網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
1. dataloader() 初始化函數(shù)
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
其中幾個(gè)常用的參數(shù):
- dataset 數(shù)據(jù)集,map-style and iterable-style 可以用index取值的對(duì)象、
- batch_size 大小
- shuffle 取batch是否隨機(jī)取, 默認(rèn)為False
- sampler 定義取batch的方法,是一個(gè)迭代器, 每次生成一個(gè)key 用于讀取dataset中的值
- batch_sampler 也是一個(gè)迭代器, 每次生次一個(gè)batch_size的key
- num_workers 參與工作的線(xiàn)程數(shù)collate_fn 對(duì)取出的batch進(jìn)行處理
- drop_last 對(duì)最后不足batchsize的數(shù)據(jù)的處理方法
下面看兩段取自DataLoader中的__init__代碼, 幫助我們理解幾個(gè)常用參數(shù)之間的關(guān)系
2. shuffle 與sample 之間的關(guān)系
當(dāng)我們sampler有輸入時(shí),shuffle的值就沒(méi)有意義,
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
當(dāng)dataset類(lèi)型是map style時(shí), shuffle其實(shí)就是改變sampler的取值
- shuffle為默認(rèn)值 False時(shí),sampler是SequentialSampler,就是按順序取樣,
- shuffle為T(mén)rue時(shí),sampler是RandomSampler, 就是按隨機(jī)取樣
3. sample 的定義方法
3.1 sampler 參數(shù)的使用
sampler 是用來(lái)定義取batch方法的一個(gè)函數(shù)或者類(lèi),返回的是一個(gè)迭代器。
我們可以看下自帶的RandomSampler類(lèi)中最重要的iter函數(shù)
def __iter__(self):
n = len(self.data_source)
# dataset的長(zhǎng)度, 按順序索引
if self.replacement:# 對(duì)應(yīng)的replace參數(shù)
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
可以看出,其實(shí)就是生成索引,然后隨機(jī)的取值, 然后再迭代。
其實(shí)還有一些細(xì)節(jié)需要注意理解:
比如__len__函數(shù),包括DataLoader的len和sample的len, 兩者區(qū)別, 這部分代碼比較簡(jiǎn)單,可以自行閱讀,其實(shí)參考著RandomSampler寫(xiě)也不會(huì)出現(xiàn)問(wèn)題。
比如,迭代器和生成器的使用, 以及區(qū)別
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
BatchSampler的生成過(guò)程:
# 略去類(lèi)的初始化
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
就是按batch_size從sampler中讀取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之間的關(guān)系
- 如果batch_sampler沒(méi)有定義的話(huà)且batch_size有定義, 會(huì)根據(jù)sampler, batch_size, drop_last生成一個(gè)batch_sampler
- 自帶的注釋中對(duì)batch_sampler有一句話(huà): Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
- 意思就是b
- atch_sampler 與這些參數(shù)沖突 ,即 如果你定義了batch_sampler, 其他參數(shù)都不需要有
4. batch 生成過(guò)程
每個(gè)batch都是由迭代器產(chǎn)生的:
# DataLoader中iter的部分
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
# 再看調(diào)用的另一個(gè)類(lèi)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index()
data = self._dataset_fetcher.fetch(index)
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
原文鏈接:https://blog.csdn.net/chumingqian/article/details/126625724
相關(guān)推薦
- 2022-04-28 Python可視化學(xué)習(xí)之matplotlib內(nèi)置單顏色_python
- 2022-04-15 Python?yield?的使用淺析_python
- 2022-08-05 C語(yǔ)言示例講解switch分支語(yǔ)句的用法_C 語(yǔ)言
- 2022-07-03 el-upload上傳組件的動(dòng)態(tài)添加;el-upload動(dòng)態(tài)上傳文件;el-upload區(qū)分文件是哪
- 2022-09-25 Clion配置STM32開(kāi)發(fā)環(huán)境printf函數(shù)打印浮點(diǎn)數(shù)快速設(shè)置方法
- 2022-09-16 Go1.16新特性embed打包靜態(tài)資源文件實(shí)現(xiàn)_Golang
- 2023-01-17 Android繪制文本與圖片及動(dòng)效詳解_Android
- 2022-05-05 基于PyQt5制作數(shù)據(jù)處理小工具_(dá)python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門(mén)
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支