網站首頁 編程語言 正文
1. 引言
FLOPs 是 floating point operations 的縮寫,指浮點運算數,可以用來衡量模型/算法的計算復雜度。本文主要討論如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相關工具計算對應模型的 FLOPs。
2. 模型結構
為了說明方便,先搭建一個簡單的神經網絡模型,其模型結構以及主要參數如表1 所示。
表 1 模型結構及主要參數
Layers | channels | Kernels | Strides | Units | Activation |
---|---|---|---|---|---|
Conv2D | 32 | (4,4) | (1,2) | \ | relu |
GRU | \ | \ | \ | 96 | \ |
Dense | \ | \ | \ | 256 | sigmoid |
用 tensorflow(實際使用 tensorflow 中的 keras 模塊)實現該模型的代碼為:
from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model
def test_model_tf(Input_shape):
# shape: [B, C, T, F]
main_input = Input(batch_shape=Input_shape, name='main_inputs')
conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)
# shape: [B, T, FC]
gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)
output = Dense(256, activation='sigmoid', name='output')(gru)
model = Model(inputs=[main_input], outputs=[output])
return model
用 pytorch 實現該模型的代碼為:
import torch
import torch.nn as nn
class test_model_torch(nn.Module):
def __init__(self):
super(test_model_torch, self).__init__()
self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))
self.relu = nn.ReLU()
self.gru = nn.GRU(input_size=4064, hidden_size=96)
self.fc = nn.Linear(96, 256)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
# shape: [B, C, T, F]
out = self.conv2d(inputs)
out = self.relu(out)
# shape: [B, T, FC]
batch, channel, frame, freq = out.size()
out = torch.reshape(out, (batch, frame, freq*channel))
out, _ = self.gru(out)
out = self.fc(out)
out = self.sigmoid(out)
return out
3. 計算模型的 FLOPs
本節討論的版本具體為:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。
3.1. tensorflow 1.12.0
在 tensorflow 1.12.0 環境中,可以使用以下代碼計算模型的 FLOPs:
import tensorflow as tf
import tensorflow.keras.backend as K
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 1.12.0:', get_flops(model))
3.2. tensorflow 2.3.1
在 tensorflow 2.3.1 環境中,可以使用以下代碼計算模型的 FLOPs :
import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 2.3.1:', get_flops(model))
3.3. pytorch 1.10.1+cu102
在 pytorch 1.10.1+cu102 環境中,可以使用以下代碼計算模型的 FLOPs(需要安裝 thop):
import thop
x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print('FLOPs of pytorch 1.10.1:', flops * 2)
需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代碼有乘 2 2 2 操作。
3.4. 結果對比
三者計算出的 FLOPs 分別為:
tensorflow 1.12.0:
tensorflow 2.3.1:
pytorch 1.10.1:
可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的結果基本在同一個量級,而與 pytorch 1.10.1 計算出來的相差甚遠。但如果將上述模型結構改為只包含第一層 Conv2D,三者計算出來的 FLOPs 卻又是一致的。所以推斷差異主要來自于 GRU 的 FLOPs。如讀者知道其中詳情,還請不吝賜教。
4. 總結
本文給出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相關工具計算模型 FLOPs 的方法,但從本文所使用的測試模型來看, tensorflow 與 pytorch 統計出的結果相差甚遠。當然,也可以根據網絡層的類型及其對應的參數,推導計算出每個網絡層所需的 FLOPs。
原文鏈接:https://blog.csdn.net/wjrenxinlei/article/details/127973081
相關推薦
- 2022-01-09 el-checkbox 狀態切換,將boolean轉換成1遇到的問題
- 2022-10-18 使用shell腳本快速登錄容器的實現步驟_linux shell
- 2022-06-02 CKAD認證中部署k8s并配置Calico插件_云和虛擬化
- 2022-07-04 如何用python實現結構體數組_python
- 2022-08-10 Python多任務版靜態Web服務器實現示例_python
- 2022-10-23 Go語言數據結構之希爾排序示例詳解_Golang
- 2022-07-10 Linux安裝及管理程序
- 2022-12-04 C#目錄和文件管理操作詳解_C#教程
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支