網(wǎng)站首頁 編程語言 正文
python神經(jīng)網(wǎng)絡使用tensorflow實現(xiàn)自編碼Autoencoder_python
作者:Bubbliiiing ? 更新時間: 2022-06-28 編程語言學習前言
當你發(fā)現(xiàn)數(shù)據(jù)的維度太多怎么辦!沒關系,我們給它降維!
當你發(fā)現(xiàn)不會降維怎么辦!沒關系,來這里看看怎么autoencode
antoencoder簡介
1、為什么要降維
隨著社會的發(fā)展,可以利用人工智能解決的越來越多,人工智能所需要處理的問題也越來越復雜,作為神經(jīng)網(wǎng)絡的輸入量,維度也越來越大,也就出現(xiàn)了當前所面臨的“維度災難”與“信息豐富、知識貧乏”的問題。
維度太多并不是一件優(yōu)秀的事情,太多的維度同樣會導致訓練效率低,特征難以提取等問題,如果可以通過優(yōu)秀的方法對特征進行提取,將會大大提高訓練效率。
常見的降維方法有PCA(主成分分析)和LDA(線性判別分析,F(xiàn)isher Linear Discriminant Analysis),二者的使用方法我會在今后的日子繼續(xù)寫B(tài)LOG進行闡明。
2、antoencoder的原理
如圖是一個降維的神經(jīng)網(wǎng)絡的示意圖,其可以將n維數(shù)據(jù)量降維2維數(shù)據(jù)量:
輸入量與輸出量都是數(shù)據(jù)原有的全部特征,我們利用tensorflow的optimizer對w1ij和w2ji進行優(yōu)化。在優(yōu)化的最后,w1ij就是我們將n維數(shù)據(jù)編碼到2維的編碼方式,w2ji就是我們將2維數(shù)據(jù)進行解碼到n維數(shù)據(jù)的解碼方式。
3、python中encode的實現(xiàn)
def encoder(x):
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
biases['encoder_b1']))
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
biases['encoder_b2']))
layer_3 = tf.nn.sigmoid(tf.add(tf.matmul(layer_2, weights['encoder_h3']),
biases['encoder_b3']))
layer_4 = tf.add(tf.matmul(layer_3, weights['encoder_h4']),
biases['encoder_b4'])
return layer_4
def decoder(x):
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
biases['decoder_b1']))
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
biases['decoder_b2']))
layer_3 = tf.nn.sigmoid(tf.add(tf.matmul(layer_2, weights['decoder_h3']),
biases['decoder_b3']))
layer_4 = tf.nn.sigmoid(tf.add(tf.matmul(layer_3, weights['decoder_h4']),
biases['decoder_b4']))
return layer_4
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)
其中encode函數(shù)的輸出就是編碼后的結(jié)果。
全部代碼
該例子為手寫體識別例子,將784維縮小為2維,并且以圖像的方式顯示。
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data",one_hot = "true")
learning_rate = 0.01 #學習率
training_epochs = 10 #訓練十次
batch_size = 256
display_step = 1
examples_to_show = 10
n_input = 784
X = tf.placeholder(tf.float32,[None,n_input])
#encode的過程分為4次,分別是784->128、128->64、64->10、10->2
n_hidden_1 = 128
n_hidden_2 = 64
n_hidden_3 = 10
n_hidden_4 = 2
weights = {
#這四個是用于encode的
'encoder_h1': tf.Variable(tf.truncated_normal([n_input, n_hidden_1],)),
'encoder_h2': tf.Variable(tf.truncated_normal([n_hidden_1, n_hidden_2],)),
'encoder_h3': tf.Variable(tf.truncated_normal([n_hidden_2, n_hidden_3],)),
'encoder_h4': tf.Variable(tf.truncated_normal([n_hidden_3, n_hidden_4],)),
#這四個是用于decode的
'decoder_h1': tf.Variable(tf.truncated_normal([n_hidden_4, n_hidden_3],)),
'decoder_h2': tf.Variable(tf.truncated_normal([n_hidden_3, n_hidden_2],)),
'decoder_h3': tf.Variable(tf.truncated_normal([n_hidden_2, n_hidden_1],)),
'decoder_h4': tf.Variable(tf.truncated_normal([n_hidden_1, n_input],)),
}
biases = {
#這四個是用于encode的
'encoder_b1': tf.Variable(tf.random_normal([n_hidden_1])),
'encoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),
'encoder_b3': tf.Variable(tf.random_normal([n_hidden_3])),
'encoder_b4': tf.Variable(tf.random_normal([n_hidden_4])),
#這四個是用于decode的
'decoder_b1': tf.Variable(tf.random_normal([n_hidden_3])),
'decoder_b2': tf.Variable(tf.random_normal([n_hidden_2])),
'decoder_b3': tf.Variable(tf.random_normal([n_hidden_1])),
'decoder_b4': tf.Variable(tf.random_normal([n_input])),
}
def encoder(x):
#encode函數(shù),分為四步,layer4為編碼后的結(jié)果
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['encoder_h1']),
biases['encoder_b1']))
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['encoder_h2']),
biases['encoder_b2']))
layer_3 = tf.nn.sigmoid(tf.add(tf.matmul(layer_2, weights['encoder_h3']),
biases['encoder_b3']))
layer_4 = tf.add(tf.matmul(layer_3, weights['encoder_h4']),
biases['encoder_b4'])
return layer_4
def decoder(x):
#decode函數(shù),分為四步,layer4為解碼后的結(jié)果
layer_1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights['decoder_h1']),
biases['decoder_b1']))
layer_2 = tf.nn.sigmoid(tf.add(tf.matmul(layer_1, weights['decoder_h2']),
biases['decoder_b2']))
layer_3 = tf.nn.sigmoid(tf.add(tf.matmul(layer_2, weights['decoder_h3']),
biases['decoder_b3']))
layer_4 = tf.nn.sigmoid(tf.add(tf.matmul(layer_3, weights['decoder_h4']),
biases['decoder_b4']))
return layer_4
encoder_op = encoder(X)
decoder_op = decoder(encoder_op)
#將編碼再解碼的結(jié)果與原始碼對比,查看區(qū)別
y_pred = decoder_op
y_label = X
#比較特征損失情況
cost = tf.reduce_mean(tf.square(y_pred-y_label))
train = tf.train.AdamOptimizer(learning_rate).minimize(cost)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
#每個世代進行total_batch次訓練
total_batch = int(mnist.train.num_examples/batch_size)
for epoch in range(training_epochs):
for i in range(total_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
_,c = sess.run([train,cost],feed_dict={X:batch_xs})
if epoch % display_step == 0:
print("Epoch :","%02d"%epoch,"cost =","%.4f"%c)
#利用test測試機進行測試
encoder_result = sess.run(encoder_op,feed_dict={X:mnist.test.images})
plt.scatter(encoder_result[:,0],encoder_result[:,1],c=np.argmax(mnist.test.labels,1),s=1)
plt.show()
實現(xiàn)結(jié)果為:
可以看到實驗結(jié)果分為很多個區(qū)域塊,基本可以識別。
原文鏈接:https://blog.csdn.net/weixin_44791964/article/details/99547770
相關推薦
- 2022-05-03 C++STL函數(shù)和排序算法的快排以及歸并排序詳解_C 語言
- 2022-12-13 C語言MFC導出dll回調(diào)函數(shù)方法詳解_C 語言
- 2022-04-10 MyBatis 查詢的時候?qū)傩悦妥侄蚊灰恢碌膯栴}
- 2022-10-17 Python可視化程序調(diào)用流程解析_python
- 2022-04-16 一盤王者的時間用C語言實現(xiàn)三子棋_C 語言
- 2022-03-12 C++類和對象之多態(tài)詳解_C 語言
- 2021-12-14 RHCE橋接,免密登錄和修改端口號介紹_Linux
- 2022-04-28 C#委托用法詳解_C#教程
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支