網站首頁 編程語言 正文
DropPath/drop_path 是一種正則化手段,其效果是將深度學習模型中的多分支結構隨機”刪除“,python中實現如下所示:
def drop_path(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
調用如下:
self.drop_path = DropPath(drop_prob) if drop_prob > 0. else nn.Identity()
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
看起來似乎有點迷茫,這怎么就隨機刪除了分支呢
實驗如下:
import torch
drop_prob = 0.2
keep_prob = 1 - drop_prob
x = torch.randn(4, 3, 2, 2)
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_()
output = x.div(keep_prob) * random_tensor
輸出:
x.size():[4,3,2,2]
x:
tensor([[[[ 1.3833, -0.3703],
? ? ? ? ? [-0.4608, ?0.6955]],
? ? ? ? ?[[ 0.8306, ?0.6882],
? ? ? ? ? [ 2.2375, ?1.6158]],
? ? ? ? ?[[-0.7108, ?1.0498],
? ? ? ? ? [ 0.6783, ?1.5673]]],? ? ? ? [[[-0.0258, -1.7539],
? ? ? ? ? [-2.0789, -0.9648]],
? ? ? ? ?[[ 0.8598, ?0.9351],
? ? ? ? ? [-0.3405, ?0.0070]],
? ? ? ? ?[[ 0.3069, -1.5878],
? ? ? ? ? [-1.1333, -0.5932]]],? ? ? ? [[[ 1.0379, ?0.6277],
? ? ? ? ? [ 0.0153, -0.4764]],
? ? ? ? ?[[ 1.0115, -0.0271],
? ? ? ? ? [ 1.6610, -0.2410]],
? ? ? ? ?[[ 0.0681, -2.0821],
? ? ? ? ? [ 0.6137, ?0.1157]]],? ? ? ? [[[ 0.5350, -2.8424],
? ? ? ? ? [ 0.6648, -1.6652]],
? ? ? ? ?[[ 0.0122, ?0.3389],
? ? ? ? ? [-1.1071, -0.6179]],
? ? ? ? ?[[-0.1843, -1.3026],
? ? ? ? ? [-0.3247, ?0.3710]]]])
random_tensor.size():[4, 1, 1, 1]
random_tensor:
tensor([[[[0.]]],
[[[1.]]],
[[[1.]]],
[[[1.]]]])
output.size():[4,3,2,2]
output:
tensor([[[[ 0.0000, -0.0000],
[-0.0000, 0.0000]],
[[ 0.0000, 0.0000],
[ 0.0000, 0.0000]],
[[-0.0000, 0.0000],
[ 0.0000, 0.0000]]],
[[[-0.0322, -2.1924],
[-2.5986, -1.2060]],
[[ 1.0748, 1.1689],
[-0.4256, 0.0088]],
[[ 0.3836, -1.9848],
[-1.4166, -0.7415]]],
[[[ 1.2974, 0.7846],
[ 0.0192, -0.5955]],
[[ 1.2644, -0.0339],
[ 2.0762, -0.3012]],
[[ 0.0851, -2.6027],
[ 0.7671, 0.1446]]],
[[[ 0.6687, -3.5530],
[ 0.8310, -2.0815]],
[[ 0.0152, 0.4236],
[-1.3839, -0.7723]],
[[-0.2303, -1.6282],
[-0.4059, 0.4638]]]])
random_tensor作為是否保留分支的直接置0項,若drop_path的概率設為0.2,random_tensor中的數有0.2的概率為0,而output中被保留概率為0.8。
結合drop_path的調用,若x為輸入的張量,其通道為[B,C,H,W],那么drop_path的含義為在一個Batch_size中,隨機有drop_prob的樣本,不經過主干,而直接由分支進行恒等映射。
總結
原文鏈接:https://blog.csdn.net/qq_43426908/article/details/121662843
相關推薦
- 2022-12-01 django第一個項目127.0.0.1:8000不能訪問的解決方案詳析_python
- 2022-12-28 golang?gin?監聽rabbitmq隊列無限消費的案例代碼_Golang
- 2022-07-01 Keras實現Vision?Transformer?VIT模型示例詳解_python
- 2022-03-07 android?studio?項目?:UI設計高精度實現簡單計算器_Android
- 2022-06-06 canvas保存圖片時,谷歌瀏覽器Chrome報錯解決方案Not allowed to naviga
- 2022-09-07 Python實現不寫硬盤上傳文件_python
- 2022-02-01 微信小程序批量獲取input的輸入值,監聽輸入框,數據同步
- 2022-05-31 使用python把Excel中的數據在頁面中可視化_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支