日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無先后,達(dá)者為師

網(wǎng)站首頁 編程語言 正文

Pytorch建模過程中的DataLoader與Dataset示例詳解_python

作者:奧辰 ? 更新時(shí)間: 2023-02-12 編程語言

處理數(shù)據(jù)樣本的代碼會(huì)因?yàn)樘幚磉^程繁雜而變得混亂且難以維護(hù),在理想情況下,我們希望數(shù)據(jù)預(yù)處理過程代碼與我們的模型訓(xùn)練代碼分離,以獲得更好的可讀性和模塊化,為此,PyTorch提供了torch.utils.data.DataLoader?和?torch.utils.data.Dataset兩個(gè)類用于數(shù)據(jù)處理。其中torch.utils.data.DataLoader用于將數(shù)據(jù)集進(jìn)行打包封裝成一個(gè)可迭代對(duì)象,torch.utils.data.Dataset存儲(chǔ)有一些常用的數(shù)據(jù)集示例以及相關(guān)標(biāo)簽。

同時(shí)PyTorch針對(duì)不同的專業(yè)領(lǐng)域,也提供有不同的模塊,例如?TorchText(自然語言處理),?TorchVision(計(jì)算機(jī)視覺),?TorchAudio(音頻),這些模塊中也都包含一些真實(shí)數(shù)據(jù)集示例。例如TorchVision模塊中提供了CIFAR, COCO, FashionMNIST 數(shù)據(jù)集。

1 定義數(shù)據(jù)集

pytorch中提供兩種風(fēng)格的數(shù)據(jù)集定義方式:

  • 字典映射風(fēng)格。之所以稱為映射風(fēng)格,是因?yàn)樵诤罄m(xù)加載數(shù)據(jù)迭代時(shí),pytorch將自動(dòng)使用迭代索引作為key,通過字典索引的方式獲取value,本質(zhì)就是將數(shù)據(jù)集定義為一個(gè)字典,使用這種風(fēng)格時(shí),需要繼承Dataset類。

In?[54]:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In?[56]:

dataset = {0: '張三', 1:'李四', 2:'王五', 3:'趙六', 4:'陳七'}
dataloader = DataLoader(dataset, batch_size=2)
for i, value in enumerate(dataloader):
    print(i, value)
0 ['張三', '李四']
1 ['王五', '趙六']
2 ['陳七']
  • 迭代器風(fēng)格。在自定義數(shù)據(jù)集類中,實(shí)現(xiàn)__iter____next__方法,即定義為迭代器,在后續(xù)加載數(shù)據(jù)迭代時(shí),pytorch將依次獲取value,使用這種風(fēng)格時(shí),需要繼承IterableDataset類。這種方法在數(shù)據(jù)量巨大,無法一下全部加載到內(nèi)存時(shí)非常實(shí)用。

In?[57]:

from torch.utils.data import DataLoader
from torch.utils.data import IterableDataset

In?[58]:

dataset = [i for i in range(10)]
dataloader = DataLoader(dataset=dataset, batch_size=3, shuffle=True) 
for i, item in enumerate(dataloader): # 迭代輸出
    print(i, item)
0 tensor([3, 1, 2])
1 tensor([9, 7, 5])
2 tensor([0, 8, 4])
3 tensor([6])

如下所示,我們有一個(gè)螞蟻蜜蜂圖像分類數(shù)據(jù)集,目錄結(jié)構(gòu)如下所示,下面我們結(jié)合這個(gè)數(shù)據(jù)集,分別介紹如何使用這兩個(gè)類定義真實(shí)數(shù)據(jù)集。

data
└── hymenoptera_data
    ├── train
    │?? ├── ants
    │?? │?? ├── 0013035.jpg
    │   │   ……
    │?? └── bees
    │??     ├── 1092977343_cb42b38d62.jpg
    │       ……
    └── val
        ├── ants
        │?? ├── 10308379_1b6c72e180.jpg
        │?? ……
        └── bees
            ├── 1032546534_06907fe3b3.jpg
            ……

1.2 Dataset類

自定義一個(gè)Dataset類,繼承torch.utils.data.Dataset,且必須實(shí)現(xiàn)下面三個(gè)方法:

  • Dataset類里面的__init__函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件。

  • Dataset類里面的__getitem__函數(shù),映射取值是調(diào)用的方法,獲取單個(gè)的數(shù)據(jù),訓(xùn)練迭代時(shí)將會(huì)調(diào)用這個(gè)方法。

  • Dataset類里面的__len__函數(shù)獲取數(shù)據(jù)的總量。

