網站首頁 編程語言 正文
一、torch.utils.data.DataLoader 簡介
作用:torch.utils.data.DataLoader 主要是對數據進行 batch 的劃分。
數據加載器,結合了數據集和取樣器,并且可以提供多個線程處理數據集。
在訓練模型時使用到此函數,用來 把訓練數據分成多個小組 ,此函數 每次拋出一組數據 。直至把所有的數據都拋出。就是做一個數據的初始化。
好處:
使用DataLoader的好處是,可以快速的迭代數據。
用于生成迭代數據非常方便。
注意:
除此之外,特別要注意的是輸入進函數的數據一定得是可迭代的。如果是自定的數據集的話可以在定義類中用def__len__、def__getitem__定義。
二、實例
BATCH_SIZE 剛好整除數據量
"""
批訓練,把數據變成一小批一小批數據進行訓練。
DataLoader就是用來包裝所使用的數據,每次拋出一批數據
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 5 # 批訓練的數據個數
x = torch.linspace(1, 10, 10) # 訓練數據
print(x)
y = torch.linspace(10, 1, 10) # 標簽
print(y)
# 把數據放在數據庫中
torch_dataset = Data.TensorDataset(x, y) # 對給定的 tensor 數據,將他們包裝成 dataset
loader = Data.DataLoader(
# 從數據庫中每次抽出batch size個樣本
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=True, # 要不要打亂數據 (打亂比較好)
num_workers=2, # 多線程來讀數據
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
show_batch()
輸出結果:
tensor([ 1., ?2., ?3., ?4., ?5., ?6., ?7., ?8., ?9., 10.])
tensor([10., ?9., ?8., ?7., ?6., ?5., ?4., ?3., ?2., ?1.])
steop:0, batch_x:tensor([10., ?1., ?3., ?7., ?6.]), batch_y:tensor([ 1., 10., ?8., ?4., ?5.])
steop:1, batch_x:tensor([8., 5., 4., 9., 2.]), batch_y:tensor([3., 6., 7., 2., 9.])
steop:0, batch_x:tensor([ 9., ?3., 10., ?1., ?5.]), batch_y:tensor([ 2., ?8., ?1., 10., ?6.])
steop:1, batch_x:tensor([2., 6., 8., 4., 7.]), batch_y:tensor([9., 5., 3., 7., 4.])
steop:0, batch_x:tensor([ 2., 10., ?9., ?6., ?1.]), batch_y:tensor([ 9., ?1., ?2., ?5., 10.])
steop:1, batch_x:tensor([8., 3., 4., 7., 5.]), batch_y:tensor([3., 8., 7., 4., 6.])
說明:共有 10 條數據,設置 BATCH_SIZE 為 5 來進行劃分,能劃分為 2 組(steop 為 0 和 1)。這兩組數據互斥。
BATCH_SIZE 不整除數據量:會輸出余下所有數據
將上述代碼中的 BATCH_SIZE 改為 4 :
"""
批訓練,把數據變成一小批一小批數據進行訓練。
DataLoader就是用來包裝所使用的數據,每次拋出一批數據
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 4 # 批訓練的數據個數
x = torch.linspace(1, 10, 10) # 訓練數據
print(x)
y = torch.linspace(10, 1, 10) # 標簽
print(y)
# 把數據放在數據庫中
torch_dataset = Data.TensorDataset(x, y) # 對給定的 tensor 數據,將他們包裝成 dataset
loader = Data.DataLoader(
# 從數據庫中每次抽出batch size個樣本
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=True, # 要不要打亂數據 (打亂比較好)
num_workers=2, # 多線程來讀數據
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
show_batch()
輸出結果:
tensor([ 1., ?2., ?3., ?4., ?5., ?6., ?7., ?8., ?9., 10.])
tensor([10., ?9., ?8., ?7., ?6., ?5., ?4., ?3., ?2., ?1.])
steop:0, batch_x:tensor([1., 5., 3., 2.]), batch_y:tensor([10., ?6., ?8., ?9.])
steop:1, batch_x:tensor([7., 8., 4., 6.]), batch_y:tensor([4., 3., 7., 5.])
steop:2, batch_x:tensor([10., ?9.]), batch_y:tensor([1., 2.])
steop:0, batch_x:tensor([ 7., 10., ?5., ?2.]), batch_y:tensor([4., 1., 6., 9.])
steop:1, batch_x:tensor([9., 1., 6., 4.]), batch_y:tensor([ 2., 10., ?5., ?7.])
steop:2, batch_x:tensor([8., 3.]), batch_y:tensor([3., 8.])
steop:0, batch_x:tensor([10., ?3., ?2., ?8.]), batch_y:tensor([1., 8., 9., 3.])
steop:1, batch_x:tensor([1., 7., 5., 9.]), batch_y:tensor([10., ?4., ?6., ?2.])
steop:2, batch_x:tensor([4., 6.]), batch_y:tensor([7., 5.])
說明:共有 10 條數據,設置 BATCH_SIZE 為 4 來進行劃分,能劃分為 3 組(steop 為 0 、1、2)。分別有 4、4、2 條數據。
參考鏈接
- torch.utils.data.DataLoader使用方法
- 【Pytorch基礎】torch.utils.data.DataLoader方法的使用
總結
原文鏈接:https://blog.csdn.net/weixin_44211968/article/details/123744513
相關推薦
- 2022-04-04 Error running : No valid Maven installation found.
- 2024-04-06 linux中redis重啟,啟動,停止的sh腳本
- 2022-11-05 pytest官方文檔解讀之安裝和使用插件的方法_python
- 2022-12-26 詳解Python中四種關系圖數據可視化的效果對比_python
- 2023-11-26 XMLHttpRequest的readyState狀態值
- 2022-09-30 Go語言編譯原理之源碼調試_Golang
- 2023-05-15 Go語言實現AES加密并編寫一個命令行應用程序_Golang
- 2022-05-08 react實現原生下拉刷新_React
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支