網站首頁 編程語言 正文
一:手寫數字模型構建與保存
1 加載數據集
# 1加載數據
digits_data = load_digits()
可以先簡單查看下 手寫數字集,如下可以隱約看出數字為8
plt.imshow(digits_data.images[8])
plt.show()
2 特征數據 標簽數據
# 數據劃分
x_data = digits_data.data
y_data = digits_data.target
3 訓練集 測試集
# 訓練集 + 測試集
x_test = x_data[:40]
y_test = y_data[:40]
x_train = x_data[40:]
y_train = y_data[40:]
# 概率問題
y_train_2 = np.zeros(shape=(len(y_train), 10))
4 數據流圖 輸入層
input_size = digits_data.data.shape[1] # 輸入的列數
# 數據流圖的構建
# x:輸入64個特征值--像素
x = tf.placeholder(np.float32, shape=[None, input_size])
# y:識別的數字 有幾個類別[0-9]
y = tf.placeholder(np.float32, shape=[None, 10])
5 隱藏層
5.1 第一層
# 第一層隱藏層
# 參數1 輸入維度 參數2:輸出維度(神經元個數) 標準差是0.1的正態(tài)分布
w1 = tf.Variable(tf.random_normal([input_size, 80], stddev=0.1))
# b的個數就是隱藏層神經元的個數
b1 = tf.Variable(tf.constant(0.01), [80])
# 第一層計算
one = tf.matmul(x, w1) + b1
# 激活函數 和0比 大于0則激活
op1 = tf.nn.relu(one)
5.2 第二層
# 第二層隱藏層 上一層輸出為下一層輸入
# 參數1 輸入維度 參數2:輸出維度(神經元個數) 標準差是0.1的正態(tài)分布
w2 = tf.Variable(tf.random_normal([80, 10], stddev=0.1))
# b的個數就是隱藏層神經元的個數
b2 = tf.Variable(tf.constant(0.01), [10])
# 第一層計算
two = tf.matmul(op1, w2) + b2
# 激活函數 和0比 大于0則激活
op2 = tf.nn.relu(two)
6 損失函數
# 構建損失函數
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=op2))
7 梯度下降算法
# 梯度下降算法
Optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.005).minimize(loss)
8 輸出損失值?
# 變量初始化
init = tf.global_variables_initializer()
data_size = digits_data.data.shape[0]
# 開啟會話
with tf.Session() as sess:
sess.run(init)
# 訓練次數
for i in range(500):
# 數據分組
start = (i * 100) % data_size
end = min(start + 100, data_size)
batch_x = x_train[start:end]
batch_y = y_train_2[start:end]
sess.run(Optimizer, feed_dict={x: batch_x, y: batch_y})
# 輸出損失值
train_loss = sess.run(loss, feed_dict={x: batch_x, y: batch_y})
print(train_loss)
9 模型 保存與使用
obj = tf.train.Saver()
# 模型保存
obj.save(sess, 'model-digits.ckpt')
10 完整源碼分享
import tensorflow as tf
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# 1加載數據
digits_data = load_digits()
# 查看數據
# print(digits_data)
# 查看數據基本特征 (1797, 64) 64:8*8像素點
# print(digits_data.data.shape)
# plt.imshow(digits_data.images[8])
# plt.show()
# 數據劃分
x_data = digits_data.data
y_data = digits_data.target
# 訓練集 + 測試集
x_test = x_data[:40]
y_test = y_data[:40]
x_train = x_data[40:]
y_train = y_data[40:]
# 概率問題
y_train_2 = np.zeros(shape=(len(y_train), 10))
# 對應的分類 當前行對應列變成1
for index, row in enumerate(y_train_2):
# 當前行 對應的數字對應列
row[int(y_train[index])] = 1
# print(y_train_2[0])
input_size = digits_data.data.shape[1] # 輸入的列數
# 數據流圖的構建
# x:輸入64個特征值--像素
x = tf.placeholder(np.float32, shape=[None, input_size])
# y:識別的數字 有幾個類別[0-9]
y = tf.placeholder(np.float32, shape=[None, 10])
# 第一層隱藏層
# 參數1 輸入維度 參數2:輸出維度(神經元個數) 標準差是0.1的正態(tài)分布
w1 = tf.Variable(tf.random_normal([input_size, 80], stddev=0.1))
# b的個數就是隱藏層神經元的個數
b1 = tf.Variable(tf.constant(0.01), [80])
# 第一層計算
one = tf.matmul(x, w1) + b1
# 激活函數 和0比 大于0則激活
op1 = tf.nn.relu(one)
# 第二層隱藏層 上一層輸出為下一層輸入
# 參數1 輸入維度 參數2:輸出維度(神經元個數) 標準差是0.1的正態(tài)分布
w2 = tf.Variable(tf.random_normal([80, 10], stddev=0.1))
# b的個數就是隱藏層神經元的個數
b2 = tf.Variable(tf.constant(0.01), [10])
# 第一層計算
two = tf.matmul(op1, w2) + b2
# 激活函數 和0比 大于0則激活
op2 = tf.nn.relu(two)
# 構建損失函數
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=op2))
# 梯度下降算法 優(yōu)化器 learning_rate學習率(步長)
Optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.005).minimize(loss)
# 變量初始化
init = tf.global_variables_initializer()
data_size = digits_data.data.shape[0]
# 開啟會話
with tf.Session() as sess:
sess.run(init)
# 訓練次數
for i in range(500):
# 數據分組
start = (i * 100) % data_size
end = min(start + 100, data_size)
batch_x = x_train[start:end]
batch_y = y_train_2[start:end]
sess.run(Optimizer, feed_dict={x: batch_x, y: batch_y})
# 輸出損失值
train_loss = sess.run(loss, feed_dict={x: batch_x, y: batch_y})
print(train_loss)
obj = tf.train.Saver()
# 模型保存
obj.save(sess, 'modelSave/model-digits.ckpt')
?損失值在0.303左右,如下圖所示
二:手寫數字模型使用與測試
對上一步創(chuàng)建的模型,使用測試
import tensorflow as tf
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# 1加載數據
digits_data = load_digits()
# 數據劃分
x_data = digits_data.data
y_data = digits_data.target
# 訓練集 + 測試集
x_test = x_data[:40]
y_test = y_data[:40]
x_train = x_data[40:]
y_train = y_data[40:]
# 概率問題
y_train_2 = np.zeros(shape=(len(y_train), 10))
# 對應的分類 當前行對應列變成1
for index, row in enumerate(y_train_2):
# 當前行 對應的數字對應列
row[int(y_train[index])] = 1
# 網絡搭建
num_class = 10 # 數字0-9
hidden_num = 80 # 神經元個數
input_size = digits_data.data.shape[1] # 輸入的列數
# 數據流圖的構建
# x:輸入64個特征值--像素
x = tf.placeholder(np.float32, shape=[None, 64])
# y:識別的數字 有幾個類別[0-9]
y = tf.placeholder(np.float32, shape=[None, 10])
# 第一層隱藏層
# 參數1 輸入維度 參數2:輸出維度(神經元個數) 標準差是0.1的正態(tài)分布
w1 = tf.Variable(tf.random_normal([input_size, 80], stddev=0.1))
# b的個數就是隱藏層神經元的個數
b1 = tf.Variable(tf.constant(0.01), [80])
# 第一層計算
one = tf.matmul(x, w1) + b1
# 激活函數 和0比 大于0則激活
op1 = tf.nn.relu(one)
# 第二層隱藏層 上一層輸出為下一層輸入
# 參數1 輸入維度 參數2:輸出維度(神經元個數) 標準差是0.1的正態(tài)分布
w2 = tf.Variable(tf.random_normal([80, 10], stddev=0.1))
# b的個數就是隱藏層神經元的個數
b2 = tf.Variable(tf.constant(0.01), [10])
# 第一層計算
two = tf.matmul(op1, w2) + b2
# 激活函數 和0比 大于0則激活
op2 = tf.nn.relu(two)
# 變量初始化
init = tf.global_variables_initializer()
train_count = 500
batch_size = 100
data_size = x_train.shape[0]
pre_max_index = tf.argmax(op2, 1)
plt.imshow(digits_data.images[13]) # 3
plt.show()
with tf.Session() as sess:
sess.run(init)
# 使用網絡
obj = tf.train.Saver()
obj.restore(sess, 'modelSave/model-digits.ckpt')
print(sess.run(op2, feed_dict={x: [x_test[13], x_test[14]]}))
print(sess.run(pre_max_index, feed_dict={x: [x_test[13], x_test[14]]}))
想要測試的數據,如下圖所示
使用模型測試出來的結果,如下圖所示,模型基本能夠使用
原文鏈接:https://blog.csdn.net/m0_56051805/article/details/128398291
相關推薦
- 2022-05-23 NetCat工具命令介紹及遠程文件傳輸實現(xiàn)_linux shell
- 2022-04-07 Redis數據庫分布式設計方案介紹_Redis
- 2022-03-31 C語言類的雙向鏈表詳解_C 語言
- 2021-10-24 Linux多線程中fork與互斥鎖過程示例_Linux
- 2023-08-15 vite打包報錯 Rollup failed to resolve
- 2022-03-23 如何解決Mac中的Docker宿主機與容器無法通信(MacOS下解決宿主機和docker容器網絡互通
- 2023-10-27 np.zeros()函數的使用方法_python
- 2022-05-11 剖析數據庫中重要而又常被曲解的概念
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支