網站首頁 編程語言 正文
opencv中也提供了一種類似于Keras的神經網絡,即為ann,這種神經網絡的使用方法與Keras的很接近。
關于mnist數據的解析,讀者可以自己從網上下載相應壓縮文件,用python自己編寫解析代碼,由于這里主要研究knn算法,為了圖簡單,直接使用Keras的mnist手寫數字解析模塊。
本次代碼運行環境為:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46
下面的代碼為使用ann進行模型的訓練:
from keras.datasets import mnist
from keras import utils
import cv2
import numpy as np
#opencv中ANN定義神經網絡層
def create_ANN():
ann=cv2.ml.ANN_MLP_create()
#設置神經網絡層的結構 輸入層為784 隱藏層為80 輸出層為10
ann.setLayerSizes(np.array([784,64,10]))
#設置網絡參數為誤差反向傳播法
ann.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP)
#設置激活函數為sigmoid
ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
#設置訓練迭代條件
#結束條件為訓練30次或者誤差小于0.00001
ann.setTermCriteria((cv2.TermCriteria_EPS|cv2.TermCriteria_COUNT,100,0.0001))
return ann
#計算測試數據上的識別率
def evaluate_acc(ann,test_images,test_labels):
#采用的sigmoid激活函數,需要對結果進行置信度處理
#對于大于0.99的可以確定為1 對于小于0.01的可以確信為0
test_ret=ann.predict(test_images)
#預測結果是一個元組
test_pre=test_ret[1]
#可以直接最大值的下標 (10000,)
test_pre=test_pre.argmax(axis=1)
true_sum=(test_pre==test_labels)
return true_sum.mean()
if __name__=='__main__':
#直接使用Keras載入的訓練數據(60000, 28, 28) (60000,)
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
#變換數據的形狀并歸一化
train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784)
train_images=train_images.astype('float32')/255
test_images=test_images.reshape(test_images.shape[0],-1)
test_images=test_images.astype('float32')/255
#將標簽變為one-hot形狀 (60000, 10) float32
train_labels=utils.to_categorical(train_labels)
#測試數據標簽不用變為one-hot (10000,)
test_labels=test_labels.astype(np.int)
#定義神經網絡模型結構
ann=create_ANN()
#開始訓練
ann.train(train_images,cv2.ml.ROW_SAMPLE,train_labels)
#在測試數據上測試準確率
print(evaluate_acc(ann,test_images,test_labels))
#保存模型
ann.save('mnist_ann.xml')
#加載模型
myann=cv2.ml.ANN_MLP_load('mnist_ann.xml')
訓練100次得到的準確率為0.9376,可以接著增加訓練次數或者提高神經網絡的層次結構深度來提高準確率。
使用ann神經網絡的模型結構非常小,因為只是保存了權重參數。
可以看到整個模型文件的大小才1M,而svm的大小為十多兆,knn的為幾百兆,因此使用ann神經網絡更加適合部署在客戶端上。
接下來使用ann進行圖片的測試識別:
import cv2
import numpy as np
if __name__=='__main__':
#讀取圖片
img=cv2.imread('shuzi.jpg',0)
img_sw=img.copy()
#將數據類型由uint8轉為float32
img=img.astype(np.float32)
#圖片形狀由(28,28)轉為(784,)
img=img.reshape(-1,)
#增加一個維度變為(1,784)
img=img.reshape(1,-1)
#圖片數據歸一化
img=img/255
#載入ann模型
ann=cv2.ml.ANN_MLP_load('minist_ann.xml')
#進行預測
img_pre=ann.predict(img)
#因為激活函數sigmoid,因此要進行置信度處理
ret=img_pre[1]
ret[ret>0.9]=1
ret[ret<0.1]=0
print(ret)
cv2.imshow('test',img_sw)
cv2.waitKey(0)
運行程序,結果如下,可見該模型正確識別了數字0.
原文鏈接:https://keras-lx.blog.csdn.net/article/details/111694841
相關推薦
- 2023-05-23 numpy數組之讀寫文件的實現_python
- 2023-02-26 C++?ROS與boost:bind()使用詳解_C 語言
- 2022-04-20 Flutter如何輕松實現動態更新ListView淺析_Android
- 2022-10-10 python讀取和保存為excel、csv、txt文件及對DataFrame文件的基本操作指南_py
- 2022-07-21 React生命周期
- 2022-11-07 ASP.NET?MVC通過勾選checkbox更改select的內容_實用技巧
- 2022-06-02 python文件與路徑操作神器?pathlib_python
- 2022-07-26 討論nginx?location?順序問題_nginx
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支