網站首頁 編程語言 正文
1.引例
給定如圖所示的某個函數,如何計算函數零點x0
在數學上我們如何處理這個問題?
最簡單的辦法是解方程f(x)=0,在代數學上還有著名的零點判定定理
如果函數y=f(x)在區間[a,b]上的圖象是連續不斷的一條曲線,并且有f(a)?f(b)<0,那么函數y=f(x)在區間(a,b)內有零點,即至少存在一個c∈(a,b),使得f(c)=0,這個c也就是方程f(x)=0的根。
然而,數學上的方法并不一定適合工程應用,當函數形式復雜,例如出現超越函數形式;非解析形式,例如遞推關系時,精確的方程解析一般難以進行,因為代數上還沒發展出任意形式的求根公式。而零點判定定理求解效率也較低,需要不停試錯。
因此,引入今天的主題——牛頓迭代法,服務于工程數值計算。
2.牛頓迭代算法求根
記第k輪迭代后,自變量更新為xk,令目標函數f(x)在x=xk泰勒展開:
f(x)=f(xk?)+f′(xk?)(x?xk?)+o(x)
我們希望下一次迭代到根點,忽略泰勒余項,令f(xk+1)=0,則
xk+1?=xk??f(xk?)/f'(xk?)?
不斷重復運算即可逼近根點。
在幾何上,上面過程實際上是在做f(x)在x=xk處的切線,并求切線的零點,在工程上稱為局部線性化。如圖所示,若xk在x0的左側,那么下一次迭代方向向右。
若xk在x0的右側,那么下一次迭代方向向左。
3.牛頓迭代優化
將優化問題轉化為求目標函數一階導數零點的問題,即可運用上面說的牛頓迭代法。
具體地,記第k輪迭代后,自變量更新為xk?,令目標函數f(x)在x=xk泰勒展開:
f(x)=f(xk?)+f′(xk?)(x?xk?)+1/2?f′′(xk?)(x?xk?)2+o(x)
兩邊求導得
f′(x)=f′(xk?)+f′′(xk?)(x?xk?)
令f′(xk+1?)=f′(xk?)+f′′(xk?)(xk+1??xk?)=0,從而得到
xk+1?=xk??f′(xk?)/f'′(xk?)?
對于向量x=[x1?? x2????xd??]T,將上述迭代公式推廣為
xk+1?=xk??[?2f(xk?)]?1?f(xk?)
?
其中?2f(xk?)是Hessian矩陣,當其正定時可以保證牛頓優化算法往 減小的方向迭代
牛頓法的特點如下:
① 以二階速率向最優點收斂,迭代次數遠小于梯度下降法,優化速度快;
梯度下降法的解析參考圖文詳解梯度下降算法的原理及Python實現
②學習率為[?2f(xk?)]?1?,包含更多函數本身的信息,迭代步長可實現自動調整,可視為自適應梯度下降算法;
③ 耗費CPU計算資源多,每次迭代需要計算一次Hessian矩陣,且無法保證Hessian矩陣可逆且正定,因而無法保證一定向最優點收斂。
在實際應用中,牛頓迭代法一般不能直接使用,會引入改進來規避其缺陷,稱為擬牛頓算法簇,其中包含大量不同的算法變種,例如共軛梯度法、DFP算法等等,今后都會介紹到。
4 代碼實戰:Logistic回歸
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib as mpl
from Logit import Logit
'''
* @breif: 從CSV中加載指定數據
* @param[in]: file -> 文件名
* @param[in]: colName -> 要加載的列名
* @param[in]: mode -> 加載模式, set: 列名與該列數據組成的字典, df: df類型
* @retval: mode模式下的返回值
'''
def loadCsvData(file, colName, mode='df'):
assert mode in ('set', 'df')
df = pd.read_csv(file, encoding='utf-8-sig', usecols=colName)
if mode == 'df':
return df
if mode == 'set':
res = {}
for col in colName:
res[col] = df[col].values
return res
if __name__ == '__main__':
# ============================
# 讀取CSV數據
# ============================
csvPath = os.path.abspath(os.path.join(__file__, "../../data/dataset3.0alpha.csv"))
dataX = loadCsvData(csvPath, ["含糖率", "密度"], 'df')
dataY = loadCsvData(csvPath, ["好瓜"], 'df')
label = np.array([
1 if i == "是" else 0
for i in list(map(lambda s: s.strip(), list(dataY['好瓜'])))
])
# ============================
# 繪制樣本點
# ============================
line_x = np.array([np.min(dataX['密度']), np.max(dataX['密度'])])
mpl.rcParams['font.sans-serif'] = [u'SimHei']
plt.title('對數幾率回歸模擬\nLogistic Regression Simulation')
plt.xlabel('density')
plt.ylabel('sugarRate')
plt.scatter(dataX['密度'][label==0],
dataX['含糖率'][label==0],
marker='^',
color='k',
s=100,
label='壞瓜')
plt.scatter(dataX['密度'][label==1],
dataX['含糖率'][label==1],
marker='^',
color='r',
s=100,
label='好瓜')
# ============================
# 實例化對數幾率回歸模型
# ============================
logit = Logit(dataX, label)
# 采用牛頓迭代法
logit.logitRegression(logit.newtomMethod)
line_y = -logit.w[0, 0] / logit.w[1, 0] * line_x - logit.w[2, 0] / logit.w[1, 0]
plt.plot(line_x, line_y, 'g-', label="牛頓迭代法")
# 繪圖
plt.legend(loc='upper left')
plt.show()
其中更新權重代碼為
'''
* @breif: 牛頓迭代法更新權重
* @param[in]: None
* @retval: 優化參數的增量dw
'''
def newtomMethod(self):
wTx = np.dot(self.w.T, self.X).reshape(-1, 1)
p = Logit.sigmod(wTx)
dw_1 = -self.X.dot(self.y - p)
dw_2 = self.X.dot(np.diag((p * (1 - p)).reshape(self.N))).dot(self.X.T)
dw = np.linalg.inv(dw_2).dot(dw_1)
return dw
原文鏈接:https://blog.csdn.net/FRIGIDWINTER/article/details/122832980
相關推薦
- 2022-01-11 Cookie、sessionStorage和localStorage的區別
- 2022-04-09 Eclipse 中Deployment Assembly 無法正常顯示
- 2022-03-31 聊聊Python?String型列表求最值的問題_python
- 2023-06-19 CentOS7使用yum安裝Golang的超詳細步驟_Golang
- 2023-01-18 React手寫redux過程分步講解_React
- 2022-04-12 如何解決:git push -u origin msster時出現error: failed to
- 2022-03-29 Qt超時鎖屏的實現示例_C 語言
- 2022-11-22 Python可視化繪制圖表的教程詳解_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支