網站首頁 編程語言 正文
Dataset類是TensorFlow非常流行的存儲數據的格式。常用來作為輸入輸出。data模塊主要的用途就是通過這種方法創建Dataset。
Dataset使用過程中的一些心得:
經常將自變量X數據以及target數據以元組的形式包裹,如db_train=tf.data.Dataset.from_tensor_slices((x_train,y_train)),創建Dataset。模型的fit()方法可以自動的解包。
Dataset能夠包括比較靈活的類型,比如db_train=tf.data.Dataset.from_tensor_slices(({"features":features_train,"biomass_start":biomass_start_trarin},y_train))。因為數據最外部依然是最外部包裹,所以model的fit()依然可以自動的對x以及target解包。但由于dataset保存component是以原始數據的形式保存的。所以,fit()里的inputs一般是這個樣子:
{'features': <tf.Tensor 'my_rnn/Cast_1:0' shape=(None, 5, 4) dtype=float32>, 'biomass_start': <tf.Tensor 'my_rnn/Cast:0' shape=(None, 1) dtype=float32>}
對于字典內部部分,需要手動的自己解包。這樣的好處是,給我們自定義模型的結構提供的很大的遍歷,輸入一部分導入A網絡,一部分導入不同的B網絡。
Dataset作為模型的輸入,需要設定batch()。而不在模型內設定batch。更加方便。然而Dataset作為迭代器,迭代完成后再次迭代數據,生成數據的前后數據是不一樣的。需要注意。
batch的drop_remainder=True參數比較重要,只有設定為True,input接下來的層還能正確的識別shape
Dataset的常用屬性
Dataset.element_spec
這個屬性可以檢測每一個元素中的component的類型。返回的是一個tf.TypeSpec對象。這個對象的結構跟元素的結構是一致的。
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
dataset1.element_spec
#TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
? ?(tf.random.uniform([4]),
? ? tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2.element_spec
# 標量和向量
# (TensorSpec(shape=(), dtype=tf.float32, name=None),
#TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
dataset.element_spec?
#(TensorSpec(shape=(), dtype=tf.int32, name=None),
# TensorSpec(shape=(), dtype=tf.int32, name=None),
# TensorSpec(shape=(), dtype=tf.int32, name=None))
# 注意這里是字典類型
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
dataset.element_spec
#{'a': TensorSpec(shape=(), dtype=tf.int32, name=None),
# 'b': TensorSpec(shape=(), dtype=tf.int32, name=None)}
Dataset的常用方法
apply方法
對dataset進行轉換。
dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())
as_numpy_iterator
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
print(element)
這個在dataset比較常用。就是將dataset變成迭代器,將所有元素都變成numy對象輸出
shuffle
shuffle(
buffer_size, seed=None, reshuffle_each_iteration=None, name=None
)
參數:
- buffer_size:緩沖區大小
- seed:隨機種子
- reshuffle_each_iteration:bool. 如果為真,表示每次迭代時數據集完成后都應該是進行偽隨機重新洗牌的。控制每個epoch的洗牌順序是否不同。
這個方法用來隨機打亂數據集的元素順序。數據集用buffer_size元素填充一個緩沖區,然后從這個緩沖區隨機取樣元素,用新元素替換選中的元素。例如,如果您的數據集包含10,000個元素,但是buffer_size被設置為1,000,那么shuffle將首先從緩沖區中的前1,000個元素中選擇一個隨機元素。一旦一個元素被選中,它在緩沖區中的空間就會被下一個(比如第1001個)元素替換,從而保持這個1,000元素緩沖區。為了實現完美的洗牌,需要一個大于或等于數據集完整大小的緩沖區。
dataset = tf.data.Dataset.range(3)
# 每個每個epoch重新洗牌
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 2, 0]
dataset = tf.data.Dataset.range(3)
# 每個每個epoch不重新洗牌
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 0, 2]
batch
batch(
batch_size,
drop_remainder=False,
num_parallel_calls=None,
deterministic=None,
name=None
)
參數:
- batch_size: 批處理大小
- drop_remainder:是否刪除最后一個短batch。==這個比較重要,只有設定為Ture,model才能正確的判斷其輸入的shape。==這也比較合理,指定為Falsel,因為誰也不知道后面是不是有一個比較短的batch,只有第一維是None,才能提高程序的穩定性。
- num_parallel_calls:并行計算的數量。不指定會順序執行。如果有 tf.data.AUTOTUNE,會自動動態的制定這個值。
- deterministic:bool. 指定了num_parallel_calls,才有效。如果設置為False,則允許轉換產生無序元素,以犧牲確定性來換取性能。如果不指定,tf.data.Options.deterministic控制這個行為(默認為True)
- name: 標識符
這個方法經常使用,將dataset進行批處理化。因為數據集比較大的時候,一下子完全進行訓練占用大量的內存。所以用分批處理。輸出的元素增加了一個額外的維度,就是batch維,shape是batch的size.
batch支持一個drop_remainder=True關鍵字,為真意味著,最后一個batch的size如果小于我們指定值,就會被舍棄。
之所以要刪掉最后一個短的batch,是因為如果我們的項目依賴這個batch的size,那最后一個batch不等長,可能會出錯。
import tensorflow as tf
from tensorflow.python.data import Dataset
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
print(list(dataset.as_numpy_iterator()))
# 通過這個看到這個elem也已經是分批了
for elem in dataset:
? ? print(elem)
# tf.Tensor([0 1 2], shape=(3,), dtype=int64)
# tf.Tensor([3 4 5], shape=(3,), dtype=int64)
# tf.Tensor([6 7], shape=(2,), dtype=int64)
for elem in dataset.as_numpy_iterator():
? ? print(elem)
# [0 1 2]
# [3 4 5]
# [6 7]
dataset = tf.data.Dataset.range(8)
# drop_remainder舍掉最后一個長度不夠的batch
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
一般情況下,shuffle跟batch是連續使用的,實現隨機讀取并批量處理數據:dataset.shuffle(buffer_size).batch(batchsize)
不能對已經batch的dataset進行連續的batch操作,其batchsize不會改變,而是生成了新的異常數據
unbatch
unbatch(
name=None
)
這里是將Batchdataset這樣的dataset分割為一個個元素,元素的格式跟定義時的格式是一樣的。而且,這里固定的是對第1個維度進行split操作,且生成shape[0]個元素。
reduce方法
reduce(
initial_state, reduce_func, name=None
)
將輸入數據集簡化為一個元素。 reduce_func作用于dataset中每一個元素,輸出其dataset的聚合信息。
參數initial_state代表進行reduce之前的初始狀態。reduce_func要接收old_state, input_element兩個參數,然后生成新的狀態newstate。old_state和new_state的結構要一致。
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
print(dataset.reduce(0, lambda state, value: state + value).numpy())
# 22
dataset不支持tf.split屬性,也不能直接把dataset給切分為訓練集和測試集。
原文鏈接:https://blog.csdn.net/yue81560/article/details/128691866
相關推薦
- 2022-05-23 iOS實現無限滑動效果_IOS
- 2022-09-22 模擬實現vector
- 2022-03-22 C語言圍圈報數題目代碼實現_C 語言
- 2023-01-28 Android?之Preference控件基本使用示例詳解_Android
- 2023-01-05 Presenting?Streams?in?Flutter小技巧_Android
- 2022-11-03 Python列表推導式,元組推導式,字典推導式,集合推導式_python
- 2022-10-06 zabbix如何添加監控主機和自定義監控項_zabbix
- 2022-05-23 Python的代理類實現,控制訪問和修改屬性的權限你都了解嗎_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支