網站首頁 編程語言 正文
學習前言
在前一段時間已經完成了卷積神經網絡的復習,現在要對循環神經網絡的結構進行更深層次的明確。
RNN簡介
RNN 是當前發展非常火熱的神經網絡中的一種,它擅長對序列數據進行處理。
什么是序列數據呢?舉個例子。
現在假設有四個字,“我” “去” “吃” “飯”。我們可以對它們進行任意的排列組合。
“我去吃飯”,表示的就是我要去吃飯了。
“飯去吃我”,表示的就是飯成精了。
“我吃去飯”,表示的我要去吃‘去飯’了。
不同的排列順序會導致不同的語意,序列數據表示的就是按照一定順序排列的序列,這種排列一般存在一定的意義。。
所以我們知道了RNN有順序存儲的這個抽象概念,但是RNN如何學習這個概念呢?
那么,讓我們來看一個傳統的神經網絡,也稱為前饋神經網絡。它有輸入層,隱藏層和輸出層。就像這樣
對于RNN來講,其結構示意圖是這樣的:
一句話可以分為N個part,比如“我去吃飯”可以分為四個字,“我” “去” “吃” “飯”,分別可以傳入四個隱含層,前一個隱含層會有一個輸出按照一定的比率傳給后一個隱含層,比如第一個“我”輸入隱含層后,有一個輸出按照w1的比率輸入給下一個隱含層,當第二個“去”進入隱含層時,隱含層同樣要接收“我”傳過來的信息。
以此類推,在到達最后一個“飯”時,最后的輸出便得到了前面全部的信息。
其偽代碼形式為:
rnn = RNN() ff = FeedForwardNN() hidden_state = [0,0,0] for word in input: output,hidden_state = rnn(word,hidden_state) prediction = ff(output)
tensorflow中RNN的相關函數
tf.nn.rnn_cell.BasicLSTMCell
tf.nn.rnn_cell.BasicRNNCell( num_units, activation=None, reuse=None, name=None, dtype=None, **kwargs)
- num_units:RNN單元中的神經元數量,即輸出神經元數量。
- activation:激活函數。
- reuse:描述是否在現有范圍中重用變量。如果不為True,并且現有范圍已經具有給定變量,則會引發錯誤。
- name:層的名稱。
- dtype:該層的數據類型。
- kwargs:常見層屬性的關鍵字命名屬性,如trainable,當從get_config()創建cell 。
在使用時,可以定義為:
RNN_cell = tf.nn.rnn_cell.BasicRNNCell(n_hidden_units,activation=tf.nn.tanh)
在定義完成后,可以進行狀態初始化:
_init_state = RNN_cell.zero_state(batch_size,tf.float32)
tf.nn.dynamic_rnn
tf.nn.dynamic_rnn( cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None )
- cell:上文所定義的lstm_cell。
- inputs:RNN輸入。如果time_major==false(默認),則必須是如下shape的tensor:[batch_size,max_time,…]或此類元素的嵌套元組。如果time_major==true,則必須是如下形狀的tensor:[max_time,batch_size,…]或此類元素的嵌套元組。
- sequence_length:Int32/Int64矢量大小。用于在超過批處理元素的序列長度時復制通過狀態和零輸出。因此,它更多的是為了性能而不是正確性。
- initial_state:上文所定義的_init_state。
- dtype:數據類型。
- parallel_iterations:并行運行的迭代次數。那些不具有任何時間依賴性并且可以并行運行的操作將是。這個參數用時間來交換空間。值>>1使用更多的內存,但花費的時間更少,而較小的值使用更少的內存,但計算需要更長的時間。
- time_major:輸入和輸出tensor的形狀格式。如果為真,這些張量的形狀必須是[max_time,batch_size,depth]。如果為假,這些張量的形狀必須是[batch_size,max_time,depth]。使用time_major=true會更有效率,因為它可以避免在RNN計算的開始和結束時進行換位。但是,大多數TensorFlow數據都是批處理主數據,因此默認情況下,此函數為False。
- scope:創建的子圖的可變作用域;默認為“RNN”。
在RNN的最后,需要用該函數得出結果。
outputs,states = tf.nn.dynamic_rnn(RNN_cell,X_in,initial_state = _init_state,time_major = False)
返回的是一個元組 (outputs, state):
outputs
:RNN的最后一層的輸出,是一個tensor。如果為time_major== False,則它的shape為[batch_size,max_time,cell.output_size]。如果為time_major== True,則它的shape為[max_time,batch_size,cell.output_size]。
states
:states是一個tensor。state是最終的狀態,也就是序列中最后一個cell輸出的狀態。一般情況下states的形狀為 [batch_size, cell.output_size],但當輸入的cell為BasicLSTMCell時,states的形狀為[2,batch_size, cell.output_size ],其中2也對應著LSTM中的cell state和hidden state。
整個RNN的定義過程為:
def RNN(X,weights,biases): #X最開始的形狀為(128 batch,28 steps,28 inputs) #轉化為(128 batch*28 steps,128 hidden) X = tf.reshape(X,[-1,n_inputs]) #經過乘法后結果為(128 batch*28 steps,256 hidden) X_in = tf.matmul(X,weights['in'])+biases['in'] #再次轉化為(128 batch,28 steps,256 hidden) X_in = tf.reshape(X_in,[-1,n_steps,n_hidden_units]) RNN_cell = tf.nn.rnn_cell.BasicRNNCell(n_hidden_units,activation=tf.nn.tanh) _init_state = RNN_cell.zero_state(batch_size,tf.float32) outputs,states = tf.nn.dynamic_rnn(RNN_cell,X_in,initial_state = _init_state,time_major = False) results = tf.matmul(states,weights['out'])+biases['out'] return results
全部代碼
該例子為手寫體識別例子,將手寫體的28行分別作為每一個step的輸入,輸入維度均為28列。
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data",one_hot = "true") lr = 0.001 #學習率 training_iters = 1000000 #學習世代數 batch_size = 128 #每一輪進入訓練的訓練量 n_inputs = 28 #輸入每一個隱含層的inputs維度 n_steps = 28 #一共分為28次輸入 n_hidden_units = 128 #每一個隱含層的神經元個數 n_classes = 10 #輸出共有10個 x = tf.placeholder(tf.float32,[None,n_steps,n_inputs]) y = tf.placeholder(tf.float32,[None,n_classes]) weights = { 'in':tf.Variable(tf.random_normal([n_inputs,n_hidden_units])), 'out':tf.Variable(tf.random_normal([n_hidden_units,n_classes])) } biases = { 'in':tf.Variable(tf.constant(0.1,shape=[n_hidden_units])), 'out':tf.Variable(tf.constant(0.1,shape=[n_classes])) } def RNN(X,weights,biases): #X最開始的形狀為(128 batch,28 steps,28 inputs) #轉化為(128 batch*28 steps,128 hidden) X = tf.reshape(X,[-1,n_inputs]) #經過乘法后結果為(128 batch*28 steps,256 hidden) X_in = tf.matmul(X,weights['in'])+biases['in'] #再次轉化為(128 batch,28 steps,256 hidden) X_in = tf.reshape(X_in,[-1,n_steps,n_hidden_units]) RNN_cell = tf.nn.rnn_cell.BasicRNNCell(n_hidden_units,activation=tf.nn.tanh) _init_state = RNN_cell.zero_state(batch_size,tf.float32) outputs,states = tf.nn.dynamic_rnn(RNN_cell,X_in,initial_state = _init_state,time_major = False) results = tf.matmul(states,weights['out'])+biases['out'] return results pre = RNN(x,weights,biases) cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = pre,labels = y)) train_op = tf.train.AdamOptimizer(lr).minimize(cost) correct_pre = tf.equal(tf.argmax(y,1),tf.argmax(pre,1)) accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32)) init = tf.initialize_all_variables() with tf.Session() as sess: sess.run(init) step = 0 while step*batch_size <training_iters: batch_xs,batch_ys = mnist.train.next_batch(batch_size) batch_xs = batch_xs.reshape([batch_size,n_steps,n_inputs]) sess.run(train_op,feed_dict = { x:batch_xs, y:batch_ys }) if step%20 == 0: print(sess.run(accuracy,feed_dict = { x:batch_xs, y:batch_ys })) step += 1
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/98476021
相關推薦
- 2022-07-29 pytest解讀一次請求多個fixtures及多次請求_python
- 2022-10-25 IDEA 安裝tomcat10創建servlet報404錯誤
- 2022-07-13 IO流分類以及分別使用字節流、字符流復制文本文件、復制圖片
- 2022-10-12 使用Docker搭建Vsftpd?的?FTP?服務的詳細過程_docker
- 2022-11-21 在react中使用windicss的問題_React
- 2022-04-11 ElasticSearch 8.x 默認密碼
- 2022-07-13 查看工具設置的編碼 sys.getdefaultencoding()
- 2022-10-11 python格式化字符串的實戰教程(使用占位符、format方法)_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支