日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網(wǎng)站首頁 編程語言 正文

Pytorch使用技巧之Dataloader中的collate_fn參數(shù)詳析_python

作者:政在學習 ? 更新時間: 2022-05-19 編程語言

以MNIST為例

from torchvision import datasets
mnist = datasets.MNIST(root='./data/', train=True, download=True)
print(mnist[0])

結果

(, 5)

MINIST數(shù)據(jù)集的dataset是由一張圖片和一個label組成的元組

dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:x)
for each in dataloader:
    print(each)
    break

結果

[(, 0), (, 2)]

collate_fn為lamda x:x時表示對傳入進來的數(shù)據(jù)不做處理

下面自定義collate_fn看看什么效果

def collate(data):
    img = []
    label = []
    for each in data:
        img.append(each[0])
        label.append(each[1])
    return img,label
dataloader = torch.utils.data.DataLoader(dataset=mnist, batch_size=2, shuffle=True,collate_fn=lambda x:collate(x))
for each in dataloader:
    print(each)
    break

結果

([, ], [9, 3])

說明:若不設置collate_fn參數(shù)則會使用默認處理函數(shù)

但必須保證傳進來的數(shù)據(jù)都是tensor格式否則會報錯

附:DataLoader完整的參數(shù)表如下:

class torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None)

DataLoader在數(shù)據(jù)集上提供單進程或多進程的迭代器

幾個關鍵的參數(shù)意思:

- shuffle:設置為True的時候,每個世代都會打亂數(shù)據(jù)集

- collate_fn:如何取樣本的,我們可以定義自己的函數(shù)來準確地實現(xiàn)想要的功能

- drop_last:告訴如何處理數(shù)據(jù)集長度除于batch_size余下的數(shù)據(jù)。True就拋棄,否則保留

總結

原文鏈接:https://blog.csdn.net/qq_47718334/article/details/122884898

欄目分類
最近更新