網站首頁 編程語言 正文
學習前言
上一篇講了如何構建回歸算法,這一次將怎么進行簡單分類。
Keras中分類的重要函數
1、np_utils.to_categorical
np_utils.to_categorical用于將標簽轉化為形如(nb_samples, nb_classes)的二值序列。
假設num_classes = 10。
如將[1,2,3,……4]轉化成:
[[0,1,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0]
[0,0,0,1,0,0,0,0]
……
[0,0,0,0,1,0,0,0]]
這樣的形態。
如將Y_train轉化為二值序列,可以用如下方式:
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
2、Activation
Activation是激活函數,一般在每一層的輸出使用。
當我們使用Sequential模型構建函數的時候,只需要在每一層Dense后面添加Activation就可以了。
Sequential函數也支持直接在參數中完成所有層的構建,使用方法如下。
model = Sequential([
Dense(32,input_dim = 784),
Activation("relu"),
Dense(10),
Activation("softmax")
]
)
其中兩次Activation分別使用了relu函數和softmax函數。
3、metrics=[‘accuracy’]
在model.compile中添加metrics=[‘accuracy’]表示需要計算分類精確度,具體使用方式如下:
model.compile(
loss = 'categorical_crossentropy',
optimizer = rmsprop,
metrics=['accuracy']
)
全部代碼
這是一個簡單的僅含有一個隱含層的神經網絡,用于完成手寫體識別。在本例中,使用的優化器是RMSprop,具體可以使用的優化器可以參照Keras中文文檔。
import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation ## 全連接層
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 獲取訓練集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先進行標準化
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 計算categorical_crossentropy需要對分類結果進行categorical
# 即需要將標簽轉化為形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 構建模型
model = Sequential([
Dense(32,input_dim = 784),
Activation("relu"),
Dense(10),
Activation("softmax")
]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
實驗結果為:
Epoch 1/2
60000/60000 [==============================] - 12s 202us/step - loss: 0.3512 - acc: 0.9022
Epoch 2/2
60000/60000 [==============================] - 11s 183us/step - loss: 0.2037 - acc: 0.9419
Test
10000/10000 [==============================] - 1s 108us/step
accuracy: 0.9464
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/101170430
相關推薦
- 2022-11-01 R語言在散點圖中添加lm線性回歸公式的問題_R語言
- 2024-02-25 maven打包測試jar包沖突
- 2022-07-09 SAP Commerce Cloud 里的 Site API 調用方式講解
- 2022-07-22 Qt鍵盤事件和鼠標事件、定時器小實例詳解
- 2022-04-10 SQL?server中提示對象名無效的解決方法_MsSql
- 2022-07-02 Pandas?如何處理DataFrame中的inf值_python
- 2022-10-23 Go語言數據結構之插入排序示例詳解_Golang
- 2022-08-20 swift?framework使用OC?代碼兩種方式示例_Swift
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支