網站首頁 編程語言 正文
什么是Densenet
據說Densenet比Resnet還要厲害,我決定好好學一下。
ResNet模型的出現使得深度學習神經網絡可以變得更深,進而實現了更高的準確度。
ResNet模型的核心是通過建立前面層與后面層之間的短路連接(shortcuts),這有助于訓練過程中梯度的反向傳播,從而能訓練出更深的CNN網絡。
DenseNet模型,它的基本思路與ResNet一致,也是建立前面層與后面層的短路連接,不同的是,但是它建立的是前面所有層與后面層的密集連接。
DenseNet還有一個特點是實現了特征重用。
這些特點讓DenseNet在參數和計算成本更少的情形下實現比ResNet更優的性能。
DenseNet示意圖如下:
代碼下載
Densenet
1、Densenet的整體結構
如圖所示Densenet由DenseBlock和中間的間隔模塊Transition Layer組成。
1、DenseBlock:DenseBlock指的就是DenseNet特有的模塊,如下圖所示,前面所有層與后面層的具有密集連接,在同一個DenseBlock當中,特征層的高寬不會發生改變,但是通道數會發生改變。
2、Transition Layer:Transition Layer是將不同DenseBlock之間進行連接的模塊,主要功能是整合上一個DenseBlock獲得的特征,并且縮小上一個DenseBlock的寬高,在Transition Layer中,一般會使用一個步長為2的AveragePooling2D縮小特征層的寬高。
2、DenseBlock
DenseBlock的實現示意圖如圖所示:
以前獲得的特征會在保留后不斷的堆疊起來。
以一個簡單例子來表現一下具體的DenseBlock的流程:
假設輸入特征層為X0。
1、對x0進行一次1x1卷積調整通道數到4*32后,再利用3x3卷積獲得一個32通道的特征層,此時會獲得一個shape為(h,w,32)的特征層x1。
2、將獲得的x1和初始的x0堆疊,獲得一個新的特征層,這個特征層會同時保留初始x0的特征也會保留經過卷積處理后的特征。
3、反復經過步驟1、2的處理,原始的特征會一直得到保留,經過卷積處理后的特征也會得到保留。當網絡程度不斷加深,就可以實現前面所有層與后面層的具有密集連接。
實現代碼為:
def dense_block(x, blocks, name):
for i in range(blocks):
x = conv_block(x, 32, name=name + '_block' + str(i + 1))
return x
def conv_block(x, growth_rate, name):
bn_axis = 3
x1 = layers.BatchNormalization(axis=bn_axis,
epsilon=1.001e-5,
name=name + '_0_bn')(x)
x1 = layers.Activation('relu', name=name + '_0_relu')(x1)
x1 = layers.Conv2D(4 * growth_rate, 1,
use_bias=False,
name=name + '_1_conv')(x1)
x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_1_bn')(x1)
x1 = layers.Activation('relu', name=name + '_1_relu')(x1)
x1 = layers.Conv2D(growth_rate, 3,
padding='same',
use_bias=False,
name=name + '_2_conv')(x1)
x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
return x
3、Transition Layer
Transition Layer將不同DenseBlock之間進行連接的模塊,主要功能是整合上一個DenseBlock獲得的特征,并且縮小上一個DenseBlock的寬高,在Transition Layer中,一般會使用一個步長為2的AveragePooling2D縮小特征層的寬高。
實現代碼為:
def transition_block(x, reduction, name):
bn_axis = 3
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_bn')(x)
x = layers.Activation('relu', name=name + '_relu')(x)
x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,
use_bias=False,
name=name + '_conv')(x)
x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)
return x
網絡實現代碼
from keras.preprocessing import image
from keras.models import Model
from keras import layers
from keras.applications import imagenet_utils
from keras.applications.imagenet_utils import decode_predictions
from keras.utils.data_utils import get_file
from keras import backend
import numpy as np
BASE_WEIGTHS_PATH = (
'https://github.com/keras-team/keras-applications/'
'releases/download/densenet/')
DENSENET121_WEIGHT_PATH = (
BASE_WEIGTHS_PATH +
'densenet121_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET169_WEIGHT_PATH = (
BASE_WEIGTHS_PATH +
'densenet169_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET201_WEIGHT_PATH = (
BASE_WEIGTHS_PATH +
'densenet201_weights_tf_dim_ordering_tf_kernels.h5')
def dense_block(x, blocks, name):
for i in range(blocks):
x = conv_block(x, 32, name=name + '_block' + str(i + 1))
return x
def conv_block(x, growth_rate, name):
bn_axis = 3
x1 = layers.BatchNormalization(axis=bn_axis,
epsilon=1.001e-5,
name=name + '_0_bn')(x)
x1 = layers.Activation('relu', name=name + '_0_relu')(x1)
x1 = layers.Conv2D(4 * growth_rate, 1,
use_bias=False,
name=name + '_1_conv')(x1)
x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_1_bn')(x1)
x1 = layers.Activation('relu', name=name + '_1_relu')(x1)
x1 = layers.Conv2D(growth_rate, 3,
padding='same',
use_bias=False,
name=name + '_2_conv')(x1)
x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
return x
def transition_block(x, reduction, name):
bn_axis = 3
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_bn')(x)
x = layers.Activation('relu', name=name + '_relu')(x)
x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,
use_bias=False,
name=name + '_conv')(x)
x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)
return x
def DenseNet(blocks,
input_shape=None,
classes=1000,
**kwargs):
img_input = layers.Input(shape=input_shape)
bn_axis = 3
# 224,224,3 -> 112,112,64
x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
x = layers.BatchNormalization(
axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)
x = layers.Activation('relu', name='conv1/relu')(x)
# 112,112,64 -> 56,56,64
x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)
# 56,56,64 -> 56,56,64+32*block[0]
# Densenet121 56,56,64 -> 56,56,64+32*6 == 56,56,256
x = dense_block(x, blocks[0], name='conv2')
# 56,56,64+32*block[0] -> 28,28,32+16*block[0]
# Densenet121 56,56,256 -> 28,28,32+16*6 == 28,28,128
x = transition_block(x, 0.5, name='pool2')
# 28,28,32+16*block[0] -> 28,28,32+16*block[0]+32*block[1]
# Densenet121 28,28,128 -> 28,28,128+32*12 == 28,28,512
x = dense_block(x, blocks[1], name='conv3')
# Densenet121 28,28,512 -> 14,14,256
x = transition_block(x, 0.5, name='pool3')
# Densenet121 14,14,256 -> 14,14,256+32*block[2] == 14,14,1024
x = dense_block(x, blocks[2], name='conv4')
# Densenet121 14,14,1024 -> 7,7,512
x = transition_block(x, 0.5, name='pool4')
# Densenet121 7,7,512 -> 7,7,256+32*block[3] == 7,7,1024
x = dense_block(x, blocks[3], name='conv5')
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
x = layers.Activation('relu', name='relu')(x)
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
inputs = img_input
if blocks == [6, 12, 24, 16]:
model = Model(inputs, x, name='densenet121')
elif blocks == [6, 12, 32, 32]:
model = Model(inputs, x, name='densenet169')
elif blocks == [6, 12, 48, 32]:
model = Model(inputs, x, name='densenet201')
else:
model = Model(inputs, x, name='densenet')
return model
def DenseNet121(input_shape=[224,224,3],
classes=1000,
**kwargs):
return DenseNet([6, 12, 24, 16],
input_shape, classes,
**kwargs)
def DenseNet169(input_shape=[224,224,3],
classes=1000,
**kwargs):
return DenseNet([6, 12, 32, 32],
input_shape, classes,
**kwargs)
def DenseNet201(input_shape=[224,224,3],
classes=1000,
**kwargs):
return DenseNet([6, 12, 48, 32],
input_shape, classes,
**kwargs)
def preprocess_input(x):
x /= 255.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
x[..., 0] -= mean[0]
x[..., 1] -= mean[1]
x[..., 2] -= mean[2]
if std is not None:
x[..., 0] /= std[0]
x[..., 1] /= std[1]
x[..., 2] /= std[2]
return x
if __name__ == '__main__':
# model = DenseNet121()
# weights_path = get_file(
# 'densenet121_weights_tf_dim_ordering_tf_kernels.h5',
# DENSENET121_WEIGHT_PATH,
# cache_subdir='models',
# file_hash='9d60b8095a5708f2dcce2bca79d332c7')
model = DenseNet169()
weights_path = get_file(
'densenet169_weights_tf_dim_ordering_tf_kernels.h5',
DENSENET169_WEIGHT_PATH,
cache_subdir='models',
file_hash='d699b8f76981ab1b30698df4c175e90b')
# model = DenseNet201()
# weights_path = get_file(
# 'densenet201_weights_tf_dim_ordering_tf_kernels.h5',
# DENSENET201_WEIGHT_PATH,
# cache_subdir='models',
# file_hash='1ceb130c1ea1b78c3bf6114dbdfd8807')
model.load_weights(weights_path)
model.summary()
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
print('Input image shape:', x.shape)
preds = model.predict(x)
print(np.argmax(preds))
print('Predicted:', decode_predictions(preds))
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/105472196
相關推薦
- 2024-03-05 git的使用
- 2022-08-25 C語言詳細分析宏定義與預處理命令的應用_C 語言
- 2022-08-23 C++深入講解函數重載_C 語言
- 2022-10-08 Python使用xlrd和xlwt實現自動化操作Excel_python
- 2022-03-29 C#算法之兩數之和_C#教程
- 2023-02-17 C++可執行文件絕對路徑值與VS安全檢查詳解_C 語言
- 2022-01-13 封裝axios以及接口管理
- 2022-04-23 uni-app之條件注釋實現跨端兼容
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支