網站首頁 編程語言 正文
1、雙向RNN
雙向RNN(Bidirectional RNN)的結構如下圖所示。
雙向的 RNN 是同時考慮“過去”和“未來”的信息。上圖是一個序列長度為 4 的雙向RNN 結構。
雙向RNN就像是我們做閱讀理解的時候從頭向后讀一遍文章,然后又從后往前讀一遍文章,然后再做題。有可能從后往前再讀一遍文章的時候會有新的不一樣的理解,最后模型可能會得到更好的結果。
2、堆疊的雙向RNN
堆疊的雙向RNN(Stacked Bidirectional RNN)的結構如上圖所示。上圖是一個堆疊了3個隱藏層的RNN網絡。
注意,這里的堆疊的雙向RNN并不是只有雙向的RNN才可以堆疊,其實任意的RNN都可以堆疊,如SimpleRNN、LSTM和GRU這些循環神經網絡也可以進行堆疊。
堆疊指的是在RNN的結構中疊加多層,類似于BP神經網絡中可以疊加多層,增加網絡的非線性。
3、雙向LSTM實現MNIST數據集分類
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM,Dropout,Bidirectional
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
# 載入數據集
mnist = tf.keras.datasets.mnist
# 載入數據,數據載入的時候就已經劃分好訓練集和測試集
# 訓練集數據x_train的數據形狀為(60000,28,28)
# 訓練集標簽y_train的數據形狀為(60000)
# 測試集數據x_test的數據形狀為(10000,28,28)
# 測試集標簽y_test的數據形狀為(10000)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 對訓練集和測試集的數據進行歸一化處理,有助于提升模型訓練速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把訓練集和測試集的標簽轉為獨熱編碼
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# 數據大小-一行有28個像素
input_size = 28
# 序列長度-一共有28行
time_steps = 28
# 隱藏層memory block個數
cell_size = 50
# 創建模型
# 循環神經網絡的數據輸入必須是3維數據
# 數據格式為(數據數量,序列長度,數據大小)
# 載入的mnist數據的格式剛好符合要求
# 注意這里的input_shape設置模型數據輸入時不需要設置數據的數量
model = Sequential([
Bidirectional(LSTM(units=cell_size,input_shape=(time_steps,input_size),return_sequences=True)),
Dropout(0.2),
Bidirectional(LSTM(cell_size)),
Dropout(0.2),
# 50個memory block輸出的50個值跟輸出層10個神經元全連接
Dense(10,activation=tf.keras.activations.softmax)
])
# 循環神經網絡的數據輸入必須是3維數據
# 數據格式為(數據數量,序列長度,數據大小)
# 載入的mnist數據的格式剛好符合要求
# 注意這里的input_shape設置模型數據輸入時不需要設置數據的數量
# model.add(LSTM(
# units = cell_size,
# input_shape = (time_steps,input_size),
# ))
# 50個memory block輸出的50個值跟輸出層10個神經元全連接
# model.add(Dense(10,activation='softmax'))
# 定義優化器
adam = Adam(lr=1e-3)
# 定義優化器,loss function,訓練過程中計算準確率 使用交叉熵損失函數
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
# 訓練模型
history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test,y_test))
#打印模型摘要
model.summary()
loss=history.history['loss']
val_loss=history.history['val_loss']
accuracy=history.history['accuracy']
val_accuracy=history.history['val_accuracy']
# 繪制loss曲線
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# 繪制acc曲線
plt.plot(accuracy, label='Training accuracy')
plt.plot(val_accuracy, label='Validation accuracy')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
這個可能對文本數據比較容易處理,這里用這個模型有點勉強,只是簡單測試下。
模型摘要:
acc曲線:
loss曲線:
原文鏈接:https://blog.csdn.net/qq_43753724/article/details/125591449
相關推薦
- 2022-04-08 pytorch?plt.savefig()的用法及保存路徑_python
- 2022-05-23 ELK與Grafana聯合打造可視化監控來分析nginx日志_nginx
- 2022-12-21 Android?ChipGroup收起折疊效果實現詳解_Android
- 2022-11-15 關于if?exists的用法及說明_MsSql
- 2022-07-10 Executor 線程池技術詳解
- 2022-08-02 python3線程池ThreadPoolExecutor處理csv文件數據_python
- 2023-01-12 python可迭代類型遍歷過程中數據改變會不會報錯_python
- 2022-12-29 解決React報錯Expected?`onClick`?listener?to?be?a?funct
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支