網站首頁 編程語言 正文
AutoGrad 是一個老少皆宜的 Python 梯度計算模塊。
對于初高中生而言,它可以用來輕易計算一條曲線在任意一個點上的斜率。
對于大學生、機器學習愛好者而言,你只需要傳遞給它Numpy這樣的標準數據庫下編寫的損失函數,它就可以自動計算損失函數的導數(梯度)。
我們將從普通斜率計算開始,介紹到如何只使用它來實現一個邏輯回歸模型。
1.準備
開始之前,你要確保Python和pip已經成功安裝在電腦上,如果沒有,可以訪問這篇文章:超詳細Python安裝指南?進行安裝。
(可選1)?如果你用Python的目的是數據分析,可以直接安裝Anaconda,它內置了Python和pip.
(可選2)?此外,推薦大家用VSCode編輯器,它有許多的優點
請選擇以下任一種方式輸入命令安裝依賴:
1. Windows 環境 打開 Cmd (開始-運行-CMD)。
2. MacOS 環境 打開 Terminal (command+空格輸入Terminal)。
3. 如果你用的是 VSCode編輯器 或 Pycharm,可以直接使用界面下方的Terminal.
pip?install?autograd
2.計算斜率
對于初高中生同學而言,它可以用來輕松計算斜率,比如我編寫一個斜率為0.5的直線函數:
# 公眾號 Python實用寶典
import?autograd.numpy?as?np
from?autograd?import?grad
def?oneline(x):
????y = x/2
????return?y
grad_oneline = grad(oneline)
print(grad_oneline(3.0))
運行代碼,傳入任意X值,你就能得到在該X值下的斜率:
(base) G:\push\20220724>python?1.py
0.5
由于這是一條直線,因此無論你傳什么值,都只會得到0.5的結果。
那么讓我們再試試一個tanh函數:
# 公眾號 Python實用寶典
import?autograd.numpy?as?np
from?autograd?import?grad
def?tanh(x):
????y = np.exp(-2.0?* x)
????return?(1.0?- y) / (1.0?+ y)
grad_tanh = grad(tanh)
print(grad_tanh(1.0))
此時你會獲得 1.0 這個 x 在tanh上的曲線的斜率:
(base) G:\push\20220724>python?1.py
0.419974341614026
我們還可以繪制出tanh的斜率的變化的曲線:
# 公眾號 Python實用寶典
import?autograd.numpy?as?np
from?autograd?import?grad
def?tanh(x):
????y = np.exp(-2.0?* x)
????return?(1.0?- y) / (1.0?+ y)
grad_tanh = grad(tanh)
print(grad_tanh(1.0))
import?matplotlib.pyplot?as?plt
from?autograd?import?elementwise_grad?as?egrad
x = np.linspace(-7,?7,?200)
plt.plot(x, tanh(x), x, egrad(tanh)(x))
plt.show()
圖中藍色的線是tanh,橙色的線是tanh的斜率,你可以非常清晰明了地看到tanh的斜率的變化。非常便于學習和理解斜率概念。
3.實現一個邏輯回歸模型
有了Autograd,我們甚至不需要借用scikit-learn就能實現一個回歸模型:
邏輯回歸的底層分類就是基于一個sigmoid函數:
import?autograd.numpy?as?np
from?autograd?import?grad
# Build a toy dataset.
inputs = np.array([[0.52,?1.12,?0.77],
???????????????????[0.88,?-1.08,?0.15],
???????????????????[0.52,?0.06,?-1.30],
???????????????????[0.74,?-2.49,?1.39]])
targets = np.array([True,?True,?False,?True])
def?sigmoid(x):
????return?0.5?* (np.tanh(x /?2.) +?1)
def?logistic_predictions(weights, inputs):
????# Outputs probability of a label being true according to logistic model.
????return?sigmoid(np.dot(inputs, weights))
從下面的損失函數可以看到,預測結果的好壞取決于weights的好壞,因此我們的問題轉化為怎么優化這個 weights 變量:
def?training_loss(weights):
????# Training loss is the negative log-likelihood of the training labels.
????preds = logistic_predictions(weights, inputs)
????label_probabilities = preds * targets + (1?- preds) * (1?- targets)
????return?-np.sum(np.log(label_probabilities))
知道了優化目標后,又有Autograd這個工具,我們的問題便迎刃而解了,我們只需要讓weights往損失函數不斷下降的方向移動即可:
# Define a function that returns gradients of training loss using Autograd.
training_gradient_fun = grad(training_loss)
# Optimize weights using gradient descent.
weights = np.array([0.0,?0.0,?0.0])
print("Initial loss:", training_loss(weights))
for?i?in?range(100):
????weights -= training_gradient_fun(weights) *?0.01
print("Trained loss:", training_loss(weights))
運行結果如下:
(base) G:\push\20220724>python?regress.py
Initial loss:?2.772588722239781
Trained loss:?1.067270675787016
由此可見損失函數以及下降方式的重要性,損失函數不正確,你可能無法優化模型。損失下降幅度太單一或者太快,你可能會錯過損失的最低點。
總而言之,AutoGrad是一個你用來優化模型的一個好工具,它可以給你提供更加直觀的損失走勢,進而讓你有更多優化想象力。
有興趣的朋友還可以看官方的更多示例代碼:https://github.com/HIPS/autograd/blob/master/examples/
原文鏈接:https://mp.weixin.qq.com/s/rH2_onXJ3Xvf3eFlIQbvdw
相關推薦
- 2022-11-02 Python封裝解構以及丟棄變量_python
- 2023-03-02 C++版本基于ros將文件夾中的圖像轉換為bag包_C 語言
- 2022-03-25 Mybatis聯合查詢的實現方法(多表聯合查詢)
- 2022-06-10 Linux環境下部署Consul集群_Linux
- 2022-11-23 python?Multiprocessing.Pool進程池模塊詳解_python
- 2022-10-02 react中的useImperativeHandle()和forwardRef()用法_React
- 2022-09-29 shell函數內調用另一個函數(不帶返回值和帶返回值)_linux shell
- 2022-07-08 C#四種計時器Timer的區別和用法_C#教程
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支