網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
python神經(jīng)網(wǎng)絡(luò)Keras實(shí)現(xiàn)GRU及其參數(shù)量_python
作者:Bubbliiiing ? 更新時(shí)間: 2022-07-01 編程語(yǔ)言什么是GRU
GRU是LSTM的一個(gè)變種。
傳承了LSTM的門結(jié)構(gòu),但是將LSTM的三個(gè)門轉(zhuǎn)化成兩個(gè)門,分別是更新門和重置門。
1、GRU單元的輸入與輸出
下圖是每個(gè)GRU單元的結(jié)構(gòu)。
在n時(shí)刻,每個(gè)GRU單元的輸入有兩個(gè):
- 當(dāng)前時(shí)刻網(wǎng)絡(luò)的輸入值Xt;
- 上一時(shí)刻GRU的輸出值ht-1;
輸出有一個(gè):
當(dāng)前時(shí)刻GRU輸出值ht;
2、GRU的門結(jié)構(gòu)
GRU含有兩個(gè)門結(jié)構(gòu),分別是:
更新門zt和重置門rt:
更新門用于控制前一時(shí)刻的狀態(tài)信息被代入到當(dāng)前狀態(tài)的程度,更新門的值越大說明前一時(shí)刻的狀態(tài)信息帶入越少,這一時(shí)刻的狀態(tài)信息帶入越多。
重置門用于控制忽略前一時(shí)刻的狀態(tài)信息的程度,重置門的值越小說明忽略得越多。
3、GRU的參數(shù)量計(jì)算
a、更新門
更新門在圖中的標(biāo)號(hào)為zt,需要結(jié)合ht-1和Xt來決定上一時(shí)刻的輸出ht-1有多少得到保留,更新門的值越大說明前一時(shí)刻的狀態(tài)信息保留越少,這一時(shí)刻的狀態(tài)信息保留越多。
結(jié)合公式我們可以知道:
zt由ht-1和Xt來決定。
當(dāng)更新門zt的值較大的時(shí)候,上一時(shí)刻的輸出ht-1保留較少,而這一時(shí)刻的狀態(tài)信息保留較多。
b、重置門
重置門在圖中的標(biāo)號(hào)為rt,需要結(jié)合ht-1和Xt來控制忽略前一時(shí)刻的狀態(tài)信息的程度,重置門的值越小說明忽略得越多。
結(jié)合公式我們可以知道:
rt由ht-1和Xt來決定。
當(dāng)重置門rt的值較小的時(shí)候,上一時(shí)刻的輸出ht-1保留較少,說明忽略得越多。
c、全部參數(shù)量
所以所有的門總參數(shù)量為:
在Keras中實(shí)現(xiàn)GRU
GRU一般需要輸入兩個(gè)參數(shù)。
一個(gè)是unit、一個(gè)是input_shape。
LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))
unit用于指定神經(jīng)元的數(shù)量。
input_shape用于指定輸入的shape,分別指定TIME_STEPS和INPUT_SIZE。
實(shí)現(xiàn)代碼
import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import GRU
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = GRU(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
print("accuracy:",accuracy)
實(shí)現(xiàn)效果:
10000/10000 [==============================] - 2s 231us/step
accuracy: 0.16749999986961484
10000/10000 [==============================] - 2s 206us/step
accuracy: 0.6134000015258789
10000/10000 [==============================] - 2s 214us/step
accuracy: 0.7058000019192696
10000/10000 [==============================] - 2s 209us/step
accuracy: 0.797899999320507
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/104011262
相關(guān)推薦
- 2023-05-23 numpy中tensordot的用法_python
- 2022-07-21 element ui 中el-row的gutter失效
- 2024-02-17 pytorch花式索引提取topk的張量
- 2022-11-28 Flutter?Widgets之標(biāo)簽類控件Chip詳解_IOS
- 2023-04-19 Android.bp語(yǔ)法和使用方法講解_Android
- 2022-05-13 c++中文字符匹配,但不匹配中文標(biāo)點(diǎn)的完美解決方案。
- 2022-07-26 react如何添加less環(huán)境配置_React
- 2022-05-27 Linux?創(chuàng)建oracle數(shù)據(jù)庫(kù)的詳細(xì)過程_oracle
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支