網站首頁 編程語言 正文
參考資料:自己debug
首先,我報錯的問題的文本是:RuntimeError: CUDA error: device-side assert triggered以及
Assertion `input_val >= zero && input_val <= one` failed

把這兩個文本放在前面以便搜索引擎檢索。下面說一下我的解決方案,因為問題解決過程中我沒有逐步截圖,所以有些步驟只能文字描述。
RCAN是超分辨率恢復領域的一個深度殘差網絡,但是它的代碼卻是很舊的了,基于EDSR和Pytorch<1.2的框架。所以我就將它的model移植到了自己的框架下,結果在訓練過程中突然報了上面這個錯誤。
并且,這個報錯不是在網絡訓練一開始就發生的,而是訓著訓著突然報錯。這讓我百思不得解,思考歷程如下:
首先,根據報錯的文字描述,是在源代碼的 /src/ATen/native/cuda/Loss.cu:102,打開github上pytorch的源碼,找到這一部分:

原來如此,這個斷言原來是發生在binary_cross_entropy_out_cuda()內,也就是在使用nn.BCELoss的時候報了錯。
回顧一下我的源碼,我是讓output和target進行一個BCELoss。但是,在我的model(記為modelB)內,最后一層是sigmoid,也就是說,我的網絡保證了輸出值一定是在[0, 1]之間的!!那么,憑什么說我的值不在這個區間?是Pytorch犯病了嗎?
Well,沒有別的辦法,我在源碼中添加了print語句打印model的輸出,看看它是否真的按照預期輸出值均在[0, 1]。一看嚇一跳,一開始的時候,輸出值十分正常,并且都在[0, 1]之間。但是到達某一個時間節點,網絡的輸出突然就變成滿屏的nan也就是無窮值!這就奇了怪了,因為眾所周知,卷積神經網絡里面的操作都是“有限的”,當輸入有限的時候輸出一定是有限的。等等,當輸入有限的時候。我想當然地認為這個輸入是有限的,但是輸入值到底是不是這樣呢。(此處輸入是modelA的輸出)
有了這種想法,我又打印了modelA(其實就是RCAN)的輸出,結果發現和modelB一樣,在某個時間節點,輸出值突然變成了nan。
問題找到了,但好像又陷入了僵局。RCAN也是一個卷積神經網絡啊,里面也沒有log這種會產生無窮量的操作,那么問題到底出在了哪里?
我靈光一閃突然想到,RCAN的網絡結構里面,貌似有一些網絡層時不需要訓練和優化的。但是在我的源碼里,優化器的代碼是這樣的:
optimizerG = optim.Adam(netG.parameters())
也就是將RCAN網絡中的所有參數都交給了OptimizerG給hook住了。那么,會不會是優化器優化了一些固定的不該優化的網絡參數,導致網絡異常。翻看RCAN的源碼,找到蛛絲馬跡,源碼中有這樣一段:
def make_optimizer(args, my_model):
trainable = filter(lambda x: x.requires_grad, my_model.parameters())
if args.optimizer == 'SGD':
optimizer_function = optim.SGD
kwargs = {'momentum': args.momentum}
elif args.optimizer == 'ADAM':
optimizer_function = optim.Adam
kwargs = {
'betas': (args.beta1, args.beta2),
'eps': args.epsilon
}
elif args.optimizer == 'RMSprop':
optimizer_function = optim.RMSprop
kwargs = {'eps': args.epsilon}
kwargs['lr'] = args.lr
kwargs['weight_decay'] = args.weight_decay
return optimizer_function(trainable, **kwargs)
真該死,RCAN的源碼好像在這里的第一行過濾了需要訓練的參數和不需要訓練的參數。而我的代碼卻沒有干這件事。那么,我就直接用原作者的代碼把,試一試:

咳,訓練效果非常不理想。但是卻再也沒有報過同樣的錯誤,看來,問題已經被解決!!!
總結一下:
1. 在Pytorch進行BCELoss的時候,需要輸入值都在[0, 1]之間,如果你的網絡的最后一層不是sigmoid,你需要把BCELoss換成BCEWithLogitsLoss,這樣損失函數會替你做Sigmoid的操作。
2. 神經網絡的輸入和輸出一般都是有限量,如果你確認你的網絡是好的,不妨查看一下網絡的輸入是不是已經變成了nan
3. 神經網絡的一些網絡層是不需要訓練的,此時你需要告訴優化器這件事,不然optim會做出一些蠢事讓輸出值變成nan
原文鏈接:https://blog.csdn.net/weixin_43590796/article/details/115714248
相關推薦
- 2022-12-28 詳解Go語言strconv與其他基本數據類型轉換函數的使用_Golang
- 2022-10-13 Python中?whl包、tar.gz包的區別詳解_python
- 2023-07-22 linux查看進程的啟動路徑:ll /proc/PID
- 2023-01-03 一文帶你掌握Go語言中文件的寫入操作_Golang
- 2023-05-31 Pandas多個條件(AND,OR,NOT)中提取行_python
- 2022-03-19 .NET6使WebApi獲取訪問者IP地址_基礎應用
- 2022-12-30 Python利用tkinter和socket實現端口掃描_python
- 2022-04-12 Windows11右鍵菜單恢復Windows10樣式
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支