日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

python神經網絡tensorflow利用訓練好的模型進行預測_python

作者:Bubbliiiing ? 更新時間: 2022-06-30 編程語言

學習前言

在神經網絡學習中slim常用函數與如何訓練、保存模型文章里已經講述了如何使用slim訓練出來一個模型,這篇文章將會講述如何預測。

載入模型思路

載入模型的過程主要分為以下四步:

1、建立會話Session;

2、將img_input的placeholder傳入網絡,建立網絡結構;

3、初始化所有變量;

4、利用saver對象restore載入所有參數。

這里要注意的重點是,在利用saver對象restore載入所有參數之前,必須要建立網絡結構,因為網絡結構對應著cpkt文件中的參數。

(網絡層具有對應的名稱scope。)

實現代碼

在運行實驗代碼前,可以直接下載代碼,因為存在許多依賴的文件

import tensorflow as tf
import numpy as np
from nets import Net
from tensorflow.examples.tutorials.mnist import input_data
def compute_accuracy(x_data,y_data):
    global prediction
    y_pre = sess.run(prediction,feed_dict={img_input:x_data})
    correct_prediction = tf.equal(tf.arg_max(y_data,1),tf.arg_max(y_pre,1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
    result = sess.run(accuracy,feed_dict = {img_input:x_data})
    return result
mnist = input_data.read_data_sets("MNIST_data",one_hot = "true")
slim = tf.contrib.slim
# img_input的placeholder
img_input = tf.placeholder(tf.float32, shape = (None, 784))
img_reshape = tf.reshape(img_input,shape = (-1,28,28,1))
# 載入模型
sess = tf.Session()
Conv_Net = Net.Conv_Net()
# 將img_input的placeholder傳入網絡
prediction = Conv_Net.net(img_reshape)
# 載入模型
ckpt_filename = './logs/model.ckpt-20000'
# 初始化所有變量
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# 恢復
saver.restore(sess, ckpt_filename)
print(compute_accuracy(mnist.test.images,mnist.test.labels))

運行結果為:

0.9921

原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/102584474

欄目分類
最近更新