日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無先后,達(dá)者為師

網(wǎng)站首頁(yè) 編程語(yǔ)言 正文

PyTorch小功能之TensorDataset解讀_python

作者:菜鳥向前沖fighting ? 更新時(shí)間: 2023-05-22 編程語(yǔ)言

PyTorch之TensorDataset

TensorDataset 可以用來對(duì)?tensor?進(jìn)行打包,就好像 python 中的 zip 功能。

該類通過每一個(gè) tensor 的第一個(gè)維度進(jìn)行索引。

因此,該類中的 tensor 第一維度必須相等。

from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b) 
# 切片輸出
print(train_ids[0:2])
print('=' * 80)
# 循環(huán)取數(shù)據(jù)
for x_train, y_label in train_ids:
    print(x_train, y_label)
# DataLoader進(jìn)行數(shù)據(jù)封裝
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):  # 注意enumerate返回值有兩個(gè),一個(gè)是序號(hào),一個(gè)是數(shù)據(jù)(包含訓(xùn)練數(shù)據(jù)和標(biāo)簽)
    x_data, label = data
    print(' batch:{0} x_data:{1}  label: {2}'.format(i, x_data, label))

運(yùn)行結(jié)果:

(tensor([[1, 2, 3],
? ? ? ? [4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
?batch:1 x_data:tensor([[1, 2, 3],
? ? ? ? [1, 2, 3],
? ? ? ? [4, 5, 6],
? ? ? ? [4, 5, 6]]) ?label: tensor([44, 44, 55, 55])
?batch:2 x_data:tensor([[4, 5, 6],
? ? ? ? [7, 8, 9],
? ? ? ? [7, 8, 9],
? ? ? ? [7, 8, 9]]) ?label: tensor([55, 66, 66, 66])
?batch:3 x_data:tensor([[1, 2, 3],
? ? ? ? [1, 2, 3],
? ? ? ? [7, 8, 9],
? ? ? ? [4, 5, 6]]) ?label: tensor([44, 44, 66, 55])

注意:TensorDataset 中的參數(shù)必須是 tensor

Pytorch中TensorDataset的快速使用

Pytorch中,TensorDataset()可以快速構(gòu)建訓(xùn)練所用的數(shù)據(jù),不用使用自建的Mydataset(),如果沒有熟悉適用的dataset可以使用TensorDataset()作為暫時(shí)替代。

只需要把data和label作為參數(shù)輸入,就可以快速構(gòu)建,之后便可以用Dataloader處理。

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)?

總結(jié)

原文鏈接:https://blog.csdn.net/qq_40211493/article/details/107529148

  • 上一篇:沒有了
  • 下一篇:沒有了
欄目分類
最近更新