In?[211]:

import os
import pandas as pd
from PIL import Image
from torchvision.transforms import ToTensor, Lambda
from torchvision import transforms
import torchvision
class AntBeeDataset(Dataset):
    # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, root_dir, transform=None, target_transform=None):
        """
        root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data
        transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize
        target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        
        # 獲取文件夾下所有圖片的名稱和對(duì)應(yīng)的標(biāo)簽
        self.img_lst = []
        for label in ['ants', 'bees']:
            path = os.path.join(root_dir, label)
            for img_name in os.listdir(path):
                self.img_lst.append((os.path.join(root_dir, label, img_name), label))
        
    def __getitem__(self, idx):
        img_path, label = self.img_lst[idx]
        img = Image.open(img_path).convert('RGB')
        
        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)
        # 這個(gè)地方要注意,我們?cè)谟?jì)算loss的時(shí)候用交叉熵nn.CrossEntropyLoss()
        # 交叉熵的輸入有兩個(gè),一個(gè)是模型的輸出outputs,一個(gè)是標(biāo)簽targets,注意targets是一維tensor
        # 例如batchsize如果是2,ants的targets的應(yīng)該[0,0],而不是[[0][0]]
        # 因此label要返回0,而不是[0]
        return img, label

    def __len__(self):
        return len(self.img_lst)

In?[310]:

train_transform = transforms.Compose([
    
    transforms.RandomResizedCrop(224),  # 將給定圖像隨機(jī)裁剪為不同的大小和寬高比,然后縮放所裁剪得到的圖像為制定的大小
    transforms.RandomHorizontalFlip(),  # 以給定的概率隨機(jī)水平旋轉(zhuǎn)給定的PIL的圖像,默認(rèn)為0.5
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# 驗(yàn)證集并不需要做與訓(xùn)練集相同的處理,所有,通常使用更加簡(jiǎn)單的transformer
val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# 根據(jù)標(biāo)簽?zāi)夸浀拿Q來確定圖片是哪一類,如果是"ants",標(biāo)簽設(shè)置為0,如果是"bees",標(biāo)簽設(shè)置為1
target_transform = transforms.Lambda(lambda y: 0 if y == "ants" else 1)

In?[311]:

train_dataset = AntBeeDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform)
val_dataset = AntBeeDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)

1.2 Dataset數(shù)據(jù)集常用操作

1. 查看數(shù)據(jù)集大小:

In?[221]:

len(train_dataset), len(val_dataset)

Out[221]:

(245, 153)

2. 合并數(shù)據(jù)集

In?[222]:

dataset = train_dataset + val_dataset

In?[223]:

len(dataset)

Out[223]:

398

3. 劃分訓(xùn)練集、測(cè)試集

In?[224]:

from torch.utils.data import random_split
# random_split 不能直接使用百分比劃分,必須指定具體數(shù)字
train_size = int( len(dataset) * 0.8)
test_size = len(dataset) - train_size

In?[225]:

train_dataset, val_dataset = random_split(dataset, [train_size, test_size])

In?[226]:

len(train_dataset), len(val_dataset)

Out[226]:

(318, 80)

1.3 IterableDataset類

使用迭代器風(fēng)格時(shí),必須繼承IterableDataset類,且實(shí)現(xiàn)下面兩個(gè)方法:

  • __init__,函數(shù)初始化一些參數(shù),如讀取外部數(shù)據(jù)源文件,在數(shù)據(jù)量過大時(shí),通常只是獲取操作句柄、數(shù)據(jù)庫連接。

  • __iter__,獲取迭代器。

雖然只需要實(shí)現(xiàn)這兩個(gè)方法,但是通常還需要在迭代過程中對(duì)數(shù)據(jù)進(jìn)行處理。IterableDataset類實(shí)現(xiàn)自定義數(shù)據(jù)集,本質(zhì)就是創(chuàng)建一個(gè)數(shù)據(jù)集類,且實(shí)現(xiàn)__iter__返回一個(gè)迭代器。一下提供兩種方法通過IterableDataset類自定義數(shù)據(jù)集:

方法一:

In?[289]:

class AntBeeIterableDataset(IterableDataset):
    # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, root_dir, transform=None, target_transform=None):
        """
        root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data
        transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize
        target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        
        # 獲取文件夾下所有圖片的名稱和對(duì)應(yīng)的標(biāo)簽
        self.img_lst = []
        for label in ['ants', 'bees']:
            path = os.path.join(root_dir, label)
            for img_name in os.listdir(path):
                self.img_lst.append((os.path.join(root_dir, label, img_name), label))
                
    def __iter__(self):
        for img_path, label in self.img_lst:
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            if self.target_transform:
                label = self.target_transform(label)
            yield img, label

