網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
PyTorch搭建LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測(cè)_python
作者:Cyril_KI ? 更新時(shí)間: 2022-07-04 編程語(yǔ)言I. 前言
在上一篇文章深入理解PyTorch中LSTM的輸入和輸出(從input輸入到Linear輸出)中,我詳細(xì)地解釋了如何利用PyTorch來(lái)搭建一個(gè)LSTM模型,本篇文章的主要目的是搭建一個(gè)LSTM模型用于時(shí)間序列預(yù)測(cè)。
系列文章:
PyTorch搭建LSTM實(shí)現(xiàn)多變量多步長(zhǎng)時(shí)序負(fù)荷預(yù)測(cè)
PyTorch搭建LSTM實(shí)現(xiàn)多變量時(shí)序負(fù)荷預(yù)測(cè)
PyTorch深度學(xué)習(xí)LSTM從input輸入到Linear輸出
PyTorch搭建雙向LSTM實(shí)現(xiàn)時(shí)間序列負(fù)荷預(yù)測(cè)
II. 數(shù)據(jù)處理
數(shù)據(jù)集為某個(gè)地區(qū)某段時(shí)間內(nèi)的電力負(fù)荷數(shù)據(jù),除了負(fù)荷以外,還包括溫度、濕度等信息。
本篇文章暫時(shí)不考慮其它變量,只考慮用歷史負(fù)荷來(lái)預(yù)測(cè)未來(lái)負(fù)荷。
本文中,我們根據(jù)前24個(gè)時(shí)刻的負(fù)荷下一時(shí)刻的負(fù)荷。有關(guān)多變量預(yù)測(cè)請(qǐng)參考:PyTorch搭建LSTM實(shí)現(xiàn)多變量時(shí)間序列預(yù)測(cè)(負(fù)荷預(yù)測(cè))。
def load_data(file_name):
global MAX, MIN
df = pd.read_csv('data/new_data/' + file_name, encoding='gbk')
columns = df.columns
df.fillna(df.mean(), inplace=True)
MAX = np.max(df[columns[1]])
MIN = np.min(df[columns[1]])
df[columns[1]] = (df[columns[1]] - MIN) / (MAX - MIN)
return df
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return len(self.data)
def nn_seq(file_name, B):
print('處理數(shù)據(jù):')
data = load_data(file_name)
load = data[data.columns[1]]
load = load.tolist()
load = torch.FloatTensor(load).view(-1)
data = data.values.tolist()
seq = []
for i in range(len(data) - 24):
train_seq = []
train_label = []
for j in range(i, i + 24):
train_seq.append(load[j])
train_label.append(load[i + 24])
train_seq = torch.FloatTensor(train_seq).view(-1)
train_label = torch.FloatTensor(train_label).view(-1)
seq.append((train_seq, train_label))
# print(seq[:5])
Dtr = seq[0:int(len(seq) * 0.7)]
Dte = seq[int(len(seq) * 0.7):len(seq)]
train_len = int(len(Dtr) / B) * B
test_len = int(len(Dte) / B) * B
Dtr, Dte = Dtr[:train_len], Dte[:test_len]
train = MyDataset(Dtr)
test = MyDataset(Dte)
Dtr = DataLoader(dataset=train, batch_size=B, shuffle=False, num_workers=0)
Dte = DataLoader(dataset=test, batch_size=B, shuffle=False, num_workers=0)
return Dtr, Dte
上面代碼用了DataLoader來(lái)對(duì)原始數(shù)據(jù)進(jìn)行處理,最終得到了batch_size=B的數(shù)據(jù)集Dtr和Dte,Dtr為訓(xùn)練集,Dte為測(cè)試集。
III. LSTM模型
這里采用了深入理解PyTorch中LSTM的輸入和輸出(從input輸入到Linear輸出)中的模型:
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.output_size = output_size
self.num_directions = 1 # 單向LSTM
self.batch_size = batch_size
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input_seq):
h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
seq_len = input_seq.shape[1] # (5, 24)
# input(batch_size, seq_len, input_size)
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 24, 1)
# output(batch_size, seq_len, num_directions * hidden_size)
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 24, 64)
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 24, 64)
pred = self.linear(output) # pred(150, 1)
pred = pred.view(self.batch_size, seq_len, -1) # (5, 24, 1)
pred = pred[:, -1, :] # (5, 1)
return pred
IV. 訓(xùn)練
def LSTM_train(name, b):
Dtr, Dte = nn_seq(file_name=name, B=b)
input_size, hidden_size, num_layers, output_size = 1, 64, 5, 1
model = LSTM(input_size, hidden_size, num_layers, output_size, batch_size=b).to(device)
loss_function = nn.MSELoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 訓(xùn)練
epochs = 15
cnt = 0
for i in range(epochs):
cnt = 0
print('當(dāng)前', i)
for (seq, label) in Dtr:
cnt += 1
seq = seq.to(device)
label = label.to(device)
y_pred = model(seq)
loss = loss_function(y_pred, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if cnt % 100 == 0:
print('epoch', i, ':', cnt - 100, '~', cnt, loss.item())
state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(state, LSTM_PATH)
一共訓(xùn)練了15輪:
V. 測(cè)試
def test(name, b):
global MAX, MIN
Dtr, Dte = nn_seq(file_name=name, B=b)
pred = []
y = []
print('loading model...')
input_size, hidden_size, num_layers, output_size = 1, 64, 5, 1
model = LSTM(input_size, hidden_size, num_layers, output_size, batch_size=b).to(device)
model.load_state_dict(torch.load(LSTM_PATH)['model'])
model.eval()
print('predicting...')
for (seq, target) in Dte:
target = list(chain.from_iterable(target.data.tolist()))
y.extend(target)
seq = seq.to(device)
seq_len = seq.shape[1]
seq = seq.view(model.batch_size, seq_len, 1) # (5, 24, 1)
with torch.no_grad():
y_pred = model(seq)
y_pred = list(chain.from_iterable(y_pred.data.tolist()))
pred.extend(y_pred)
y, pred = np.array(y), np.array(pred)
y = (MAX - MIN) * y + MIN
pred = (MAX - MIN) * pred + MIN
print('accuracy:', get_mape(y, pred))
# plot
x = [i for i in range(1, 151)]
x_smooth = np.linspace(np.min(x), np.max(x), 600)
y_smooth = make_interp_spline(x, y[0:150])(x_smooth)
plt.plot(x_smooth, y_smooth, c='green', marker='*', ms=1, alpha=0.75, label='true')
y_smooth = make_interp_spline(x, pred[0:150])(x_smooth)
plt.plot(x_smooth, y_smooth, c='red', marker='o', ms=1, alpha=0.75, label='pred')
plt.grid(axis='y')
plt.legend()
plt.show()
MAPE為6.07%:
VI. 源碼及數(shù)據(jù)
源碼及數(shù)據(jù)我放在了GitHub上,LSTM-Load-Forecasting
原文鏈接:https://blog.csdn.net/Cyril_KI/article/details/122569775
相關(guān)推薦
- 2023-01-02 GO比較兩個(gè)對(duì)象是否相同實(shí)戰(zhàn)案例_Golang
- 2022-11-28 go?mod文件內(nèi)容版本號(hào)簡(jiǎn)單用法詳解_Golang
- 2022-12-02 C語(yǔ)言實(shí)現(xiàn)動(dòng)態(tài)順序表的示例代碼_C 語(yǔ)言
- 2022-05-29 C#調(diào)用USB攝像頭的方法_C#教程
- 2022-08-24 .net新興日志框架Serilog簡(jiǎn)介_(kāi)實(shí)用技巧
- 2022-10-16 python?sys模塊使用方法介紹_python
- 2022-09-12 nginx訪(fǎng)問(wèn)報(bào)403錯(cuò)誤的幾種情況詳解_nginx
- 2023-03-29 C++中字符串全排列算法及next_permutation原理詳解_C 語(yǔ)言
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門(mén)
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支