網站首頁 編程語言 正文
函數原型
參數介紹
mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)
需要轉換的模型,支持的模型類型有:torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction
args (tuple or torch.Tensor)
args可以被設置成三種形式
1.一個tuple
args = (x, y, z)
這個tuple應該與模型的輸入相對應,任何非Tensor的輸入都會被硬編碼入onnx模型,所有Tensor類型的參數會被當做onnx模型的輸入。
2.一個Tensor
args = torch.Tensor([1, 2, 3])
一般這種情況下模型只有一個輸入
3.一個帶有字典的tuple
args = (x,
{'y': input_y,
'z': input_z})
這種情況下,所有字典之前的參數會被當做“非關鍵字”參數傳入網絡,字典種的鍵值對會被當做關鍵字參數傳入網絡。如果網絡中的關鍵字參數未出現在此字典中,將會使用默認值,如果沒有設定默認值,則會被指定為None。
NOTE:
一個特殊情況,當網絡本身最后一個參數為字典時,直接在tuple最后寫一個字典則會被誤認為關鍵字傳參。所以,可以通過在tuple最后添加一個空字典來解決。
#錯誤寫法:
torch.onnx.export(
model,
(x,
# WRONG: will be interpreted as named arguments
{y: z}),
"test.onnx.pb")
# 糾正
torch.onnx.export(
model,
(x,
{y: z},
{}),
"test.onnx.pb")
f
一個文件類對象或一個路徑字符串,二進制的protocol buffer將被寫入此文件
export_params (bool, default True)
如果為True則導出模型的參數。如果想導出一個未訓練的模型,則設為False
verbose (bool, default False)
如果為True,則打印一些轉換日志,并且onnx模型中會包含doc_string信息。
training (enum, default TrainingMode.EVAL)
枚舉類型包括:
TrainingMode.EVAL - 以推理模式導出模型。
TrainingMode.PRESERVE - 如果model.training為False,則以推理模式導出;否則以訓練模式導出。
TrainingMode.TRAINING - 以訓練模式導出,此模式將禁止一些影響訓練的優化操作。
input_names (list of str, default empty list)
按順序分配給onnx圖的輸入節點的名稱列表。
output_names (list of str, default empty list)
按順序分配給onnx圖的輸出節點的名稱列表。
operator_export_type (enum, default None)
默認為OperatorExportTypes.ONNX, 如果Pytorch built with DPYTORCH_ONNX_CAFFE2_BUNDLE,則默認為OperatorExportTypes.ONNX_ATEN_FALLBACK。
枚舉類型包括:
OperatorExportTypes.ONNX - 將所有操作導出為ONNX操作。
OperatorExportTypes.ONNX_FALLTHROUGH - 試圖將所有操作導出為ONNX操作,但碰到無法轉換的操作(如onnx未實現的操作),則將操作導出為“自定義操作”,為了使導出的模型可用,運行時必須支持這些自定義操作。支持自定義操作方法見鏈接。
OperatorExportTypes.ONNX_ATEN - 所有ATen操作導出為ATen操作,ATen是Pytorch的內建tensor庫,所以這將使得模型直接使用Pytorch實現。(此方法轉換的模型只能被Caffe2直接使用)
OperatorExportTypes.ONNX_ATEN_FALLBACK - 試圖將所有的ATen操作也轉換為ONNX操作,如果無法轉換則轉換為ATen操作(此方法轉換的模型只能被Caffe2直接使用)。例如:
# 轉換前:
graph(%0 : Float):
%3 : int = prim::Constant[value=0]()
# conversion unsupported
%4 : Float = aten::triu(%0, %3)
# conversion supported
%5 : Float = aten::mul(%4, %0)
return (%5)
# 轉換后:
graph(%0 : Float):
%1 : Long() = onnx::Constant[value={0}]()
# not converted
%2 : Float = aten::ATen[operator="triu"](%0, %1)
# converted
%3 : Float = onnx::Mul(%2, %0)
return (%3)
opset_version (int, default 9)
默認是9。值必須等于_onnx_main_opset或在_onnx_stable_opsets之內。具體可在torch/onnx/symbolic_helper.py中找到。例如:
_default_onnx_opset_version = 9
_onnx_main_opset = 13
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12]
_export_onnx_opset_version = _default_onnx_opset_version
do_constant_folding (bool, default False)
是否使用“常量折疊”優化。常量折疊將使用一些算好的常量來優化一些輸入全為常量的節點。
example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)
當需輸入模型為ScriptModule 或 ScriptFunction時必須提供。此參數用于確定輸出的類型和形狀,而不跟蹤(tracing )模型的執行。
dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)
通過以下規則設置動態的維度:
KEY(str) - 必須是input_names或output_names指定的名稱,用來指定哪個變量需要使用到動態尺寸。
VALUE(dict or list) - 如果是一個dict,dict中的key是變量的某個維度,dict中的value是我們給這個維度取的名稱。如果是一個list,則list中的元素都表示此變量的某個維度。
具體可參考如下示例:
class SumModule(torch.nn.Module):
def forward(self, x):
return torch.sum(x, dim=1)
# 以動態尺寸模式導出模型
torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
input_names=["x"], output_names=["sum"],
dynamic_axes={
# dict value: manually named axes
"x": {0: "my_custom_axis_name"},
# list value: automatic names
"sum": [0],
})
### 導出后的節點信息
##input
input {
name: "x"
...
shape {
dim {
dim_param: "my_custom_axis_name" # axis 0
}
dim {
dim_value: 2 # axis 1
...
##output
output {
name: "sum"
...
shape {
dim {
dim_param: "sum_dynamic_axes_1" # axis 0
...
keep_initializers_as_inputs (bool, default None)
NONE
custom_opsets (dict<str, int>, default empty dict)
NONE
Torch.onnx.export執行流程:
1、如果輸入到torch.onnx.export的模型是nn.Module類型,則默認會將模型使用torch.jit.trace轉換為ScriptModule
2、使用args參數和torch.jit.trace將模型轉換為ScriptModule,torch.jit.trace不能處理模型中的循環和if語句
3、如果模型中存在循環或者if語句,在執行torch.onnx.export之前先使用torch.jit.script將nn.Module轉換為ScriptModule
4、模型轉換成onnx之后,預測結果與之前會有稍微的差別,這些差別往往不會改變模型的預測結果,比如預測的概率在小數點之后五六位有差別。
總結
原文鏈接:https://blog.csdn.net/Dteam_f/article/details/122487634
相關推薦
- 2023-03-23 Android進階CoordinatorLayout協調者布局實現吸頂效果_Android
- 2021-12-12 c++虛函數與虛函數表原理_C 語言
- 2022-03-19 Android使用DocumentFile讀寫外置存儲的問題_Android
- 2022-05-06 Python判斷字符串中是否是中英文文小技巧
- 2022-06-22 如何利用Android仿微博正文鏈接交互效果_Android
- 2023-02-17 Python導入其他文件夾中函數的實現方法_python
- 2023-10-31 IP地址、網關、網絡/主機號、子網掩碼關系
- 2022-10-10 Go?代碼規范錯誤處理示例經驗總結_Golang
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支