網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)
參數(shù)說(shuō)明
- input_size輸入特征的維度, 一般rnn中輸入的是詞向量,那么 input_size 就等于一個(gè)詞向量的維度
- hidden_size隱藏層神經(jīng)元個(gè)數(shù),或者也叫輸出的維度(因?yàn)閞nn輸出為各個(gè)時(shí)間步上的隱藏狀態(tài))
- num_layers網(wǎng)絡(luò)的層數(shù)
- nonlinearity激活函數(shù)
- bias是否使用偏置
- batch_first輸入數(shù)據(jù)的形式,默認(rèn)是 False,就是這樣形式,(seq(num_step), batch, input_dim),也就是將序列長(zhǎng)度放在第一位,batch 放在第二位
- dropout是否應(yīng)用dropout, 默認(rèn)不使用,如若使用將其設(shè)置成一個(gè)0-1的數(shù)字即可
- birdirectional是否使用雙向的 rnn,默認(rèn)是 False
- 注意某些參數(shù)的默認(rèn)值在標(biāo)題中已注明
輸入輸出shape
- input_shape = [時(shí)間步數(shù), 批量大小, 特征維度] = [num_steps(seq_length), batch_size, input_dim]
- 在前向計(jì)算后會(huì)分別返回輸出和隱藏狀態(tài)h,其中輸出指的是隱藏層在各個(gè)時(shí)間步上計(jì)算并輸出的隱藏狀態(tài),它們通常作為后續(xù)輸出層的輸?。需要強(qiáng)調(diào)的是,該“輸出”本身并不涉及輸出層計(jì)算,形狀為(時(shí)間步數(shù), 批量大小, 隱藏單元個(gè)數(shù));隱藏狀態(tài)指的是隱藏層在最后時(shí)間步的隱藏狀態(tài):當(dāng)隱藏層有多層時(shí),每?層的隱藏狀態(tài)都會(huì)記錄在該變量中;對(duì)于像?短期記憶(LSTM),隱藏狀態(tài)是?個(gè)元組(h, c),即hidden state和cell state(此處普通rnn只有一個(gè)值)隱藏狀態(tài)h的形狀為(層數(shù), 批量大小,隱藏單元個(gè)數(shù))
代碼
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens, ) # 定義模型, 其中vocab_size = 1027, hidden_size = 256
num_steps = 35 batch_size = 2 state = None # 初始隱藏層狀態(tài)可以不定義 X = torch.rand(num_steps, batch_size, vocab_size) Y, state_new = rnn_layer(X, state) print(Y.shape, len(state_new), state_new.shape)
輸出
torch.Size([35, 2, 256]) ? ? 1 ? ? ? torch.Size([1, 2, 256])
具體計(jì)算過(guò)程
H t = i n p u t ? W x h + H t ? 1 ? W h h + b i a s H_t = input * W_{xh} + H_{t-1} * W_{hh} + bias Ht?=input?Wxh?+Ht?1??Whh?+bias[batch_size, input_dim] * [input_dim, num_hiddens] + [batch_size, num_hiddens] *[num_hiddens, num_hiddens] +bias
可以發(fā)現(xiàn)每個(gè)隱藏狀態(tài)形狀都是[batch_size, num_hiddens], 起始輸出也是一樣的
注意:上面為了方便假設(shè)num_step=1
GRU/LSTM等參數(shù)同上面RNN
原文鏈接:https://blog.csdn.net/orangerfun/article/details/103934290
相關(guān)推薦
- 2022-08-18 python上下文管理器使用場(chǎng)景及異常處理_python
- 2022-04-01 SQL?Server?Transact-SQL編程詳解_MsSql
- 2022-10-30 Matlab利用遺傳算法GA求解非連續(xù)函數(shù)問(wèn)題詳解_C 語(yǔ)言
- 2022-04-01 Docker一直starting如何解決
- 2023-08-28 react:使用 moment 來(lái)獲取日期
- 2022-12-12 Python實(shí)現(xiàn)打印九九乘法表的不同方法總結(jié)_python
- 2022-03-23 詳細(xì)聊聊Redis的過(guò)期策略_Redis
- 2022-09-14 Android多渠道打包神器ProductFlavor詳解_Android
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支