日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

pytorch?collate_fn的基礎與應用教程_python

作者:音程 ? 更新時間: 2022-04-16 編程語言

作用

collate_fn:即用于collate的function,用于整理數據的函數。

說到整理數據,你當然要會用數據,即會用數據制作工具torch.utils.data.Dataset,雖然我們今天談的是torch.utils.data.DataLoader,但是,其實:

  1. 這兩個你如何定義;
  2. 從裝載器dataloader中取數據后做什么處理;
  3. 模型的forward()中如何處理。

這三部分都是有機統一而不可分割的,一個地方改了,其他地方就要改。

emmm…,小小總結,collate_fn籠統的說就是用于整理數據,通常我們不需要使用,其應用的情形是:各個數據長度不一樣的情況,比如第一張圖片大小是28*28,第二張是50*50,這樣的話就如果不自己寫collate_fn,而使用默認的,就會報錯。

原則

其實說起來,我們也沒有什么原則,但是如今大多數做深度學習都是使用GPU,所以這個時候我們需要記住一個總則:只有tensor數據類型才能運行在GPU上,list和numpy都不可以。

從而,我們什么時候將我們的數據轉化為tensor是一個問題,我的答案是前一節中的三個部分都可以來轉化,只是我們大多數的人都習慣在部分一轉化。

基礎

dataset

我們必須先看看torch.utils.data.Dataset如何使用,以一個例子為例:

import torch.utils.data as Data
class mydataset(Data.Dataset):
    def __init__(self,train_inputs,train_targets):#必須有
        super(mydataset,self).__init__()
        self.inputs=train_inputs
        self.targets=train_targets
        
    def __getitem__(self, index):#必須重寫
        return self.inputs[index],self.targets[index]
        
    def __len__(self):#必須重寫
        return len(self.targets)
#構造訓練數據
datax=torch.randn(4,3)#構造4個輸入
datay=torch.empty(4).random_(2)#構造4個標簽
#制作dataset
dataset=mydataset(datax,datay)

下面,可以對dataset進行一系列操作,這些操作返回的結果和你之前那個class的三個函數定義都息息相關。我想說,那三個函數非常自由,你想怎么定義就怎么定義,上述只是一種常見的而已,你可以定制一個特色的。

len(dataset)#調用了你上面定義的def __len__()那個函數
#4
dataset[0]#調用了你上面定義的def __getitem__()那個函數
#(tensor([-1.1426, -1.3239,  1.8372]), tensor(0.))

所以我再三強調的是上面的輸出結果和你的定義有關,比如你完全可以把def __getitem__()改成:

    def __getitem__(self, index):
        return self.inputs[index]#不輸出標簽

那么,

dataset[0]#此時當然變化。
#tensor([-1.1426, -1.3239,  1.8372])

可以看到,是非常隨便的,你隨便定制就好。

dataloader

torch.utils.data.DataLoader

dataloader=Data.DataLoader(dataset,batch_size=2)

4個數據,batch_size=2,所以一共有2個batch。

collate_fn如果你不指定,會調用pytorch內部的,也就是說這個函數是一定會調用的,而且調用這個函數時pytorch會往這個函數里面傳入一個參數batch。

def my_collate(batch):
	return xxx

這個batch是什么?這個東西和你定義的dataset, batch_size息息相關。batch是一個列表[x,...,xx],長度就是batch_size,里面每一個元素是dataset的某一個元素,即dataset[i](我在上一節展示過dataset[0])。

在我們的例子中,由于我們沒有對dataloader設置需要打亂數據,即shuffle=True,那么第1個batch就是前兩個數據,如下:

print(datax)
print(datay)
batch=[dataset[0],dataset[1]]#所以才說和你dataset中get_item的定義有關。
print(batch)

對,你沒有看錯,上述代碼展示的batch就會傳入到pytorch默認的collate_fn中,然后經過默認的處理,輸出如下:

it=iter(dataloader)
nex=next(it)#我們展示第一個batch經過collate_fn之后的輸出結果
print(nex)

其實,上面就是我們常用的,經典的輸出結果,即輸入和標簽是分開的,第一項是輸入tensor,第二項是標簽tensor,輸入的維度變成了(batch_size,input_size)。

但是我們乍一看,將第一個batch變成上述輸出結果很容易呀,我們也會!我們下面就來自己寫一個collate_fn實現這個功能。

# a simple custom collate function, just to show the idea
# `batch` is a list of tuple where first element is input tensor and the second element is corresponding label
def my_collate(batch):
    inputs=[data[0].tolist() for data in batch]
    target = torch.tensor([data[1] for data in batch])
    return [data, target]

?

dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
print(datax)
print(datay)

it=iter(dataloader)
nex=next(it)
print(nex)

這不就和默認的collate_fn的輸出結果一樣了嘛!無非就是默認的還把輸入變成了tensor,標簽變成了tensor,我上面是列表,我改就是了嘛!如下:

def my_collate(batch):
    inputs=[data[0].tolist() for data in batch]
    inputs=torch.tensor(inputs)
    target =[data[1].tolist() for data in batch]
    target=torch.tensor(target)
    return [inputs, target]
    
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader)
nex=next(it)
print(nex)

這下好了吧!

對了,作為彩蛋,告訴大家一個秘密:默認的collate_fn函數中有一些語句是轉tensor以及tensor合并的操作,所以你的dataset如果沒有設計成經典模式的話,使用默認的就容易報錯,而我們自己會寫collate_fn,當然就不存在這個問題啦。同時,給大家的一個經驗就是,一般dataset是不會報錯的,而是根據dataset制作dataloader的時候容易報錯,因為默認collate_fn把dataset的類型限制得比較死。

應用情形

假設我們還是4個輸入,但是維度不固定的。

a=[[1,2],[3,4,5],[1],[3,4,9]]
b=[1,0,0,1]
dataset=mydataset(a,b)
dataloader=Data.DataLoader(dataset,batch_size=2)
it=iter(dataloader)
nex=next(it)
nex

使用默認的collate_fn,直接報錯,要求相同維度。

這個時候,我們可以使用自己的collate_fn,避免報錯。

不過話說回來,我個人感受是:

在這里避免報錯好像也沒有什么用,因為大多數的神經網絡都是定長輸入的,而且很多的操作也要求相同維度才能相加或相乘,所以:這里不報錯,后面還是報錯。如果后面解決這個問題的方法是:在不足維度上進行補0操作,那么我們為什么不在建立dataset之前先補好呢?所以,collate_fn這個東西的應用場景還是有限的。

總結

原文鏈接:https://blog.csdn.net/qq_43391414/article/details/120462055

欄目分類
最近更新