網站首頁 編程語言 正文
PyTorch之TensorDataset
TensorDataset 可以用來對?tensor?進行打包,就好像 python 中的 zip 功能。
該類通過每一個 tensor 的第一個維度進行索引。
因此,該類中的 tensor 第一維度必須相等。
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)
# 切片輸出
print(train_ids[0:2])
print('=' * 80)
# 循環取數據
for x_train, y_label in train_ids:
print(x_train, y_label)
# DataLoader進行數據封裝
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1): # 注意enumerate返回值有兩個,一個是序號,一個是數據(包含訓練數據和標簽)
x_data, label = data
print(' batch:{0} x_data:{1} label: {2}'.format(i, x_data, label))
運行結果:
(tensor([[1, 2, 3],
? ? ? ? [4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
?batch:1 x_data:tensor([[1, 2, 3],
? ? ? ? [1, 2, 3],
? ? ? ? [4, 5, 6],
? ? ? ? [4, 5, 6]]) ?label: tensor([44, 44, 55, 55])
?batch:2 x_data:tensor([[4, 5, 6],
? ? ? ? [7, 8, 9],
? ? ? ? [7, 8, 9],
? ? ? ? [7, 8, 9]]) ?label: tensor([55, 66, 66, 66])
?batch:3 x_data:tensor([[1, 2, 3],
? ? ? ? [1, 2, 3],
? ? ? ? [7, 8, 9],
? ? ? ? [4, 5, 6]]) ?label: tensor([44, 44, 66, 55])
注意:TensorDataset 中的參數必須是 tensor
Pytorch中TensorDataset的快速使用
Pytorch中,TensorDataset()可以快速構建訓練所用的數據,不用使用自建的Mydataset(),如果沒有熟悉適用的dataset可以使用TensorDataset()作為暫時替代。
只需要把data和label作為參數輸入,就可以快速構建,之后便可以用Dataloader處理。
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)?
總結
原文鏈接:https://blog.csdn.net/qq_40211493/article/details/107529148
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2022-10-11 golang游戲等資源壓縮包創建和操作方法_Golang
- 2022-05-05 C++繼承中的對象構造與析構和賦值重載詳解_C 語言
- 2022-11-09 Flutter?異步編程之單線程下異步模型圖文示例詳解_Android
- 2023-05-08 Docker中的compose簡介_docker
- 2022-10-17 Kotlin編程循環控制示例詳解_Android
- 2024-07-15 Redis 底層數據結構-簡單動態字符串(SDS)
- 2022-04-24 基于Python制作一個文件去重小工具_python
- 2023-07-03 什么是懶加載,如何實現圖片或列表懶加載?
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支