網站首頁 編程語言 正文
Xception是繼Inception后提出的對Inception v3的另一種改進,學一學總是好的
什么是Xception模型
Xception是谷歌公司繼Inception后,提出的InceptionV3的一種改進模型,其改進的主要內容為采用depthwise separable convolution來替換原來Inception v3中的多尺寸卷積核特征響應操作。
在講Xception模型之前,首先要講一下什么是depthwise separable convolution(深度可分離卷積塊)。
深度可分離卷積塊由兩個部分組成,分別是深度可分離卷積和1x1普通卷積,深度可分離卷積的卷積核大小一般是3x3的,便于理解的話我們可以把它當作是特征提取,1x1的普通卷積可以完成通道數的調整。
下圖為深度可分離卷積塊的結構示意圖:
深度可分離卷積塊的目的是使用更少的參數來代替普通的3x3卷積。
我們可以進行一下普通卷積和深度可分離卷積塊的對比:
假設有一個3×3大小的卷積層,其輸入通道為16、輸出通道為32。具體為,32個3×3大小的卷積核會遍歷16個通道中的每個數據,最后可得到所需的32個輸出通道,所需參數為16×32×3×3=4608個。
應用深度可分離卷積,用16個3×3大小的卷積核分別遍歷16通道的數據,得到了16個特征圖譜。在融合操作之前,接著用32個1×1大小的卷積核遍歷這16個特征圖譜,所需參數為16×3×3+16×32×1×1=656個。
可以看出來depthwise separable convolution可以減少模型的參數。
通俗地理解深度可分離卷積結構塊,就是3x3的卷積核厚度只有一層,然后在輸入張量上一層一層地滑動,每一次卷積完生成一個輸出通道,當卷積完成后,再利用1x1的卷積調整厚度。
(視頻中有些許錯誤,感謝zl960929的提醒,Xception使用的深度可分離卷積塊SeparableConv2D也就是先深度可分離卷積再進行1x1卷積。)
對于Xception模型而言,其一共可以分為3個flow,分別是Entry flow、Middle flow、Exit flow;
分為14個block,其中Entry flow中有4個、Middle flow中有8個、Exit flow中有2個。
具體結構如下:
其內部主要結構就是殘差卷積網絡搭配SeparableConv2D層實現一個個block,在Xception模型中,常見的兩個block的結構如下。這個主要在Entry flow和Exit flow中:
這個主要在Middle flow中:
Xception網絡部分實現代碼
#-------------------------------------------------------------#
# Xception的網絡部分
#-------------------------------------------------------------#
from keras.preprocessing import image
from keras.models import Model
from keras import layers
from keras.layers import Dense,Input,BatchNormalization,Activation,Conv2D,SeparableConv2D,MaxPooling2D
from keras.layers import GlobalAveragePooling2D,GlobalMaxPooling2D
from keras import backend as K
from keras.applications.imagenet_utils import decode_predictions
def Xception(input_shape = [299,299,3],classes=1000):
img_input = Input(shape=input_shape)
#--------------------------#
# Entry flow
#--------------------------#
#--------------------#
# block1
#--------------------#
# 299,299,3 -> 149,149,64
x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input)
x = BatchNormalization(name='block1_conv1_bn')(x)
x = Activation('relu', name='block1_conv1_act')(x)
x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
x = BatchNormalization(name='block1_conv2_bn')(x)
x = Activation('relu', name='block1_conv2_act')(x)
#--------------------#
# block2
#--------------------#
# 149,149,64 -> 75,75,128
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
x = BatchNormalization(name='block2_sepconv1_bn')(x)
x = Activation('relu', name='block2_sepconv2_act')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
x = BatchNormalization(name='block2_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
x = layers.add([x, residual])
#--------------------#
# block3
#--------------------#
# 75,75,128 -> 38,38,256
residual = Conv2D(256, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block3_sepconv1_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
x = BatchNormalization(name='block3_sepconv1_bn')(x)
x = Activation('relu', name='block3_sepconv2_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
x = BatchNormalization(name='block3_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
x = layers.add([x, residual])
#--------------------#
# block4
#--------------------#
# 38,38,256 -> 19,19,728
residual = Conv2D(728, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block4_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
x = BatchNormalization(name='block4_sepconv1_bn')(x)
x = Activation('relu', name='block4_sepconv2_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
x = BatchNormalization(name='block4_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
x = layers.add([x, residual])
#--------------------------#
# Middle flow
#--------------------------#
#--------------------#
# block5--block12
#--------------------#
# 19,19,728 -> 19,19,728
for i in range(8):
residual = x
prefix = 'block' + str(i + 5)
x = Activation('relu', name=prefix + '_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
x = Activation('relu', name=prefix + '_sepconv2_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
x = Activation('relu', name=prefix + '_sepconv3_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
x = layers.add([x, residual])
#--------------------------#
# Exit flow
#--------------------------#
#--------------------#
# block13
#--------------------#
# 19,19,728 -> 10,10,1024
residual = Conv2D(1024, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block13_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
x = BatchNormalization(name='block13_sepconv1_bn')(x)
x = Activation('relu', name='block13_sepconv2_act')(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
x = BatchNormalization(name='block13_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
x = layers.add([x, residual])
#--------------------#
# block14
#--------------------#
# 10,10,1024 -> 10,10,2048
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
x = BatchNormalization(name='block14_sepconv1_bn')(x)
x = Activation('relu', name='block14_sepconv1_act')(x)
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
x = BatchNormalization(name='block14_sepconv2_bn')(x)
x = Activation('relu', name='block14_sepconv2_act')(x)
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Dense(classes, activation='softmax', name='predictions')(x)
inputs = img_input
model = Model(inputs, x, name='xception')
model.load_weights("xception_weights_tf_dim_ordering_tf_kernels.h5")
return model
圖片預測
建立網絡后,可以用以下的代碼進行預測。
def preprocess_input(x):
x /= 255.
x -= 0.5
x *= 2.
return x
if __name__ == '__main__':
model = Xception()
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(299, 299))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
print('Input image shape:', x.shape)
preds = model.predict(x)
print(np.argmax(preds))
print('Predicted:', decode_predictions(preds))
預測所需的已經訓練好的Xception模型可以在https://github.com/fchollet/deep-learning-models/releases下載。非常方便。
預測結果為:
Predicted: [[('n02504458', 'African_elephant', 0.47570863), ('n01871265', 'tusker', 0.3173351), ('n02504013', 'Indian_elephant', 0.030323735), ('n02963159', 'cardigan', 0.0007877756), ('n02410509', 'bison', 0.00075616257)]]
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/102813517
相關推薦
- 2022-10-23 C#中的yield關鍵字詳解_C#教程
- 2022-09-17 C++詳解PIMPL指向實現的指針_C 語言
- 2022-02-20 Android?WebView實現全屏播放視頻_Android
- 2024-03-08 SpringBoot項目多模塊開發詳解
- 2022-08-15 數據結構之鏈式棧的實現與簡單運用
- 2024-07-13 IDEA無法使用@WebServlet()注解
- 2022-03-21 C語言楊輝三角兩種實現方法_C 語言
- 2022-09-01 React新文檔切記不要濫用effect_React
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支