網站首頁 編程語言 正文
1. dataloader() 初始化函數
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, multiprocessing_context=None):
其中幾個常用的參數:
- dataset 數據集,map-style and iterable-style 可以用index取值的對象、
- batch_size 大小
- shuffle 取batch是否隨機取, 默認為False
- sampler 定義取batch的方法,是一個迭代器, 每次生成一個key 用于讀取dataset中的值
- batch_sampler 也是一個迭代器, 每次生次一個batch_size的key
- num_workers 參與工作的線程數collate_fn 對取出的batch進行處理
- drop_last 對最后不足batchsize的數據的處理方法
下面看兩段取自DataLoader中的__init__代碼, 幫助我們理解幾個常用參數之間的關系
2. shuffle 與sample 之間的關系
當我們sampler有輸入時,shuffle的值就沒有意義,
if sampler is None: # give default samplers
if self._dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
當dataset類型是map style時, shuffle其實就是改變sampler的取值
- shuffle為默認值 False時,sampler是SequentialSampler,就是按順序取樣,
- shuffle為True時,sampler是RandomSampler, 就是按隨機取樣
3. sample 的定義方法
3.1 sampler 參數的使用
sampler 是用來定義取batch方法的一個函數或者類,返回的是一個迭代器。
我們可以看下自帶的RandomSampler類中最重要的iter函數
def __iter__(self):
n = len(self.data_source)
# dataset的長度, 按順序索引
if self.replacement:# 對應的replace參數
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
可以看出,其實就是生成索引,然后隨機的取值, 然后再迭代。
其實還有一些細節需要注意理解:
比如__len__函數,包括DataLoader的len和sample的len, 兩者區別, 這部分代碼比較簡單,可以自行閱讀,其實參考著RandomSampler寫也不會出現問題。
比如,迭代器和生成器的使用, 以及區別
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
BatchSampler的生成過程:
# 略去類的初始化
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
就是按batch_size從sampler中讀取索引, 并形成生成器返回。
以上可以看出, batch_sampler和sampler, batch_size, drop_last之間的關系
- 如果batch_sampler沒有定義的話且batch_size有定義, 會根據sampler, batch_size, drop_last生成一個batch_sampler
- 自帶的注釋中對batch_sampler有一句話: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
- 意思就是b
- atch_sampler 與這些參數沖突 ,即 如果你定義了batch_sampler, 其他參數都不需要有
4. batch 生成過程
每個batch都是由迭代器產生的:
# DataLoader中iter的部分
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
# 再看調用的另一個類
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self, loader):
super(_SingleProcessDataLoaderIter, self).__init__(loader)
assert self._timeout == 0
assert self._num_workers == 0
self._dataset_fetcher = _DatasetKind.create_fetcher(
self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
def __next__(self):
index = self._next_index()
data = self._dataset_fetcher.fetch(index)
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
原文鏈接:https://blog.csdn.net/chumingqian/article/details/126625724
相關推薦
- 2022-05-01 Python字符串和其常用函數合集_python
- 2023-05-22 解決Python報錯:ValueError:operands?could?not?be?broadc
- 2023-02-04 golang?Gorm框架講解_Golang
- 2022-10-05 帶你深度走入C語言取整以及4種函數_C 語言
- 2022-07-11 Linux刪除某個字母開頭的所有文件
- 2022-05-16 輕松讀懂Golang中的數組和切片_Golang
- 2023-01-20 Python中數組切片的用法實例詳解_python
- 2023-03-21 通俗易懂的C語言快速排序和歸并排序的時間復雜度分析_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支