網(wǎng)站首頁 編程語言 正文
為什么要保存和加載模型
用數(shù)據(jù)對模型進行訓練后得到了比較理想的模型,但在實際應(yīng)用的時候不可能每次都先進行訓練然后再使用,所以就得先將之前訓練好的模型保存下來,然后在需要用到的時候加載一下直接使用。
模型的本質(zhì)是一堆用某種結(jié)構(gòu)存儲起來的參數(shù),所以在保存的時候有兩種方式
- 一種方式是直接將整個模型保存下來,之后直接加載整個模型,但這樣會比較耗內(nèi)存;
- 另一種是只保存模型的參數(shù),之后用到的時候再創(chuàng)建一個同樣結(jié)構(gòu)的新模型,然后把所保存的參數(shù)導入新模型。
兩種情況的實現(xiàn)方法
(1)只保存模型參數(shù)字典(推薦)
#保存
torch.save(the_model.state_dict(), PATH)
#讀取
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
(2)保存整個模型
#保存
torch.save(the_model, PATH)
#讀取
the_model = torch.load(PATH)
只保存模型參數(shù)的情況(例子)
pytorch會把模型的參數(shù)放在一個字典里面,而我們所要做的就是將這個字典保存,然后再調(diào)用。
比如說設(shè)計一個單層LSTM的網(wǎng)絡(luò),然后進行訓練,訓練完之后將模型的參數(shù)字典進行保存,保存為同文件夾下面的rnn.pt文件:
class LSTM(nn.Module):
? ? def __init__(self, input_size, hidden_size, num_layers):
? ? ? ? super(LSTM, self).__init__()
? ? ? ? self.hidden_size = hidden_size
? ? ? ? self.num_layers = num_layers
? ? ? ? self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
? ? ? ? self.fc = nn.Linear(hidden_size, 1)
? ? def forward(self, x):
? ? ? ? # Set initial states
? ? ? ? h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)?
? ? ? ? ?# 2 for bidirection
? ? ? ? c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
? ? ? ? # Forward propagate LSTM
? ? ? ? out, _ = self.lstm(x, (h0, c0)) ?
? ? ? ? # out: tensor of shape (batch_size, seq_length, hidden_size*2)
? ? ? ? out = self.fc(out)
? ? ? ? return out
rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
# optimize all cnn parameters
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001) ?
# the target label is not one-hotted
loss_func = nn.MSELoss() ?
for epoch in range(1000):
? ? output = rnn(train_tensor) ?# cnn output`
? ? loss = loss_func(output, train_labels_tensor) ?# cross entropy loss
? ? optimizer.zero_grad() ?# clear gradients for this training step
? ? loss.backward() ?# backpropagation, compute gradients
? ? optimizer.step() ?# apply gradients
? ? output_sum = output
# 保存模型
torch.save(rnn.state_dict(), 'rnn.pt')
保存完之后利用這個訓練完的模型對數(shù)據(jù)進行處理:
# 測試所保存的模型
m_state_dict = torch.load('rnn.pt')
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
new_m.load_state_dict(m_state_dict)
predict = new_m(test_tensor)
這里做一下說明,在保存模型的時候rnn.state_dict()表示rnn這個模型的參數(shù)字典,在測試所保存的模型時要先將這個參數(shù)字典加載一下
m_state_dict = torch.load('rnn.pt');
然后再實例化一個LSTM對像,這里要保證傳入的參數(shù)跟實例化rnn是傳入的對象時一樣的,即結(jié)構(gòu)相同
new_m = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device);
下面是給這個新的模型傳入之前加載的參數(shù)
new_m.load_state_dict(m_state_dict);
最后就可以利用這個模型處理數(shù)據(jù)了
predict = new_m(test_tensor)
保存整個模型的情況(例子)
class LSTM(nn.Module):
? ? def __init__(self, input_size, hidden_size, num_layers):
? ? ? ? super(LSTM, self).__init__()
? ? ? ? self.hidden_size = hidden_size
? ? ? ? self.num_layers = num_layers
? ? ? ? self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
? ? ? ? self.fc = nn.Linear(hidden_size, 1)
? ? def forward(self, x):
? ? ? ? # Set initial states
? ? ? ? h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) ?# 2 for bidirection
? ? ? ? c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
? ? ? ? # Forward propagate LSTM
? ? ? ? out, _ = self.lstm(x, (h0, c0)) ?# out: tensor of shape (batch_size, seq_length, hidden_size*2)
? ? ? ? # print("output_in=", out.shape)
? ? ? ? # print("fc_in_shape=", out[:, -1, :].shape)
? ? ? ? # Decode the hidden state of the last time step
? ? ? ? # out = torch.cat((out[:, 0, :], out[-1, :, :]), axis=0)
? ? ? ? # out = self.fc(out[:, -1, :]) ?# 取最后一列為out
? ? ? ? out = self.fc(out)
? ? ? ? return out
rnn = LSTM(input_size=1, hidden_size=10, num_layers=2).to(device)
print(rnn)
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.001) ?# optimize all cnn parameters
loss_func = nn.MSELoss() ?# the target label is not one-hotted
for epoch in range(1000):
? ? output = rnn(train_tensor) ?# cnn output`
? ? loss = loss_func(output, train_labels_tensor) ?# cross entropy loss
? ? optimizer.zero_grad() ?# clear gradients for this training step
? ? loss.backward() ?# backpropagation, compute gradients
? ? optimizer.step() ?# apply gradients
? ? output_sum = output
# 保存模型
torch.save(rnn, 'rnn1.pt')
保存完之后利用這個訓練完的模型對數(shù)據(jù)進行處理:
new_m = torch.load('rnn1.pt')
predict = new_m(test_tensor)
參考pytorch的官方文檔
總結(jié)
原文鏈接:https://blog.csdn.net/comli_cn/article/details/107516740
- 上一篇:沒有了
- 下一篇:沒有了
相關(guān)推薦
- 2022-12-05 Android不同版本兼容性適配方法教程_Android
- 2023-10-31 IP地址、網(wǎng)關(guān)、網(wǎng)絡(luò)/主機號、子網(wǎng)掩碼關(guān)系
- 2022-06-14 golang并發(fā)安全及讀寫互斥鎖的示例分析_Golang
- 2022-05-05 RabbitMQ的Web管理與監(jiān)控簡介_web2.0
- 2022-05-11 C++類繼承時的構(gòu)造函數(shù)_C 語言
- 2022-07-30 jQuery?UI旋轉(zhuǎn)器部件Spinner?Widget_jquery
- 2022-05-06 react-router-domV6版本的路由和嵌套路由寫法詳解_React
- 2022-10-26 如何查看git分支從哪個源分支拉的_相關(guān)技巧
- 欄目分類
-
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支