網站首頁 編程語言 正文
背景
數據增強作為前處理的關鍵步驟,在整個計算機視覺中有著具足輕重的地位;
數據增強往往是決定數據集質量的關鍵,主要用于數據增廣,在基于深度學習的任務中,數據的多樣性和數量往往能夠決定模型的上限;
本次記錄主要是對數據增強中一些方法的源碼實現;
常用數據增強方法
首先如果是使用Pytorch框架,其內部的torchvision已經包裝好了數據增強的很多方法;
from torchvision import transforms
data_aug = transforms.Compose[
transforms.Resize(size=240),
transforms.RandomHorizontalFlip(0.5),
transforms.ToTensor()
]
接下來自己實現一些主要的方法;
常見的數據增強方法有:Compose、RandomHflip、RandomVflip、Reszie、RandomCrop、Normalize、Rotate、RandomRotate
1、Compose
作用:對多個方法的排序整合,并且依次調用;
# 排序(compose)
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img) # 通過循環不斷調用列表中的方法
return img
2、RandomHflip
作用:隨機水平翻轉;
# 隨機水平翻轉(random h flip)
class RandomHflip(object):
def __call__(self, image):
if random.randint(2):
return cv2.flip(image, 1)
else:
return image
通過隨機數0或1,實現對圖像可能反轉或不翻轉;
3、RandomVflip
作用:隨機垂直翻轉
class RandomVflip(object):
def __call__(self, image):
if random.randint(2):
return cv2.flip(image, 0)
else:
return image
4、RandomCrop
作用:隨機裁剪;
# 縮放(scale)
def scale_down(src_size, size):
w, h = size
sw, sh = src_size
if sh < h:
w, h = float(w * sh) / h, sh
if sw < w:
w, h = sw, float(h * sw) / w
return int(w), int(h)
# 固定裁剪(fixed crop)
def fixed_crop(src, x0, y0, w, h, size=None):
out = src[y0:y0 + h, x0:x0 + w]
if size is not None and (w, h) != size:
out = cv2.resize(out, (size[0], size[1]), interpolation=cv2.INTER_CUBIC)
return out
# 隨機裁剪(random crop)
class RandomCrop(object):
def __init__(self, size):
self.size = size
def __call__(self, image):
h, w, _ = image.shape
new_w, new_h = scale_down((w, h), self.size)
if w == new_w:
x0 = 0
else:
x0 = random.randint(0, w - new_w)
if h == new_h:
y0 = 0
else:
y0 = random.randint(0, h - new_h)
??????? out = fixed_crop(image, x0, y0, new_w, new_h, self.size)
return out
5、Normalize
作用:對圖像數據進行正則化,也就是減均值除方差的作用;
# 正則化(normalize)
class Normalize(object):
def __init__(self,mean, std):
'''
:param mean: RGB order
:param std: RGB order
'''
self.mean = np.array(mean).reshape(3,1,1)
self.std = np.array(std).reshape(3,1,1)
def __call__(self, image):
'''
:param image: (H,W,3) RGB
:return:
'''
return (image.transpose((2, 0, 1)) / 255. - self.mean) / self.std
6、Rotate
作用:對圖像進行旋轉;
# 旋轉(rotate)
def rotate_nobound(image, angle, center=None, scale=1.):
(h, w) = image.shape[:2]
# if the center is None, initialize it as the center of the image
if center is None:
center = (w // 2, h // 2) # perform the rotation
M = cv2.getRotationMatrix2D(center, angle, scale) # 這里是實現得到旋轉矩陣
rotated = cv2.warpAffine(image, M, (w, h)) # 通過矩陣進行仿射變換
return rotated
7、RandomRotate
作用:隨機旋轉,廣泛適用于圖像增強;
# 隨機旋轉(random rotate)
class FixRandomRotate(object):
# 這里的隨機旋轉是指在0、90、180、270四個角度下的
def __init__(self, angles=[0,90,180,270], bound=False):
self.angles = angles
self.bound = bound
def __call__(self,img):
do_rotate = random.randint(0, 4)
angle=self.angles[do_rotate]
if self.bound:
img = rotate_bound(img, angle)
else:
img = rotate_nobound(img, angle)
return img
8、Resize
作用:實現縮放;
# 大小重置(resize)
class Resize(object):
def __init__(self, size, inter=cv2.INTER_CUBIC):
self.size = size
self.inter = inter
def __call__(self, image):
return cv2.resize(image, (self.size[0], self.size[0]), interpolation=self.inter)
其他數據增強方法
其他一些數據增強的方法大部分是特殊的裁剪;
1、中心裁剪
# 中心裁剪(center crop)
def center_crop(src, size):
h, w = src.shape[0:2]
new_w, new_h = scale_down((w, h), size)
x0 = int((w - new_w) / 2)
y0 = int((h - new_h) / 2)
out = fixed_crop(src, x0, y0, new_w, new_h, size)
return out
2、隨機亮度增強
# 隨機亮度增強(random brightness)
class RandomBrightness(object):
def __init__(self, delta=10):
assert delta >= 0
assert delta <= 255
self.delta = delta
def __call__(self, image):
if random.randint(2):
delta = random.uniform(-self.delta, self.delta)
image = (image + delta).clip(0.0, 255.0)
# print('RandomBrightness,delta ',delta)
return image
3、隨機對比度增強
# 隨機對比度增強(random contrast)
class RandomContrast(object):
def __init__(self, lower=0.9, upper=1.05):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "contrast upper must be >= lower."
assert self.lower >= 0, "contrast lower must be non-negative."
# expects float image
def __call__(self, image):
if random.randint(2):
alpha = random.uniform(self.lower, self.upper)
# print('contrast:', alpha)
image = (image * alpha).clip(0.0,255.0)
return image
4、隨機飽和度增強
# 隨機飽和度增強(random saturation)
class RandomSaturation(object):
def __init__(self, lower=0.8, upper=1.2):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "contrast upper must be >= lower."
assert self.lower >= 0, "contrast lower must be non-negative."
def __call__(self, image):
if random.randint(2):
alpha = random.uniform(self.lower, self.upper)
image[:, :, 1] *= alpha
# print('RandomSaturation,alpha',alpha)
return image
5、邊界擴充
# 邊界擴充(expand border)
class ExpandBorder(object):
def __init__(self, mode='constant', value=255, size=(336,336), resize=False):
self.mode = mode
self.value = value
self.resize = resize
self.size = size
def __call__(self, image):
h, w, _ = image.shape
if h > w:
pad1 = (h-w)//2
pad2 = h - w - pad1
if self.mode == 'constant':
image = np.pad(image, ((0, 0), (pad1, pad2), (0, 0)),
self.mode, constant_values=self.value)
else:
image = np.pad(image,((0,0), (pad1, pad2),(0,0)), self.mode)
elif h < w:
pad1 = (w-h)//2
pad2 = w-h - pad1
if self.mode == 'constant':
image = np.pad(image, ((pad1, pad2),(0, 0), (0, 0)),
self.mode,constant_values=self.value)
else:
image = np.pad(image, ((pad1, pad2), (0, 0), (0, 0)),self.mode)
if self.resize:
image = cv2.resize(image, (self.size[0], self.size[0]),interpolation=cv2.INTER_LINEAR)
return image
當然還有很多其他數據增強的方式,在這里就不繼續做說明了;
拓展
除了可以使用Pytorch中自帶的數據增強包之外,也可以使用imgaug這個包(一個基于數據處理的包、包含大量的數據處理方法,并且代碼完全開源)
代碼地址:https://github.com/aleju/imgaug
說明文檔:https://imgaug.readthedocs.io/en/latest/index.html
強烈建議大家看看這個說明文檔,其中的很多數據處理方法可以快速的應用到實際項目中,也可以加深對圖像處理的理解;
原文鏈接:https://blog.csdn.net/weixin_40620310/article/details/126875826
相關推薦
- 2022-05-11 Qt編寫地圖之實現經緯度坐標糾偏_C 語言
- 2023-03-29 C語言楊氏矩陣實例教你編寫_C 語言
- 2022-04-20 為WPF框架Prism注冊Nlog日志服務_實用技巧
- 2022-05-31 關于k8s?使用?Service?控制器對外暴露服務的問題_云其它
- 2022-03-19 Linux系統下安裝Redis數據庫過程_Redis
- 2023-04-24 Numpy創建NumPy矩陣的簡單實現_python
- 2022-08-27 Qt實現一個簡單的word文檔編輯器_C 語言
- 2022-12-05 Flutter控制組件顯示和隱藏三種方式詳解_Android
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支