日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

PyTorch小功能之TensorDataset解讀_python

作者:菜鳥向前沖fighting ? 更新時間: 2023-05-22 編程語言

PyTorch之TensorDataset

TensorDataset 可以用來對?tensor?進行打包,就好像 python 中的 zip 功能。

該類通過每一個 tensor 的第一個維度進行索引。

因此,該類中的 tensor 第一維度必須相等。

from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader

a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b) 
# 切片輸出
print(train_ids[0:2])
print('=' * 80)
# 循環取數據
for x_train, y_label in train_ids:
    print(x_train, y_label)
# DataLoader進行數據封裝
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1):  # 注意enumerate返回值有兩個,一個是序號,一個是數據(包含訓練數據和標簽)
    x_data, label = data
    print(' batch:{0} x_data:{1}  label: {2}'.format(i, x_data, label))

運行結果:

(tensor([[1, 2, 3],
? ? ? ? [4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
?batch:1 x_data:tensor([[1, 2, 3],
? ? ? ? [1, 2, 3],
? ? ? ? [4, 5, 6],
? ? ? ? [4, 5, 6]]) ?label: tensor([44, 44, 55, 55])
?batch:2 x_data:tensor([[4, 5, 6],
? ? ? ? [7, 8, 9],
? ? ? ? [7, 8, 9],
? ? ? ? [7, 8, 9]]) ?label: tensor([55, 66, 66, 66])
?batch:3 x_data:tensor([[1, 2, 3],
? ? ? ? [1, 2, 3],
? ? ? ? [7, 8, 9],
? ? ? ? [4, 5, 6]]) ?label: tensor([44, 44, 66, 55])

注意:TensorDataset 中的參數必須是 tensor

Pytorch中TensorDataset的快速使用

Pytorch中,TensorDataset()可以快速構建訓練所用的數據,不用使用自建的Mydataset(),如果沒有熟悉適用的dataset可以使用TensorDataset()作為暫時替代。

只需要把data和label作為參數輸入,就可以快速構建,之后便可以用Dataloader處理。

import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)?

總結

原文鏈接:https://blog.csdn.net/qq_40211493/article/details/107529148

  • 上一篇:沒有了
  • 下一篇:沒有了
欄目分類
最近更新