網站首頁 編程語言 正文
我們有時會看到GRU中輸入的參數有時是一個,但是有時又有兩個。這難免會讓人們感到疑惑,那么這些參數到底是什么呢。
一、輸入到GRU的參數
輸入的參數有兩個,分別是input和h_0。
Inputs: input, h_0
①input的shape
The shape of input:(seq_len, batch, input_size) : tensor containing the feature of the input sequence. The input can also be a packed variable length sequence。
See functorch.nn.utils.rnn.pack_padded_sequencefor details.
②h_0的shape
從下面的解釋中也可以看出,這個參數可以不提供,那么就默認為0.
The shape of h_0:(num_layers * num_directions, batch, hidden_size): tensor containing the initial hidden state for each element in the batch.
Defaults to zero if not provided. If the RNN is bidirectional num_directions should be 2, else it should be 1.
綜上,可以只輸入一個參數。當輸入兩個參數的時候,那么第二個參數相當于是一個隱含層的輸出。
為了便于理解,下面是一幅圖:
二、GRU返回的數據
輸出有兩個,分別是output和h_n
①output
output 的shape是:(seq_len, batch, num_directions * hidden_size): tensor containing the output features h_t from the last layer of the GRU, for each t.
If a class:torch.nn.utils.rnn.PackedSequence has been given as the input, the output will also be a packed sequence.
For the unpacked case, the directions can be separated using output.view(seq_len, batch, num_directions, hidden_size), with forward and backward being direction 0 and 1 respectively.
Similarly, the directions can be separated in the packed case.
②h_n
h_n的shape是:(num_layers * num_directions, batch, hidden_size): tensor containing the hidden state for t = seq_len
Like output, the layers can be separated using
h_n.view(num_layers, num_directions, batch, hidden_size).
三、代碼示例
數據的shape是[batch,seq_len,emb_dim]
RNN接收輸入的數據的shape是[seq_len,batch,emb_dim]
即前兩個維度調換就行了。
可以知道,加入批處理的時候一次處理128個句子,每個句子中有5個單詞,那么上圖中展示的input_data的shape是:[128,5,emb_dim]。
結合代碼分析,本例子將演示有1個句子和5個句子的情況。假設每個句子中有9個單詞,所以seq_len=9,并且每個單詞對應的emb_dim=3,所以對應數據的shape是: [batch,9,3],由于輸入到RNN中數據格式的格式,所以為[9,batch,3]
import torch
import torch.nn as nn
emb_dim = 3
hidden_dim = 2
rnn = nn.GRU(emb_dim,hidden_dim)
#rnn = nn.GRU(9,1,3)
print(type(rnn))
tensor1 = torch.tensor([[-0.5502, -0.1920, 1.1845],
[-0.8003, 2.0783, 0.0175],
[ 0.6761, 0.7183, -1.0084],
[ 0.9514, 1.4772, -0.2271],
[-1.0146, 0.7912, 0.2003],
[-0.5502, -0.1920, 1.1845],
[-0.8003, 2.0783, 0.0175],
[ 0.1718, 0.1070, 0.4255],
[-2.6727, -1.5680, -0.8369]])
tensor2 = torch.tensor([[-0.5502, -0.1920]])
# 假設input只有一個句子,那么batch為1
print('--------------batch=1時------------')
data = tensor1.unsqueeze(0)
h_0 = tensor2[0].unsqueeze(0).unsqueeze(0)
print('data.shape: [batch,seq_len,emb_dim]',data.shape)
print('')
input = data.transpose(0,1)
print('input.shape: [seq_len,batch,emb_dim]',input.shape)
print('h_0.shape: [1,batch,hidden_dim]',h_0.shape)
print('')
# 輸入到rnn中
output,h_n = rnn(input,h_0)
print('output.shape: [seq_len,batch,hidden_dim]',output.shape)
print('h_n.shape: [1,batch,hidden_dim]',h_n.shape)
# 假設input中有5個句子,所以,batch = 5
print('\n--------------batch=5時------------')
data = tensor1.unsqueeze(0).repeat(5,1,1) # 由于batch為5
h_0 = tensor2[0].unsqueeze(0).repeat(1,5,1) # 由于batch為5
print('data.shape: [batch,seq_len,emb_dim]',data.shape)
print('')
input = data.transpose(0,1)
print('input.shape: [seq_len,batch,emb_dim]',input.shape)
print('h_0.shape: [1,batch,hidden_dim]',h_0.shape)
print('')
# 輸入到rnn中
output,h_n = rnn(input,h_0)
print('output.shape: [seq_len,batch,hidden_dim]',output.shape)
print('h_n.shape: [1,batch,hidden_dim]',h_n.shape)
四、輸出
<class ‘torch.nn.modules.rnn.GRU’>
--------------batch=1時------------
data.shape: [batch,seq_len,emb_dim] torch.Size([1, 9, 3])input.shape: [seq_len,batch,emb_dim] torch.Size([9, 1, 3])
h_0.shape: [1,batch,hidden_dim] torch.Size([1, 1, 2])output.shape: [seq_len,batch,hidden_dim] torch.Size([9, 1, 2])
h_n.shape: [1,batch,hidden_dim] torch.Size([1, 1, 2])--------------batch=5時------------
data.shape: [batch,seq_len,emb_dim] torch.Size([5, 9, 3])input.shape: [seq_len,batch,emb_dim] torch.Size([9, 5, 3])
h_0.shape: [1,batch,hidden_dim] torch.Size([1, 5, 2])output.shape: [seq_len,batch,hidden_dim] torch.Size([9, 5, 2])
h_n.shape: [1,batch,hidden_dim] torch.Size([1, 5, 2])
總結
原文鏈接:https://blog.csdn.net/jiuweideqixu/article/details/109492863
相關推薦
- 2022-01-16 對象的綁定、滾輪滾動事件及鍵盤事件
- 2022-06-02 React中的Props類型校驗和默認值詳解_React
- 2023-03-22 nginx.conf配置兩個前端路徑_nginx
- 2023-03-18 詳解Flutter中key的正確使用方式_Android
- 2022-08-18 python上下文管理器使用場景及異常處理_python
- 2022-06-14 colab中修改python版本的全過程_python
- 2022-07-24 C#導入和導出CSV文件_C#教程
- 2023-06-18 C#最小二乘法擬合曲線成直線的實例_C#教程
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支