網站首頁 編程語言 正文
Pytorch中TensorDataset,DataLoader的聯合使用
首先從字面意義上來理解TensorDataset和DataLoader,TensorDataset是個只用來存放tensor(張量)的數據集,而DataLoader是一個數據加載器,一般用到DataLoader的時候就說明需要遍歷和操作數據了。
TensorDataset(tensor1,tensor2)的功能就是形成數據tensor1和標簽tensor2的對應,也就是說tensor1中是數據,而tensor2是tensor1所對應的標簽。
來個小例子
from torch.utils.data import TensorDataset,DataLoader
import torch
?
a = torch.tensor([[1, 2, 3],
? ? ? ? ? ? ? ? ? [4, 5, 6],
? ? ? ? ? ? ? ? ? [7, 8, 9],
? ? ? ? ? ? ? ? ? [1, 2, 3],
? ? ? ? ? ? ? ? ? [4, 5, 6],
? ? ? ? ? ? ? ? ? [7, 8, 9],
? ? ? ? ? ? ? ? ? [1, 2, 3],
? ? ? ? ? ? ? ? ? [4, 5, 6],
? ? ? ? ? ? ? ? ? [7, 8, 9],
? ? ? ? ? ? ? ? ? [1, 2, 3],
? ? ? ? ? ? ? ? ? [4, 5, 6],
? ? ? ? ? ? ? ? ? [7, 8, 9]])
?
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a,b)
# 切片輸出
print(train_ids[0:4]) # 第0,1,2,3行
# 循環(huán)取數據
for x_train,y_label in train_ids:
? ? print(x_train,y_label)
下面是對應的輸出:
(tensor([[1, 2, 3],
? ? ? ? [4, 5, 6],
? ? ? ? [7, 8, 9],
? ? ? ? [1, 2, 3]]), tensor([44, 55, 66, 44]))
===============================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
從輸出結果我們就可以很好的理解,tensor型數據和tensor型標簽的對應了,這就是TensorDataset的基本應用。
接下來我們把構造好的TensorDataset封裝到DataLoader來操作里面的數據:
# 參數說明,dataset=train_ids表示需要封裝的數據集,batch_size表示一次取幾個
# shuffle表示亂序取數據,設為False表示順序取數據,True表示亂序取數據
train_loader = DataLoader(dataset=train_ids,batch_size=4,shuffle=False)
# 注意enumerate返回值有兩個,一個是序號,一個是數據(包含訓練數據和標簽)
for i,data in enumerate(train_loader,1):
? ? train_data, label = data
? ? print(' batch:{0} train_data:{1} ?label: {2}'.format(i+1, train_data, label))
下面是對應的輸出:
?batch:1 x_data:tensor([[1, 2, 3],
? ? ? ? [4, 5, 6],
? ? ? ? [7, 8, 9],
? ? ? ? [1, 2, 3]]) ?label: tensor([44, 55, 66, 44])
?batch:2 x_data:tensor([[4, 5, 6],
? ? ? ? [7, 8, 9],
? ? ? ? [1, 2, 3],
? ? ? ? [4, 5, 6]]) ?label: tensor([55, 66, 44, 55])
?batch:3 x_data:tensor([[7, 8, 9],
? ? ? ? [1, 2, 3],
? ? ? ? [4, 5, 6],
? ? ? ? [7, 8, 9]]) ?label: tensor([66, 44, 55, 66])
至此,TensorDataset和DataLoader的聯合使用就介紹完了。
我們再看一下這兩種方法的源碼:
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
? ? r"""Dataset wrapping tensors.
? ? Each sample will be retrieved by indexing tensors along the first dimension.
? ? Arguments:
? ? ? ? *tensors (Tensor): tensors that have the same size of the first dimension.
? ? """
? ? tensors: Tuple[Tensor, ...]
?
? ? def __init__(self, *tensors: Tensor) -> None:
? ? ? ? assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
? ? ? ? self.tensors = tensors
?
? ? def __getitem__(self, index):
? ? ? ? return tuple(tensor[index] for tensor in self.tensors)
?
? ? def __len__(self):
? ? ? ? return self.tensors[0].size(0)
?
# 由于此類內容過多,故僅列舉了與本文相關的參數,其余參數可以自行去查看源碼
class DataLoader(Generic[T_co]):
? ? r"""
? ? Data loader. Combines a dataset and a sampler, and provides an iterable over
? ? the given dataset.
? ? The :class:`~torch.utils.data.DataLoader` supports both map-style and
? ? iterable-style datasets with single- or multi-process loading, customizing
? ? loading order and optional automatic batching (collation) and memory pinning.
? ? See :py:mod:`torch.utils.data` documentation page for more details.
? ? Arguments:
? ? ? ? dataset (Dataset): dataset from which to load the data.
? ? ? ? batch_size (int, optional): how many samples per batch to load
? ? ? ? ? ? (default: ``1``).
? ? ? ? shuffle (bool, optional): set to ``True`` to have the data reshuffled
? ? ? ? ? ? at every epoch (default: ``False``).
? ? """
? ? dataset: Dataset[T_co]
? ? batch_size: Optional[int]
?
? ? def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
? ? ? ? ? ? ? ? ?shuffle: bool = False):
?
? ? ? ? self.dataset = dataset
? ? ? ? self.batch_size = batch_size
Pytorch的DataLoader和Dataset以及TensorDataset的源碼分析
1.為什么要用DataLoader和Dataset
要對大量數據進行加載和處理時因為可能會出現內存不夠用的情況,這時候就需要用到數據集類Dataset或TensorDataset和數據集加載類DataLoader了。
使用這些類后可以將原本的數據分成小塊,在需要使用的時候再一部分一本分讀進內存中,而不是一開始就將所有數據讀進內存中。
2.Dateset的使用
pytorch中的torch.utils.data.Dataset是表示數據集的抽象類,但它一般不直接使用,而是通過自定義一個數據集來使用。
來自定義數據集應該繼承Dataset并應該有實現返回數據集尺寸的__len__方法和用來獲取索引數據的__getitem__方法。
Dataset類的源碼如下:
class Dataset(object):
? ? r"""An abstract class representing a :class:`Dataset`.
? ? All datasets that represent a map from keys to data samples should subclass
? ? it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
? ? data sample for a given key. Subclasses could also optionally overwrite
? ? :meth:`__len__`, which is expected to return the size of the dataset by many
? ? :class:`~torch.utils.data.Sampler` implementations and the default options
? ? of :class:`~torch.utils.data.DataLoader`.
? ? .. note::
? ? ? :class:`~torch.utils.data.DataLoader` by default constructs a index
? ? ? sampler that yields integral indices. ?To make it work with a map-style
? ? ? dataset with non-integral indices/keys, a custom sampler must be provided.
? ? """
? ? def __getitem__(self, index):
? ? ? ? raise NotImplementedError
? ? def __add__(self, other):
? ? ? ? return ConcatDataset([self, other])
? ? # No `def __len__(self)` default?
? ? # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
? ? # in pytorch/torch/utils/data/sampler.py
可以看到Dataset類中沒有__len__方法,雖然有__getitem__方法,但是并沒有實現啥有用的功能。
所以要寫一個Dataset類的子類來實現其應有的功能。
自定義類的實現舉例:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.autograd import Variable
import numpy as np
import pandas as pd
value_df = pd.read_csv('data1.csv')
value_array = np.array(value_df)
print("value_array.shape =", value_array.shape) ?# (73700, 300)
value_size = value_array.shape[0] ?# 73700
train_size = int(0.7*value_size)
train_array = val_array[:train_size] ?
train_label_array = val_array[60:train_size+60]
class DealDataset(Dataset):
? ? """
? ? ? ? 下載數據、初始化數據,都可以在這里完成
? ? """
? ? def __init__(self, *arrays):
? ? ? ? assert all(arrays[0].shape[0] == array.shape[0] for array in arrays)
? ? ? ? self.arrays = arrays
? ? def __getitem__(self, index):
? ? ? ? return tuple(array[index] for array in self.arrays)
? ? def __len__(self):
? ? ? ? return self.arrays[0].shape[0]
# 實例化這個類,然后我們就得到了Dataset類型的數據,記下來就將這個類傳給DataLoader,就可以了。
train_dataset = DealDataset(train_array, train_label_array)
train_loader2 = DataLoader(dataset=train_dataset,
? ? ? ? ? ? ? ? ? ? ? ? ? ?batch_size=32,
? ? ? ? ? ? ? ? ? ? ? ? ? ?shuffle=True)
for epoch in range(2):
? ? for i, data in enumerate(train_loader2):
? ? ? ? # 將數據從 train_loader 中讀出來,一次讀取的樣本數是32個
? ? ? ? inputs, labels = data
? ? ? ? # 將這些數據轉換成Variable類型
? ? ? ? inputs, labels = Variable(inputs), Variable(labels)
? ? ? ? # 接下來就是跑模型的環(huán)節(jié)了,我們這里使用print來代替
? ? ? ? print("epoch:", epoch, "的第", i, "個inputs", inputs.data.size(), "labels", labels.data.size())
結果:
epoch: 0 的第 0 個inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 1 個inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 2 個inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 3 個inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 4 個inputs torch.Size([32, 300]) labels torch.Size([32, 300])
epoch: 0 的第 5 個inputs torch.Size([32, 300]) labels torch.Size([32, 300])
...
3.TensorDataset的使用
TensorDataset是可以直接使用的數據集類,它的源碼如下:
class TensorDataset(Dataset):
? ? r"""Dataset wrapping tensors.
? ? Each sample will be retrieved by indexing tensors along the first dimension.
? ? Arguments:
? ? ? ? *tensors (Tensor): tensors that have the same size of the first dimension.
? ? """
? ? def __init__(self, *tensors):
? ? ? ? assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
? ? ? ? self.tensors = tensors
? ? def __getitem__(self, index):
? ? ? ? return tuple(tensor[index] for tensor in self.tensors)
? ? def __len__(self):
? ? ? ? return self.tensors[0].size(0)
可以看到TensorDataset類是Dataset類的子類,且擁有返回數據集尺寸的__len__方法和用來獲取索引數據的__getitem__方法,所以可以直接使用。
它的結構跟上面自定義的子類的結構是一樣的,惟一的不同是TensorDataset已經規(guī)定了傳入的數據必須是torch.Tensor類型的,而自定義子類可以自由設定。
使用舉例:
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.autograd import Variable
import numpy as np
import pandas as pd
value_df = pd.read_csv('data1.csv')
value_array = np.array(value_df)
print("value_array.shape =", value_array.shape) ?# (73700, 300)
value_size = value_array.shape[0] ?# 73700
train_size = int(0.7*value_size)
train_array = val_array[:train_size] ?
train_tensor = torch.tensor(train_array, dtype=torch.float32).to(device)
train_label_array = val_array[60:train_size+60]
train_labels_tensor = torch.tensor(train_label_array,dtype=torch.float32).to(device)
train_dataset = TensorDataset(train_tensor, train_labels_tensor)
train_loader = DataLoader(dataset=train_dataset,
? ? ? ? ? ? ? ? ? ? ? ? ? batch_size=100,
? ? ? ? ? ? ? ? ? ? ? ? ? shuffle=False,
? ? ? ? ? ? ? ? ? ? ? ? ? num_workers=0)
for epoch in range(2):
? ? for i, data in enumerate(train_loader):
? ? ? ? inputs, labels = data
? ? ? ? inputs, labels = Variable(inputs), Variable(labels)
? ? ? ? print(epoch, i, "inputs", inputs.data.size(), "labels", labels.data.size())
結果:
0 0 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 1 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 2 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 3 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 4 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 5 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 6 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 7 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 8 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 9 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
0 10 inputs torch.Size([100, 300]) labels torch.Size([100, 300])
...
總結
原文鏈接:https://blog.csdn.net/F845992311/article/details/123478399
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2022-10-22 BroadcastReceiver靜態(tài)注冊案例詳解_Android
- 2022-10-11 Spring Boot 使用@Scheduled注解實現定時任務
- 2022-05-08 Pandas修改DataFrame列名的兩種方法實例_python
- 2022-03-31 Docker使用鏡像倉庫的方法_docker
- 2022-06-08 FreeRTOS編碼標準及風格指南_操作系統(tǒng)
- 2023-04-03 React中使用Mobx的方法_React
- 2022-09-14 jQuery實現簡單計算器功能_jquery
- 2022-07-10 JDK13版本的環(huán)境變量的配置
- 欄目分類
-
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支