網站首頁 編程語言 正文
本文介紹一下 Pytorch 中常用乘法的 TensorRT 實現。
pytorch 用于訓練,TensorRT 用于推理是很多 AI 應用開發的標配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,這里拿比較常用的乘法運算在兩種框架下的實現做一個對比,可能會有更加直觀一些的認識。
1.乘法運算總覽
先把 pytorch 中的一些常用的乘法運算進行一個總覽:
- torch.mm:用于兩個矩陣 (不包括向量) 的乘法,如維度 (m, n) 的矩陣乘以維度 (n, p) 的矩陣;
- torch.bmm:用于帶 batch 的三維向量的乘法,如維度 (b, m, n) 的矩陣乘以維度 (b, n, p) 的矩陣;
- torch.mul:用于同維度矩陣的逐像素點相乘,也即點乘,如維度 (m, n) 的矩陣點乘維度 (m, n) 的矩陣。該方法支持廣播,也即支持矩陣和元素點乘;
- torch.mv:用于矩陣和向量的乘法,矩陣在前,向量在后,如維度 (m, n) 的矩陣乘以維度為 (n) 的向量,輸出維度為 (m);
- torch.matmul:用于兩個張量相乘,或矩陣與向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
- @:作用相當于 torch.matmul;
- *:作用相當于 torch.mul;
如上進行了一些具體羅列,可以歸納出,常用的乘法無非兩種:矩陣乘 和 點乘,所以下面分這兩類進行介紹。
2.乘法算子實現
2.1矩陣乘算子實現
先來看看矩陣乘法的 pytorch 的實現 (以下實現在終端):
>>> import torch >>> # torch.mm >>> a = torch.randn(66, 99) >>> b = torch.randn(99, 88) >>> c = torch.mm(a, b) >>> c.shape torch.size([66, 88]) >>> >>> # torch.bmm >>> a = torch.randn(3, 66, 99) >>> b = torch.randn(3, 99, 77) >>> c = torch.bmm(a, b) >>> c.shape torch.size([3, 66, 77]) >>> >>> # torch.mv >>> a = torch.randn(66, 99) >>> b = torch.randn(99) >>> c = torch.mv(a, b) >>> c.shape torch.size([66]) >>> >>> # torch.matmul >>> a = torch.randn(32, 3, 66, 99) >>> b = torch.randn(32, 3, 99, 55) >>> c = torch.matmul(a, b) >>> c.shape torch.size([32, 3, 66, 55]) >>> >>> # @ >>> d = a @ b >>> d.shape torch.size([32, 3, 66, 55])
來看 TensorRT 的實現,以上乘法都可使用?addMatrixMultiply
?方法覆蓋,對應 torch.matmul,先來看該方法的定義:
//! //! \brief Add a MatrixMultiply layer to the network. //! //! \param input0 The first input tensor (commonly A). //! \param op0 The operation to apply to input0. //! \param input1 The second input tensor (commonly B). //! \param op1 The operation to apply to input1. //! //! \see IMatrixMultiplyLayer //! //! \warning Int32 tensors are not valid input tensors. //! //! \return The new matrix multiply layer, or nullptr if it could not be created. //! IMatrixMultiplyLayer* addMatrixMultiply( ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept { return mImpl->addMatrixMultiply(input0, op0, input1, op1); }
可以看到這個方法有四個傳參,對應兩個張量和其?operation
。來看這個算子在 TensorRT 中怎么添加:
// 構造張量 Tensor0 nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0); // 構造張量 Tensor1 nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1); // 添加矩陣乘法 nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type); // 獲取輸出 matmulOutput = Matmul_layer->getOputput(0);
2.2點乘算子實現
再來看看點乘的 pytorch 的實現 (以下實現在終端):
>>> import torch >>> # torch.mul >>> a = torch.randn(66, 99) >>> b = torch.randn(66, 99) >>> c = torch.mul(a, b) >>> c.shape torch.size([66, 99]) >>> d = 0.125 >>> e = torch.mul(a, d) >>> e.shape torch.size([66, 99]) >>> # * >>> f = a * b >>> f.shape torch.size([66, 99])
來看 TensorRT 的實現,以上乘法都可使用?addScale
?方法覆蓋,這在圖像預處理中十分常用,先來看該方法的定義:
//! //! \brief Add a Scale layer to the network. //! //! \param input The input tensor to the layer. //! This tensor is required to have a minimum of 3 dimensions in implicit batch mode //! and a minimum of 4 dimensions in explicit batch mode. //! \param mode The scaling mode. //! \param shift The shift value. //! \param scale The scale value. //! \param power The power value. //! //! If the weights are available, then the size of weights are dependent on the ScaleMode. //! For ::kUNIFORM, the number of weights equals 1. //! For ::kCHANNEL, the number of weights equals the channel dimension. //! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input. //! //! \see addScaleNd //! \see IScaleLayer //! \warning Int32 tensors are not valid input tensors. //! //! \return The new Scale layer, or nullptr if it could not be created. //! IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept { return mImpl->addScale(input, mode, shift, scale, power); }
?可以看到有三個模式:
- kUNIFORM:weights 為一個值,對應張量乘一個元素;
- kCHANNEL:weights 維度和輸入張量通道的 c 維度對應,可以做一些以通道為基準的預處理;
- kELEMENTWISE:weights 維度和輸入張量的 c、h、w 對應,不考慮 batch,所以是輸入的后三維;
再來看這個算子在 TensorRT 中怎么添加:
// 構造張量 input nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value); // scalemode選擇,kUNIFORM、kCHANNEL、kELEMENTWISE scalemode = kUNIFORM; // 構建 Weights 類型的 shift、scale、power,其中 volume 為元素數量 nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume }; nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume }; nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume }; // !! 注意這里還需要對 shift、scale、power 的 values 進行賦值,若只是乘法只需要對 scale 進行賦值就行 // 添加張量乘法 nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower); // 獲取輸出 scaleOutput = Scale_layer->getOputput(0);
有一點你可能會比較疑惑,既然是點乘,那么輸入只需要兩個張量就可以了,為啥這里有 input、shift、scale、power 四個張量這么多呢。解釋一下,input 不用說,就是輸入張量,而 shift 表示加法參數、scale 表示乘法參數、power 表示指數參數,說到這里,你應該能發現,這個函數除了我們上面講的點乘外還有其他更加豐富的運算功能。
原文鏈接:https://blog.csdn.net/weixin_42405819/article/details/125070931
相關推薦
- 2022-04-21 catalina.out 和 catalina.log 的區別和用途
- 2022-11-14 值類型和引用類型的區別 I 數據結構中的堆和棧和內存中的堆和棧的區別
- 2022-04-09 Maven 編譯提示:spring-boot-maven-plugin:2.1.9.RELEASE
- 2022-04-23 .NET?Core使用APB?vNext框架入門教程_實用技巧
- 2022-02-11 安裝element UI (全局引入與按需引入)
- 2022-04-22 Jmeter之控制線程執行到某個結果時退出執行(第二種解決方案)
- 2022-03-26 C語言猜兇手及類似題目的實現示例_C 語言
- 2022-10-19 C++模板編程特性之移動語義_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支