網站首頁 編程語言 正文
pytorch和tensorflow計算Flops和params
1.只計算params
net = model() # 定義好的網絡模型
total = sum([param.nelement() for param in net.parameters()])
print("Number of parameter: %.2fM" % total)
這是網上很常見的直接用自帶方法計算params,基本不會出錯。勝在簡潔。
2.計算flops和params
要計算flops,目前沒見到用自帶方法計算的,基本都是要安裝別的庫。
這邊我們安裝thop庫。
pip install thop # 安裝thop庫
import torch
from thop import profile
net = model() # 定義好的網絡模型
img1 = torch.randn(1, 3, 512, 512)
img2 = torch.randn(1, 3, 512, 512)
img3 = torch.randn(1, 3, 512, 512)
macs, params = profile(net, (img1,img2,img3))
print('flops: ', 2*macs, 'params: ', params)
這邊和其他網上教程的區別便是,他們macs和flops不分。因為macs表示乘加累積操作數,一個乘法加上一個加法才算一個macs。而flops表示浮點運算次數,每一個加、減、乘、除操作都算1FLOPs操作。所以很明顯,在數值上,1flops=2macs。此外,(img1,img2,img3)就表示你如果有三個輸入要輸入模型,就這樣寫。
另外,要注意,params只和模型參數量相關,而和輸入tensor大小無關。但flops和輸入圖片大小是相關的.
3.tensorflow計算params和flops
此處是我找到的一些用于tensorflow計算params和flops的方法,僅供參考,不保證效果。
def get_flops_params():
sess = tf.compat.v1.Session()
graph = sess.graph
flops = tf.compat.v1.profiler.profile(graph, options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation())
params = tf.compat.v1.profiler.profile(graph,
options=tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter())
print('FLOPs: {}; Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))
def count2():
print(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))
def get_nb_params_shape(shape):
'''
Computes the total number of params for a given shap.
Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C.
'''
nb_params = 1
for dim in shape:
nb_params = nb_params * int(dim)
return nb_params
def count3():
tot_nb_params = 0
for trainable_variable in tf.trainable_variables():
shape = trainable_variable.get_shape() # e.g [D,F] or [W,H,C]
current_nb_params = get_nb_params_shape(shape)
tot_nb_params = tot_nb_params + current_nb_params
print(tot_nb_params)
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
from model import Model
import keras.backend as K
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
# We use the Keras session graph in the call to the profiler.
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops # Prints the "flops" of the model.
# .... Define your model here ....
M = Model(BATCH_SIZE=1, INPUT_H=268, INPUT_W=360, is_training=False)
print(get_flops(M))
原文鏈接:https://blog.csdn.net/qq_40840829/article/details/126334037
相關推薦
- 2022-11-03 python中for循環的多種使用實例_python
- 2023-01-12 Python讀取mat(matlab數據文件)并實現畫圖_python
- 2022-03-25 修改?asp.net?core?5?程序的默認端口號_ASP.NET
- 2023-01-10 Oracle如何獲取數據庫系統的當前時間_oracle
- 2022-09-23 python?pandas創建多層索引MultiIndex的6種方式_python
- 2022-06-21 Android實現登錄界面的注冊功能_Android
- 2022-11-14 關于C++解決內存泄漏問題的心得
- 2022-04-11 Python - logging.Formatter 的常用格式字符串
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支