網站首頁 編程語言 正文
pytorch retain_graph==True的作用說明
總的來說進行一次backward之后,各個節點的值會清除,這樣進行第二次backward會報錯,如果加上retain_graph==True后,可以再來一次backward。?
retain_graph參數的作用
官方定義:
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
大意是如果設置為False,計算圖中的中間變量在計算完后就會被釋放。
但是在平時的使用中這個參數默認都為False從而提高效率,和creat_graph的值一樣。
具體看一個例子理解
假設一個我們有一個輸入x,y = x **2, z = y*4,然后我們有兩個輸出,一個output_1 = z.mean(),另一個output_2 = z.sum()。
然后我們對兩個output執行backward。
import torch
x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
y = x ** 2
z = y * 4
print(x)
print(y)
print(z)
loss1 = z.mean()
loss2 = z.sum()
print(loss1,loss2)
loss1.backward() ? ?# 這個代碼執行正常,但是執行完中間變量都free了,所以下一個出現了問題
print(loss1,loss2)
loss2.backward() ? ?# 這時會引發錯誤
程序正常執行到第12行,所有的變量正常保存。
但是在第13行報錯:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
分析:計算節點數值保存了,但是計算圖x-y-z結構被釋放了,而計算loss2的backward仍然試圖利用x-y-z的結構,因此會報錯。
因此需要retain_graph參數為True去保留中間參數從而兩個loss的backward()不會相互影響。
正確的代碼應當把第11行以及之后改成
- 1 # 假如你需要執行兩次backward,先執行第一個的backward,再執行第二個backward
- 2 loss1.backward(retain_graph=True)# 這里參數表明保留backward后的中間參數。
- 3 loss2.backward() # 執行完這個后,所有中間變量都會被釋放,以便下一次的循環
- 4 ?#如果是在訓練網絡optimizer.step() # 更新參數
create_graph參數比較簡單,參考官方定義:
create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.
Pytorch retain_graph=True錯誤信息
(Pytorch:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time)
具有多個loss值
retain_graph設置True,一般多用于兩次backward
# 假如有兩個Loss,先執行第一個的backward,再執行第二個backward
loss1.backward(retain_graph=True) # 這樣計算圖就不會立即釋放
loss2.backward() # 執行完這個后,所有中間變量都會被釋放,以便下一次的循環
optimizer.step() # 更新參數
retain_graph設置True后一定要知道釋放,否則顯卡會占用越來越多,代碼速度也會跑的越來越慢。
有的時候我明明僅有一個模型的也會出現這種錯誤
第一種是輸入的原因。
// Example
x = torch.randn((100,1), requires_grad = True)
y = 1 + 2 * x + 0.3 * torch.randn(100,1)
x_train, y_train = x[:70], y[:70]
x_val, y_val = x[70:], y[70:]
for epoch in range(n_epochs):
?? ?...
?? ?prediction = model(x_train)
?? ?loss.backward()
?? ?...
在多次循環的過程中,input的梯度沒有清除,而且我們也不需要計算輸入的梯度,因此將x的require_grad設置為False就可以解決問題。
第二種是我在訓練LSTM時候發現的。
class LSTMpred(nn.Module):
? ? def __init__(self, input_size, hidden_dim):
? ? ?? ?self.hidden = self.init_hidden()
? ? ? ?...
? ? def init_hidden(self):?? ?#這里我們是需要個隱層參數的
? ? ? ? return (torch.zeros(1, 1, self.hidden_dim, requires_grad=True),
? ? ? ? ? ? ? ? torch.zeros(1, 1, self.hidden_dim, requires_grad=True))
? ? def forward(self, seq):
? ? ? ? ...
這里面的self.hidden我們在每一次訓練的時候都要重新初始化隱層參數:
for epoch in range(Epoch):
?? ?...
?? ?model.hidden = model.init_hidden()
?? ?modout = model(seq)
? ? ...
3. 我的看法
其實,想想這幾種情況都是一回事,都是網絡在反向傳播中不允許多個backward(),也就是梯度下降反饋的時候,有多個循環過程中共用了同一個需要計算梯度的變量,在前一個循環清除梯度后,后面一個循環過程就會在這個變量上栽跟頭(個人想法)。
總結
原文鏈接:https://blog.csdn.net/qq_39861441/article/details/104129368
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2022-10-10 GO必知必會的常見面試題匯總_Golang
- 2022-05-06 SQL語句獲取表結構
- 2022-03-08 用C語言實現鏈式棧介紹_C 語言
- 2022-04-28 如何利用?Python?繪制動態可視化圖表_python
- 2022-09-13 C語言實現統計一行字符串的單詞個數_C 語言
- 2022-09-04 關于python?DataFrame的合并方法總結_python
- 2023-10-15 自定義帶下箭頭彈出框
- 2022-10-15 python-yml文件讀寫與xml文件讀寫_python
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支