網站首頁 編程語言 正文
鉤子編程(hooking
),也稱作“掛鉤”,是計算機程序設計術語,指通過攔截軟件模塊間的函數調用、消息傳遞、事件傳遞來修改或擴展操作系統、應用程序或其他軟件組件的行為的各種技術。處理被攔截的函數調用、事件、消息的代碼,被稱為鉤子(hook)。
Hook 是 PyTorch
中一個十分有用的特性。利用它,我們可以不必改變網絡輸入輸出的結構,方便地獲取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 feature
、gradient
,從而診斷神經網絡中可能出現的問題,分析網絡有效性。
本文主要用 hook 函數輸出網絡執行過程中 forward 和 backward 的執行順序,以此找到了bug所在。
用法如下:
# 設置hook func def hook_func(name, module): ? ? def hook_function(module, inputs, outputs): ? ? ? ? # 請依據使用場景自定義函數 ? ? ? ? print(name+' inputs', inputs) ? ? ? ? print(name+' outputs', outputs) ? ? return hook_function # 注冊正反向hook for name, module in model.named_modules(): ? ? module.register_forward_hook(hook_func('[forward]: '+name, module)) ? ? module.register_backward_hook(hook_func('[backward]: '+name, module))
如一個簡單的 MNIST 手寫數字識別的模型結構如下:
class Net(nn.Module): ? ? def __init__(self): ? ? ? ? super(Net, self).__init__() ? ? ? ? self.conv1 = nn.Conv2d(1, 32, 3, 1) ? ? ? ? self.conv2 = nn.Conv2d(32, 64, 3, 1) ? ? ? ? self.dropout1 = nn.Dropout(0.25) ? ? ? ? self.dropout2 = nn.Dropout(0.5) ? ? ? ? self.fc1 = nn.Linear(9216, 128) ? ? ? ? self.fc2 = nn.Linear(128, 10) ? ? def forward(self, x): ? ? ? ? x = self.conv1(x) ? ? ? ? x = F.relu(x) ? ? ? ? x = self.conv2(x) ? ? ? ? x = F.relu(x) ? ? ? ? x = F.max_pool2d(x, 2) ? ? ? ? x = self.dropout1(x) ? ? ? ? x = torch.flatten(x, 1) ? ? ? ? x = self.fc1(x) ? ? ? ? x = F.relu(x) ? ? ? ? x = self.dropout2(x) ? ? ? ? x = self.fc2(x) ? ? ? ? output = F.log_softmax(x, dim=1) ? ? ? ? return output
打印模型:
Net( ? (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1)) ? (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1)) ? (dropout1): Dropout(p=0.25, inplace=False) ? (dropout2): Dropout(p=0.5, inplace=False) ? (fc1): Linear(in_features=9216, out_features=128, bias=True) ? (fc2): Linear(in_features=128, out_features=10, bias=True) )
構建hook函數:
# 設置hook func def hook_func(name, module): ? ? def hook_function(module, inputs, outputs): ? ? ? ? with open("log_model.txt", 'a+') as f: ? ? ? ? ? ? # 請依據使用場景自定義函數 ? ? ? ? ? ? f.write(name + ' ? len(inputs): ' + str(len(inputs)) + '\n') ? ? ? ? ? ? f.write(name + ' ? len(outputs): ?' + str(len(outputs)) + '\n') ? ? return hook_function # 注冊正反向hook for name, module in model.named_modules(): ? ? module.register_forward_hook(hook_func('[forward]: '+name, module)) ? ? module.register_backward_hook(hook_func('[backward]: '+name, module))
輸出的前向和反向傳播過程:
[forward]: conv1 ? len(inputs): 1
[forward]: conv1 ? len(outputs): ?8
[forward]: conv2 ? len(inputs): 1
[forward]: conv2 ? len(outputs): ?8
[forward]: dropout1 ? len(inputs): 1
[forward]: dropout1 ? len(outputs): ?8
[forward]: fc1 ? len(inputs): 1
[forward]: fc1 ? len(outputs): ?8
[forward]: dropout2 ? len(inputs): 1
[forward]: dropout2 ? len(outputs): ?8
[forward]: fc2 ? len(inputs): 1
[forward]: fc2 ? len(outputs): ?8
[forward]: ? ?len(inputs): 1
[forward]: ? ?len(outputs): ?8
[backward]: ? ?len(inputs): 2
[backward]: ? ?len(outputs): ?1
[backward]: fc2 ? len(inputs): 3
[backward]: fc2 ? len(outputs): ?1
[backward]: dropout2 ? len(inputs): 1
[backward]: dropout2 ? len(outputs): ?1
[backward]: fc1 ? len(inputs): 3
[backward]: fc1 ? len(outputs): ?1
[backward]: dropout1 ? len(inputs): 1
[backward]: dropout1 ? len(outputs): ?1
[backward]: conv2 ? len(inputs): 2
[backward]: conv2 ? len(outputs): ?1
[backward]: conv1 ? len(inputs): 2
[backward]: conv1 ? len(outputs): ?1
因為只要模型處于train狀態,hook_func
就會執行,導致不斷輸出 [forward] 和 [backward],所以將輸出內容建議寫到文件中,而不是 print
原文鏈接:https://blog.csdn.net/qq_43799400/article/details/119348675
相關推薦
- 2022-03-11 Linux fatal error: iostream: No such file or direc
- 2022-04-23 Android如何使用ViewPager2實現頁面滑動切換效果_Android
- 2022-10-16 Python?numpy中np.random.seed()的詳細用法實例_python
- 2024-01-10 給idea添加右鍵打開功能
- 2022-03-09 C語言通過gets和gets_s分別實現讀取含空格的字符串_C 語言
- 2022-06-18 Android?ProgressBar實現進度條效果_Android
- 2023-07-14 css :如何讓背景平鋪整個頁面
- 2023-01-26 Android?源碼淺析RecyclerView?ItemAnimator_Android
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支