日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無(wú)先后,達(dá)者為師

網(wǎng)站首頁(yè) 編程語(yǔ)言 正文

pytorch加載自己的數(shù)據(jù)集源碼分享_python

作者:徽先生 ? 更新時(shí)間: 2022-10-11 編程語(yǔ)言

一、標(biāo)準(zhǔn)的數(shù)據(jù)集流程梳理

分為幾個(gè)步驟
數(shù)據(jù)準(zhǔn)備以及加載數(shù)據(jù)庫(kù)–>數(shù)據(jù)加載器的調(diào)用或者設(shè)計(jì)–>批量調(diào)用進(jìn)行訓(xùn)練或者其他作用

數(shù)據(jù)來(lái)源

直接讀取了x和y的數(shù)據(jù)變量,對(duì)比后面的就從把對(duì)應(yīng)的路徑寫(xiě)進(jìn)了文本文件中,通過(guò)加載器進(jìn)行讀取

x = torch.linspace(1, 10, 10)   # 訓(xùn)練數(shù)據(jù) linspace返回一個(gè)一維的張量,(最小值,最大值,多少個(gè)數(shù))
print(x)
y = torch.linspace(10, 1, 10)   # 標(biāo)簽
print(y)

將數(shù)據(jù)加載進(jìn)數(shù)據(jù)庫(kù)

輸出的結(jié)果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加載器進(jìn)行加載,才能迭代遍歷

import torch.utils.data as Data
torch_dataset = Data.TensorDataset(x, y)  # 對(duì)給定的 tensor 數(shù)據(jù),將他們包裝成 dataset
#輸出的結(jié)果是<torch.utils.data.dataset.TensorDataset object at 0x00000145BD93F1C0>,需要使用加載器進(jìn)行加載,才能迭代遍歷
print(torch_dataset)

所以要想看里面的內(nèi)容,就需要用迭代進(jìn)行操作或者查看。

BATCH_SIZE=5
loader = Data.DataLoader(#使用支持的默認(rèn)的數(shù)據(jù)集加載的方式
    # 從數(shù)據(jù)庫(kù)中每次抽出batch size個(gè)樣本
    dataset=torch_dataset,       # torch TensorDataset format   加載數(shù)據(jù)集
    batch_size=BATCH_SIZE,       # mini batch size 5
    shuffle=False,                # 要不要打亂數(shù)據(jù) (打亂比較好)
    num_workers=2,               # 多線程來(lái)讀數(shù)據(jù)
)
 
def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader): #加載數(shù)據(jù)集的時(shí)候起的作用很奇怪
            # training
            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
            print("*"*100)
if __name__ == '__main__':
    show_batch()

二、實(shí)現(xiàn)加載自己的數(shù)據(jù)集

實(shí)現(xiàn)自己的數(shù)據(jù)集就需要完成對(duì)dataset類的重載。這個(gè)類的重載完成幾個(gè)函數(shù)的作用

  • 初始化數(shù)據(jù)集中的數(shù)據(jù)以及標(biāo)簽__init__()
  • 返回?cái)?shù)據(jù)和對(duì)應(yīng)標(biāo)簽__getitem__
  • 返回?cái)?shù)據(jù)集的大小__len__

基本的數(shù)據(jù)集的方法就是完成以上步驟,但是可以想想數(shù)據(jù)集通常是一些圖片和標(biāo)簽組成,而這些數(shù)據(jù)集以及標(biāo)簽是保存在計(jì)算機(jī)上,具有相對(duì)應(yīng)的位置,那么直接訪問(wèn)對(duì)應(yīng)的位置因?yàn)槭窃谖募A下需要進(jìn)行遍歷等一系列操作,而且這就顯得和dataset類沒(méi)有解耦,因?yàn)橛袝r(shí)候在這些位置的操作可能會(huì)有一些特殊操作,所以如果能夠?qū)⑵湮恢帽4嬖谖谋疚募锌赡芫蜁?huì)方便很多,所以就采取保存文本文件的方式。

# 自定義數(shù)據(jù)集類
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, *args):
        super().__init__()
        # 初始化數(shù)據(jù)集包含的數(shù)據(jù)和標(biāo)簽
        pass
        
    def __getitem__(self, index):
        # 根據(jù)索引index從文件中讀取一個(gè)數(shù)據(jù)
        # 對(duì)數(shù)據(jù)預(yù)處理
        # 返回?cái)?shù)據(jù)和對(duì)應(yīng)標(biāo)簽
        pass
    
    def __len__(self):
        # 返回?cái)?shù)據(jù)集的大小
        return len()

1. 保存在txt文件中(生成訓(xùn)練集和測(cè)試集,其實(shí)這里的訓(xùn)練集以及測(cè)試集也都是用文本文件的形式保存下來(lái)的)

所以這里新建一個(gè)數(shù)據(jù)庫(kù)就是新建了兩個(gè)文本文件,然后加載器通過(guò)文本文件就將圖片以及l(fā)abel加載進(jìn)去了。而標(biāo)準(zhǔn)的數(shù)據(jù)集操作是使用了自帶的數(shù)據(jù)集接口,在加載的時(shí)候也不用再去實(shí)現(xiàn)相關(guān)的__getitem__方法

  • 數(shù)組定義
  • 將絕對(duì)路徑加載進(jìn)數(shù)組中
  • 數(shù)組定義
  • 將絕對(duì)路徑加載進(jìn)數(shù)組中
  • 通過(guò)os.walk操作
  • os.walk可以獲得根路徑、文件夾以及文件,并會(huì)一直進(jìn)行迭代遍歷下去,直至只有文件才會(huì)結(jié)束
  • 將數(shù)組的內(nèi)容打亂順序
  • 分別將絕對(duì)路徑對(duì)應(yīng)的數(shù)組內(nèi)容寫(xiě)進(jìn)文本文件里,那么這里的文本文件就是保存的數(shù)據(jù)庫(kù),其實(shí)數(shù)據(jù)就是一個(gè)保存相關(guān)信息或者其內(nèi)容的文件,而標(biāo)準(zhǔn)也是將將其數(shù)據(jù)保存在了一個(gè)地方,然后對(duì)應(yīng)到標(biāo)準(zhǔn)接口就可以加載了(Data.TensorDataset以及Data.DataLoader)

