日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學(xué)無(wú)先后,達(dá)者為師

網(wǎng)站首頁(yè) 編程語(yǔ)言 正文

pytorch中dataloader?的sampler?參數(shù)詳解_python

作者:mingqian_chu ? 更新時(shí)間: 2022-10-27 編程語(yǔ)言

1. dataloader() 初始化函數(shù)

 def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
 batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):

其中幾個(gè)常用的參數(shù):

  • dataset 數(shù)據(jù)集,map-style and iterable-style 可以用index取值的對(duì)象、
  • batch_size 大小
  • shuffle 取batch是否隨機(jī)取, 默認(rèn)為False
  • sampler 定義取batch的方法,是一個(gè)迭代器, 每次生成一個(gè)key 用于讀取dataset中的值
  • batch_sampler 也是一個(gè)迭代器, 每次生次一個(gè)batch_size的key
  • num_workers 參與工作的線(xiàn)程數(shù)collate_fn 對(duì)取出的batch進(jìn)行處理
  • drop_last 對(duì)最后不足batchsize的數(shù)據(jù)的處理方法

下面看兩段取自DataLoader中的__init__代碼, 幫助我們理解幾個(gè)常用參數(shù)之間的關(guān)系

2. shuffle 與sample 之間的關(guān)系

當(dāng)我們sampler有輸入時(shí),shuffle的值就沒(méi)有意義,

	if sampler is None:  # give default samplers
	    if self._dataset_kind == _DatasetKind.Iterable:
	        # See NOTE [ Custom Samplers and IterableDataset ]
	        sampler = _InfiniteConstantSampler()
	    else:  # map-style
	        if shuffle:
	            sampler = RandomSampler(dataset)
	        else:
	            sampler = SequentialSampler(dataset)

當(dāng)dataset類(lèi)型是map style時(shí), shuffle其實(shí)就是改變sampler的取值

  • shuffle為默認(rèn)值 False時(shí),sampler是SequentialSampler,就是按順序取樣,
  • shuffle為T(mén)rue時(shí),sampler是RandomSampler, 就是按隨機(jī)取樣

3. sample 的定義方法

3.1 sampler 參數(shù)的使用

sampler 是用來(lái)定義取batch方法的一個(gè)函數(shù)或者類(lèi),返回的是一個(gè)迭代器。

我們可以看下自帶的RandomSampler類(lèi)中最重要的iter函數(shù)

    def __iter__(self):
        n = len(self.data_source)
        # dataset的長(zhǎng)度, 按順序索引
        if self.replacement:# 對(duì)應(yīng)的replace參數(shù)
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        return iter(torch.randperm(n).tolist())        

可以看出,其實(shí)就是生成索引,然后隨機(jī)的取值, 然后再迭代。

其實(shí)還有一些細(xì)節(jié)需要注意理解:

比如__len__函數(shù),包括DataLoader的len和sample的len, 兩者區(qū)別, 這部分代碼比較簡(jiǎn)單,可以自行閱讀,其實(shí)參考著RandomSampler寫(xiě)也不會(huì)出現(xiàn)問(wèn)題。
比如,迭代器和生成器的使用, 以及區(qū)別

    if batch_size is not None and batch_sampler is None:
        # auto_collation without custom batch_sampler
        batch_sampler = BatchSampler(sampler, batch_size, drop_last)
        
    self.sampler = sampler
    self.batch_sampler = batch_sampler

BatchSampler的生成過(guò)程:

# 略去類(lèi)的初始化
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

就是按batch_size從sampler中讀取索引, 并形成生成器返回。

以上可以看出, batch_sampler和sampler, batch_size, drop_last之間的關(guān)系

  • 如果batch_sampler沒(méi)有定義的話(huà)且batch_size有定義, 會(huì)根據(jù)sampler, batch_size, drop_last生成一個(gè)batch_sampler
  • 自帶的注釋中對(duì)batch_sampler有一句話(huà): Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
  • 意思就是b
  • atch_sampler 與這些參數(shù)沖突 ,即 如果你定義了batch_sampler, 其他參數(shù)都不需要有

4. batch 生成過(guò)程

每個(gè)batch都是由迭代器產(chǎn)生的:

# DataLoader中iter的部分
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

# 再看調(diào)用的另一個(gè)類(lèi)
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def __next__(self):
        index = self._next_index()  
        data = self._dataset_fetcher.fetch(index)  
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

原文鏈接:https://blog.csdn.net/chumingqian/article/details/126625724

欄目分類(lèi)
最近更新