網站首頁 編程語言 正文
批標準化層 tf.keras.layers.Batchnormalization()
tf.keras.layers.Batchnormalization()
重要參數:
-
training
:布爾值,指示圖層應在訓練模式還是在推理模式下運行。 -
training=True
:該圖層將使用當前批輸入的均值和方差對其輸入進行標準化。 -
training=False
:該層將使用在訓練期間學習的移動統計數據的均值和方差來標準化其輸入。
BatchNormalization 廣泛用于 Keras 內置的許多高級卷積神經網絡架構,比如 ResNet50、Inception V3 和 Xception。
BatchNormalization 層通常在卷積層或密集連接層之后使用。
批標準化的實現過程
- 求每一個訓練批次數據的均值
- 求每一個訓練批次數據的方差
- 數據進行標準化
- 訓練參數γ,β
- 輸出y通過γ與β的線性變換得到原來的數值
在訓練的正向傳播中,不會改變當前輸出,只記錄下γ與β。在反向傳播的時候,根據求得的γ與β通過鏈式求導方式,求出學習速率以至改變權值。
對于預測階段時所使用的均值和方差,其實也是來源于訓練集。比如我們在模型訓練時我們就記錄下每個batch下的均值和方差,待訓練完畢后,我們求整個訓練樣本的均值和方差期望值,作為我們進行預測時進行BN的的均值和方差。
批標準化的使用位置
原始論文講在CNN中一般應作用與非線性激活函數之前,但是,實際上放在激活函數之后效果可能會更好。
# 放在非線性激活函數之前 model.add(tf.keras.layers.Conv2D(64, (3, 3))) model.add(tf.keras.layers.BatchNormalization()) model.add(tf.keras.layers.Activation('relu')) # 放在激活函數之后 model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu')) model.add(tf.keras.layers.BatchNormalization())
tf.keras.layers.BatchNormalization使用細節
關于keras中的BatchNormalization使用,官方文檔說的足夠詳細。本文的目的旨在說明在BatchNormalization的使用過程中容易被忽略的細節。
在BatchNormalization的Arguments參數中有trainable屬性;以及在Call arguments參數中有training。兩個都是bool類型。第一次看到有兩個參數的時候,我有點懵,為什么需要兩個?
后來在查閱資料后發現了兩者的不同作用。
1,trainable是Argument參數,類似于c++中構造函數的參數一樣,是構建一個BatchNormalization層時就需要傳入的,至于它的作用在下面會講到。
2,training參數時Call argument(調用參數),是運行過程中需要傳入的,用來控制模型在那個模式(train還是interfere)下運行。關于這個參數,如果使用模型調用fit()的話,是可以不給的(官方推薦是不給),因為在fit()的時候,模型會自己根據相應的階段(是train階段還是inference階段)決定training值,這是由learning——phase機制實現的。
重點
關于trainable=False:如果設置trainable=False,那么這一層的BatchNormalization層就會被凍結(freeze),它的trainable weights(可訓練參數)(就是gamma和beta)就不會被更新。
注意:freeze mode和inference mode是兩個概念。
但是,在BatchNormalization層中,如果把某一層BatchNormalization層設置為trainable=False,那么這一層BatchNormalization層將一inference mode運行,也就是說(meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).
總結
原文鏈接:https://blog.csdn.net/weixin_46072771/article/details/108591263
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2022-04-17 sa-token快速添加多鑒權體系
- 2022-11-16 Python中Pygame模塊的詳細安裝過程_python
- 2022-05-26 Python?pass語句作用和Python?assert斷言函數的用法_python
- 2022-11-07 python學習pymongo模塊的使用方法_python
- 2022-09-03 C#中DataSet、DataTable、DataRow數據的復制方法_C#教程
- 2022-12-07 C++?基本數據類型中int、long等整數類型取值范圍及原理分析_C 語言
- 2022-08-21 C語言數據結構之單鏈表的實現_C 語言
- 2022-05-02 Python中的變量和數據類型詳情_python
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支