網(wǎng)站首頁 編程語言 正文
本文實例為大家分享了TensorFlow實現(xiàn)簡單線性回歸的具體代碼,供大家參考,具體內容如下
簡單的一元線性回歸
一元線性回歸公式:
其中x是特征:[x1,x2,x3,…,xn,]T
w是權重,b是偏置值
代碼實現(xiàn)
導入必須的包
import tensorflow as tf import matplotlib.pyplot as plt import numpy as np import os # 屏蔽warning以下的日志信息 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
產生模擬數(shù)據(jù)
def generate_data(): ? ? x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32) ? ? y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30)) ? ? return x, y
x是100行1列的數(shù)據(jù),tf.matmul是矩陣相乘,所以權值設置成二維的。
設置的w是1.3, b是1
實現(xiàn)回歸
def myregression(): ? ? """ ? ? 自實現(xiàn)線性回歸 ? ? :return: ? ? """ ? ? x, y = generate_data() ? ? # ? ? 建立模型 ?y = x * w + b ? ? # w 1x1的二維數(shù)據(jù) ? ? w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a') ? ? b = tf.Variable(0.0, name='bias_b') ? ? y_predict = tf.matmul(x, a) + b ? ? # 建立損失函數(shù) ? ? loss = tf.reduce_mean(tf.square(y_predict - y)) ? ?? ? ? # 訓練 ? ? train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss=loss) ? ? # 初始化全局變量 ? ? init_op = tf.global_variables_initializer() ?? ? ? with tf.Session() as sess: ? ? ? ? sess.run(init_op) ? ? ? ? print('初始的權重:%f偏置值:%f' % (a.eval(), b.eval())) ? ?? ? ? ? ? # 訓練優(yōu)化 ? ? ? ? for i in range(1, 100): ? ? ? ? ? ? sess.run(train_op) ? ? ? ? ? ? print('第%d次優(yōu)化的權重:%f偏置值:%f' % (i, a.eval(), b.eval())) ? ? ? ? # 顯示回歸效果 ? ? ? ? show_img(x.eval(), y.eval(), y_predict.eval())
使用matplotlib查看回歸效果
def show_img(x, y, y_pre): ? ? plt.scatter(x, y) ? ? plt.plot(x, y_pre) ? ? plt.show()
完整代碼
import tensorflow as tf import matplotlib.pyplot as plt import numpy as np import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' def generate_data(): ? ? x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32) ? ? y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30)) ? ? return x, y def myregression(): ? ? """ ? ? 自實現(xiàn)線性回歸 ? ? :return: ? ? """ ? ? x, y = generate_data() ? ? # 建立模型 ?y = x * w + b ? ? w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a') ? ? b = tf.Variable(0.0, name='bias_b') ? ? y_predict = tf.matmul(x, w) + b ? ? # 建立損失函數(shù) ? ? loss = tf.reduce_mean(tf.square(y_predict - y)) ? ? # 訓練 ? ? train_op = tf.train.GradientDescentOptimizer(0.0001).minimize(loss=loss) ? ? init_op = tf.global_variables_initializer() ? ? with tf.Session() as sess: ? ? ? ? sess.run(init_op) ? ? ? ? print('初始的權重:%f偏置值:%f' % (w.eval(), b.eval())) ? ? ? ? # 訓練優(yōu)化 ? ? ? ? for i in range(1, 35000): ? ? ? ? ? ? sess.run(train_op) ? ? ? ? ? ? print('第%d次優(yōu)化的權重:%f偏置值:%f' % (i, w.eval(), b.eval())) ? ? ? ? show_img(x.eval(), y.eval(), y_predict.eval()) def show_img(x, y, y_pre): ? ? plt.scatter(x, y) ? ? plt.plot(x, y_pre) ? ? plt.show() if __name__ == '__main__': ? ? myregression()
看看訓練的結果(因為數(shù)據(jù)是隨機產生的,每次的訓練結果都會不同,可適當調節(jié)梯度下降的學習率和訓練步數(shù))
35000次的訓練結果
原文鏈接:https://blog.csdn.net/kylinxjd/article/details/105557304
相關推薦
- 2023-01-12 Kotlin?Option與Either及Result實現(xiàn)異常處理詳解_Android
- 2021-12-03 找不到或無法加載主類 CMD || 找不到\*\*\路徑|| 原因大全
- 2022-07-11 Oracle使用dblink同步數(shù)據(jù)
- 2022-05-09 React中的axios模塊及使用方法_React
- 2022-06-22 android使用intent傳遞參數(shù)實現(xiàn)乘法計算_Android
- 2022-10-16 Android基于方法池與回調實現(xiàn)登錄攔截的場景_Android
- 2022-04-04 iview在Table表格中渲染title文字提示,使用render實現(xiàn)
- 2022-06-04 解決Go語言time包數(shù)字與時間相乘的問題_Golang
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結構-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支