網站首頁 編程語言 正文
Batch Normalization和Dropout是深度學習模型中常用的結構。
但BN和dropout在訓練和測試時使用卻不相同。
Batch Normalization
BN在訓練時是在每個batch上計算均值和方差來進行歸一化,每個batch的樣本量都不大,所以每次計算出來的均值和方差就存在差異。預測時一般傳入一個樣本,所以不存在歸一化,其次哪怕是預測一個batch,但batch計算出來的均值和方差是偏離總體樣本的,所以通常是通過滑動平均結合訓練時所有batch的均值和方差來得到一個總體均值和方差。
以tensorflow代碼實現為例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5):
# 獲取輸入維度并判斷是否匹配卷積層(4)或者全連接層(2)
shape = inputs.shape
param_shape = shape[-1]
with tf.variable_scope(name):
# 聲明BN中唯一需要學習的兩個參數,y=gamma*x+beta
gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1))
beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0))
# 計算當前整個batch的均值與方差
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments')
# 采用滑動平均更新均值與方差
ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema")
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
# 訓練時,更新均值與方差,測試時使用之前最后一次保存的均值與方差
mean, var = tf.cond(tf.equal(training,True), mean_var_with_update,
lambda:(ema.average(batch_mean), ema.average(batch_var)))
# 最后執行batch normalization
return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)
training參數可以通過tf.placeholder傳入,這樣就可以控制訓練和預測時training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout
Dropout在訓練時會隨機丟棄一些神經元,這樣會導致輸出的結果變小。而預測時往往關閉dropout,保證預測結果的一致性(不關閉dropout可能同一個輸入會得到不同的輸出,不過輸出會服從某一分布。另外有些情況下可以不關閉dropout,比如文本生成下,不關閉會增大輸出的多樣性)。
為了對齊Dropout訓練和預測的結果,通常有兩種做法,假設dropout rate = 0.2。一種是訓練時不做處理,預測時輸出乘以(1 - dropout rate)。另一種是訓練時留下的神經元除以(1 - dropout rate),預測時不做處理。以tensorflow為例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二種做法,訓練時除以(1 - dropout rate),源碼如下:
binary_tensor = math_ops.floor(random_tensor)
ret = math_ops.div(x, keep_prob) * binary_tensor
if not context.executing_eagerly():
ret.set_shape(x.get_shape())
return ret
binary_tensor就是一個mask tensor,即里面的值由0或1組成。keep_prob = 1 - dropout rate。
原文鏈接:https://www.cnblogs.com/jiangxinyang/p/14333903.html
相關推薦
- 2022-04-25 C#使用NPOI實現Excel和DataTable的互轉_C#教程
- 2022-04-04 微信小程序:返回上一頁,刷新頁面內容
- 2022-12-08 C語言實現計算圓周長以及面積_C 語言
- 2022-06-14 Docker?配置容器固定IP的方法_docker
- 2024-01-28 spring ioc容器
- 2023-01-30 C++指針和數組:字符和字符串、字符數組的關聯和區別_C 語言
- 2022-04-12 git項目初次push提示error: failed to push some refs to ht
- 2023-05-21 golang代碼中調用Linux命令_Golang
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支