日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無先后,達者為師

網(wǎng)站首頁 編程語言 正文

torch.utils.data.DataLoader與迭代器轉(zhuǎn)換操作_python

作者:Orion's?Blog ? 更新時間: 2022-04-24 編程語言

在做實驗時,我們常常會使用用開源的數(shù)據(jù)集進行測試。而Pytorch中內(nèi)置了許多數(shù)據(jù)集,這些數(shù)據(jù)集我們常常使用DataLoader類進行加載。
如下面這個我們使用DataLoader類加載torch.vision中的FashionMNIST數(shù)據(jù)集。

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
? ? root="data",
? ? train=True,
? ? download=True,
? ? transform=ToTensor()
)

test_data = datasets.FashionMNIST(
? ? root="data",
? ? train=False,
? ? download=True,
? ? transform=ToTensor()
)

我們接下來定義Dataloader對象用于加載這兩個數(shù)據(jù)集:

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

那么這個train_dataloader究竟是什么類型呢?

print(type(train_dataloader)) ?# <class 'torch.utils.data.dataloader.DataLoader'>

我們可以將先其轉(zhuǎn)換為迭代器類型。

print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>

然后再使用next(iter(train_dataloader))從迭代器里取數(shù)據(jù),如下所示:

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

可以看到我們成功獲取了數(shù)據(jù)集中第一張圖片的信息,控制臺打印:

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 2

圖片可視化顯示如下:

不過有讀者可能就會產(chǎn)生疑問,很多時候我們并沒有將DataLoader類型強制轉(zhuǎn)換成迭代器類型呀,大多數(shù)時候我們會寫如下代碼:

for train_features, train_labels in train_dataloader:?
? ? print(train_features.shape) # torch.Size([64, 1, 28, 28])
? ? print(train_features[0].shape) # torch.Size([1, 28, 28])
? ? print(train_features[0].squeeze().shape) # torch.Size([28, 28])
? ??
? ? img = train_features[0].squeeze()
? ? label = train_labels[0]
? ? plt.imshow(img, cmap="gray")
? ? plt.show()
? ? print(f"Label: {label}")

可以看到,該代碼也能夠正常迭代訓(xùn)練數(shù)據(jù),前三個樣本的控制臺打印輸出為:

torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 7
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 4
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 1

那么為什么我們這里沒有顯式將Dataloader轉(zhuǎn)換為迭代器類型呢,其實是Python語言for循環(huán)的一種機制,一旦我們用for ... in ...句式來迭代一個對象,那么Python解釋器就會偷偷地自動幫我們創(chuàng)建好迭代器,也就是說

for train_features, train_labels in train_dataloader:

實際上等同于

for train_features, train_labels in iter(train_dataloader):

更進一步,這實際上等同于

train_iterator = iter(train_dataloader)
try:
? ? while True:
? ? ? ? train_features, train_labels = next(train_iterator)
except StopIteration:
? ? pass

推而廣之,我們在用Python迭代直接迭代列表時:

for x in [1, 2, 3, 4]:

其實Python解釋器已經(jīng)為我們隱式轉(zhuǎn)換為迭代器了:

list_iterator = iter([1, 2, 3, 4])
try:
? ? while True:
? ? ? ? x = next(list_iterator)
except StopIteration:
? ? pass

原文鏈接:https://www.cnblogs.com/orion-orion/p/15651037.html

欄目分類
最近更新