網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
一、標(biāo)準(zhǔn)的數(shù)據(jù)集流程梳理
分為幾個(gè)步驟
數(shù)據(jù)準(zhǔn)備以及加載數(shù)據(jù)庫(kù)–>數(shù)據(jù)加載器的調(diào)用或者設(shè)計(jì)–>批量調(diào)用進(jìn)行訓(xùn)練或者其他作用
數(shù)據(jù)來(lái)源
直接讀取了x和y的數(shù)據(jù)變量,對(duì)比后面的就從把對(duì)應(yīng)的路徑寫(xiě)進(jìn)了文本文件中,通過(guò)加載器進(jìn)行讀取
x = torch.linspace(1, 10, 10) # 訓(xùn)練數(shù)據(jù) linspace返回一個(gè)一維的張量,(最小值,最大值,多少個(gè)數(shù)) print(x) y = torch.linspace(10, 1, 10) # 標(biāo)簽 print(y)
將數(shù)據(jù)加載進(jìn)數(shù)據(jù)庫(kù)
輸出的結(jié)果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>
,需要使用加載器進(jìn)行加載,才能迭代遍歷
import torch.utils.data as Data torch_dataset = Data.TensorDataset(x, y) # 對(duì)給定的 tensor 數(shù)據(jù),將他們包裝成 dataset #輸出的結(jié)果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加載器進(jìn)行加載,才能迭代遍歷 print(torch_dataset)
所以要想看里面的內(nèi)容,就需要用迭代進(jìn)行操作或者查看。
BATCH_SIZE=5 loader = Data.DataLoader(#使用支持的默認(rèn)的數(shù)據(jù)集加載的方式 # 從數(shù)據(jù)庫(kù)中每次抽出batch size個(gè)樣本 dataset=torch_dataset, # torch TensorDataset format 加載數(shù)據(jù)集 batch_size=BATCH_SIZE, # mini batch size 5 shuffle=False, # 要不要打亂數(shù)據(jù) (打亂比較好) num_workers=2, # 多線程來(lái)讀數(shù)據(jù) ) def show_batch(): for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): #加載數(shù)據(jù)集的時(shí)候起的作用很奇怪 # training print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y)) print("*"*100) if __name__ == '__main__': show_batch()
二、實(shí)現(xiàn)加載自己的數(shù)據(jù)集
實(shí)現(xiàn)自己的數(shù)據(jù)集就需要完成對(duì)dataset類的重載。這個(gè)類的重載完成幾個(gè)函數(shù)的作用
- 初始化數(shù)據(jù)集中的數(shù)據(jù)以及標(biāo)簽
__init__()
- 返回?cái)?shù)據(jù)和對(duì)應(yīng)標(biāo)簽
__getitem__
- 返回?cái)?shù)據(jù)集的大小
__len__
基本的數(shù)據(jù)集的方法就是完成以上步驟,但是可以想想數(shù)據(jù)集通常是一些圖片和標(biāo)簽組成,而這些數(shù)據(jù)集以及標(biāo)簽是保存在計(jì)算機(jī)上,具有相對(duì)應(yīng)的位置,那么直接訪問(wèn)對(duì)應(yīng)的位置因?yàn)槭窃谖募A下需要進(jìn)行遍歷等一系列操作,而且這就顯得和dataset類沒(méi)有解耦,因?yàn)橛袝r(shí)候在這些位置的操作可能會(huì)有一些特殊操作,所以如果能夠?qū)⑵湮恢帽4嬖谖谋疚募锌赡芫蜁?huì)方便很多,所以就采取保存文本文件的方式。
# 自定義數(shù)據(jù)集類 class MyDataset(torch.utils.data.Dataset): def __init__(self, *args): super().__init__() # 初始化數(shù)據(jù)集包含的數(shù)據(jù)和標(biāo)簽 pass def __getitem__(self, index): # 根據(jù)索引index從文件中讀取一個(gè)數(shù)據(jù) # 對(duì)數(shù)據(jù)預(yù)處理 # 返回?cái)?shù)據(jù)和對(duì)應(yīng)標(biāo)簽 pass def __len__(self): # 返回?cái)?shù)據(jù)集的大小 return len()
1. 保存在txt文件中(生成訓(xùn)練集和測(cè)試集,其實(shí)這里的訓(xùn)練集以及測(cè)試集也都是用文本文件的形式保存下來(lái)的)
所以這里新建一個(gè)數(shù)據(jù)庫(kù)就是新建了兩個(gè)文本文件,然后加載器通過(guò)文本文件就將圖片以及l(fā)abel加載進(jìn)去了。而標(biāo)準(zhǔn)的數(shù)據(jù)集操作是使用了自帶的數(shù)據(jù)集接口,在加載的時(shí)候也不用再去實(shí)現(xiàn)相關(guān)的__getitem__方法
- 數(shù)組定義
- 將絕對(duì)路徑加載進(jìn)數(shù)組中
- 數(shù)組定義
- 將絕對(duì)路徑加載進(jìn)數(shù)組中
- 通過(guò)os.walk操作
- os.walk可以獲得根路徑、文件夾以及文件,并會(huì)一直進(jìn)行迭代遍歷下去,直至只有文件才會(huì)結(jié)束
- 將數(shù)組的內(nèi)容打亂順序
- 分別將絕對(duì)路徑對(duì)應(yīng)的數(shù)組內(nèi)容寫(xiě)進(jìn)文本文件里,那么這里的文本文件就是保存的數(shù)據(jù)庫(kù),其實(shí)數(shù)據(jù)就是一個(gè)保存相關(guān)信息或者其內(nèi)容的文件,而標(biāo)準(zhǔn)也是將將其數(shù)據(jù)保存在了一個(gè)地方,然后對(duì)應(yīng)到標(biāo)準(zhǔn)接口就可以加載了(Data.TensorDataset以及Data.DataLoader)
以下代碼用于生成對(duì)應(yīng)的train.txt val.txt
''' 生成訓(xùn)練集和測(cè)試集,保存在txt文件中 ''' import os import random train_ratio = 0.6 test_ratio = 1-train_ratio rootdata = r"dataset" #數(shù)組定義 train_list, test_list = [],[] data_list = [] class_flag = -1 # 將絕對(duì)路徑加載進(jìn)數(shù)組中 for a,b,c in os.walk(rootdata):#os.walk可以獲得根路徑、文件夾以及文件,并會(huì)一直進(jìn)行迭代遍歷下去,直至只有文件才會(huì)結(jié)束 print(a) for i in range(len(c)): data_list.append(os.path.join(a,c[i])) for i in range(0,int(len(c)*train_ratio)): train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n' #class_flag表示分類的類別 train_list.append(train_data) for i in range(int(len(c) * train_ratio),len(c)): test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n' test_list.append(test_data) class_flag += 1 print(train_list) # 將數(shù)組的內(nèi)容打亂順序 random.shuffle(train_list) random.shuffle(test_list) #分別將絕對(duì)路徑對(duì)應(yīng)的數(shù)組內(nèi)容寫(xiě)進(jìn)文本文件里 with open('train.txt','w',encoding='UTF-8') as f: for train_img in train_list: f.write(str(train_img)) with open('test.txt','w',encoding='UTF-8') as f: for test_img in test_list: f.write(test_img)
2. 在繼承dataset類LoadData的三個(gè)函數(shù)里調(diào)用train.txt以及test.txt實(shí)現(xiàn)相關(guān)功能
初始化數(shù)據(jù)集中的數(shù)據(jù)以及標(biāo)簽、相關(guān)變量__init__()
def __init__(self, txt_path, train_flag=True): #初始化圖片對(duì)應(yīng)的變量imgs_info以及一些相關(guān)變量 self.imgs_info = self.get_images(txt_path) #imgs_info保存了圖片以及標(biāo)簽 self.train_flag = train_flag self.train_tf = transforms.Compose([#對(duì)訓(xùn)練集的圖片進(jìn)行預(yù)處理 transforms.Resize(224), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), transform_BZ ]) self.val_tf = transforms.Compose([#對(duì)測(cè)試集的圖片進(jìn)行預(yù)處理 transforms.Resize(224), transforms.ToTensor(), transform_BZ ])
返回數(shù)據(jù)和對(duì)應(yīng)標(biāo)簽__getitem__
def __getitem__(self, index): img_path, label = self.imgs_info[index] #打開(kāi)圖片,并將RGBA轉(zhuǎn)換為RGB,這里是通過(guò)PIL庫(kù)打開(kāi)圖片的 img = Image.open(img_path) img = img.convert('RGB') img = self.padding_black(img) #將圖片添加上黑邊的 if self.train_flag: #選擇是訓(xùn)練集還是測(cè)試集 img = self.train_tf(img) else: img = self.val_tf(img) label = int(label) return img, label
返回?cái)?shù)據(jù)集的大小__len__
def __len__(self): return len(self.imgs_info)
由于前面已經(jīng)對(duì)集成dataset的類進(jìn)行了實(shí)現(xiàn)三種方法,那么就可以在加載器中進(jìn)行加載,將加載后的數(shù)據(jù)傳入到train函數(shù)或者test函數(shù)都可以
-
train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
:使用加載器加載數(shù)據(jù) -
train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model)
:將數(shù)據(jù)傳入train或者test中進(jìn)行訓(xùn)練或者測(cè)試 - 注意:LoadData是繼承了dataset的類
if __name__=='__main__': batch_size = 16 # # 給訓(xùn)練集和測(cè)試集分別創(chuàng)建一個(gè)數(shù)據(jù)集加載器 train_data = LoadData("train.txt", True) valid_data = LoadData("test.txt", False) train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True) test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size) for X, y in test_dataloader: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break
原文鏈接:https://blog.csdn.net/weixin_42295969/article/details/126333679
相關(guān)推薦
- 2022-08-06 Qt編寫(xiě)顯示密碼強(qiáng)度的控件_C 語(yǔ)言
- 2022-03-20 6ull加載linux驅(qū)動(dòng)模塊失敗解決方法_Linux
- 2022-08-23 C++詳解使用floor&ceil&round實(shí)現(xiàn)保留小數(shù)點(diǎn)后兩位_C 語(yǔ)言
- 2023-02-12 基于Redis驗(yàn)證碼發(fā)送及校驗(yàn)方案實(shí)現(xiàn)_Redis
- 2022-04-24 python使用技巧-查找文件?_python
- 2022-03-06 C#中List用法介紹詳解_C#教程
- 2022-08-28 keil5仿真相關(guān)配置,解決相關(guān)bug
- 2022-06-18 C#如何在窗體程序中操作數(shù)據(jù)庫(kù)數(shù)據(jù)_C#教程
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門(mén)
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支