網站首頁 編程語言 正文
前言
Open Neural Network Exchange (ONNX,開放神經網絡交換) 格式,是一個用于表示深度學習模型的標準,可使模型在不同框架之間進行轉移
PyTorch 所定義的模型為動態圖,其前向傳播是由類方法定義和實現的
但是 Python 代碼的效率是比較底下的,試想把動態圖轉化為靜態圖,模型的推理速度應當有所提升
PyTorch 框架中,torch.onnx.export 可以將父類為 nn.Module 的模型導出到 onnx 文件中,
最重要的有三個參數:
- model:父類為 nn.Module 的模型
- args:傳入 model 的 forward 方法的變量列表,類型應為
- tuplef:onnx 文件名稱的字符串
import torch
from torchvision.models import resnet50
file = 'resnet.onnx'
# 聲明模型
resnet = resnet50(pretrained=False).eval()
image = torch.rand([1, 3, 224, 224])
# 導出為 onnx 文件
torch.onnx.export(resnet, (image,), file)
onnx 文件可被 Netron 打開,以查看模型結構
基本用法
要在 Python 中運行 onnx 模型,需要下載 onnxruntime
# 選其一即可
pip install onnxruntime # CPU 版本
pip install onnxruntime-gpu # GPU 版本
推理時需要借助其中的 InferenceSession,其中較為重要的實例方法有:
- get_inputs():得到輸入變量的列表 (變量屬性:name、shape、type)
- get_outputs():得到輸入變量的列表 (變量屬性:name、shape、type)run(output_names, input_feed):輸入變量為 numpy.ndarray (注意 dtype 應為 float32),使用模型推理并返回輸出
可得出 onnx 模型的基本用法:
import onnxruntime as ort
import numpy as np
file = 'resnet.onnx'
# 找到 GPU / CPU
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
print('設備:', provider)
# 聲明 onnx 模型
model = ort.InferenceSession(file, providers=[provider])
# 參考: ort.NodeArg
for node_list in model.get_inputs(), model.get_outputs():
for node in node_list:
attr = {'name': node.name,
'shape': node.shape,
'type': node.type}
print(attr)
print('-' * 60)
# 得到輸入、輸出結點的名稱
input_node_name = model.get_inputs()[0].name
ouput_node_name = [node.name for node in model.get_outputs()]
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model.run(output_names=ouput_node_name,
input_feed={input_node_name: image}))
高級 API
為了簡化使用步驟,使用類進行封裝:
class Onnx_Module(ort.InferenceSession):
''' onnx 推理模型
provider: 優先使用 GPU'''
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
def __init__(self, file):
super(Onnx_Module, self).__init__(file, providers=[self.provider])
# 參考: ort.NodeArg
self.inputs = [node_arg.name for node_arg in self.get_inputs()]
self.outputs = [node_arg.name for node_arg in self.get_outputs()]
def __call__(self, *arrays):
input_feed = {name: x for name, x in zip(self.inputs, arrays)}
return self.run(self.outputs, input_feed)
在 PyTorch 中,對于卷積神經網絡 model 與圖像 image,推理的代碼為 "model(image)",而使用這個封裝的類也是類似:
import numpy as np
file = 'resnet.onnx'
model = Onnx_Module(file)
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model(image))
為了方便觀察 Torch 模型與 onnx 模型的速度差異,同時檢查兩個模型的輸出是否一致,又編寫了 test 函數
test 方法的參數與 torch.onnx.export 一致,其基本流程為:
- 得到 Torch 模型的輸出,并 print 推斷耗時
- 將 Torch 模型導出為 onnx 文件,將輸入變量中的 torch.tensor 轉化為 numpy.ndarray
- 初始化 onnx 模型,得到 onnx?模型的輸出,并 print 推斷耗時
- 計算 Torch 模型與 onnx 模型輸出的絕對誤差的均值
- 將 onnx 模型 return
class Timer:
repeat = 3
def __new__(cls, fun, *args, **kwargs):
import time
start = time.time()
for _ in range(cls.repeat): fun(*args, **kwargs)
cost = (time.time() - start) / cls.repeat
return cost * 1e3 # ms
class Onnx_Module(ort.InferenceSession):
''' onnx 推理模型
provider: 優先使用 GPU'''
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
def __init__(self, file):
super(Onnx_Module, self).__init__(file, providers=[self.provider])
# 參考: ort.NodeArg
self.inputs = [node_arg.name for node_arg in self.get_inputs()]
self.outputs = [node_arg.name for node_arg in self.get_outputs()]
def __call__(self, *arrays):
input_feed = {name: x for name, x in zip(self.inputs, arrays)}
return self.run(self.outputs, input_feed)
@classmethod
def test(cls, model, args, file, **export_kwargs):
# 測試 Torch 的運行時間
torch_output = model(*args).data.numpy()
print(f'Torch: {Timer(model, *args):.2f} ms')
# model: Torch -> onnx
torch.onnx.export(model, args, file, **export_kwargs)
# data: tensor -> array
args = tuple(map(lambda tensor: tensor.data.numpy(), args))
onnx_model = cls(file)
# 測試 onnx 的運行時間
onnx_output = onnx_model(*args)
print(f'Onnx: {Timer(onnx_model, *args):.2f} ms')
# 計算 Torch 模型與 onnx 模型輸出的絕對誤差
abs_error = np.abs(torch_output - onnx_output).mean()
print(f'Mean Error: {abs_error:.2f}')
return onnx_model
對于 ResNet50 而言,Torch 模型的推斷耗時為 172.67 ms,onnx 模型的推斷耗時為 36.56 ms,onnx 模型的推斷耗時僅為 Torch 模型的 21.17%
原文鏈接:https://blog.csdn.net/qq_55745968/article/details/125965503
相關推薦
- 2023-02-10 Golang中interface的基本用法詳解_Golang
- 2022-06-16 nginx限流及配置管理實戰記錄_nginx
- 2022-10-06 Android?Jetpack庫重要組件WorkManager的使用_Android
- 2024-03-08 Linux虛擬機輸入ifconfig不顯示IP地址解決方法
- 2022-05-24 Python?6種基本變量操作技巧總結_python
- 2022-03-31 C#判斷語句的表達式樹實現_C#教程
- 2022-06-24 C#中緩存System.Web.Caching用法總結_C#教程
- 2023-02-10 批處理從html格式(接收到的郵件)中讀取數據的操作方法_DOS/BAT
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支