網站首頁 編程語言 正文
在Pytorch中,torch.utils.data中的Dataset與DataLoader是處理數據集的兩個函數,用來處理加載數據集。通常情況下,使用的關鍵在于構建dataset類。
一:dataset類構建。
在構建數據集類時,除了__init__(self),還要有__len__(self)與__getitem__(self,item)兩個方法,這三個是必不可少的,至于其它用于數據處理的函數,可以任意定義。
class dataset:
def __init__(self,...):
...
def __len__(self,...):
return n
def __getitem__(self,item):
return data[item]
正常情況下,該數據集是要繼承Pytorch中Dataset類的,但實際操作中,即使不繼承,數據集類構建后仍可以用Dataloader()加載的。
在dataset類中,__len__(self)返回數據集中數據個數,__getitem__(self,item)表示每次返回第item條數據。
二:DataLoader使用
在構建dataset類后,即可使用DataLoader加載。DataLoader中常用參數如下:
1.dataset:需要載入的數據集,如前面構造的dataset類。
2.batch_size:批大小,在神經網絡訓練時我們很少逐條數據訓練,而是幾條數據作為一個batch進行訓練。
3.shuffle:是否在打亂數據集樣本順序。True為打亂,False反之。
4.drop_last:是否舍去最后一個batch的數據(很多情況下數據總數N與batch size不整除,導致最后一個batch不為batch size)。True為舍去,False反之。
三:舉例
兔兔以指標為1,數據個數為100的數據為例。
import torch
from torch.utils.data import DataLoader
class dataset:
def __init__(self):
self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
self.y=(torch.sin(self.x)+1)/2
def __len__(self):
return 100
def __getitem__(self, item):
return self.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
for batch in data:
print(batch)
當然,利用這個數據集可以進行簡單的神經網絡訓練。
from torch import nn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
nn.Sigmoid(),
nn.Linear(5,1),
nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
for epoch in range(10):
print('the {} epoch'.format(epoch))
for batch in data:
yp=bp(batch[0])
loss=Loss(yp,batch[1])
optim.zero_grad()
loss.backward()
optim.step()
ps:下面再給大家補充介紹下Pytorch中DataLoader的使用。
前言
最近開始接觸pytorch,從跑別人寫好的代碼開始,今天需要把輸入數據根據每個batch的最長輸入數據,填充到一樣的長度(之前是將所有的數據直接填充到一樣的長度再輸入)。
剛開始是想偷懶,沒有去認真了解輸入的機制,結果一直報錯…還是要認真學習呀!
加載數據
pytorch中加載數據的順序是:
①創建一個dataset對象
②創建一個dataloader對象
③循環dataloader對象,將data,label拿到模型中去訓練
dataset
你需要自己定義一個class,里面至少包含3個函數:
①__init__:傳入數據,或者像下面一樣直接在函數里加載數據
②__len__:返回這個數據集一共有多少個item
③__getitem__:返回一條訓練數據,并將其轉換成tensor
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
def __init__(self):
a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)
b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)
d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)
c = np.load("D:/Python/nlp/NRE/c.npy")
self.x = list(zip(a,b,d,c))
def __getitem__(self, idx):
assert idx < len(self.x)
return self.x[idx]
def __len__(self):
return len(self.x)
dataloader
參數:
dataset:傳入的數據
shuffle = True:是否打亂數據
collate_fn:使用這個參數可以自己操作每個batch的數據
dataset = Mydata()
dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)
下面是將每個batch的數據填充到該batch的最大長度
def mycollate(data):
a = []
b = []
c = []
d = []
max_len = len(data[0][0])
for i in data:
if len(i[0])>max_len:
max_len = len(i[0])
if len(i[1])>max_len:
max_len = len(i[1])
if len(i[2])>max_len:
max_len = len(i[2])
print(max_len)
# 填充
for i in data:
if len(i[0])<max_len:
i[0].extend([27] * (max_len-len(i[0])))
if len(i[1])<max_len:
i[1].extend([27] * (max_len-len(i[1])))
if len(i[2])<max_len:
i[2].extend([27] * (max_len-len(i[2])))
a.append(i[0])
b.append(i[1])
d.append(i[2])
c.extend(i[3])
# 這里要自己轉成tensor
a = torch.Tensor(a)
b = torch.Tensor(b)
c = torch.Tensor(c)
d = torch.Tensor(d)
data1 = [a,b,d,c]
print("data1",data1)
return data1
結果:
最后循環該dataloader ,拿到數據放入模型進行訓練:
for ii, data in enumerate(test_data_loader):
if opt.use_gpu:
data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data))
else:
data = list(map(lambda x: torch.LongTensor(x.long()), data))
out = model(data[:-1]) #數據data[:-1]
loss = F.cross_entropy(out, data[-1])# 最后一列是標簽
寫在最后:建議像我一樣剛開始不太熟練的小伙伴,在處理數據輸入的時候可以打印出來仔細查看。
原文鏈接:https://blog.csdn.net/weixin_60737527/article/details/126754254
相關推薦
- 2022-08-11 boost.asio框架系列之調度器io_service_C 語言
- 2023-02-15 C#?9使用foreach擴展的示例詳解_C#教程
- 2022-02-02 去掉chorme瀏覽器自動補全時input框的背景樣式
- 2023-01-12 一文帶你入木三分地理解字符串KMP算法以及C++實現_C 語言
- 2022-07-07 python?如何求N的階乘_python
- 2023-01-28 C#實現自定義單選和復選按鈕樣式_C#教程
- 2022-08-04 python連接FTP服務器的實現方法_python
- 2023-08-16 el-input輸入框去除邊框,且實現自動換行功能
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支