網站首頁 編程語言 正文
前言
本文主要使用 cpu 版本的 tensorflow-2.1 來完成深度學習權重參數/模型的保存和加載操作。
在我們進行項目期間,很多時候都要在模型訓練期間、訓練結束之后對模型或者模型權重進行保存,然后我們可以從之前停止的地方恢復原模型效果繼續進行訓練或者直接投入實際使用,另外為了節省存儲空間我們還可以自定義保存內容和保存頻率。
實現方法
1. 讀取數據
(1)本文重點介紹模型或者模型權重的保存和讀取的相關操作,使用到的是 MNIST 數據集僅是為了演示效果,我們無需關心模型訓練的質量好壞。
(2)這里是常規的讀取數據操作,我們為了能較快介紹本文重點內容,只使用了 MNIST 前 1000 條數據,然后對數據進行歸一化操作,加快模型訓練收斂速度,并且將每張圖片的數據從二維壓縮成一維。
import os
import tensorflow as tf
from tensorflow import keras
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
2. 搭建深度學習模型
(1)這里主要是搭建一個最簡單的深度學習模型。
(2)第一層將圖片的長度為 784 的一維向量轉換成 256 維向量的全連接操作,并且用到了 relu 激活函數。
(3)第二層緊接著使用了防止過擬合的 Dropout 操作,神經元丟棄率為 50% 。
(4)第三層為輸出層,也就是輸出每張圖片屬于對應 10 種類別的分布概率。
(5)優化器我們選擇了最常見的 Adam 。
(6)損失函數選擇了 SparseCategoricalCrossentropy 。
(7)評估指標選用了 SparseCategoricalAccuracy 。
def create_model():
model = tf.keras.Sequential([keras.layers.Dense(256, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.5),
keras.layers.Dense(10) ])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
return model
3. 使用回調函數在每個 epoch 后自動保存模型權重
(1)這里介紹一種在模型訓練期間保存權重參數的方法,我們定義一個回調函數 callback ,它可以在訓練過程中將權重保存在自定義目錄中 weights_path ,在訓練過程中一共執行 5 次 epoch ,每次 epoch 結束之后就會保存一次模型的權重到指定的目錄。
(2)可以看到最后使用測試集進行評估的 loss 為 0.4952 ,分類準確率為 0.8500 。
weights_path = "training_weights/cp.ckpt"
weights_dir = os.path.dirname(weights_path)
callback = tf.keras.callbacks.ModelCheckpoint(filepath=weights_path, save_weights_only=True, verbose=1)
model = create_model()
model.fit(train_images,
train_labels,
epochs=5,
validation_data=(test_images, test_labels),
callbacks=[callback])
輸出結果為:
?val_loss: 0.4952 - val_sparse_categorical_accuracy: 0.8500?? ? ? ? ? ??
(3)我們瀏覽目標文件夾里,只有三個文件,每個 epoch 后自動都會保存三個文件,在下一次 epoch 之后會自動更新這三個文件的內容。
os.listdir(weights_dir)
結果為:
['checkpoint', 'cp.ckpt.data-00000-of-00001', 'cp.ckpt.index']
(4) 我們通過 create_model 定義了一個新的模型實例,然后讓其在沒有訓練的情況下使用測試數據進行評估,結果可想而知,準確率差的離譜。
NewModel = create_model()
loss, acc = NewModel.evaluate(test_images, test_labels, verbose=2)
結果為:
loss: 2.3694 - sparse_categorical_accuracy: 0.1330
(5) tensorflow 中只要兩個模型有相同的模型結構,就可以在它們之間共享權重,所以我們使用 NewModel 讀取了之前訓練好的模型權重,再使用測試集對其進行評估發現,損失值和準確率和舊模型的結果完全一樣,說明權重被相同結構的新模型成功加載并使用。
NewModel.load_weights(checkpoint_path)
loss, acc = NewModel.evaluate(test_images, test_labels, verbose=2)
輸出結果:
loss: 0.4952 - sparse_categorical_accuracy: 0.8500
4. 使用回調函數每經過 5 個 epoch 對模型權重保存一次
(1)如果我們想保留多個中間 epoch 的模型訓練的權重,或者我們想每隔幾個 epoch 保存一次模型訓練的權重,這時候我們可以通過設置保存頻率 period 來完成,我這里讓新建的模型訓練 30 個 epoch ,在每經過 10 epoch 后保存一次模型訓練好的權重。
(2)使用測試集對此次模型進行評估,損失值為 0.4047 ,準確率為 0.8680 。
weights_path = "training_weights2/cp-{epoch:04d}.ckpt"
weights_dir = os.path.dirname(weights_path)
batch_size = 64
cp_callback = tf.keras.callbacks.ModelCheckpoint( filepath=weights_path,
verbose=1,
save_weights_only=True,
period=10)
model = create_model()
model.save_weights(weights_path.format(epoch=1))
model.fit(train_images,
train_labels,
epochs=30,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=1)
結果輸出為:
val_loss: 0.4047 - val_sparse_categorical_accuracy: 0.8680 ??
(3)這里我們能看到指定目錄中的文件組成,這里的 0001 是因為訓練時指定了要保存的 epoch 的權重,其他都是每 10 個 epoch 保存的權重參數文件。目錄中有一個 checkpoint ,它是一個檢查點文本文件,文件保存了一個目錄下所有的模型文件列表,首行記錄的是最后(最近)一次保存的模型名稱。
(4)每個 epoch 保存下來的文件都包含:
- 一個索引文件,指示哪些權重存儲在哪個分片中
- 一個或多個包含模型權重的分片
瀏覽文件夾內容
os.listdir(weights_dir)
結果如下:
['checkpoint', 'cp-0001.ckpt.data-00000-of-00001', 'cp-0001.ckpt.index', 'cp-0010.ckpt.data-00000-of-00001', 'cp-0010.ckpt.index', 'cp-0020.ckpt.data-00000-of-00001', 'cp-0020.ckpt.index', 'cp-0030.ckpt.data-00000-of-00001', 'cp-0030.ckpt.index']
(5)我們將最后一次保存的權重讀取出來,然后創建一個新的模型去讀取剛剛保存的最新的之前訓練好的模型權重,然后通過測試集對新模型進行評估,發現損失值準確率和之前完全一樣,說明權重被成功讀取并使用。
latest = tf.train.latest_checkpoint(weights_dir)
newModel = create_model()
newModel.load_weights(latest)
loss, acc = newModel.evaluate(test_images, test_labels, verbose=2)
結果如下:
loss: 0.4047 - sparse_categorical_accuracy: 0.8680
5. 手動保存模型權重到指定目錄
(1)有時候我們還想手動將模型訓練好的權重保存到指定的目錄下,我們可以使用 save_weights 函數,通過我們新建了一個同樣的新模型,然后使用 load_weights 函數去讀取權重并使用測試集對其進行評估,發現損失值和準確率仍然和之前的兩種結果完全一樣。
model.save_weights('./training_weights3/my_cp')
newModel = create_model()
newModel.load_weights('./training_weights3/my_cp')
loss, acc = newModel.evaluate(test_images, test_labels, verbose=2)
結果如下:
loss: 0.4047 - sparse_categorical_accuracy: 0.8680
6. 手動保存整個模型結構和權重
(1)有時候我們還需要保存整個模型的結構和權重,這時候我們直接使用 save 函數即可將這些內容保存到指定目錄,使用該方法要保證目錄是存在的否則會報錯,所以這里我們要創建文件夾。我們能看到損失值為 0.4821,準確率為 0.8460 。
model = create_model()
model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels), verbose=1)
!mkdir my_model
modelPath = './my_model'
model.save(modelPath)
輸出結果:
val_loss: 0.4821 - val_sparse_categorical_accuracy: 0.8460
(2)然后我們通過函數 load_model 即可生成出一個新的完全一樣結構和權重的模型,我們使用測試集對其進行評估,發現準確率和損失值和之前完全一樣,說明模型結構和權重被完全讀取恢復。
new_model = tf.keras.models.load_model(modelPath)
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
輸出結果:
?loss: 0.4821 - sparse_categorical_accuracy: 0.8460
原文鏈接:https://juejin.cn/post/7166486878714068999
相關推薦
- 2022-08-14 C++學習之算術運算符使用詳解_C 語言
- 2022-01-21 Flink中window 窗口和時間以及watermark水印
- 2022-10-05 linux查看服務器開放的端口和啟用的端口多種方式_Linux
- 2022-03-29 C++中的拷貝構造詳解_C 語言
- 2022-04-08 iOS實現簡單計算器功能_IOS
- 2022-09-16 Pandas統計計數value_counts()的使用_python
- 2023-01-29 Python操作lxml庫之基礎使用篇_python
- 2022-09-26 你了解Redis事務嗎_Redis
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支