網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
深度學(xué)習(xí)TextRNN的tensorflow1.14實(shí)現(xiàn)示例_python
作者:我是王大你是誰(shuí) ? 更新時(shí)間: 2023-02-14 編程語(yǔ)言實(shí)現(xiàn)對(duì)下一個(gè)單詞的預(yù)測(cè)
RNN 原理自己找,這里只給出簡(jiǎn)單例子的實(shí)現(xiàn)代碼
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
sentences = ['i love damao','i like mengjun','we love all']
words = list(set(" ".join(sentences).split()))
word2idx = {v:k for k,v in enumerate(words)}
idx2word = {k:v for k,v in enumerate(words)}
V = len(words) # 詞典大小
step = 2 # 時(shí)間序列長(zhǎng)度
hidden = 5 # 隱層大小
dim = 50 # 詞向量維度
# 制作輸入和標(biāo)簽
def make_batch(sentences):
input_batch = []
target_batch = []
for sentence in sentences:
words = sentence.split()
input = [word2idx[word] for word in words[:-1]]
target = word2idx[words[-1]]
input_batch.append(input)
target_batch.append(np.eye(V)[target]) # 這里將標(biāo)簽改為 one-hot 編碼,之后計(jì)算交叉熵的時(shí)候會(huì)用到
return input_batch, target_batch
# 初始化詞向量
embedding = tf.get_variable(shape=[V, dim], initializer=tf.random_normal_initializer(), name="embedding")
X = tf.placeholder(tf.int32, [None, step])
XX = tf.nn.embedding_lookup(embedding, X)
Y = tf.placeholder(tf.int32, [None, V])
# 定義 cell
cell = tf.nn.rnn_cell.BasicRNNCell(hidden)
# 計(jì)算各個(gè)時(shí)間點(diǎn)的輸出和隱層輸出的結(jié)果
outputs, hiddens = tf.nn.dynamic_rnn(cell, XX, dtype=tf.float32) # outputs: [batch_size, step, hidden] hiddens: [batch_size, hidden]
# 這里將所有時(shí)間點(diǎn)的狀態(tài)向量都作為了后續(xù)分類器的輸入(也可以只將最后時(shí)間節(jié)點(diǎn)的狀態(tài)向量作為后續(xù)分類器的輸入)
W = tf.Variable(tf.random_normal([step*hidden, V]))
b = tf.Variable(tf.random_normal([V]))
L = tf.matmul(tf.reshape(outputs,[-1, step*hidden]), W) + b
# 計(jì)算損失并進(jìn)行優(yōu)化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=L))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)
# 預(yù)測(cè)
prediction = tf.argmax(L, 1)
# 初始化 tf
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 喂訓(xùn)練數(shù)據(jù)
input_batch, target_batch = make_batch(sentences)
for epoch in range(5000):
_, loss = sess.run([optimizer, cost], feed_dict={X:input_batch, Y:target_batch})
if (epoch+1)%1000 == 0:
print("epoch: ", '%04d'%(epoch+1), 'cost= ', '%04f'%(loss))
# 預(yù)測(cè)數(shù)據(jù)
predict = sess.run([prediction], feed_dict={X: input_batch})
print([sentence.split()[:2] for sentence in sentences], '->', [idx2word[n] for n in predict[0]])
結(jié)果打印
epoch: ?1000 cost= ?0.008979
epoch: ?2000 cost= ?0.002754
epoch: ?3000 cost= ?0.001283
epoch: ?4000 cost= ?0.000697
epoch: ?5000 cost= ?0.000406
[['i', 'love'], ['i', 'like'], ['we', 'love']] -> ['damao', 'mengjun', 'all']?
原文鏈接:https://juejin.cn/post/6949412624215834638
相關(guān)推薦
- 2022-08-15 SpringMVC異常處理流程總結(jié)
- 2022-08-12 Python使用Opencv打開筆記本電腦攝像頭報(bào)錯(cuò)解問(wèn)題及解決_python
- 2022-04-02 詳析C++中的auto_C 語(yǔ)言
- 2022-05-01 C#程序調(diào)用cmd.exe執(zhí)行命令_C#教程
- 2022-05-26 python中的getter與setter你了解嗎_python
- 2022-07-07 通過(guò)Golang編寫一個(gè)AES加密解密工具_(dá)Golang
- 2022-06-04 Android自定義ScrollView實(shí)現(xiàn)阻尼回彈_Android
- 2022-08-23 Python可視化模塊altair的使用詳解_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過(guò)濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支