網站首頁 編程語言 正文
什么是LSTM
1、LSTM的結構
我們可以看出,在n時刻,LSTM的輸入有三個:
- 當前時刻網絡的輸入值Xt;
- 上一時刻LSTM的輸出值ht-1;
- 上一時刻的單元狀態Ct-1。
LSTM的輸出有兩個:
- 當前時刻LSTM輸出值ht;
- 當前時刻的單元狀態Ct。
2、LSTM獨特的門結構
LSTM用兩個門來控制單元狀態cn的內容:
- 遺忘門(forget gate),它決定了上一時刻的單元狀態cn-1有多少保留到當前時刻;
- 輸入門(input gate),它決定了當前時刻網絡的輸入c’n有多少保存到新的單元狀態cn中。
LSTM用一個門來控制當前輸出值hn的內容:
輸出門(output gate),它利用當前時刻單元狀態cn對hn的輸出進行控制。
3、LSTM參數量計算
a、遺忘門
遺忘門這里需要結合ht-1和Xt來決定上一時刻的單元狀態cn-1有多少保留到當前時刻;
由圖我們可以得到,我們在這一環節需要計一個參數ft。
b、輸入門
輸入門這里需要結合ht-1和Xt來決定當前時刻網絡的輸入c’n有多少保存到單元狀態cn中。
由圖我們可以得到,我們在這一環節需要計算兩個參數,分別是it。
和C’t
里面需要訓練的參數分別是Wi、bi、WC和bC。
在定義LSTM的時候我們會使用到一個參數叫做units,其實就是神經元的個數,也就是LSTM的輸出——ht的維度。
所以:
c、輸出門
輸出門利用當前時刻單元狀態cn對hn的輸出進行控制;
由圖我們可以得到,我們在這一環節需要計一個參數ot。
里面需要訓練的參數分別是Wo和bo。在定義LSTM的時候我們會使用到一個參數叫做units,其實就是神經元的個數,也就是LSTM的輸出——ht的維度。所以:
d、全部參數量
所以所有的門總參數量為:
在Keras中實現LSTM
LSTM一般需要輸入兩個參數。
一個是unit、一個是input_shape。
LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))
unit用于指定神經元的數量。
input_shape用于指定輸入的shape,分別指定TIME_STEPS和INPUT_SIZE。
實現代碼
import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import LSTM
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
print("accuracy:",accuracy)
實現效果:
10000/10000 [==============================] - 3s 340us/step
accuracy: 0.14040000014007092
10000/10000 [==============================] - 3s 310us/step
accuracy: 0.6507000041007995
10000/10000 [==============================] - 3s 320us/step
accuracy: 0.7740999992191792
10000/10000 [==============================] - 3s 305us/step
accuracy: 0.8516999959945679
10000/10000 [==============================] - 3s 322us/step
accuracy: 0.8669999945163727
10000/10000 [==============================] - 3s 324us/step
accuracy: 0.889699995815754
10000/10000 [==============================] - 3s 307us/step
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/103896884
相關推薦
- 2022-05-20 C++實現公司人事管理系統_C 語言
- 2022-10-07 Unity游戲開發實現場景切換示例_C#教程
- 2022-03-26 淺談C語言數組元素下標為何從0開始_C 語言
- 2022-06-02 Apache?Hudi集成Spark?SQL操作hide表_數據庫其它
- 2022-11-12 C語言數據結構之單鏈表的查找和建立_C 語言
- 2022-09-10 Python實現自定義異常堆棧信息的示例代碼_python
- 2022-05-04 Django點贊的實現示例_python
- 2023-01-07 Python源碼加密與Pytorch模型加密分別介紹_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支