網(wǎng)站首頁 編程語言 正文
1.原理
首先粗略的講一下原理,如果這部分你看起來有難度,建議去看看西瓜書或者其他作者的博客。
我們知道線性回歸的損失函數(shù)如下:
L
o
s
s
=
1
N
∑
i
=
1
N
(
y
i
?
(
w
x
i
+
b
)
)
2
Loss=\frac{1}{N}\sum_{i=1}^N(y_i-(wx_i+b))^2
Loss=N1i=1∑N(yi?(wxi+b))2
我們所要做的是:
通過尋找最優(yōu)的參數(shù)
w
,
b
w,b
w,b,使得上述損失函數(shù)最小。
你可以通過求偏導(dǎo)來實(shí)現(xiàn),但是這種方法的適應(yīng)范圍在機(jī)器學(xué)習(xí)領(lǐng)域十分有限,所以我更推薦你使用更為通用的方法:梯度下降法。
所謂梯度下降法,是通過讓參數(shù)通過梯度去更新自身。顯然,損失函數(shù)對(duì)兩個(gè)參數(shù)的梯度分別為:
?
L
o
s
s
?
w
=
?
2
N
∑
i
=
1
N
x
i
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial w}=-\frac{2}{N}\sum_{i=1}^Nx_i(y_i-(wx_i+b))
?w?Loss=?N2i=1∑Nxi(yi?(wxi+b))
?
L
o
s
s
?
b
=
?
2
N
∑
i
=
1
N
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial b}=-\frac{2}{N}\sum_{i=1}^N(y_i-(wx_i+b))
?b?Loss=?N2i=1∑N(yi?(wxi+b))
以參數(shù)w為例,其更新的策略為:
w
=
w
?
?
L
o
s
s
?
w
w=w-\frac{\partial Loss}{\partial w}
w=w??w?Loss
其余參數(shù)類似。
2.各個(gè)模塊的實(shí)現(xiàn)
2.1 初始化模塊
首先,作為一個(gè)類,該模型需要傳入一些必要的參數(shù),比如特征與目標(biāo):
class LinearRegression:
def __init__(self, data, target, ):
self.data = data
self.target = target
n_feature = data.shape[1] # 特征個(gè)數(shù)
self.w = np.zeros(n_feature)
self.b = np.zeros(1)
2.2 預(yù)測模塊
作為一個(gè)機(jī)器學(xué)習(xí)模型,預(yù)測模塊是必不可少的,也是最基礎(chǔ)的模塊,這在神經(jīng)網(wǎng)絡(luò)中被稱為前向傳播模塊:
def predict(self):
out = self.data.dot(self.w) + self.b
return out
2.3 梯度計(jì)算模塊
在這個(gè)模塊中,我們計(jì)算當(dāng)前損失函數(shù)對(duì)當(dāng)前參數(shù)的梯度,為此,我們需要使用第一節(jié)中的兩個(gè)梯度公式:
?
L
o
s
s
?
w
=
?
2
N
∑
i
=
1
N
x
i
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial w}=-\frac{2}{N}\sum_{i=1}^Nx_i(y_i-(wx_i+b))
?w?Loss=?N2i=1∑Nxi(yi?(wxi+b))
?
L
o
s
s
?
b
=
?
2
N
∑
i
=
1
N
(
y
i
?
(
w
x
i
+
b
)
)
\frac{\partial Loss}{\partial b}=-\frac{2}{N}\sum_{i=1}^N(y_i-(wx_i+b))
?b?Loss=?N2i=1∑N(yi?(wxi+b))
代碼:
def gradient(self):
"""計(jì)算損失函數(shù)對(duì)w,b的梯度"""
sample_num = self.data.shape[0] # 發(fā)樣本個(gè)數(shù)
dw = (-2 / sample_num) * np.sum(self.data.T.dot(self.target - self.predict()))
db = (-2 / sample_num) * np.sum(self.target - self.predict())
return dw, db
2.4 訓(xùn)練模塊
在這個(gè)模塊中,通過梯度下降法去更新參數(shù),實(shí)際上這也是訓(xùn)練(學(xué)習(xí))的過程,既然是訓(xùn)練(學(xué)習(xí)),則必須要有訓(xùn)練的次數(shù),即max_iter和學(xué)習(xí)率alpha:
def train(self, alpha=0.01, max_iter=200):
"""訓(xùn)練"""
loss_history = []
for i in range(max_iter):
dw, db = self.gradient()
self.w -= alpha * dw
self.b -= alpha * db # 更新權(quán)重
now_loss = self.loss()
loss_history.append(now_loss)
return loss_history
2.5 其他模塊
如果你足夠細(xì)心,你可能已經(jīng)發(fā)現(xiàn),在train模塊中,有一個(gè)self.loss()方法,這個(gè)方法是用來計(jì)算當(dāng)前預(yù)測值與真實(shí)值之間的誤差(損失)的:
def loss(self):
"""誤差,損失"""
sample_num = self.data.shape[0]
pre = self.predict()
loss = (1 / sample_num) * np.sum((self.target - pre) ** 2)
return loss
通過這個(gè)方法,將每次迭代得到的誤差記錄下來,便于我們觀察訓(xùn)練情況,不僅如此,如果你對(duì)神經(jīng)網(wǎng)絡(luò)有所了解,你會(huì)發(fā)現(xiàn)這是一種非常常見的好方法。
3.實(shí)例化并訓(xùn)練
選取你的訓(xùn)練集,將其傳入實(shí)例:
my_lr = LinearRegression(data=x_train_s, target=y_train)
loss_list = my_lr.train()
plt.plot(np.arange(len(loss_list)),loss_list,'r--')
plt.xlabel('iter num')
plt.ylabel('$Loss$')
plt.show()
結(jié)果:
這和我們的期望是一致的,說明訓(xùn)練成功。
也許你的程序會(huì)拋出這樣的警告:
RuntimeWarning: overflow encountered in multiply
return 2 * self.X.T.dot(self.X.dot(self.w_hat) - self.y)
那是因?yàn)槟憧赡芡泴⒛愕臄?shù)據(jù)進(jìn)行標(biāo)準(zhǔn)化處理了,因?yàn)榫€性回歸是要求樣本近似服從正態(tài)分布的。
除此之外,還有其他歸一化的好處:(摘自https://blog.csdn.net/weixin_43772533/article/details/100826616)
理論層面上,神經(jīng)網(wǎng)絡(luò)是以樣本在事件中的統(tǒng)計(jì)分布概率為基礎(chǔ)進(jìn)行訓(xùn)練和預(yù)測的,所以它對(duì)樣本數(shù)據(jù)的要求比較苛刻。具體說明如下:
1.樣本的各個(gè)特征的取值要符合概率分布,即[0,1]
2.樣本的度量單位要相同。我們并沒有辦法去比較1米和1公斤的區(qū)別,但是,如果我們知道了1米在整個(gè)樣本中的大小比例,以及1公斤在整個(gè)樣本中的大小比例,比如一個(gè)處于0.2的比例位置,另一個(gè)處于0.3的比例位置,就可以說這個(gè)樣本的1米比1公斤要小!
3.神經(jīng)網(wǎng)絡(luò)假設(shè)所有的輸入輸出數(shù)據(jù)都是標(biāo)準(zhǔn)差為1,均值為0,包括權(quán)重值的初始化,激活函數(shù)的選擇,以及優(yōu)化算法的的設(shè)計(jì)。
4.數(shù)值問題
歸一化可以避免一些不必要的數(shù)值問題。因?yàn)榧せ詈瘮?shù)sigmoid/tanh的非線性區(qū)間大約在[-1.7,1.7]。意味著要使神經(jīng)元有效,線性計(jì)算輸出的值的數(shù)量級(jí)應(yīng)該在1(1.7所在的數(shù)量級(jí))左右。這時(shí)如果輸入較大,就意味著權(quán)值必須較小,一個(gè)較大,一個(gè)較小,兩者相乘,就引起數(shù)值問題了。
5.梯度更新
若果輸出層的數(shù)量級(jí)很大,會(huì)引起損失函數(shù)的數(shù)量級(jí)很大,這樣做反向傳播時(shí)的梯度也就很大,這時(shí)會(huì)給梯度的更新帶來數(shù)值問題。
6.學(xué)習(xí)率
知道梯度非常大,學(xué)習(xí)率就必須非常小,因此,學(xué)習(xí)率(學(xué)習(xí)率初始值)的選擇需要參考輸入的范圍,不如直接將數(shù)據(jù)歸一化,這樣學(xué)習(xí)率就不必再根據(jù)數(shù)據(jù)范圍作調(diào)整。 對(duì)w1適合的學(xué)習(xí)率,可能相對(duì)于w2來說會(huì)太小,若果使用適合w1的學(xué)習(xí)率,會(huì)導(dǎo)致在w2方向上步進(jìn)非常慢,會(huì)消耗非常多的時(shí)間,而使用適合w2的學(xué)習(xí)率,對(duì)w1來說又太大,搜索不到適合w1的解。
原文鏈接:https://blog.csdn.net/weixin_57005504/article/details/126769487
相關(guān)推薦
- 2023-01-17 關(guān)于最大池化層和平均池化層圖解_python
- 2022-09-21 android實(shí)現(xiàn)簡單底部導(dǎo)航欄_Android
- 2022-08-31 python中ndarray數(shù)組的索引和切片的使用_python
- 2022-08-04 yolov5中head修改為decouple?head詳解_python
- 2022-12-26 C++內(nèi)存分區(qū)模型超詳細(xì)講解_C 語言
- 2022-07-20 C/C++詳解實(shí)現(xiàn)二層轉(zhuǎn)發(fā)_C 語言
- 2022-05-27 C++超詳細(xì)分析單鏈表的實(shí)現(xiàn)與常見接口_C 語言
- 2022-09-20 Flink實(shí)踐Savepoint使用示例詳解_服務(wù)器其它
- 欄目分類
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支