方法二:

In?[285]:

class AntBeeIterableDataset(IterableDataset):
    # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, root_dir, transform=None, target_transform=None):
        """
        root_dir:存放數(shù)據(jù)的根目錄,即:data/hymenoptera_data
        transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize
        target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        
        # 獲取文件夾下所有圖片的名稱和對(duì)應(yīng)的標(biāo)簽
        self.img_lst = []
        for label in ['ants', 'bees']:
            path = os.path.join(root_dir, label)
            for img_name in os.listdir(path):
                self.img_lst.append((os.path.join(root_dir, label, img_name), label))
        self.index = 0
                
    def __iter__(self):
        return self
    
    def __next__(self):
        try:
            img_path, label = self.img_lst[self.index]
            self.index += 1
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            if self.target_transform:
                label = self.target_transform(label)
            return img, label
        except IndexError:
            raise StopIteration()

In?[290]:

train_dataset = AntBeeIterableDataset('data/hymenoptera_data/train', transform=train_transform, target_transform=target_transform)
val_dataset = AntBeeIterableDataset('data/hymenoptera_data/val', transform=val_transform, target_transform=target_transform)

在處理大數(shù)據(jù)集時(shí),IterableDataset會(huì)比Dataset更有優(yōu)勢(shì),例如數(shù)據(jù)存儲(chǔ)在文件或者數(shù)據(jù)庫中,只需要在自定義的IterableDataset之類中獲取文件操作句柄或者數(shù)據(jù)庫連接和游標(biāo)驚喜迭代,每次只返回一條數(shù)據(jù)即可。我們把上文中螞蟻蜜蜂數(shù)據(jù)集的所有圖片、標(biāo)簽這里后寫入hymenoptera_data.txt中,內(nèi)容如下所示,假設(shè)有數(shù)億行,那么,就不能直接將數(shù)據(jù)加載到內(nèi)存了:

data/hymenoptera_data/train/ants/2288481644_83ff7e4572.jpg, ants
data/hymenoptera_data/train/ants/2278278459_6b99605e50.jpg, ants
data/hymenoptera_data/train/ants/543417860_b14237f569.jpg, ants
...
...

可以參考一下方式定義IterableDataset子類:

In?[299]:

class AntBeeIterableDataset(IterableDataset):
    # 把圖片所在的文件夾路徑分成兩個(gè)部分,一部分是根目錄,一部分是標(biāo)簽?zāi)夸洠@是因?yàn)闃?biāo)簽?zāi)夸浀拿Q我們需要用到
    def __init__(self, filepath, transform=None, target_transform=None):
        """
        filepath:hymenoptera_data.txt完整路徑
        transform: 對(duì)圖像數(shù)據(jù)進(jìn)行處理,例如,將圖片轉(zhuǎn)換為Tensor、圖片的維度可能不一致需要進(jìn)行resize
        target_transform:對(duì)標(biāo)簽數(shù)據(jù)進(jìn)行處理,例如,將文本標(biāo)簽轉(zhuǎn)換為數(shù)值
        """
        self.filepath = filepath
        self.transform = transform
        self.target_transform = target_transform

                
    def __iter__(self):
        with open(self.filepath, 'r') as f:
            for line in f:
                img_path, label = line.replace('\n', '').split(', ')
                img = Image.open(img_path).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                if self.target_transform:
                    label = self.target_transform(label)
                yield img, label

In?[307]:

train_dataset = AntBeeIterableDataset('hymenoptera_data.txt', transform=train_transform, target_transform=target_transform)

注意,IterableDataset方法在處理大數(shù)據(jù)集時(shí)確實(shí)比Dataset更有優(yōu)勢(shì),但是,IterableDataset在迭代過程中,樣本輸出順序是固定的,在使用DataLoader進(jìn)行加載時(shí),無法使用shuffle進(jìn)行打亂,同時(shí),因?yàn)樵贗terableDataset中并未強(qiáng)制限定必須實(shí)現(xiàn)__len__()方法(很多時(shí)候確實(shí)也沒法獲取數(shù)據(jù)總量),不能通過len()方法獲取數(shù)據(jù)總量。

2 DataLoad

DataLoader的功能是構(gòu)建可迭代的數(shù)據(jù)裝載器,在訓(xùn)練的時(shí)候,每一個(gè)for循環(huán),每一次Iteration,就是從DataLoader中獲取一個(gè)batch_size大小的數(shù)據(jù),節(jié)省內(nèi)存的同時(shí),它還可以實(shí)現(xiàn)多進(jìn)程、數(shù)據(jù)打亂等處理。我們通過一張圖來了解DataLoader數(shù)據(jù)讀取機(jī)制:

首先,在for循環(huán)中使用了DataLoader,進(jìn)入DataLoader后,首先根據(jù)是否使用多進(jìn)程DataLoaderIter,做出判斷之后單線程還是多線程,接著使用Sampler得索引Index,然后將索引給到DatasetFetcher,在這里面調(diào)用Dataset,根據(jù)索引,通過getitem得到實(shí)際的數(shù)據(jù)和標(biāo)簽,得到一個(gè)batch size大小的數(shù)據(jù)后,通過collate_fn函數(shù)整理成一個(gè)Batch Data的形式輸入到模型去訓(xùn)練。

在pytorch建模的數(shù)據(jù)處理、加載流程中,DataLoader應(yīng)該算是最核心的一步操作DataLoader有很多參數(shù),這里我們列出常用的幾個(gè):

  • dataset:表示Dataset類,它決定了數(shù)據(jù)從哪讀取以及如何讀取;
  • batch_size:表示批大小;
  • num_works:表示是否多進(jìn)程讀取數(shù)據(jù);
  • shuffle:表示每個(gè)epoch是否亂序;
  • drop_last:表示當(dāng)樣本數(shù)不能被batch_size整除時(shí),是否舍棄最后一批數(shù)據(jù);
  • num_workers:?jiǎn)?dòng)多少個(gè)進(jìn)程來加載數(shù)據(jù)。

我們重點(diǎn)說說多進(jìn)程模式下使用DataLoader,在多進(jìn)程模式下,每次 DataLoader 創(chuàng)建 iterator 時(shí)(遍歷DataLoader時(shí),例如,當(dāng)調(diào)用時(shí)enumerate(dataloader)),都會(huì)創(chuàng)建 num_workers 工作進(jìn)程。dataset, collate_fn, worker_init_fn 都會(huì)被傳到每個(gè)worker中,每個(gè)worker都用獨(dú)立的進(jìn)程。

對(duì)于映射風(fēng)格的數(shù)據(jù)集,即Dataset子類,主線程會(huì)用Sampler(采樣器)產(chǎn)生indice,并將它們送到進(jìn)程里。因此,shuffle是在主線程做的

對(duì)于迭代器風(fēng)格的數(shù)據(jù)集,即IterableDataset子類,因?yàn)槊總€(gè)進(jìn)程都有相同的data復(fù)制樣本,并在各個(gè)進(jìn)程里進(jìn)行不同的操作,以防止每個(gè)進(jìn)程輸出的數(shù)據(jù)是重復(fù)的,所以一般用 torch.utils.data.get_worker_info() 來進(jìn)行輔助處理。

這里,torch.utils.data.get_worker_info() 返回worker進(jìn)程的一些信息(id, dataset, num_workers, seed),如果在主線程跑的話返回None

注意,通常不建議在多進(jìn)程加載中返回CUDA張量,因?yàn)樵谑褂肅UDA和在多處理中共享CUDA張量時(shí)存在許多微妙之處(文檔中提出:只要接收過程保留張量的副本,就需要發(fā)送過程來保留原始張量)。建議采用 pin_memory=True ,以將數(shù)據(jù)快速傳輸?shù)街С諧UDA的GPU。簡(jiǎn)而言之,不建議在使用多線程的情況下返回CUDA的tensor。

In?[313]:

dataload = DataLoader(train_dataset, batch_size=2)

In?[315]:

img, label = next(iter(dataload))

In?[316]:

img.shape, label

Out[316]:

(torch.Size([2, 3, 224, 224]), tensor([0, 0]))

原文鏈接:https://www.cnblogs.com/chenhuabin/p/17026018.html

欄目分類
最近更新