網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
正則化DropPath/drop_path用法示例(Python實(shí)現(xiàn))_python
作者:風(fēng)巽·劍染春水 ? 更新時(shí)間: 2022-06-13 編程語(yǔ)言DropPath/drop_path 是一種正則化手段,其效果是將深度學(xué)習(xí)模型中的多分支結(jié)構(gòu)隨機(jī)”刪除“,python中實(shí)現(xiàn)如下所示:
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
調(diào)用如下:
self.drop_path = DropPath(drop_prob) if drop_prob > 0. else nn.Identity()
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
看起來(lái)似乎有點(diǎn)迷茫,這怎么就隨機(jī)刪除了分支呢
實(shí)驗(yàn)如下:
import torch
drop_prob = 0.2
keep_prob = 1 - drop_prob
x = torch.randn(4, 3, 2, 2)
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
輸出:
x.size():[4,3,2,2]
x:
tensor([[[[ 1.3833, -0.3703],
? ? ? ? ? [-0.4608, ?0.6955]],
? ? ? ? ?[[ 0.8306, ?0.6882],
? ? ? ? ? [ 2.2375, ?1.6158]],
? ? ? ? ?[[-0.7108, ?1.0498],
? ? ? ? ? [ 0.6783, ?1.5673]]],? ? ? ? [[[-0.0258, -1.7539],
? ? ? ? ? [-2.0789, -0.9648]],
? ? ? ? ?[[ 0.8598, ?0.9351],
? ? ? ? ? [-0.3405, ?0.0070]],
? ? ? ? ?[[ 0.3069, -1.5878],
? ? ? ? ? [-1.1333, -0.5932]]],? ? ? ? [[[ 1.0379, ?0.6277],
? ? ? ? ? [ 0.0153, -0.4764]],
? ? ? ? ?[[ 1.0115, -0.0271],
? ? ? ? ? [ 1.6610, -0.2410]],
? ? ? ? ?[[ 0.0681, -2.0821],
? ? ? ? ? [ 0.6137, ?0.1157]]],? ? ? ? [[[ 0.5350, -2.8424],
? ? ? ? ? [ 0.6648, -1.6652]],
? ? ? ? ?[[ 0.0122, ?0.3389],
? ? ? ? ? [-1.1071, -0.6179]],
? ? ? ? ?[[-0.1843, -1.3026],
? ? ? ? ? [-0.3247, ?0.3710]]]])
random_tensor.size():[4, 1, 1, 1]
random_tensor:
tensor([[[[0.]]],
[[[1.]]],
[[[1.]]],
[[[1.]]]])
output.size():[4,3,2,2]
output:
tensor([[[[ 0.0000, -0.0000],
[-0.0000, 0.0000]],
[[ 0.0000, 0.0000],
[ 0.0000, 0.0000]],
[[-0.0000, 0.0000],
[ 0.0000, 0.0000]]],
[[[-0.0322, -2.1924],
[-2.5986, -1.2060]],
[[ 1.0748, 1.1689],
[-0.4256, 0.0088]],
[[ 0.3836, -1.9848],
[-1.4166, -0.7415]]],
[[[ 1.2974, 0.7846],
[ 0.0192, -0.5955]],
[[ 1.2644, -0.0339],
[ 2.0762, -0.3012]],
[[ 0.0851, -2.6027],
[ 0.7671, 0.1446]]],
[[[ 0.6687, -3.5530],
[ 0.8310, -2.0815]],
[[ 0.0152, 0.4236],
[-1.3839, -0.7723]],
[[-0.2303, -1.6282],
[-0.4059, 0.4638]]]])
random_tensor作為是否保留分支的直接置0項(xiàng),若drop_path的概率設(shè)為0.2,random_tensor中的數(shù)有0.2的概率為0,而output中被保留概率為0.8。
結(jié)合drop_path的調(diào)用,若x為輸入的張量,其通道為[B,C,H,W],那么drop_path的含義為在一個(gè)Batch_size中,隨機(jī)有drop_prob的樣本,不經(jīng)過(guò)主干,而直接由分支進(jìn)行恒等映射。
總結(jié)
原文鏈接:https://blog.csdn.net/qq_43426908/article/details/121662843
相關(guān)推薦
- 2022-03-29 redis的list數(shù)據(jù)類(lèi)型相關(guān)命令介紹及使用_Redis
- 2022-10-22 SQLMAP插件tamper模塊簡(jiǎn)介_(kāi)MsSql
- 2023-03-18 C#調(diào)用dll報(bào)錯(cuò):無(wú)法加載dll,找不到指定模塊的解決_C#教程
- 2023-07-06 css flex實(shí)現(xiàn)div固定在瀏覽器右下角
- 2024-07-15 在物理及和虛擬主機(jī)上配置ftp,實(shí)現(xiàn)上傳和下載的功能(five day)
- 2022-07-13 Docker的數(shù)據(jù)管理
- 2022-07-27 python中format的用法實(shí)例詳解_python
- 2022-10-06 SQL語(yǔ)句中的ON?DUPLICATE?KEY?UPDATE使用_MsSql
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門(mén)
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支