網站首頁 編程語言 正文
pytorch Backward過程用時太長
問題描述
使用pytorch對網絡進行訓練的時候遇到一個問題,forward階段很快(只需要幾毫秒),backward階段卻用時很長(需要十多秒)。
導致這個問題的原因很容易被大家忽視,而且網上基本上沒有直接的解決方案,經過一天的折騰,總算把導致這個問題的原因搞清楚了。
解決方案
導致這個問題的原因在于訓練數據的淺拷貝,由于backward過程中的梯度是和模型推理過程中的張量相關的,如果這些張量在被模型使用之前沒有被深拷貝,意味著backward過程的會重復從這些張量的原始內存地址中取值,這個過程非常耗時。所以為了避免這個問題,需要養成一個好習慣,就是將張量數據輸入模型之前進行深拷貝
pytorch的深拷貝方式如下:
tensor_a = tensor_b.clone().detach()
Pytorch backward()簡單理解
backward()是反向傳播求梯度,具體實現過程如下
import torch
x=torch.tensor([1,2,3],requires_grad=True,dtype=torch.double)
y=x**2
z=y.mean()
z.backward()
print(x.grad)
結果
tensor([0.6667, 1.3333, 2.0000], dtype=torch.float64)
有幾個重要的點
1.必須要加上requires_grad=True才能求
2. 一般來說,需要標量才能求梯度。
3.具體過程如下:
z是一個標量(1*1矩陣)分別對x1,x2,x3求偏導, 再代入x1,x2,x3的數值,就是如上程序輸出的結果
總結
原文鏈接:https://blog.csdn.net/ahhhhhh520/article/details/124864850
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2022-11-23 Python多線程使用方法詳細講解_python
- 2023-03-23 詳解python?ThreadPoolExecutor異常捕獲_python
- 2023-04-26 Numpy對于NaN值的判斷方法_python
- 2023-04-18 Python之split函數的深入理解_python
- 2022-04-08 c++11中std::move函數的使用_C 語言
- 2023-10-15 webrtc用clang編譯支持h264,支持msvc調用庫
- 2022-03-16 Android線程池源碼閱讀記錄介紹_Android
- 2024-03-02 前端directus對接單點登錄
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支