網(wǎng)站首頁 編程語言 正文
前言
Open Neural Network Exchange (ONNX,開放神經(jīng)網(wǎng)絡(luò)交換) 格式,是一個用于表示深度學(xué)習(xí)模型的標(biāo)準(zhǔn),可使模型在不同框架之間進(jìn)行轉(zhuǎn)移
PyTorch 所定義的模型為動態(tài)圖,其前向傳播是由類方法定義和實現(xiàn)的
但是 Python 代碼的效率是比較底下的,試想把動態(tài)圖轉(zhuǎn)化為靜態(tài)圖,模型的推理速度應(yīng)當(dāng)有所提升
PyTorch 框架中,torch.onnx.export 可以將父類為 nn.Module 的模型導(dǎo)出到 onnx 文件中,
最重要的有三個參數(shù):
- model:父類為 nn.Module 的模型
- args:傳入 model 的 forward 方法的變量列表,類型應(yīng)為
- tuplef:onnx 文件名稱的字符串
import torch
from torchvision.models import resnet50
file = 'resnet.onnx'
# 聲明模型
resnet = resnet50(pretrained=False).eval()
image = torch.rand([1, 3, 224, 224])
# 導(dǎo)出為 onnx 文件
torch.onnx.export(resnet, (image,), file)
onnx 文件可被 Netron 打開,以查看模型結(jié)構(gòu)
基本用法
要在 Python 中運行 onnx 模型,需要下載 onnxruntime
# 選其一即可
pip install onnxruntime # CPU 版本
pip install onnxruntime-gpu # GPU 版本
推理時需要借助其中的 InferenceSession,其中較為重要的實例方法有:
- get_inputs():得到輸入變量的列表 (變量屬性:name、shape、type)
- get_outputs():得到輸入變量的列表 (變量屬性:name、shape、type)run(output_names, input_feed):輸入變量為 numpy.ndarray (注意 dtype 應(yīng)為 float32),使用模型推理并返回輸出
可得出 onnx 模型的基本用法:
import onnxruntime as ort
import numpy as np
file = 'resnet.onnx'
# 找到 GPU / CPU
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
print('設(shè)備:', provider)
# 聲明 onnx 模型
model = ort.InferenceSession(file, providers=[provider])
# 參考: ort.NodeArg
for node_list in model.get_inputs(), model.get_outputs():
for node in node_list:
attr = {'name': node.name,
'shape': node.shape,
'type': node.type}
print(attr)
print('-' * 60)
# 得到輸入、輸出結(jié)點的名稱
input_node_name = model.get_inputs()[0].name
ouput_node_name = [node.name for node in model.get_outputs()]
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model.run(output_names=ouput_node_name,
input_feed={input_node_name: image}))
高級 API
為了簡化使用步驟,使用類進(jìn)行封裝:
class Onnx_Module(ort.InferenceSession):
''' onnx 推理模型
provider: 優(yōu)先使用 GPU'''
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
def __init__(self, file):
super(Onnx_Module, self).__init__(file, providers=[self.provider])
# 參考: ort.NodeArg
self.inputs = [node_arg.name for node_arg in self.get_inputs()]
self.outputs = [node_arg.name for node_arg in self.get_outputs()]
def __call__(self, *arrays):
input_feed = {name: x for name, x in zip(self.inputs, arrays)}
return self.run(self.outputs, input_feed)
在 PyTorch 中,對于卷積神經(jīng)網(wǎng)絡(luò) model 與圖像 image,推理的代碼為 "model(image)",而使用這個封裝的類也是類似:
import numpy as np
file = 'resnet.onnx'
model = Onnx_Module(file)
image = np.random.random([1, 3, 224, 224]).astype(np.float32)
print(model(image))
為了方便觀察 Torch 模型與 onnx 模型的速度差異,同時檢查兩個模型的輸出是否一致,又編寫了 test 函數(shù)
test 方法的參數(shù)與 torch.onnx.export 一致,其基本流程為:
- 得到 Torch 模型的輸出,并 print 推斷耗時
- 將 Torch 模型導(dǎo)出為 onnx 文件,將輸入變量中的 torch.tensor 轉(zhuǎn)化為 numpy.ndarray
- 初始化 onnx 模型,得到 onnx?模型的輸出,并 print 推斷耗時
- 計算 Torch 模型與 onnx 模型輸出的絕對誤差的均值
- 將 onnx 模型 return
class Timer:
repeat = 3
def __new__(cls, fun, *args, **kwargs):
import time
start = time.time()
for _ in range(cls.repeat): fun(*args, **kwargs)
cost = (time.time() - start) / cls.repeat
return cost * 1e3 # ms
class Onnx_Module(ort.InferenceSession):
''' onnx 推理模型
provider: 優(yōu)先使用 GPU'''
provider = ort.get_available_providers()[
1 if ort.get_device() == 'GPU' else 0]
def __init__(self, file):
super(Onnx_Module, self).__init__(file, providers=[self.provider])
# 參考: ort.NodeArg
self.inputs = [node_arg.name for node_arg in self.get_inputs()]
self.outputs = [node_arg.name for node_arg in self.get_outputs()]
def __call__(self, *arrays):
input_feed = {name: x for name, x in zip(self.inputs, arrays)}
return self.run(self.outputs, input_feed)
@classmethod
def test(cls, model, args, file, **export_kwargs):
# 測試 Torch 的運行時間
torch_output = model(*args).data.numpy()
print(f'Torch: {Timer(model, *args):.2f} ms')
# model: Torch -> onnx
torch.onnx.export(model, args, file, **export_kwargs)
# data: tensor -> array
args = tuple(map(lambda tensor: tensor.data.numpy(), args))
onnx_model = cls(file)
# 測試 onnx 的運行時間
onnx_output = onnx_model(*args)
print(f'Onnx: {Timer(onnx_model, *args):.2f} ms')
# 計算 Torch 模型與 onnx 模型輸出的絕對誤差
abs_error = np.abs(torch_output - onnx_output).mean()
print(f'Mean Error: {abs_error:.2f}')
return onnx_model
對于 ResNet50 而言,Torch 模型的推斷耗時為 172.67 ms,onnx 模型的推斷耗時為 36.56 ms,onnx 模型的推斷耗時僅為 Torch 模型的 21.17%
原文鏈接:https://blog.csdn.net/qq_55745968/article/details/125965503
相關(guān)推薦
- 2022-09-25 面向?qū)ο蠛兔嫦蜻^程:兩種程序設(shè)計思想的基礎(chǔ)介紹和對比
- 2022-06-17 基于pgrouting的路徑規(guī)劃處理方法_PostgreSQL
- 2022-04-25 ASP.NET?Core中Cookie驗證身份用法詳解_實用技巧
- 2022-04-17 Mac使用pandoc 將docx文件轉(zhuǎn)換成html文件 快速實現(xiàn)協(xié)議文件的轉(zhuǎn)換
- 2022-08-30 MongoDB集合的增刪改查管理_MongoDB
- 2023-06-05 python中xlwt模塊的具體用法_python
- 2022-04-12 Python實現(xiàn)批量向PDF文件添加中文水印_python
- 2022-06-19 Go語言列表List獲取元素的4種方式_Golang
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支