網(wǎng)站首頁 編程語言 正文
在做實驗時,我們常常會使用用開源的數(shù)據(jù)集進行測試。而Pytorch中內(nèi)置了許多數(shù)據(jù)集,這些數(shù)據(jù)集我們常常使用DataLoader
類進行加載。
如下面這個我們使用DataLoader
類加載torch.vision
中的FashionMNIST
數(shù)據(jù)集。
from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import ToTensor import matplotlib.pyplot as plt training_data = datasets.FashionMNIST( ? ? root="data", ? ? train=True, ? ? download=True, ? ? transform=ToTensor() ) test_data = datasets.FashionMNIST( ? ? root="data", ? ? train=False, ? ? download=True, ? ? transform=ToTensor() )
我們接下來定義Dataloader對象用于加載這兩個數(shù)據(jù)集:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
那么這個train_dataloader
究竟是什么類型呢?
print(type(train_dataloader)) ?# <class 'torch.utils.data.dataloader.DataLoader'>
我們可以將先其轉(zhuǎn)換為迭代器類型。
print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>
然后再使用next(iter(train_dataloader))
從迭代器里取數(shù)據(jù),如下所示:
train_features, train_labels = next(iter(train_dataloader)) print(f"Feature batch shape: {train_features.size()}") print(f"Labels batch shape: {train_labels.size()}") img = train_features[0].squeeze() label = train_labels[0] plt.imshow(img, cmap="gray") plt.show() print(f"Label: {label}")
可以看到我們成功獲取了數(shù)據(jù)集中第一張圖片的信息,控制臺打印:
Feature batch shape: torch.Size([64, 1, 28, 28]) Labels batch shape: torch.Size([64]) Label: 2
圖片可視化顯示如下:
不過有讀者可能就會產(chǎn)生疑問,很多時候我們并沒有將DataLoader類型強制轉(zhuǎn)換成迭代器類型呀,大多數(shù)時候我們會寫如下代碼:
for train_features, train_labels in train_dataloader:? ? ? print(train_features.shape) # torch.Size([64, 1, 28, 28]) ? ? print(train_features[0].shape) # torch.Size([1, 28, 28]) ? ? print(train_features[0].squeeze().shape) # torch.Size([28, 28]) ? ?? ? ? img = train_features[0].squeeze() ? ? label = train_labels[0] ? ? plt.imshow(img, cmap="gray") ? ? plt.show() ? ? print(f"Label: {label}")
可以看到,該代碼也能夠正常迭代訓(xùn)練數(shù)據(jù),前三個樣本的控制臺打印輸出為:
torch.Size([64, 1, 28, 28]) torch.Size([1, 28, 28]) torch.Size([28, 28]) Label: 7 torch.Size([64, 1, 28, 28]) torch.Size([1, 28, 28]) torch.Size([28, 28]) Label: 4 torch.Size([64, 1, 28, 28]) torch.Size([1, 28, 28]) torch.Size([28, 28]) Label: 1
那么為什么我們這里沒有顯式將Dataloader
轉(zhuǎn)換為迭代器類型呢,其實是Python語言for循環(huán)的一種機制,一旦我們用for ... in ...句式來迭代一個對象,那么Python
解釋器就會偷偷地自動幫我們創(chuàng)建好迭代器,也就是說
for train_features, train_labels in train_dataloader:
實際上等同于
for train_features, train_labels in iter(train_dataloader):
更進一步,這實際上等同于
train_iterator = iter(train_dataloader) try: ? ? while True: ? ? ? ? train_features, train_labels = next(train_iterator) except StopIteration: ? ? pass
推而廣之,我們在用Python迭代直接迭代列表時:
for x in [1, 2, 3, 4]:
其實Python解釋器已經(jīng)為我們隱式轉(zhuǎn)換為迭代器了:
list_iterator = iter([1, 2, 3, 4]) try: ? ? while True: ? ? ? ? x = next(list_iterator) except StopIteration: ? ? pass
原文鏈接:https://www.cnblogs.com/orion-orion/p/15651037.html
相關(guān)推薦
- 2023-01-23 Python列表對象中元素的刪除操作方法_python
- 2023-04-06 C++中的memset用法詳解_C 語言
- 2022-11-17 Android自定義一個view?ViewRootImpl繪制流程示例_Android
- 2023-11-13 matplotlib按照論文要求繪圖并保存pdf格式
- 2022-07-06 C#線程開發(fā)之System.Thread類詳解_C#教程
- 2022-03-24 Android?TextView文本控件介紹_Android
- 2023-02-23 Redis的setNX分布式鎖超時時間失效?-1問題及解決_Redis
- 2022-05-09 Python的數(shù)據(jù)結(jié)構(gòu)與算法的隊列詳解(3)_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支