以下代碼用于生成對(duì)應(yīng)的train.txt val.txt

'''
生成訓(xùn)練集和測(cè)試集,保存在txt文件中
'''
import os
import random


train_ratio = 0.6


test_ratio = 1-train_ratio

rootdata = r"dataset"

#數(shù)組定義
train_list, test_list = [],[]
data_list = []

class_flag = -1
# 將絕對(duì)路徑加載進(jìn)數(shù)組中
for a,b,c in os.walk(rootdata):#os.walk可以獲得根路徑、文件夾以及文件,并會(huì)一直進(jìn)行迭代遍歷下去,直至只有文件才會(huì)結(jié)束
    print(a)
    for i in range(len(c)):
        data_list.append(os.path.join(a,c[i]))

    for i in range(0,int(len(c)*train_ratio)):
        train_data = os.path.join(a, c[i])+'\t'+str(class_flag)+'\n' #class_flag表示分類的類別
        train_list.append(train_data)

    for i in range(int(len(c) * train_ratio),len(c)):
        test_data = os.path.join(a, c[i]) + '\t' + str(class_flag)+'\n'
        test_list.append(test_data)

    class_flag += 1 

print(train_list)
# 將數(shù)組的內(nèi)容打亂順序
random.shuffle(train_list)
random.shuffle(test_list)

#分別將絕對(duì)路徑對(duì)應(yīng)的數(shù)組內(nèi)容寫(xiě)進(jìn)文本文件里
with open('train.txt','w',encoding='UTF-8') as f:
    for train_img in train_list:
        f.write(str(train_img))

with open('test.txt','w',encoding='UTF-8') as f:
    for test_img in test_list:
        f.write(test_img)

2. 在繼承dataset類LoadData的三個(gè)函數(shù)里調(diào)用train.txt以及test.txt實(shí)現(xiàn)相關(guān)功能

初始化數(shù)據(jù)集中的數(shù)據(jù)以及標(biāo)簽、相關(guān)變量__init__()

def __init__(self, txt_path, train_flag=True):
     #初始化圖片對(duì)應(yīng)的變量imgs_info以及一些相關(guān)變量
     self.imgs_info = self.get_images(txt_path) #imgs_info保存了圖片以及標(biāo)簽
     self.train_flag = train_flag

     self.train_tf = transforms.Compose([#對(duì)訓(xùn)練集的圖片進(jìn)行預(yù)處理
             transforms.Resize(224),
             transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
             transforms.ToTensor(),
             transform_BZ
         ])
     self.val_tf = transforms.Compose([#對(duì)測(cè)試集的圖片進(jìn)行預(yù)處理
             transforms.Resize(224),
             transforms.ToTensor(),
             transform_BZ
         ])

返回數(shù)據(jù)對(duì)應(yīng)標(biāo)簽__getitem__

def __getitem__(self, index):
     img_path, label = self.imgs_info[index]
     #打開(kāi)圖片,并將RGBA轉(zhuǎn)換為RGB,這里是通過(guò)PIL庫(kù)打開(kāi)圖片的
     img = Image.open(img_path)
     img = img.convert('RGB')
     img = self.padding_black(img) #將圖片添加上黑邊的
     if self.train_flag: #選擇是訓(xùn)練集還是測(cè)試集
         img = self.train_tf(img)
     else:
         img = self.val_tf(img)
     label = int(label)

     return img, label

返回?cái)?shù)據(jù)集的大小__len__

def __len__(self):
     return len(self.imgs_info)

由于前面已經(jīng)對(duì)集成dataset的類進(jìn)行了實(shí)現(xiàn)三種方法,那么就可以在加載器中進(jìn)行加載,將加載后的數(shù)據(jù)傳入到train函數(shù)或者test函數(shù)都可以

  • train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True):使用加載器加載數(shù)據(jù)
  • train(train_dataloader, model, loss_fn, optimizer) test(test_dataloader, model):將數(shù)據(jù)傳入train或者test中進(jìn)行訓(xùn)練或者測(cè)試
  • 注意:LoadData是繼承了dataset的類
if __name__=='__main__':
    batch_size = 16

    # # 給訓(xùn)練集和測(cè)試集分別創(chuàng)建一個(gè)數(shù)據(jù)集加載器
    train_data = LoadData("train.txt", True)
    valid_data = LoadData("test.txt", False)


    train_dataloader = DataLoader(dataset=train_data, num_workers=4, pin_memory=True, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(dataset=valid_data, num_workers=4, pin_memory=True, batch_size=batch_size)

    for X, y in test_dataloader:
        print("Shape of X [N, C, H, W]: ", X.shape)
        print("Shape of y: ", y.shape, y.dtype)
        break




原文鏈接:https://blog.csdn.net/weixin_42295969/article/details/126333679

欄目分類
最近更新