網站首頁 編程語言 正文
一、標準的數據集流程梳理
分為幾個步驟
數據準備以及加載數據庫–>數據加載器的調用或者設計–>批量調用進行訓練或者其他作用
數據來源
直接讀取了x和y的數據變量,對比后面的就從把對應的路徑寫進了文本文件中,通過加載器進行讀取
x = torch.linspace(1, 10, 10) # 訓練數據 linspace返回一個一維的張量,(最小值,最大值,多少個數) print(x) y = torch.linspace(10, 1, 10) # 標簽 print(y)
將數據加載進數據庫
輸出的結果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>
,需要使用加載器進行加載,才能迭代遍歷
import torch.utils.data as Data torch_dataset = Data.TensorDataset(x, y) # 對給定的 tensor 數據,將他們包裝成 dataset #輸出的結果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加載器進行加載,才能迭代遍歷 print(torch_dataset)
所以要想看里面的內容,就需要用迭代進行操作或者查看。
BATCH_SIZE=5 loader = Data.DataLoader(#使用支持的默認的數據集加載的方式 # 從數據庫中每次抽出batch size個樣本 dataset=torch_dataset, # torch TensorDataset format 加載數據集 batch_size=BATCH_SIZE, # mini batch size 5 shuffle=False, # 要不要打亂數據 (打亂比較好) num_workers=2, # 多線程來讀數據 ) def show_batch(): for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): #加載數據集的時候起的作用很奇怪 # training print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y)) print("*"*100) if __name__ == '__main__': show_batch()
二、實現加載自己的數據集
實現自己的數據集就需要完成對dataset類的重載。這個類的重載完成幾個函數的作用
- 初始化數據集中的數據以及標簽
__init__()
- 返回數據和對應標簽
__getitem__
- 返回數據集的大小
__len__
基本的數據集的方法就是完成以上步驟,但是可以想想數據集通常是一些圖片和標簽組成,而這些數據集以及標簽是保存在計算機上,具有相對應的位置,那么直接訪問對應的位置因為是在文件夾下需要進行遍歷等一系列操作,而且這就顯得和dataset類沒有解耦,因為有時候在這些位置的操作可能會有一些特殊操作,所以如果能夠將其位置保存在文本文件中可能就會方便很多,所以就采取保存文本文件的方式。
# 自定義數據集類 class MyDataset(torch.utils.data.Dataset): def __init__(self, *args): super().__init__() # 初始化數據集包含的數據和標簽 pass def __getitem__(self, index): # 根據索引index從文件中讀取一個數據 # 對數據預處理 # 返回數據和對應標簽 pass def __len__(self): # 返回數據集的大小 return len()
1. 保存在txt文件中(生成訓練集和測試集,其實這里的訓練集以及測試集也都是用文本文件的形式保存下來的)
所以這里新建一個數據庫就是新建了兩個文本文件,然后加載器通過文本文件就將圖片以及label加載進去了。而標準的數據集操作是使用了自帶的數據集接口,在加載的時候也不用再去實現相關的__getitem__方法
- 數組定義
- 將絕對路徑加載進數組中
- 數組定義
- 將絕對路徑加載進數組中
- 通過os.walk操作
- os.walk可以獲得根路徑、文件夾以及文件,并會一直進行迭代遍歷下去,直至只有文件才會結束
- 將數組的內容打亂順序
- 分別將絕對路徑對應的數組內容寫進文本文件里,那么這里的文本文件就是保存的數據庫,其實數據就是一個保存相關信息或者其內容的文件,而標準也是將將其數據保存在了一個地方,然后對應到標準接口就可以加載了(Data.TensorDataset以及Data.DataLoader)
以下代碼用于生成對應的train.txt val.txt
''' 生成訓練集和測試集,保存在txt文件中 ''' import os import random train_ratio = 0.6 test_ratio = 1-train_ratio rootdata = r"dataset" #數組定義 train_list, test_list = [],[] data_list = [] class_flag = -1 # 將絕對路徑加載進數組中 for a,b,c in os.walk(rootdata):#os.walk可以獲得根路徑、文件夾以及文件,并會一直進行迭代遍歷下去,直至只有文件才會結束 print(a) for i in range(len(c)): data_list.append(os.path.join(a,c[i])) for i in range(0,int(len(c)*train_ratio)): train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n' #class_flag表示分類的類別 train_list.append(train_data) for i in range(int(len(c) * train_ratio),len(c)): test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n' test_list.append(test_data) class_flag += 1 print(train_list) # 將數組的內容打亂順序 random.shuffle(train_list) random.shuffle(test_list) #分別將絕對路徑對應的數組內容寫進文本文件里 with open('train.txt','w',encoding='UTF-8') as f: for train_img in train_list: f.write(str(train_img)) with open('test.txt','w',encoding='UTF-8') as f: for test_img in test_list: f.write(test_img)
2. 在繼承dataset類LoadData的三個函數里調用train.txt以及test.txt實現相關功能
初始化數據集中的數據以及標簽、相關變量__init__()
def __init__(self, txt_path, train_flag=True): #初始化圖片對應的變量imgs_info以及一些相關變量 self.imgs_info = self.get_images(txt_path) #imgs_info保存了圖片以及標簽 self.train_flag = train_flag self.train_tf = transforms.Compose([#對訓練集的圖片進行預處理 transforms.Resize(224), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), transform_BZ ]) self.val_tf = transforms.Compose([#對測試集的圖片進行預處理 transforms.Resize(224), transforms.ToTensor(), transform_BZ ])
返回數據和對應標簽__getitem__
def __getitem__(self, index): img_path, label = self.imgs_info[index] #打開圖片,并將RGBA轉換為RGB,這里是通過PIL庫打開圖片的 img = Image.open(img_path) img = img.convert('RGB') img = self.padding_black(img) #將圖片添加上黑邊的 if self.train_flag: #選擇是訓練集還是測試集 img = self.train_tf(img) else: img = self.val_tf(img) label = int(label) return img, label
返回數據集的大小__len__
def __len__(self): return len(self.imgs_info)
由于前面已經對集成dataset的類進行了實現三種方法,那么就可以在加載器中進行加載,將加載后的數據傳入到train函數或者test函數都可以
-
train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
:使用加載器加載數據 -
train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model)
:將數據傳入train或者test中進行訓練或者測試 - 注意:LoadData是繼承了dataset的類
if __name__=='__main__': batch_size = 16 # # 給訓練集和測試集分別創建一個數據集加載器 train_data = LoadData("train.txt", True) valid_data = LoadData("test.txt", False) train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True) test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size) for X, y in test_dataloader: print("Shape of X [N, C, H, W]: ", X.shape) print("Shape of y: ", y.shape, y.dtype) break
原文鏈接:https://blog.csdn.net/weixin_42295969/article/details/126333679
相關推薦
- 2022-10-13 Python中?whl包、tar.gz包的區別詳解_python
- 2022-04-21 python數據類型bytes?和?bytearray的使用與區別_python
- 2022-06-17 mongodb?數據塊的遷移流程分析_MongoDB
- 2022-05-23 C++的繼承特性你了解嗎_C 語言
- 2022-06-26 利用ASP.Net?Core中的Razor實現動態菜單_實用技巧
- 2024-01-10 右鍵添加 idea 打開功能
- 2022-06-18 C語言圖文并茂詳解鏈接過程_C 語言
- 2022-12-13 C++?POSIX?API超詳細分析_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支