網站首頁 編程語言 正文
概述
在pytorch中有兩種方式可以保存推理模型,第一種是只保存模型的參數,比如parameters和buffers;另外一種是保存整個模型;
1.保存模型 - 權重參數
我們可以用torch.save()函數來保存model.state_dict();state_dict()里面包含模型的parameters&buffers;這種方法只保存模型中必要的訓練參數。
你可以用pytorch中的pickle來保存模型;使用這種方法可以生成最直觀的語法,并涉及最少的代碼;這種方法的缺點是,序列化的數據被綁定到特定的類和保存模型時使用的確切的目錄結構。
這樣做的原因是pickle并不保存模型類本身。相反,它保存包含類的文件的路徑,在加載期間使用;因此,當在其他項目中使用或重構后,您的代碼可能以各種方式中斷。
我們將探討如何保存和加載模型進行推斷的兩種方法。
步驟:
(1)導入所有必要的庫來加載我們的數據
(2)定義和初始化神經網絡
(3)初始化優化器
(4)保存并通過state_dict加載模型
(5)保存并加載整個模型
1.1代碼
# -*- coding: utf-8 -*- # @Project: zc # @Author: zc # @File name: Neural_Network_test # @Create time: 2022/3/19 15:33 # 1.導入相關數據庫 import torch import torch.nn as nn import torch.optim as optim from torch.nn import functional as F # 2.定義神經網絡模型 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # 3. 實例化神經網絡 net = Net() # 4. 實例化優化器 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 5. 保存模型參數 # Specify a path PATH = "state_dict_model.pt" # 6. 保存模型的參數字典:parameters and buffers torch.save(net.state_dict(), PATH) # 7. 實例化新的模型 model = Net() # 8. 給新的實例加載之前的模型參數 model.load_state_dict(torch.load(PATH)) # 9. 設置模型為評估模式 model.eval()
注意(1):
pytorch中常用的慣例是將model.state_dict()保存為"state_dict_model.pt",即文件的格式一般是.pt或者.pth格式文件;注意load_state_dict加載的是一個字典,而不是路徑。
注意(2):
模型參數在推理階段一定要設置model.eval();這樣可以讓dropout和batchnorm失效,如果沒設置推理模式,會得到不一樣的結果。
2.保存模型 - 整個模型
將模型所有的內容都保存下來。?
# Specify a path PATH = "entire_model.pt" # Save torch.save(net, PATH) # Load model = torch.load(PATH) model.eval()
3.保存模型 - checkpoints
我們按照checkpoints模式來保存模型,本質上就是按照字典的模式進行分門別類的保存,我們可以通過鍵值進行加載。
-
epoch
:訓練周期 -
model_state_dict
:模型可訓練參數 -
optimizer_state_dict
:模型優化器參數 -
loss
:模型的損失函數
# Additional information EPOCH = 5 PATH = "model.pt" LOSS = 0.4 torch.save({ 'epoch': EPOCH, 'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': LOSS, }, PATH)
保存和加載通用的檢查點模型以進行推斷或恢復訓練,這有助于您從上一個地方繼續進行。
當保存一個常規檢查點時,您必須保存模型的state_dict之外的更多信息。
保存優化器的state_dict也很重要,因為它包含緩沖區和參數,隨著模型的運行而更新。
您可能希望保存的其他項目是您離開的時期,最新記錄的訓練損失,外部torch.nn.嵌入層,以及更多,基于自己的算法
3.1代碼
# 1.導入相關數據庫 import torch import torch.nn as nn import torch.optim as optim from torch.nn import functional as F # 2. 定義神經網絡 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # 3. 實例化神經網絡 net = Net() # 4. 實例化優化器 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # Additional information # 5. 定義超參數 EPOCH = 5 PATH = "model.pt" LOSS = 0.4 # 6. 以checkpoints形式保存模型的相關數據 torch.save({ 'epoch': EPOCH, 'model_state_dict': net.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': LOSS, }, PATH) # 7. 重新實例化一個模型 model = Net() # 8. 實例化優化器 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 9. 加載以前的checkpoint checkpoint = torch.load(PATH) # 10. 通過鍵值來加載相關參數 model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] # 11.設置推理模式 model.eval() # - or - model.train()
4.保存雙模型
當保存有多個神經網絡模型組成的神經網絡時,比如GAN對抗模型,sequence-to-sequence序列到序列模型,或者一個組合模型,你必須為每一個模型保存狀態字典state_dict()和其對應的優化器參數optimizer.state_dict();您還可以保存任何其他項目,可能會幫助您恢復訓練,只需將它們添加到字典;為了加載模型,第一步是初始化神經網絡模型和優化器,然后用torch.load()去加載checkpoint對應的數據,因為checkpoints是字典,所以我們可以通過鍵值進行查詢導入;
4.1相關步驟
(1)導入所有相關的數據庫
(2)定義和實例化神經網絡模型
(3)初始化優化器
(4)保存多重模型
(5)加載多重模型
# 1.導入相關數據庫 import torch import torch.nn as nn import torch.optim as optim from torch.nn import functional as F # 2. 定義神經網絡 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # 3. 實例化神經網絡A,B netA = Net() netB = Net() # 4. 實例化優化器A,B optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9) optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9) # 5. 保存模型 # Specify a path to save to PATH = "model.pt" torch.save({ 'modelA_state_dict': netA.state_dict(), 'modelB_state_dict': netB.state_dict(), 'optimizerA_state_dict': optimizerA.state_dict(), 'optimizerB_state_dict': optimizerB.state_dict(), }, PATH) # 6.重新實例化新的網絡模型A,B modelA = Net() modelB = Net() # 7. 重新實例化新的網絡模型A,B optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9) optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9) # 8. 將以前模型的參數重新加載到新的模型A,B中 checkpoint = torch.load(PATH) modelA.load_state_dict(checkpoint['modelA_state_dict']) modelB.load_state_dict(checkpoint['modelB_state_dict']) optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']) optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']) # 9. 開啟預測模式 modelA.eval() modelB.eval() # - or - # 10.開啟訓練模式 modelA.train() modelB.train()
5.機器學習流程圖
6.機器學習常用庫
總結
原文鏈接:https://blog.csdn.net/scar2016/article/details/123618089
相關推薦
- 2023-11-21 NVIDIA jetson nano/ Linux/ Ubuntu18.0.4 配置固定IP靜態IP
- 2022-07-29 Linux中文件的基本屬性介紹_linux shell
- 2022-12-10 React實現控制減少useContext導致非必要的渲染詳解_React
- 2022-04-25 C++特殊成員函數以及其生成機制詳解_C 語言
- 2022-01-19 iview-admin 富文本編輯器(wangEditor)菜單無法選中解決方案
- 2022-12-06 Docker基礎和常用命令詳解_docker
- 2022-05-13 Centos error: cannot remove “core“: snap “core“ is
- 2022-05-29 ASP.NET?Core在WebApi項目中使用Cookie_實用技巧
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支