網(wǎng)站首頁 編程語言 正文
(數(shù)據(jù))圖像預(yù)處理——image augmentation圖像增廣之cutout、Mixup、CutMix方法及其實現(xiàn)
作者:甘霖佳佳 更新時間: 2022-01-31 編程語言圖片增廣(增強) image-augmentation
圖像增強即通過一系列的隨機變化生成大量“新的樣本”,從而減低過擬合的可能。現(xiàn)在在深度卷積神經(jīng)網(wǎng)絡(luò)訓(xùn)練中,圖像增強是必不可少的一部分。
常用增廣方法
圖像增廣方法一般分為兩類:一是對圖片做變形,二是對圖片做顏色變化
圖像增廣的一般方法的代碼和實現(xiàn)見以下鏈接,我們不再闡述。
深度學(xué)習(xí)圖像數(shù)據(jù)增廣方法總結(jié)
下面我們實現(xiàn)兩種圖像增強的高級方法:Cutout、Mixup和CutMix。
Mixup方法
Mixup is 是一個普遍通用的數(shù)據(jù)增強原則。本質(zhì)上,mixup訓(xùn)練神經(jīng)網(wǎng)絡(luò)的凸組合的例子和他們的標簽。通過這樣做,mixup正則化了神經(jīng)網(wǎng)絡(luò),以支持訓(xùn)練示例之間的簡單線性行為。
如圖所示,Mixup將兩個圖像根據(jù)透明度混淆在一起,使得機器更好的學(xué)習(xí)。
代碼實現(xiàn)
# mixup function
def mixup_data(x, y, alpha=1.0, use_cuda=True):
'''Returns mixed inputs, pairs of targets, and lambda'''
if alpha > 0:
lam = np.random.beta(alpha, alpha) # bata分布隨機數(shù)
else:
lam = 1
batch_size = x.size()[0]
if use_cuda:
index = torch.randperm(batch_size).cuda() # 返回一個[0, batch_size-1]的隨機數(shù)組
else:
index = torch.randperm(batch_size)
mixed_x = lam * x + (1 - lam) * x[index, :]
y_a, y_b = y, y[index]
return mixed_x, y_a, y_b, lam
Cutout方法
Cutout是一種簡單的卷積神經(jīng)網(wǎng)絡(luò)正則化方法,它包括在訓(xùn)練過程中屏蔽輸入圖像的隨機部分。這種技術(shù)模擬閉塞的例子,鼓勵模型在做決策時考慮更多次要的特性,而不是依賴于幾個主要特性的存在。
如圖所示,Cutout方法是隨機選取圖像上一個或者多個正方形區(qū)域?qū)⑵鋼赋?/p>
代碼實現(xiàn)
import torch
import numpy as np
class Cutout(object):
"""Randomly mask out one or more patches from an image.
#
Args:
n_holes (int): Number of patches to cut out of each image.
length (int): The length (in pixels) of each square patch.
"""
def __init__(self, n_holes, length):
self.n_holes = n_holes
self.length = length
def __call__(self, img):
"""
Args:
img (Tensor): Tensor image of size (C, H, W).
Returns:
Tensor: Image with n_holes of dimension length x length cut out of it.
"""
h = img.size(1) #32圖片的高
w = img.size(2) #32圖片的寬
mask = np.ones((h, w), np.float32) #32*32w*h的全1矩陣
for n in range(self.n_holes): #n_holes=2,length=4 選擇2個區(qū)域;每個區(qū)域的邊長為4
y = np.random.randint(h) #0~31隨機選擇一個數(shù) y=4
x = np.random.randint(w) #0~31隨機選擇一個數(shù) x=24
y1 = np.clip(y - self.length // 2, 0, h) #2,0,32 ->2
y2 = np.clip(y + self.length // 2, 0, h) #6,0,32 ->6
x1 = np.clip(x - self.length // 2, 0, w) #24-2,0,32 ->22
x2 = np.clip(x + self.length // 2, 0, w) #24+2,0,32 ->26
mask[y1: y2, x1: x2] = 0. #將這一小塊區(qū)域去除
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
# expand_as()函數(shù)與expand()函數(shù)類似,功能都是用來擴展張量中某維數(shù)據(jù)的尺寸,區(qū)別是它括號內(nèi)的輸入?yún)?shù)是另一個張量,作用是將輸入tensor的維度擴展為與指定tensor相同的size。
img = img * mask
return img
幫助理解代碼的鏈接:
python中numpy模塊下的np.clip()的用法
pytorch中的expand()和expand_as()函數(shù)
CutMix
CutMix的所選取的正方形區(qū)域在訓(xùn)練圖像之間剪切和粘貼,真實標簽值也按patches的面積比例混合。通過有效利用訓(xùn)練像素,并保留區(qū)域dropout的正則化效果,CutMix在CIFAR分類任務(wù)上的表現(xiàn)始終優(yōu)于最先進的增強策略。
代碼實現(xiàn)
def rand_bbox(size, lam):
W = size[2]
H = size[3]
cut_rat = np.sqrt(1. - lam)
cut_w = np.int(W * cut_rat)
cut_h = np.int(H * cut_rat)
# uniform
cx = np.random.randint(W)
cy = np.random.randint(H)
bbx1 = np.clip(cx - cut_w // 2, 0, W)
bby1 = np.clip(cy - cut_h // 2, 0, H)
bbx2 = np.clip(cx + cut_w // 2, 0, W)
bby2 = np.clip(cy + cut_h // 2, 0, H)
return bbx1, bby1, bbx2, bby2
# generate mixed sample
lam = np.random.beta(args.beta, args.beta)
rand_index = torch.randperm(images.size()[0]).cuda()
labels_a = labels
labels_b = labels[rand_index]
bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]
# adjust lambda to exactly match pixel ratio
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))
原文鏈接:https://blog.csdn.net/weixin_45928096/article/details/122406271
相關(guān)推薦
- 2022-06-29 Python解決非線性規(guī)劃中經(jīng)濟調(diào)度問題_python
- 2022-09-08 Python支持異步的列表解析式_python
- 2023-04-20 elementUI無線滾動+監(jiān)聽滾動條觸底
- 2022-03-26 C++函數(shù)指針的用法詳解_C 語言
- 2022-04-11 解決git push 錯誤error: src refspec master does not ma
- 2021-12-15 Android?studio導(dǎo)出APP測試包和構(gòu)建正式簽名包_Android
- 2022-07-07 ASP.NET對Cookie的操作_ASP.NET
- 2022-11-13 Redis中HyperLogLog的使用詳情_Redis
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支