網(wǎng)站首頁 編程語言 正文
LSTM介紹
關(guān)于LSTM的具體原理,可以參考:
https://www.jb51.net/article/178582.htm
https://www.jb51.net/article/178423.htm
系列文章:
PyTorch搭建雙向LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測
PyTorch搭建LSTM實(shí)現(xiàn)多變量多步長時(shí)序負(fù)荷預(yù)測
PyTorch搭建LSTM實(shí)現(xiàn)多變量時(shí)序負(fù)荷預(yù)測
PyTorch搭建LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測
LSTM參數(shù)
關(guān)于nn.LSTM的參數(shù),官方文檔給出的解釋為:
總共有七個(gè)參數(shù),其中只有前三個(gè)是必須的。由于大家普遍使用PyTorch的DataLoader來形成批量數(shù)據(jù),因此batch_first也比較重要。LSTM的兩個(gè)常見的應(yīng)用場景為文本處理和時(shí)序預(yù)測,因此下面對每個(gè)參數(shù)我都會從這兩個(gè)方面來進(jìn)行具體解釋。
- input_size:在文本處理中,由于一個(gè)單詞沒法參與運(yùn)算,因此我們得通過Word2Vec來對單詞進(jìn)行嵌入表示,將每一個(gè)單詞表示成一個(gè)向量,此時(shí)input_size=embedding_size。
- 比如每個(gè)句子中有五個(gè)單詞,每個(gè)單詞用一個(gè)100維向量來表示,那么這里input_size=100;
- 在時(shí)間序列預(yù)測中,比如需要預(yù)測負(fù)荷,每一個(gè)負(fù)荷都是一個(gè)單獨(dú)的值,都可以直接參與運(yùn)算,因此并不需要將每一個(gè)負(fù)荷表示成一個(gè)向量,此時(shí)input_size=1。
- 但如果我們使用多變量進(jìn)行預(yù)測,比如我們利用前24小時(shí)每一時(shí)刻的[負(fù)荷、風(fēng)速、溫度、壓強(qiáng)、濕度、天氣、節(jié)假日信息]來預(yù)測下一時(shí)刻的負(fù)荷,那么此時(shí)input_size=7。
- hidden_size:隱藏層節(jié)點(diǎn)個(gè)數(shù)。可以隨意設(shè)置。
- num_layers:層數(shù)。nn.LSTMCell與nn.LSTM相比,num_layers默認(rèn)為1。
- batch_first:默認(rèn)為False,意義見后文。
Inputs
關(guān)于LSTM的輸入,官方文檔給出的定義為:
可以看到,輸入由兩部分組成:input、(初始的隱狀態(tài)h_0,初始的單元狀態(tài)c_0)
其中input:
input(seq_len, batch_size, input_size)
- seq_len:在文本處理中,如果一個(gè)句子有7個(gè)單詞,則seq_len=7;在時(shí)間序列預(yù)測中,假設(shè)我們用前24個(gè)小時(shí)的負(fù)荷來預(yù)測下一時(shí)刻負(fù)荷,則seq_len=24。
- batch_size:一次性輸入LSTM中的樣本個(gè)數(shù)。在文本處理中,可以一次性輸入很多個(gè)句子;在時(shí)間序列預(yù)測中,也可以一次性輸入很多條數(shù)據(jù)。
- input_size:見前文。
(h_0, c_0):
h_0(num_directions * num_layers, batch_size, hidden_size)
c_0(num_directions * num_layers, batch_size, hidden_size)
h_0和c_0的shape一致。
- num_directions:如果是雙向LSTM,則num_directions=2;否則num_directions=1。
- num_layers:見前文。
- batch_size:見前文。
- hidden_size:見前文。
Outputs
關(guān)于LSTM的輸出,官方文檔給出的定義為:
可以看到,輸出也由兩部分組成:otput、(隱狀態(tài)h_n,單元狀態(tài)c_n)
其中output的shape為:
output(seq_len, batch_size, num_directions * hidden_size)
h_n和c_n的shape保持不變,參數(shù)解釋見前文。
batch_first
如果在初始化LSTM時(shí)令batch_first=True,那么input和output的shape將由:
input(seq_len, batch_size, input_size)
output(seq_len, batch_size, num_directions * hidden_size)
變?yōu)椋?/p>
input(batch_size, seq_len, input_size)
output(batch_size, seq_len, num_directions * hidden_size)
即batch_size提前。
案例
簡單搭建一個(gè)LSTM如下所示:
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.output_size = output_size
self.num_directions = 1 # 單向LSTM
self.batch_size = batch_size
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input_seq):
h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
seq_len = input_seq.shape[1] # (5, 30)
# input(batch_size, seq_len, input_size)
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 30, 1)
# output(batch_size, seq_len, num_directions * hidden_size)
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
pred = self.linear(output) # pred(150, 1)
pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
pred = pred[:, -1, :] # (5, 1)
return pred
其中定義模型的代碼為:
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)
我們加上具體的數(shù)字:
self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True)
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
再看前向傳播:
def forward(self, input_seq):
h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
seq_len = input_seq.shape[1] # (5, 30)
# input(batch_size, seq_len, input_size)
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 30, 1)
# output(batch_size, seq_len, num_directions * hidden_size)
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
pred = self.linear(output) # (150, 1)
pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
pred = pred[:, -1, :] # (5, 1)
return pred
假設(shè)用前30個(gè)預(yù)測下一個(gè),則seq_len=30,batch_size=5,由于設(shè)置了batch_first=True,因此,輸入到LSTM中的input的shape應(yīng)該為:
input(batch_size, seq_len, input_size) = input(5, 30, 1)
但實(shí)際上,經(jīng)過DataLoader處理后的input_seq為:
input_seq(batch_size, seq_len) = input_seq(5, 30)
(5, 30)表示一共5條數(shù)據(jù),每條數(shù)據(jù)的維度都為30。為了匹配LSTM的輸入,我們需要對input_seq的shape進(jìn)行變換:
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 30, 1)
然后將input_seq送入LSTM:
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
根據(jù)前文,output的shape為:
output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)
全連接層的定義為:
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
因此,我們需要將output的第二維度變換為64(150, 64):
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
然后將output送入全連接層:
pred = self.linear(output) # pred(150, 1)
得到的預(yù)測值shape為(150, 1)。我們需要將其進(jìn)行還原,變成(5, 30, 1):
pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
在用DataLoader處理了數(shù)據(jù)后,得到的input_seq和label的shape分別為:
input_seq(batch_size, seq_len) = input_seq(5, 30)label(batch_size, output_size) = label(5, 1)
由于輸出是輸入右移,我們只需要取pred第二維度(time)中的最后一個(gè)數(shù)據(jù):
pred = pred[:, -1, :] # (5, 1)
這樣,我們就得到了預(yù)測值,然后與label求loss,然后再反向更新參數(shù)即可。
時(shí)間序列預(yù)測的一個(gè)真實(shí)案例請見:PyTorch搭建LSTM實(shí)現(xiàn)時(shí)間序列預(yù)測(負(fù)荷預(yù)測)
原文鏈接:https://blog.csdn.net/Cyril_KI/article/details/122557880
相關(guān)推薦
- 2023-11-15 Latex解決表格過寬問題,自適應(yīng)調(diào)整寬度;自動調(diào)整適合的表格大小
- 2023-11-18 Python將字符串String轉(zhuǎn)換成要使用的變量
- 2022-03-29 python中的classmethod與staticmethod_python
- 2022-11-06 Matplotlib學(xué)習(xí)筆記之plt.xticks()用法_python
- 2022-07-02 react用axios的 get/post請求/獲取數(shù)據(jù)
- 2022-12-09 Android入門之ProgressBar的使用教程_Android
- 2022-12-09 ReactQuery系列React?Query?實(shí)踐示例詳解_React
- 2023-01-19 Yolov5更換BiFPN的詳細(xì)步驟總結(jié)_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支