網(wǎng)站首頁 編程語言 正文
作用
collate_fn:即用于collate的function,用于整理數(shù)據(jù)的函數(shù)。
說到整理數(shù)據(jù),你當(dāng)然要會用數(shù)據(jù),即會用數(shù)據(jù)制作工具torch.utils.data.Dataset
,雖然我們今天談的是torch.utils.data.DataLoader
,但是,其實:
- 這兩個你如何定義;
- 從裝載器dataloader中取數(shù)據(jù)后做什么處理;
- 模型的forward()中如何處理。
這三部分都是有機統(tǒng)一而不可分割的,一個地方改了,其他地方就要改。
emmm…,小小總結(jié),collate_fn籠統(tǒng)的說就是用于整理數(shù)據(jù),通常我們不需要使用,其應(yīng)用的情形是:各個數(shù)據(jù)長度不一樣的情況,比如第一張圖片大小是28*28,第二張是50*50,這樣的話就如果不自己寫collate_fn,而使用默認(rèn)的,就會報錯。
原則
其實說起來,我們也沒有什么原則,但是如今大多數(shù)做深度學(xué)習(xí)都是使用GPU,所以這個時候我們需要記住一個總則:只有tensor數(shù)據(jù)類型才能運行在GPU上,list和numpy都不可以。
從而,我們什么時候?qū)⑽覀兊臄?shù)據(jù)轉(zhuǎn)化為tensor是一個問題,我的答案是前一節(jié)中的三個部分都可以來轉(zhuǎn)化,只是我們大多數(shù)的人都習(xí)慣在部分一轉(zhuǎn)化。
基礎(chǔ)
dataset
我們必須先看看torch.utils.data.Dataset如何使用,以一個例子為例:
import torch.utils.data as Data class mydataset(Data.Dataset): def __init__(self,train_inputs,train_targets):#必須有 super(mydataset,self).__init__() self.inputs=train_inputs self.targets=train_targets def __getitem__(self, index):#必須重寫 return self.inputs[index],self.targets[index] def __len__(self):#必須重寫 return len(self.targets)
#構(gòu)造訓(xùn)練數(shù)據(jù) datax=torch.randn(4,3)#構(gòu)造4個輸入 datay=torch.empty(4).random_(2)#構(gòu)造4個標(biāo)簽
#制作dataset dataset=mydataset(datax,datay)
下面,可以對dataset進(jìn)行一系列操作,這些操作返回的結(jié)果和你之前那個class的三個函數(shù)定義都息息相關(guān)。我想說,那三個函數(shù)非常自由,你想怎么定義就怎么定義,上述只是一種常見的而已,你可以定制一個特色的。
len(dataset)#調(diào)用了你上面定義的def __len__()那個函數(shù) #4
dataset[0]#調(diào)用了你上面定義的def __getitem__()那個函數(shù) #(tensor([-1.1426, -1.3239, 1.8372]), tensor(0.))
所以我再三強調(diào)的是上面的輸出結(jié)果和你的定義有關(guān),比如你完全可以把def __getitem__()改成:
def __getitem__(self, index): return self.inputs[index]#不輸出標(biāo)簽
那么,
dataset[0]#此時當(dāng)然變化。 #tensor([-1.1426, -1.3239, 1.8372])
可以看到,是非常隨便的,你隨便定制就好。
dataloader
torch.utils.data.DataLoader
dataloader=Data.DataLoader(dataset,batch_size=2)
4個數(shù)據(jù),batch_size=2,所以一共有2個batch。
collate_fn如果你不指定,會調(diào)用pytorch內(nèi)部的,也就是說這個函數(shù)是一定會調(diào)用的,而且調(diào)用這個函數(shù)時pytorch會往這個函數(shù)里面?zhèn)魅胍粋€參數(shù)batch。
def my_collate(batch): return xxx
這個batch是什么?這個東西和你定義的dataset, batch_size息息相關(guān)。batch是一個列表[x,...,xx],長度就是batch_size,里面每一個元素是dataset的某一個元素,即dataset[i](我在上一節(jié)展示過dataset[0])。
在我們的例子中,由于我們沒有對dataloader設(shè)置需要打亂數(shù)據(jù),即shuffle=True,那么第1個batch就是前兩個數(shù)據(jù),如下:
print(datax) print(datay) batch=[dataset[0],dataset[1]]#所以才說和你dataset中g(shù)et_item的定義有關(guān)。 print(batch)
對,你沒有看錯,上述代碼展示的batch就會傳入到pytorch默認(rèn)的collate_fn中,然后經(jīng)過默認(rèn)的處理,輸出如下:
it=iter(dataloader) nex=next(it)#我們展示第一個batch經(jīng)過collate_fn之后的輸出結(jié)果 print(nex)
其實,上面就是我們常用的,經(jīng)典的輸出結(jié)果,即輸入和標(biāo)簽是分開的,第一項是輸入tensor,第二項是標(biāo)簽tensor,輸入的維度變成了(batch_size,input_size)。
但是我們乍一看,將第一個batch變成上述輸出結(jié)果很容易呀,我們也會!我們下面就來自己寫一個collate_fn實現(xiàn)這個功能。
# a simple custom collate function, just to show the idea # `batch` is a list of tuple where first element is input tensor and the second element is corresponding label def my_collate(batch): inputs=[data[0].tolist() for data in batch] target = torch.tensor([data[1] for data in batch]) return [data, target]
?
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
print(datax) print(datay)
it=iter(dataloader) nex=next(it) print(nex)
這不就和默認(rèn)的collate_fn的輸出結(jié)果一樣了嘛!無非就是默認(rèn)的還把輸入變成了tensor,標(biāo)簽變成了tensor,我上面是列表,我改就是了嘛!如下:
def my_collate(batch): inputs=[data[0].tolist() for data in batch] inputs=torch.tensor(inputs) target =[data[1].tolist() for data in batch] target=torch.tensor(target) return [inputs, target]
dataloader=Data.DataLoader(dataset,batch_size=2,collate_fn=my_collate)
it=iter(dataloader) nex=next(it) print(nex)
這下好了吧!
對了,作為彩蛋,告訴大家一個秘密:默認(rèn)的collate_fn函數(shù)中有一些語句是轉(zhuǎn)tensor以及tensor合并的操作,所以你的dataset如果沒有設(shè)計成經(jīng)典模式的話,使用默認(rèn)的就容易報錯,而我們自己會寫collate_fn,當(dāng)然就不存在這個問題啦。同時,給大家的一個經(jīng)驗就是,一般dataset是不會報錯的,而是根據(jù)dataset制作dataloader的時候容易報錯,因為默認(rèn)collate_fn把dataset的類型限制得比較死。
應(yīng)用情形
假設(shè)我們還是4個輸入,但是維度不固定的。
a=[[1,2],[3,4,5],[1],[3,4,9]] b=[1,0,0,1] dataset=mydataset(a,b) dataloader=Data.DataLoader(dataset,batch_size=2) it=iter(dataloader) nex=next(it) nex
使用默認(rèn)的collate_fn,直接報錯,要求相同維度。
這個時候,我們可以使用自己的collate_fn,避免報錯。
不過話說回來,我個人感受是:
在這里避免報錯好像也沒有什么用,因為大多數(shù)的神經(jīng)網(wǎng)絡(luò)都是定長輸入的,而且很多的操作也要求相同維度才能相加或相乘,所以:這里不報錯,后面還是報錯。如果后面解決這個問題的方法是:在不足維度上進(jìn)行補0操作,那么我們?yōu)槭裁床辉诮ataset之前先補好呢?所以,collate_fn這個東西的應(yīng)用場景還是有限的。
總結(jié)
原文鏈接:https://blog.csdn.net/qq_43391414/article/details/120462055
相關(guān)推薦
- 2022-08-23 python文件讀取read及readlines兩種方法使用詳解_python
- 2022-12-23 Kotlin?try?catch異常處理i詳解_Android
- 2022-12-25 React?redux?原理及使用詳解_React
- 2022-05-02 在kali上安裝AWVS的圖文教程_相關(guān)技巧
- 2022-12-04 .NET?Core利用BsonDocumentProjectionDefinition和Lookup
- 2022-12-10 Qt界面中滑動條的實現(xiàn)方式_C 語言
- 2022-09-23 Redux中異步action與同步action的使用_React
- 2022-02-27 Postgres -- 報錯:right sibling‘s left-link doesn‘t m
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支