網(wǎng)站首頁 編程語言 正文
python神經(jīng)網(wǎng)絡(luò)使用Keras構(gòu)建RNN訓(xùn)練_python
作者:Bubbliiiing ? 更新時(shí)間: 2022-06-28 編程語言Keras中構(gòu)建RNN的重要函數(shù)
1、SimpleRNN
SimpleRNN用于在Keras中構(gòu)建普通的簡單RNN層,在使用前需要import。
from keras.layers import SimpleRNN
在實(shí)際使用時(shí),需要用到幾個(gè)參數(shù)。
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
其中,batch_input_shape代表RNN輸入數(shù)據(jù)的shape,shape的內(nèi)容分別是每一次訓(xùn)練使用的BATCH,TIME_STEPS表示這個(gè)RNN按順序輸入的時(shí)間點(diǎn)的數(shù)量,INPUT_SIZE表示每一個(gè)時(shí)間點(diǎn)的輸入數(shù)據(jù)大小。
CELL_SIZE代表訓(xùn)練每一個(gè)時(shí)間點(diǎn)的神經(jīng)元數(shù)量。
2、model.train_on_batch
與之前的訓(xùn)練CNN網(wǎng)絡(luò)和普通分類網(wǎng)絡(luò)不同,RNN網(wǎng)絡(luò)在建立時(shí)就規(guī)定了batch_input_shape,所以訓(xùn)練的時(shí)候也需要一定量一定量的傳入訓(xùn)練數(shù)據(jù)。
model.train_on_batch在使用前需要對數(shù)據(jù)進(jìn)行處理。獲取指定BATCH大小的訓(xùn)練集。
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
具體訓(xùn)練過程如下:
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
x = X_test[1].reshape(1,28,28)
全部代碼
這是一個(gè)RNN神經(jīng)網(wǎng)絡(luò)的例子,用于識別手寫體。
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全連接層
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
實(shí)驗(yàn)結(jié)果為:
10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/101609556
相關(guān)推薦
- 2022-09-16 C++中的位運(yùn)算和位圖bitmap解析_C 語言
- 2022-10-11 Tomcat 9.x啟動時(shí)控制臺亂碼
- 2022-04-20 C語言特殊符號的補(bǔ)充理解_C 語言
- 2022-11-25 Go實(shí)現(xiàn)快速生成固定長度的隨機(jī)字符串_Golang
- 2022-03-12 用C語言實(shí)現(xiàn)圣誕樹(簡易版+進(jìn)階版)_C 語言
- 2022-06-14 C#實(shí)現(xiàn)密碼驗(yàn)證與輸錯(cuò)密碼賬戶鎖定_C#教程
- 2022-01-26 阿里云服務(wù)器端口請求失敗(在控制臺把端口添加到服務(wù)器的安全組)
- 2022-08-10 對WPF中的TreeView實(shí)現(xiàn)右鍵選定_C#教程
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支