網站首頁 編程語言 正文
隨機梯度下降法
為什么使用隨機梯度下降法?
如果當我們數據量和樣本量非常大時,每一項都要參與到梯度下降,那么它的計算量時非常大的,所以我們可以采用隨機梯度下降法。
隨機梯度下降法中的學習率必須是隨著循環的次數增加而遞減的。如果eta取一樣的話有可能在非常接近我們的最優值時會跳過,所以隨著迭代次數的增加,學習率eta要隨之減小,我們可以用模擬退火的思想實現(如下圖所示),t0和t1是一個常數,定值,其通常是根據經驗取得一些值。
隨機梯度下降法的實現
隨機梯度下降法的公式如下圖所示,其中挑出一個樣本出來計算。
先創建x,y,以下取10000個樣本
import numpy as np
m = 10000
x = np.random.random(size=m)
y = x*3 + 4 + np.random.normal(size=m)
寫入函數
def dj_sgd(theta, x_i, y_i): # 傳入一個樣本,獲取對應的梯度
return x_i.T.dot(x_i.dot(theta)-y_i)*2 # MSE
def sgd(X_b, y, initial_theta, n_iters): # 求出整個theta的函數
def learning_rate(i_iter):
t0 = 5
t1 = 50
return t0/(i_iter+t1)
theta = initial_theta
i_iter = 1
while i_iter <= n_iters:
index = np.random.randint(0, len(X_b))
x_i = X_b[index]
y_i = y[index]
gradient = dj_sgd(theta, x_i, y_i) # 求導數
theta = theta - gradient*learning_rate(i_iter) # 求步長
i_iter += 1
return theta
調用函數,求出截距和系數
以上隨機梯度的缺點是不能照顧到每一點,因此需要進行改進。
以下對其中的函數進行修改。
def dj_sgd(theta, x_i, y_i): # 傳入一個樣本,獲取對應的梯度
return x_i.T.dot(x_i.dot(theta)-y_i)*2 # MSE
def sgd(X_b, y, initial_theta, n_iters): # 求出整個theta的函數
def learning_rate(i_iter):
t0 = 5
t1 = 50
return t0/(i_iter+t1)
theta = initial_theta
m = len(X_b)
for cur_iter in range(n_iters): # 每一次循環都把樣本打亂,n_iters的代表整個樣本看幾輪
random_indexs = np.random.permutation(m)
X_random = X_b[random_indexs]
y_random = y[random_indexs]
for i in range(m):
theta = theta - learning_rate(cur_iter*m+i) * (dj_sgd(theta, X_random[i], y_random[i]))
return theta
與前邊運算結果進行對比,其耗時更長。
原文鏈接:https://blog.csdn.net/Oh_Python/article/details/129233709
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2023-05-29 python怎樣判斷一個數值(字符串)為整數_python
- 2022-09-14 Python安裝xarray庫讀取.nc文件的詳細步驟_python
- 2023-01-30 基于redis樂觀鎖實現并發排隊_Redis
- 2022-05-06 利用python實現蝴蝶曲線_python
- 2022-09-19 React?Hook?四種組件優化總結_React
- 2023-01-11 Python?基于xml.etree.ElementTree實現XML對比示例詳解_python
- 2022-06-02 使用kubeadm命令行工具創建kubernetes集群_云和虛擬化
- 2022-06-29 C語言超詳細講解遞歸算法漢諾塔_C 語言
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支