網站首頁 編程語言 正文
1、dataset:(數據類型 dataset)
輸入的數據類型,這里是原始數據的輸入。PyTorch內也有這種數據結構。
2、batch_size:(數據類型 int)
批訓練數據量的大小,根據具體情況設置即可(默認:1)。PyTorch訓練模型時調用數據不是一行一行進行的(這樣太沒效率),而是一捆一捆來的。這里就是定義每次喂給神經網絡多少行數據,如果設置成1,那就是一行一行進行(個人偏好,PyTorch默認設置是1)。每次是隨機讀取大小為batch_size。如果dataset中的數據個數不是batch_size的整數倍,這最后一次把剩余的數據全部輸出。若想把剩下的不足batch size個的數據丟棄,則將drop_last設置為True,會將多出來不足一個batch的數據丟棄。
3、shuffle:(數據類型 bool)
洗牌。默認設置為False。在每次迭代訓練時是否將數據洗牌,默認設置是False。將輸入數據的順序打亂,是為了使數據更有獨立性,但如果數據是有序列特征的,就不要設置成True了。
4、collate_fn:(數據類型 callable,沒見過的類型)
將一小段數據合并成數據列表,默認設置是False。如果設置成True,系統會在返回前會將張量數據(Tensors)復制到CUDA內存中。
5、batch_sampler:(數據類型 Sampler)
批量采樣,默認設置為None。但每次返回的是一批數據的索引(注意:不是數據)。其和batch_size、shuffle 、sampler and drop_last參數是不兼容的。我想,應該是每次輸入網絡的數據是隨機采樣模式,這樣能使數據更具有獨立性質。所以,它和一捆一捆按順序輸入,數據洗牌,數據采樣,等模式是不兼容的。
6、sampler:(數據類型 Sampler)
采樣,默認設置為None。根據定義的策略從數據集中采樣輸入。如果定義采樣規則,則洗牌(shuffle)設置必須為False。
7、num_workers:(數據類型 Int)
工作者數量,默認是0。使用多少個子進程來導入數據。設置為0,就是使用主進程來導入數據。注意:這個數字必須是大于等于0的,負數估計會出錯。
8、pin_memory:(數據類型 bool)
內存寄存,默認為False。在數據返回前,是否將數據復制到CUDA內存中。
9、drop_last:(數據類型 bool)
丟棄最后數據,默認為False。設置了 batch_size 的數目后,最后一批數據未必是設置的數目,有可能會小些。這時你是否需要丟棄這批數據。
10、timeout:(數據類型 numeric)
超時,默認為0。是用來設置數據讀取的超時時間的,但超過這個時間還沒讀取到數據的話就會報錯。 所以,數值必須大于等于0。
11、worker_init_fn(數據類型 callable,沒見過的類型)
子進程導入模式,默認為Noun。在數據導入前和步長結束后,根據工作子進程的ID逐個按順序導入數據。
對batch_size舉例分析:
"""
批訓練,把數據變成一小批一小批數據進行訓練。
DataLoader就是用來包裝所使用的數據,每次拋出一批數據
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 11, 11)
y = torch.linspace(11, 1, 11)
print(x)
print(y)
# 把數據放在數據庫中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
# 從數據庫中每次抽出batch size個樣本
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
# num_workers=2,
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
if __name__ == '__main__':
show_batch()
輸出為:
tensor([ 1., ?2., ?3., ?4., ?5., ?6., ?7., ?8., ?9., 10., 11.])
tensor([11., 10., ?9., ?8., ?7., ?6., ?5., ?4., ?3., ?2., ?1.])
steop:0, batch_x:tensor([ 3., ?2., ?8., 11., ?1.]), batch_y:tensor([ 9., 10., ?4., ?1., 11.])
steop:1, batch_x:tensor([ 5., ?6., ?7., ?4., 10.]), batch_y:tensor([7., 6., 5., 8., 2.])
steop:2, batch_x:tensor([9.]), batch_y:tensor([3.])
steop:0, batch_x:tensor([ 9., ?7., 10., ?2., ?4.]), batch_y:tensor([ 3., ?5., ?2., 10., ?8.])
steop:1, batch_x:tensor([ 5., 11., ?3., ?6., ?8.]), batch_y:tensor([7., 1., 9., 6., 4.])
steop:2, batch_x:tensor([1.]), batch_y:tensor([11.])
steop:0, batch_x:tensor([10., ?5., ?7., ?4., ?2.]), batch_y:tensor([ 2., ?7., ?5., ?8., 10.])
steop:1, batch_x:tensor([3., 9., 1., 8., 6.]), batch_y:tensor([ 9., ?3., 11., ?4., ?6.])
steop:2, batch_x:tensor([11.]), batch_y:tensor([1.])
?
Process finished with exit code 0
若drop_last=True
"""
批訓練,把數據變成一小批一小批數據進行訓練。
DataLoader就是用來包裝所使用的數據,每次拋出一批數據
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 11, 11)
y = torch.linspace(11, 1, 11)
print(x)
print(y)
# 把數據放在數據庫中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
# 從數據庫中每次抽出batch size個樣本
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
# num_workers=2,
drop_last=True,
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
if __name__ == '__main__':
show_batch()
對應的輸出為:
tensor([ 1., ?2., ?3., ?4., ?5., ?6., ?7., ?8., ?9., 10., 11.])
tensor([11., 10., ?9., ?8., ?7., ?6., ?5., ?4., ?3., ?2., ?1.])
steop:0, batch_x:tensor([ 9., ?2., ?7., ?4., 11.]), batch_y:tensor([ 3., 10., ?5., ?8., ?1.])
steop:1, batch_x:tensor([ 3., ?5., 10., ?1., ?8.]), batch_y:tensor([ 9., ?7., ?2., 11., ?4.])
steop:0, batch_x:tensor([ 5., 11., ?6., ?1., ?2.]), batch_y:tensor([ 7., ?1., ?6., 11., 10.])
steop:1, batch_x:tensor([ 3., ?4., 10., ?8., ?9.]), batch_y:tensor([9., 8., 2., 4., 3.])
steop:0, batch_x:tensor([10., ?4., ?9., ?8., ?7.]), batch_y:tensor([2., 8., 3., 4., 5.])
steop:1, batch_x:tensor([ 6., ?1., 11., ?2., ?5.]), batch_y:tensor([ 6., 11., ?1., 10., ?7.])
?
Process finished with exit code 0
總結
原文鏈接:https://blog.csdn.net/qq_36044523/article/details/118914223
相關推薦
- 2023-03-23 Rust應用調用C語言動態庫的操作方法_Rust語言
- 2022-02-25 C++實現單例模式的方法_C 語言
- 2022-09-20 C#單線程和多線程端口掃描器詳解_C#教程
- 2022-11-19 Django項目中表的查詢的操作_python
- 2022-05-17 Git分支管理策略_其它綜合
- 2023-01-28 Flutter框架解決盒約束widget和assets里加載資產技術_Android
- 2022-06-18 C語言?詳解字符串基礎_C 語言
- 2023-06-20 React?DOM-diff?節點源碼解析_React
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支