網(wǎng)站首頁 編程語言 正文
1.原理
首先粗略的講一下原理,如果這部分你看起來有難度,建議去看看西瓜書或者其他作者的博客。
我們知道線性回歸的損失函數(shù)如下:
L
o
s
s
=
1
N
∑
i
=
1
N
(
y
i
?
(
w
x
i
+
b
)
)
2
Loss=\frac{1}{N}\sum_{i=1}^N(y_i-(wx_i+b))^2
Loss=N1i=1∑N(yi?(wxi+b))2
我們所要做的是:
通過尋找最優(yōu)的參數(shù)
w
,
b
w,b
w,b,使得上述損失函數(shù)最小。
你可以通過求偏導來實現(xiàn),但是這種方法的適應范圍在機器學習領(lǐng)域十分有限,所以我更推薦你使用更為通用的方法:梯度下降法。
所謂梯度下降法,是通過讓參數(shù)通過梯度去更新自身。顯然,損失函數(shù)對兩個參數(shù)的梯度分別為:
?
L
o
s
s
?
w
=
?
2
N
∑
i
=
1
N
x
i
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial w}=-\frac{2}{N}\sum_{i=1}^Nx_i(y_i-(wx_i+b))
?w?Loss=?N2i=1∑Nxi(yi?(wxi+b))
?
L
o
s
s
?
b
=
?
2
N
∑
i
=
1
N
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial b}=-\frac{2}{N}\sum_{i=1}^N(y_i-(wx_i+b))
?b?Loss=?N2i=1∑N(yi?(wxi+b))
以參數(shù)w為例,其更新的策略為:
w
=
w
?
?
L
o
s
s
?
w
w=w-\frac{\partial Loss}{\partial w}
w=w??w?Loss
其余參數(shù)類似。
2.各個模塊的實現(xiàn)
2.1 初始化模塊
首先,作為一個類,該模型需要傳入一些必要的參數(shù),比如特征與目標:
class LinearRegression:
def __init__(self, data, target, ):
self.data = data
self.target = target
n_feature = data.shape[1] # 特征個數(shù)
self.w = np.zeros(n_feature)
self.b = np.zeros(1)
2.2 預測模塊
作為一個機器學習模型,預測模塊是必不可少的,也是最基礎(chǔ)的模塊,這在神經(jīng)網(wǎng)絡中被稱為前向傳播模塊:
def predict(self):
out = self.data.dot(self.w) + self.b
return out
2.3 梯度計算模塊
在這個模塊中,我們計算當前損失函數(shù)對當前參數(shù)的梯度,為此,我們需要使用第一節(jié)中的兩個梯度公式:
?
L
o
s
s
?
w
=
?
2
N
∑
i
=
1
N
x
i
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial w}=-\frac{2}{N}\sum_{i=1}^Nx_i(y_i-(wx_i+b))
?w?Loss=?N2i=1∑Nxi(yi?(wxi+b))
?
L
o
s
s
?
b
=
?
2
N
∑
i
=
1
N
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial b}=-\frac{2}{N}\sum_{i=1}^N(y_i-(wx_i+b))
?b?Loss=?N2i=1∑N(yi?(wxi+b))
代碼:
def gradient(self):
"""計算損失函數(shù)對w,b的梯度"""
sample_num = self.data.shape[0] # 發(fā)樣本個數(shù)
dw = (-2 / sample_num) * np.sum(self.data.T.dot(self.target - self.predict()))
db = (-2 / sample_num) * np.sum(self.target - self.predict())
return dw, db
2.4 訓練模塊
在這個模塊中,通過梯度下降法去更新參數(shù),實際上這也是訓練(學習)的過程,既然是訓練(學習),則必須要有訓練的次數(shù),即max_iter和學習率alpha:
def train(self, alpha=0.01, max_iter=200):
"""訓練"""
loss_history = []
for i in range(max_iter):
dw, db = self.gradient()
self.w -= alpha * dw
self.b -= alpha * db # 更新權(quán)重
now_loss = self.loss()
loss_history.append(now_loss)
return loss_history
2.5 其他模塊
如果你足夠細心,你可能已經(jīng)發(fā)現(xiàn),在train模塊中,有一個self.loss()方法,這個方法是用來計算當前預測值與真實值之間的誤差(損失)的:
def loss(self):
"""誤差,損失"""
sample_num = self.data.shape[0]
pre = self.predict()
loss = (1 / sample_num) * np.sum((self.target - pre) ** 2)
return loss
通過這個方法,將每次迭代得到的誤差記錄下來,便于我們觀察訓練情況,不僅如此,如果你對神經(jīng)網(wǎng)絡有所了解,你會發(fā)現(xiàn)這是一種非常常見的好方法。
3.實例化并訓練
選取你的訓練集,將其傳入實例:
my_lr = LinearRegression(data=x_train_s, target=y_train)
loss_list = my_lr.train()
plt.plot(np.arange(len(loss_list)),loss_list,'r--')
plt.xlabel('iter num')
plt.ylabel('$Loss$')
plt.show()
結(jié)果:
這和我們的期望是一致的,說明訓練成功。
也許你的程序會拋出這樣的警告:
RuntimeWarning: overflow encountered in multiply
return 2 * self.X.T.dot(self.X.dot(self.w_hat) - self.y)
那是因為你可能忘記將你的數(shù)據(jù)進行標準化處理了,因為線性回歸是要求樣本近似服從正態(tài)分布的。
除此之外,還有其他歸一化的好處:(摘自https://blog.csdn.net/weixin_43772533/article/details/100826616)
理論層面上,神經(jīng)網(wǎng)絡是以樣本在事件中的統(tǒng)計分布概率為基礎(chǔ)進行訓練和預測的,所以它對樣本數(shù)據(jù)的要求比較苛刻。具體說明如下:
1.樣本的各個特征的取值要符合概率分布,即[0,1]
2.樣本的度量單位要相同。我們并沒有辦法去比較1米和1公斤的區(qū)別,但是,如果我們知道了1米在整個樣本中的大小比例,以及1公斤在整個樣本中的大小比例,比如一個處于0.2的比例位置,另一個處于0.3的比例位置,就可以說這個樣本的1米比1公斤要小!
3.神經(jīng)網(wǎng)絡假設所有的輸入輸出數(shù)據(jù)都是標準差為1,均值為0,包括權(quán)重值的初始化,激活函數(shù)的選擇,以及優(yōu)化算法的的設計。
4.數(shù)值問題
歸一化可以避免一些不必要的數(shù)值問題。因為激活函數(shù)sigmoid/tanh的非線性區(qū)間大約在[-1.7,1.7]。意味著要使神經(jīng)元有效,線性計算輸出的值的數(shù)量級應該在1(1.7所在的數(shù)量級)左右。這時如果輸入較大,就意味著權(quán)值必須較小,一個較大,一個較小,兩者相乘,就引起數(shù)值問題了。
5.梯度更新
若果輸出層的數(shù)量級很大,會引起損失函數(shù)的數(shù)量級很大,這樣做反向傳播時的梯度也就很大,這時會給梯度的更新帶來數(shù)值問題。
6.學習率
知道梯度非常大,學習率就必須非常小,因此,學習率(學習率初始值)的選擇需要參考輸入的范圍,不如直接將數(shù)據(jù)歸一化,這樣學習率就不必再根據(jù)數(shù)據(jù)范圍作調(diào)整。 對w1適合的學習率,可能相對于w2來說會太小,若果使用適合w1的學習率,會導致在w2方向上步進非常慢,會消耗非常多的時間,而使用適合w2的學習率,對w1來說又太大,搜索不到適合w1的解。
原文鏈接:https://blog.csdn.net/weixin_57005504/article/details/126769487
相關(guān)推薦
- 2022-06-19 python繪制分組對比柱狀圖_python
- 2022-07-13 VMware Workstation Pro界面設置為中文界面
- 2022-09-04 docker部署可執(zhí)行jar包的思路與完整步驟_docker
- 2022-10-09 C#實現(xiàn)折半查找算法_C#教程
- 2023-07-24 前端常見狀態(tài)碼
- 2022-06-29 python版單鏈表反轉(zhuǎn)_python
- 2022-09-17 Python高效處理大文件的方法詳解_python
- 2023-07-13 遍歷對象并改變對象某個屬性的值
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支