網站首頁 編程語言 正文
文章目錄
- 自動微分
- 一個簡單的例子
- 非標量變量的反向傳播
- 分離計算
- Python控制流的梯度計算
- 小結
自動微分
正如微積分中所說,求導是幾乎所有深度學習優化算法的關鍵步驟。 雖然求導的計算很簡單,只需要一些基本的微積分。 但對于復雜的模型,手工進行更新是一件很痛苦的事情(而且經常容易出錯)。
度學習框架通過自動計算導數,即 自動微分(automatic differentiation) 來加快求導。 實際中,根據我們設計的模型,系統會構建一個 計算圖(computational graph), 來跟蹤計算是哪些數據通過哪些操作組合起來產生輸出。 自動微分使系統能夠隨后反向傳播梯度。 這里,反向傳播(backpropagate) 意味著跟蹤整個計算圖,填充關于每個參數的偏導數。
重要理解 requires_grad 梯度參數、backward() 反向傳播函數、detach() 分離函數。
一個簡單的例子
作為一個演示例子,假設我們想對函數 y = 2 x T x y = 2x^{T}x y=2xTx 關于列向量 x ? \vec{x} x 求導。 首先,我們創建變量x并為其分配一個初始值。
import torch
x = torch.arange(4.0) #生成一個一維張量
x
tensor([0., 1., 2., 3.])
在我們計算 y y y 關于的梯度 x x x 之前,我們需要一個地方(grad)來存儲梯度。 重要的是,我們不會在每次對一個參數求導時都分配新的內存。 因為我們經常會成千上萬次地更新相同的參數,每次都分配新的內存可能很快就會將內存耗盡。 注意,一個標量函數關于向量 x x x 的梯度是向量,并且與 x x x 具有相同的形狀。
x.requires_grad_(True) #為張量x設置梯度,用來存儲計算得出的x梯度向量,等同于torch.arange(4.0, requires_grad=True)
print(x.grad) #輸出x默認的梯度值, 為None
None
現在讓我們計算 y y y 。
y = 2 * torch.dot(x.T, x) #計算y = 2 * (x.T與x的內積),結果為一個標量
y
tensor(28., grad_fn=<MulBackward0>)
x是一個長度為4的向量,計算x和x的點積,得到了我們賦值給y的標量輸出。 接下來,我們通過調用**反向傳播函數(backward())**來自動計算y關于x每個分量的梯度,并打印這些梯度。
這里我覺得大家需要明白一個概念。梯度,也就是標量對向量內的每個元素的偏導數組成的向量。
y.backward() #調用反向傳播函數,計算之前組成y結果的x向量梯度
x.grad #輸出x向量梯度
tensor([ 0., 4., 8., 12.])
它的計算過程如下所示:
其中,函數 y = 2 x T x y = 2x^{T}x y=2xTx 關于 x x x 的梯度應為 4 x 4x 4x (計算原理為二次型對向量的求導)。 如下圖(詳情請閱讀矩陣求導知識):
其中
y = 2 x T x = 2 x T E x = 2 ( E T x + E x ) = 4 x y = 2x^{T}x = 2x^{T}Ex = 2(E^{T}x + Ex) = 4x y=2xTx=2xTEx=2(ETx+Ex)=4x
讓我們快速驗證這個梯度是否計算正確。
x.grad == 4*x #驗證反向傳播函數的偏導數計算是否正確
tensor([True, True, True, True])
現在讓我們計算x的另一個函數。
#默認情況下,pytorch會累計梯度,我們需要清楚之前的值
x.grad.zero_()
y = x.sum() #此時的y = x_0 + x_1 + x_2 + x_3
y.backward() #反向傳播求x向量內分量的偏導數,即x關于y的梯度
x.grad #輸出x的梯度
tensor([1., 1., 1., 1.])
非標量變量的反向傳播
當y不是標量時,向量y關于向量x的導數的最自然解釋是一個矩陣。 對于高階和高維的y和x,求導的結果可以是一個高階張量。
然而,雖然這些更奇特的對象確實出現在高級機器學習中(包括深度學習中), 但當我們調用向量的反向計算時,我們通常會試圖計算一批訓練樣本中每個組成部分的損失函數的導數。 這里,我們的目的不是計算微分矩陣,而是單獨計算批量中每個樣本的偏導數之和。
# 對非標量調用backward需要傳入一個gradient參數,該參數指定微分函數關于self的梯度。
# 在我們的例子中,我們只想求偏導數的和,所以傳遞一個1的梯度是合適的
x.grad.zero_()
y = x * x
# 等價于y.backward(torch.ones(len(x)))
y.sum().backward() #y.sum() = x_o的平方 + x_1的平方 + x_2的平方 + x_3的平方
x.grad #其對x的偏導數分別為2x_0,2x_1,2x_2,2x_3, 梯度為(2x_0,2x_1,2x_2,2x_3)
tensor([0., 2., 4., 6.])
分離計算
有時,我們希望將某些計算移動到記錄的計算圖之外。 例如,假設y是作為x的函數計算的,而z則是作為y和x的函數計算的。 想象一下,我們想計算z關于x的梯度,但由于某種原因,我們希望將y視為一個常數, 并且只考慮到x在y被計算后發揮的作用。
在這里,我們可以分離y來返回一個新變量u,該變量與y具有相同的值, 但丟棄計算圖中如何計算y的任何信息。 換句話說,梯度不會向后流經u到x。 因此,下面的反向傳播函數計算z=ux關于x的偏導數,同時將u作為常數處理, 而不是z=xx*x關于x的偏導數。
#將x之前的梯度進行清空
x.grad.zero_()
y = x * x #y是關于自變量x的函數
u = y.detach() #detach()函數,即分離。只是將y的值賦值給u,即u僅僅作為普通常數,而不會把y是如何得來的計算圖賦值給u
u.requires_grad_(True)
z = u * x #表面上還是z = y * x,但此時的u和x未有任何聯系,可以無顧慮地進行偏導數求解
#z.sum() = u * x = u0x0 + u1x1 + u2x2 + u3x3, 對x的偏導數為(u0,u1,u2,u3), 對u的偏導數為(x0,x1,x2,x3)
#故x關于z的梯度為(0,1,4,9), u關于z的梯度為(0,1,2,3)
z.sum().backward() #進行反向傳播求出偏導數
x.grad, u.grad, x.grad == u #分別輸出x與u的梯度,并判斷x的梯度是否等于y的值
(tensor([0., 1., 4., 9.]),
tensor([0., 1., 2., 3.]),
tensor([True, True, True, True]))
由于記錄了 y y y 的計算結果,我們可以隨后在 y y y 上調用反向傳播, 得到 y = x ? x y=x*x y=x?x 關于的 x x x 的導數,即 2 x 2x 2x 。
x.grad.zero_() #將x的梯度清0
y.sum().backward() #其中, y.sum() = x0的平方 + x1的平方 + x2的平方 + x3的平方。計算出x的梯度為2(x0,x1,x2,x3)
x.grad == 2 * x #判斷梯度是否正確
tensor([True, True, True, True])
Python控制流的梯度計算
使用自動微分的一個好處是: 即使構建函數的計算圖需要通過Python控制流(例如,條件、循環或任意函數調用),我們仍然可以計算得到的變量的梯度。 在下面的代碼中,while循環的迭代次數和if語句的結果都取決于輸入a的值。
def f(a):
b = a * 2
while b.norm() < 1000:
b = b * 2
if b.sum() > 0:
c = b
else:
c = b * 100
return c
讓我們計算梯度。
a = torch.randn(size=(),requires_grad=True) #隨機生成一個符合正態分布的小數
d = f(a) #調用函數f,計算關于a的函數
d.backward() #調用反向傳播函數
我們現在可以分析上面定義的 f f f 函數。 請注意,它在其輸入 a a a 中是分段線性的。 換言之,對于任何 a a a ,存在某個常量標量k,使得 f ( a ) = k a f(a)=ka f(a)=ka ,其中 k k k 的值取決于輸入 a a a 。 因此,我們可以用 d / a d/a d/a 驗證梯度是否正確。
a.grad, a.grad == d / a #計算a的梯度
(tensor(1024.), tensor(True))
小結
深度學習框架可以自動計算導數:
我們首先將梯度附加到想要對其計算偏導數的變量上,然后我們記錄目標值的計算,執行它的反向傳播函數,并訪問得到的梯度。
原文鏈接:https://blog.csdn.net/weixin_43479947/article/details/126989990
- 上一篇:線性回歸的從零開始實現(線性神經網絡)
- 下一篇:瀏覽網站時發生的過程
相關推薦
- 2022-04-27 Python線程之線程安全的隊列Queue_python
- 2022-10-11 C語言超詳細分析多進程的概念與使用_C 語言
- 2024-03-18 sql篇-輸入數據提示[HY000][1366] Incorrect string value: ‘
- 2022-11-23 TypeScript前端上傳文件到MinIO示例詳解_其它
- 2023-01-07 Android?RecyclerLineChart實現圖表繪制教程_Android
- 2022-07-18 Nio中Buffer的Scattering和Gathering
- 2022-07-21 數據庫表數據操作-新增、刪除、修改
- 2022-08-17 yolov5中anchors設置實例詳解_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支