網站首頁 編程語言 正文
利用Keras構建完普通BP神經網絡后,還要會構建CNN
Keras中構建CNN的重要函數
1、Conv2D
Conv2D用于在CNN中構建卷積層,在使用它之前需要在庫函數處import它。
from keras.layers import Conv2D
在實際使用時,需要用到幾個參數。
Conv2D(
nb_filter = 32,
nb_row = 5,
nb_col = 5,
border_mode = 'same',
input_shape = (28,28,1)
)
其中,nb_filter代表卷積層的輸出有多少個channel,卷積之后圖像會越來越厚,這就是卷積后圖像的厚度。nb_row和nb_col的組合就是卷積器的大小,這里卷積器是(5,5)的大小。border_mode代表著padding的方式,same表示卷積前后圖像的shape不變。input_shape代表輸入的shape。
2、MaxPooling2D
MaxPooling2D指的是池化層,在使用它之前需要在庫函數處import它。
from keras.layers import MaxPooling2D
在實際使用時,需要用到幾個參數。
MaxPooling2D(
pool_size = (2,2),
strides = (2,2),
border_mode = 'same'
)
其中,pool_size表示池化器的大小,在這里,池化器的shape是(2,2)。strides是池化器的步長,這里在X和Y方向上都是2,池化后,輸出比輸入的shape小了1/2。border_mode代表著padding的方式。
3、Flatten
Flatten用于將卷積池化后最后的輸出變為一維向量,這樣才可以和全連接層連接,用于計算。在使用前需要用import導入。
from keras.layers import Flatten
在實際使用時,在最后一個池化層后直接添加層即可
model.add(Flatten())
全部代碼
這是一個卷積神經網絡的例子,用于識別手寫體,其神經網絡結構如下:
卷積層1->池化層1->卷積層2->池化層2->flatten->全連接層1->全連接層2->全連接層3。
單個樣本的shape如下:
(28,28,1)->(28,28,32)->(14,14,32)->(14,14,64)->(7,7,64)->(3136)->(1024)->(256)
import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation,Conv2D,MaxPooling2D,Flatten ## 全連接層
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28,1)
X_test = X_test.reshape(-1,28,28,1)
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
Conv2D(
nb_filter = 32,
nb_row = 5,
nb_col = 5,
border_mode = 'same',
input_shape = (28,28,1)
)
)
model.add(Activation("relu"))
# pool1
model.add(
MaxPooling2D(
pool_size = (2,2),
strides = (2,2),
border_mode = 'same'
)
)
# conv2
model.add(
Conv2D(
nb_filter = 64,
nb_row = 5,
nb_col = 5,
border_mode = 'same'
)
)
model.add(Activation("relu"))
# pool2
model.add(
MaxPooling2D(
pool_size = (2,2),
strides = (2,2),
border_mode = 'same'
)
)
# 全連接層
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation("relu"))
model.add(Dense(256))
model.add(Activation("relu"))
model.add(Dense(10))
model.add(Activation("softmax"))
adam = Adam(lr = 1e-4)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
## acc
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
實驗結果為:
Epoch 1/2
60000/60000 [==============================] - 64s 1ms/step - loss: 0.7664 - acc: 0.9224
Epoch 2/2
60000/60000 [==============================] - 62s 1ms/step - loss: 0.0473 - acc: 0.9858
Test
10000/10000 [==============================] - 2s 169us/step
accuracy: 0.9856
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/101171576
相關推薦
- 2021-12-09 Jenkins+GitLab+Docker持續集成LNMP
- 2022-04-15 python實現AES算法及AES-CFB8加解密源碼_python
- 2022-04-19 C#多線程系列之線程的創建和生命周期_C#教程
- 2022-06-15 ASP.NET?MVC使用區域(Area)功能_基礎應用
- 2024-02-27 credentials to a set of origins, list them explici
- 2022-12-07 C++成員函數后面加override問題_C 語言
- 2022-05-15 Python?matplotlib?seaborn繪圖教程詳解_python
- 2022-07-21 解決win10系統網絡連接正常,但是網頁打不開的問題
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支