網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
python神經(jīng)網(wǎng)絡(luò)Batch?Normalization底層原理詳解_python
作者:Bubbliiiing ? 更新時(shí)間: 2022-07-01 編程語(yǔ)言什么是Batch Normalization
Batch Normalization是神經(jīng)網(wǎng)絡(luò)中常用的層,解決了很多深度學(xué)習(xí)中遇到的問(wèn)題,我們一起來(lái)學(xué)習(xí)一哈。
Batch Normalization是由google提出的一種訓(xùn)練優(yōu)化方法。參考論文:Batch Normalization Accelerating Deep Network Training by Reducing Internal Covariate Shift。
Batch Normalization的名稱為批標(biāo)準(zhǔn)化,它的功能是使得輸入的X數(shù)據(jù)符合同一分布,從而使得訓(xùn)練更加簡(jiǎn)單、快速。
一般來(lái)講,Batch Normalization會(huì)放在卷積層后面,即卷積 + 標(biāo)準(zhǔn)化 + 激活函數(shù)。
其計(jì)算過(guò)程可以簡(jiǎn)單歸納為以下3點(diǎn):
1、求數(shù)據(jù)均值。
2、求數(shù)據(jù)方差。
3、數(shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化。
Batch Normalization的計(jì)算公式
Batch Normalization的計(jì)算公式主要看如下這幅圖:
這個(gè)公式一定要靜下心來(lái)看,整個(gè)公式可以分為四行:
1、對(duì)輸入進(jìn)來(lái)的數(shù)據(jù)X進(jìn)行均值求取。
2、利用輸入進(jìn)來(lái)的數(shù)據(jù)X減去第一步得到的均值,然后求平方和,獲得輸入X的方差。
3、利用輸入X、第一步獲得的均值和第二步獲得的方差對(duì)數(shù)據(jù)進(jìn)行歸一化,即利用X減去均值,然后除上方差開(kāi)根號(hào)。方差開(kāi)根號(hào)前需要添加上一個(gè)極小值。
4、引入γ和β變量,對(duì)輸入進(jìn)來(lái)的數(shù)據(jù)進(jìn)行縮放和平移。利用γ和β兩個(gè)參數(shù),讓我們的網(wǎng)絡(luò)可以學(xué)習(xí)恢復(fù)出原始網(wǎng)絡(luò)所要學(xué)習(xí)的特征分布。
前三步是標(biāo)準(zhǔn)化工序,最后一步是反標(biāo)準(zhǔn)化工序。
Bn層的好處
1、加速網(wǎng)絡(luò)的收斂速度。在神經(jīng)網(wǎng)絡(luò)中,存在內(nèi)部協(xié)變量偏移的現(xiàn)象,如果每層的數(shù)據(jù)分布不同的話,會(huì)導(dǎo)致非常難收斂,如果把每層的數(shù)據(jù)都在轉(zhuǎn)換在均值為零,方差為1的狀態(tài)下,這樣每層數(shù)據(jù)的分布都是一樣的,訓(xùn)練會(huì)比較容易收斂。
2、防止梯度爆炸和梯度消失。對(duì)于梯度消失而言,以Sigmoid函數(shù)為例,它會(huì)使得輸出在[0,1]之間,實(shí)際上當(dāng)x到了一定的大小,sigmoid激活函數(shù)的梯度值就變得非常小,不易訓(xùn)練。歸一化數(shù)據(jù)的話,就能讓梯度維持在比較大的值和變化率;
對(duì)于梯度爆炸而言,在方向傳播的過(guò)程中,每一層的梯度都是由上一層的梯度乘以本層的數(shù)據(jù)得到。如果歸一化的話,數(shù)據(jù)均值都在0附近,很顯然,每一層的梯度不會(huì)產(chǎn)生爆炸的情況。
3、防止過(guò)擬合。在網(wǎng)絡(luò)的訓(xùn)練中,Bn使得一個(gè)minibatch中所有樣本都被關(guān)聯(lián)在了一起,因此網(wǎng)絡(luò)不會(huì)從某一個(gè)訓(xùn)練樣本中生成確定的結(jié)果,這樣就會(huì)使得整個(gè)網(wǎng)絡(luò)不會(huì)朝這一個(gè)方向使勁學(xué)習(xí)。一定程度上避免了過(guò)擬合。
為什么要引入γ和β變量
Bn層在進(jìn)行前三步后,會(huì)引入γ和β變量,對(duì)輸入進(jìn)來(lái)的數(shù)據(jù)進(jìn)行縮放和平移。
γ和β變量是網(wǎng)絡(luò)參數(shù),是可學(xué)習(xí)的。
引入γ和β變量進(jìn)行縮放平移可以使得神經(jīng)網(wǎng)絡(luò)有自適應(yīng)的能力,在標(biāo)準(zhǔn)化效果好時(shí),盡量不抵消標(biāo)準(zhǔn)化的作用,而在標(biāo)準(zhǔn)化效果不好時(shí),盡量去抵消一部分標(biāo)準(zhǔn)化的效果,相當(dāng)于讓神經(jīng)網(wǎng)絡(luò)學(xué)會(huì)要不要標(biāo)準(zhǔn)化,如何折中選擇。
Bn層的代碼實(shí)現(xiàn)
Pytorch代碼看起來(lái)比較簡(jiǎn)單,而且和上面的公式非常符合,可以學(xué)習(xí)一下,參考自
https://www.jb51.net/article/247197.htm
def batch_norm(is_training, x, gamma, beta, moving_mean, moving_var, eps=1e-5, momentum=0.9):
if not is_training:
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
else:
mean = x.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
var = ((x - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
x_hat = (x - mean) / torch.sqrt(var + eps)
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
Y = gamma * x_hat + beta
return Y, moving_mean, moving_var
class BatchNorm2d(nn.Module):
def __init__(self, num_features):
super(BatchNorm2d, self).__init__()
shape = (1, num_features, 1, 1)
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
self.register_buffer('moving_mean', torch.zeros(shape))
self.register_buffer('moving_var', torch.ones(shape))
def forward(self, x):
if self.moving_mean.device != x.device:
self.moving_mean = self.moving_mean.to(x.device)
self.moving_var = self.moving_var.to(x.device)
y, self.moving_mean, self.moving_var = batch_norm(self.training,
x, self.gamma, self.beta, self.moving_mean,
self.moving_var, eps=1e-5, momentum=0.9)
return y
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/114998793
相關(guān)推薦
- 2022-06-01 Python學(xué)習(xí)之虛擬環(huán)境原理詳解_python
- 2024-03-15 docker安裝RabbitMq插件
- 2022-04-23 uni-app項(xiàng)目之商品列表的下拉刷新與上拉加載更多
- 2022-07-13 Andorid 自定義 View - 自定義屬性 - 屬性重復(fù)導(dǎo)致沖突
- 2022-06-12 C語(yǔ)言詳解float類型在內(nèi)存中的存儲(chǔ)方式_C 語(yǔ)言
- 2023-04-24 Python?argparse中的action=store_true用法小結(jié)_python
- 2022-07-06 Python3?DataFrame缺失值的處理方法_python
- 2022-12-13 C++?Boost實(shí)現(xiàn)數(shù)字與字符串轉(zhuǎn)化詳解_C 語(yǔ)言
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支