網站首頁 編程語言 正文
什么是MobileNet模型
MobileNet是一種輕量級網絡,相比于其它結構網絡,它不一定是最準的,但是它真的很輕
MobileNet模型是Google針對手機等嵌入式設備提出的一種輕量級的深層神經網絡,其使用的核心思想便是depthwise separable convolution。
對于一個卷積點而言:
假設有一個3×3大小的卷積層,其輸入通道為16、輸出通道為32。具體為,32個3×3大小的卷積核會遍歷16個通道中的每個數據,最后可得到所需的32個輸出通道,所需參數為16×32×3×3=4608個。
應用深度可分離卷積,用16個3×3大小的卷積核分別遍歷16通道的數據,得到了16個特征圖譜。在融合操作之前,接著用32個1×1大小的卷積核遍歷這16個特征圖譜,所需參數為16×3×3+16×32×1×1=656個。
可以看出來depthwise separable convolution可以減少模型的參數。
如下這張圖就是depthwise separable convolution的結構
在建立模型的時候,可以使用Keras中的DepthwiseConv2D層實現深度可分離卷積,然后再利用1x1卷積調整channels數。
通俗地理解就是3x3的卷積核厚度只有一層,然后在輸入張量上一層一層地滑動,每一次卷積完生成一個輸出通道,當卷積完成后,在利用1x1的卷積調整厚度。
如下就是MobileNet的結構,其中Conv dw就是分層卷積,在其之后都會接一個1x1的卷積進行通道處理,
MobileNet網絡部分實現代碼
#-------------------------------------------------------------#
# MobileNet的網絡部分
#-------------------------------------------------------------#
import warnings
import numpy as np
from keras.preprocessing import image
from keras.models import Model
from keras.layers import DepthwiseConv2D,Input,Activation,Dropout,Reshape,BatchNormalization,GlobalAveragePooling2D,GlobalMaxPooling2D,Conv2D
from keras.applications.imagenet_utils import decode_predictions
from keras import backend as K
def MobileNet(input_shape=[224,224,3],
depth_multiplier=1,
dropout=1e-3,
classes=1000):
img_input = Input(shape=input_shape)
# 224,224,3 -> 112,112,32
x = _conv_block(img_input, 32, strides=(2, 2))
# 112,112,32 -> 112,112,64
x = _depthwise_conv_block(x, 64, depth_multiplier, block_id=1)
# 112,112,64 -> 56,56,128
x = _depthwise_conv_block(x, 128, depth_multiplier,
strides=(2, 2), block_id=2)
# 56,56,128 -> 56,56,128
x = _depthwise_conv_block(x, 128, depth_multiplier, block_id=3)
# 56,56,128 -> 28,28,256
x = _depthwise_conv_block(x, 256, depth_multiplier,
strides=(2, 2), block_id=4)
# 28,28,256 -> 28,28,256
x = _depthwise_conv_block(x, 256, depth_multiplier, block_id=5)
# 28,28,256 -> 14,14,512
x = _depthwise_conv_block(x, 512, depth_multiplier,
strides=(2, 2), block_id=6)
# 14,14,512 -> 14,14,512
x = _depthwise_conv_block(x, 512, depth_multiplier, block_id=7)
x = _depthwise_conv_block(x, 512, depth_multiplier, block_id=8)
x = _depthwise_conv_block(x, 512, depth_multiplier, block_id=9)
x = _depthwise_conv_block(x, 512, depth_multiplier, block_id=10)
x = _depthwise_conv_block(x, 512, depth_multiplier, block_id=11)
# 14,14,512 -> 7,7,1024
x = _depthwise_conv_block(x, 1024, depth_multiplier,
strides=(2, 2), block_id=12)
x = _depthwise_conv_block(x, 1024, depth_multiplier, block_id=13)
# 7,7,1024 -> 1,1,1024
x = GlobalAveragePooling2D()(x)
x = Reshape((1, 1, 1024), name='reshape_1')(x)
x = Dropout(dropout, name='dropout')(x)
x = Conv2D(classes, (1, 1),padding='same', name='conv_preds')(x)
x = Activation('softmax', name='act_softmax')(x)
x = Reshape((classes,), name='reshape_2')(x)
inputs = img_input
model = Model(inputs, x, name='mobilenet_1_0_224_tf')
model_name = 'mobilenet_1_0_224_tf.h5'
model.load_weights(model_name)
return model
def _conv_block(inputs, filters, kernel=(3, 3), strides=(1, 1)):
x = Conv2D(filters, kernel,
padding='same',
use_bias=False,
strides=strides,
name='conv1')(inputs)
x = BatchNormalization(name='conv1_bn')(x)
return Activation(relu6, name='conv1_relu')(x)
def _depthwise_conv_block(inputs, pointwise_conv_filters,
depth_multiplier=1, strides=(1, 1), block_id=1):
x = DepthwiseConv2D((3, 3),
padding='same',
depth_multiplier=depth_multiplier,
strides=strides,
use_bias=False,
name='conv_dw_%d' % block_id)(inputs)
x = BatchNormalization(name='conv_dw_%d_bn' % block_id)(x)
x = Activation(relu6, name='conv_dw_%d_relu' % block_id)(x)
x = Conv2D(pointwise_conv_filters, (1, 1),
padding='same',
use_bias=False,
strides=(1, 1),
name='conv_pw_%d' % block_id)(x)
x = BatchNormalization(name='conv_pw_%d_bn' % block_id)(x)
return Activation(relu6, name='conv_pw_%d_relu' % block_id)(x)
def relu6(x):
return K.relu(x, max_value=6)
圖片預測
建立網絡后,可以用以下的代碼進行預測。
def preprocess_input(x):
x /= 255.
x -= 0.5
x *= 2.
return x
if __name__ == '__main__':
model = MobileNet(input_shape=(224, 224, 3))
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
print('Input image shape:', x.shape)
preds = model.predict(x)
print(np.argmax(preds))
print('Predicted:', decode_predictions(preds, 1))
預測所需的已經訓練好的Xception模型可以在https://github.com/fchollet/deep-learning-models/releases下載。非常方便。預測結果為:
Predicted: [[('n02504458', 'African_elephant', 0.7590296)]]
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/102819915
相關推薦
- 2023-11-11 Flask 表單form.validate_on_submit()什么情況下會是false——解決辦
- 2022-06-17 C語言深入講解棧與堆和靜態存儲區的使用_C 語言
- 2022-07-07 Python?pluggy模塊的用法示例演示_python
- 2022-08-22 Python連接數據庫使用matplotlib畫柱形圖_python
- 2023-04-01 PyTorch基礎之torch.nn.CrossEntropyLoss交叉熵損失_python
- 2023-02-18 Go?gin權限驗證實現過程詳解_Golang
- 2024-03-19 Rust 中Self 關鍵字的兩種不同用法
- 2022-09-03 redis?主從哨兵模式實現一主二從_Redis
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支