日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

pytorch?hook?鉤子函數的用法_python

作者:ctrl?A_ctrl?C_ctrl?V ? 更新時間: 2022-05-25 編程語言

鉤子編程(hooking),也稱作“掛鉤”,是計算機程序設計術語,指通過攔截軟件模塊間的函數調用、消息傳遞、事件傳遞來修改或擴展操作系統、應用程序或其他軟件組件的行為的各種技術。處理被攔截的函數調用、事件、消息的代碼,被稱為鉤子(hook)。

Hook 是 PyTorch 中一個十分有用的特性。利用它,我們可以不必改變網絡輸入輸出的結構,方便地獲取、改變網絡中間層變量的值和梯度。這個功能被廣泛用于可視化神經網絡中間層的 featuregradient,從而診斷神經網絡中可能出現的問題,分析網絡有效性。

本文主要用 hook 函數輸出網絡執行過程中 forward 和 backward 的執行順序,以此找到了bug所在。

用法如下:

# 設置hook func
def hook_func(name, module):
? ? def hook_function(module, inputs, outputs):
? ? ? ? # 請依據使用場景自定義函數
? ? ? ? print(name+' inputs', inputs)
? ? ? ? print(name+' outputs', outputs)
? ? return hook_function

# 注冊正反向hook
for name, module in model.named_modules():
? ? module.register_forward_hook(hook_func('[forward]: '+name, module))
? ? module.register_backward_hook(hook_func('[backward]: '+name, module))

如一個簡單的 MNIST 手寫數字識別的模型結構如下:

class Net(nn.Module):
? ? def __init__(self):
? ? ? ? super(Net, self).__init__()
? ? ? ? self.conv1 = nn.Conv2d(1, 32, 3, 1)
? ? ? ? self.conv2 = nn.Conv2d(32, 64, 3, 1)
? ? ? ? self.dropout1 = nn.Dropout(0.25)
? ? ? ? self.dropout2 = nn.Dropout(0.5)
? ? ? ? self.fc1 = nn.Linear(9216, 128)
? ? ? ? self.fc2 = nn.Linear(128, 10)

? ? def forward(self, x):
? ? ? ? x = self.conv1(x)
? ? ? ? x = F.relu(x)
? ? ? ? x = self.conv2(x)
? ? ? ? x = F.relu(x)
? ? ? ? x = F.max_pool2d(x, 2)
? ? ? ? x = self.dropout1(x)
? ? ? ? x = torch.flatten(x, 1)
? ? ? ? x = self.fc1(x)
? ? ? ? x = F.relu(x)
? ? ? ? x = self.dropout2(x)
? ? ? ? x = self.fc2(x)
? ? ? ? output = F.log_softmax(x, dim=1)
? ? ? ? return output

打印模型:

Net(
? (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
? (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
? (dropout1): Dropout(p=0.25, inplace=False)
? (dropout2): Dropout(p=0.5, inplace=False)
? (fc1): Linear(in_features=9216, out_features=128, bias=True)
? (fc2): Linear(in_features=128, out_features=10, bias=True)
)

構建hook函數:

# 設置hook func
def hook_func(name, module):
? ? def hook_function(module, inputs, outputs):
? ? ? ? with open("log_model.txt", 'a+') as f:
? ? ? ? ? ? # 請依據使用場景自定義函數
? ? ? ? ? ? f.write(name + ' ? len(inputs): ' + str(len(inputs)) + '\n')
? ? ? ? ? ? f.write(name + ' ? len(outputs): ?' + str(len(outputs)) + '\n')
? ? return hook_function

# 注冊正反向hook
for name, module in model.named_modules():
? ? module.register_forward_hook(hook_func('[forward]: '+name, module))
? ? module.register_backward_hook(hook_func('[backward]: '+name, module))

輸出的前向和反向傳播過程:

[forward]: conv1 ? len(inputs): 1
[forward]: conv1 ? len(outputs): ?8
[forward]: conv2 ? len(inputs): 1
[forward]: conv2 ? len(outputs): ?8
[forward]: dropout1 ? len(inputs): 1
[forward]: dropout1 ? len(outputs): ?8
[forward]: fc1 ? len(inputs): 1
[forward]: fc1 ? len(outputs): ?8
[forward]: dropout2 ? len(inputs): 1
[forward]: dropout2 ? len(outputs): ?8
[forward]: fc2 ? len(inputs): 1
[forward]: fc2 ? len(outputs): ?8
[forward]: ? ?len(inputs): 1
[forward]: ? ?len(outputs): ?8
[backward]: ? ?len(inputs): 2
[backward]: ? ?len(outputs): ?1
[backward]: fc2 ? len(inputs): 3
[backward]: fc2 ? len(outputs): ?1
[backward]: dropout2 ? len(inputs): 1
[backward]: dropout2 ? len(outputs): ?1
[backward]: fc1 ? len(inputs): 3
[backward]: fc1 ? len(outputs): ?1
[backward]: dropout1 ? len(inputs): 1
[backward]: dropout1 ? len(outputs): ?1
[backward]: conv2 ? len(inputs): 2
[backward]: conv2 ? len(outputs): ?1
[backward]: conv1 ? len(inputs): 2
[backward]: conv1 ? len(outputs): ?1

因為只要模型處于train狀態,hook_func 就會執行,導致不斷輸出 [forward] 和 [backward],所以將輸出內容建議寫到文件中,而不是 print

原文鏈接:https://blog.csdn.net/qq_43799400/article/details/119348675

欄目分類
最近更新