網站首頁 編程語言 正文
學習前言
開始做項目的話,有些時候會用到別人訓練好的模型,這個時候要學會load噢。
Keras中保存與讀取的重要函數
1、model.save
model.save用于保存模型,在保存模型前,首先要利用pip install安裝h5py的模塊,這個模塊在Keras的模型保存與讀取中常常被使用,用于定義保存格式。
pip install h5py
完成安裝后,可以通過如下函數保存模型。
model.save("./model.hdf5")
其中,model是已經訓練完成的模型,save函數傳入的參數就是保存后的位置+名字。
2、load_model
load_model用于載入模型。
具體使用方式如下:
model = load_model("./model.hdf5")
其中,load_model函數傳入的參數就是已經完成保存的模型的位置+名字。./表示保存在當前目錄。
全部代碼
這是一個簡單的手寫體識別例子,在之前也講解過如何構建
python神經網絡學習使用Keras進行簡單分類,在最后我添加上了模型的保存與讀取函數。
import numpy as np
from keras.models import Sequential,load_model,save_model
from keras.layers import Dense,Activation ## 全連接層
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 獲取訓練集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先進行標準化
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 計算categorical_crossentropy需要對分類結果進行categorical
# 即需要將標簽轉化為形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 構建模型
model = Sequential([
Dense(32,input_dim = 784),
Activation("relu"),
Dense(10),
Activation("softmax")
]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 100)
print("\nTest")
# 測試
cost,accuracy = model.evaluate(X_test,Y_test)
print("accuracy:",accuracy)
# 保存模型
model.save("./model.hdf5")
# 刪除現有模型
del model
print("model had been del")
# 再次載入模型
model = load_model("./model.hdf5")
# 預測
cost,accuracy = model.evaluate(X_test,Y_test)
print("accuracy:",accuracy)
實驗結果為:
Epoch 1/2
60000/60000 [==============================] - 6s 104us/step - loss: 0.4217 - acc: 0.8888
Epoch 2/2
60000/60000 [==============================] - 6s 99us/step - loss: 0.2240 - acc: 0.9366
Test
10000/10000 [==============================] - 1s 149us/step
accuracy: 0.9419
model had been del
10000/10000 [==============================] - 1s 117us/step
accuracy: 0.9419
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/101613118
相關推薦
- 2021-12-31 element 級聯下拉菜單 獲取value 同時 獲取label
- 2022-09-05 C語言之數組名與數組起始地址的關系解析_C 語言
- 2023-06-03 C++一個函數如何調用其他.cpp文件中的函數_C 語言
- 2023-07-28 獲取當前日期以及前6天的日期集合
- 2023-01-01 Python交換字典鍵值對的四種方法實例_python
- 2023-02-06 Golang泛型實現類型轉換的方法實例_Golang
- 2022-10-25 基于Pytorch使用GPU運行模型方法及可能出現的問題解決方法
- 2022-07-23 python雙向鏈表實例詳解_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支