網(wǎng)站首頁 編程語言 正文
【1】方法一:獲取nn.Sequential的中間層輸出
import torch import torch.nn as nn model = nn.Sequential( ? ? ? ? ? ? nn.Conv2d(3, 9, 1, 1, 0, bias=False), ? ? ? ? ? ? nn.BatchNorm2d(9), ? ? ? ? ? ? nn.ReLU(inplace=True), ? ? ? ? ? ? nn.AdaptiveAvgPool2d((1, 1)), ? ? ? ? ) # 假如想要獲得ReLu的輸出 x = torch.rand([2, 3, 224, 224]) for i in range(len(model)): ? ? x = model[i](x) ? ? if i == 2: ? ? ? ? ReLu_out = x print('ReLu_out.shape:\n\t',ReLu_out.shape) print('x.shape:\n\t',x.shape)
結(jié)果:
ReLu_out.shape:
? torch.Size([2, 9, 224, 224])
x.shape:
? torch.Size([2, 9, 1, 1])
【2】方法二:IntermediateLayerGetter
from collections import OrderedDict ? import torch from torch import nn ? ? class IntermediateLayerGetter(nn.ModuleDict): ? ? """ ? ? Module wrapper that returns intermediate layers from a model ? ? It has a strong assumption that the modules have been registered ? ? into the model in the same order as they are used. ? ? This means that one should **not** reuse the same nn.Module ? ? twice in the forward if you want this to work. ? ? Additionally, it is only able to query submodules that are directly ? ? assigned to the model. So if `model` is passed, `model.feature1` can ? ? be returned, but not `model.feature1.layer2`. ? ? Arguments: ? ? ? ? model (nn.Module): model on which we will extract the features ? ? ? ? return_layers (Dict[name, new_name]): a dict containing the names ? ? ? ? ? ? of the modules for which the activations will be returned as ? ? ? ? ? ? the key of the dict, and the value of the dict is the name ? ? ? ? ? ? of the returned activation (which the user can specify). ? ? """ ? ?? ? ? def __init__(self, model, return_layers): ? ? ? ? if not set(return_layers).issubset([name for name, _ in model.named_children()]): ? ? ? ? ? ? raise ValueError("return_layers are not present in model") ? ? ? ? ? orig_return_layers = return_layers ? ? ? ? return_layers = {k: v for k, v in return_layers.items()} ? ? ? ? layers = OrderedDict() ? ? ? ? for name, module in model.named_children(): ? ? ? ? ? ? layers[name] = module ? ? ? ? ? ? if name in return_layers: ? ? ? ? ? ? ? ? del return_layers[name] ? ? ? ? ? ? if not return_layers: ? ? ? ? ? ? ? ? break ? ? ? ? ? super(IntermediateLayerGetter, self).__init__(layers) ? ? ? ? self.return_layers = orig_return_layers ? ? ? def forward(self, x): ? ? ? ? out = OrderedDict() ? ? ? ? for name, module in self.named_children(): ? ? ? ? ? ? x = module(x) ? ? ? ? ? ? if name in self.return_layers: ? ? ? ? ? ? ? ? out_name = self.return_layers[name] ? ? ? ? ? ? ? ? out[out_name] = x ? ? ? ? return out
# example m = torchvision.models.resnet18(pretrained=True) # extract layer1 and layer3, giving as names `feat1` and feat2` new_m = torchvision.models._utils.IntermediateLayerGetter(m,{'layer1': 'feat1', 'layer3': 'feat2'}) out = new_m(torch.rand(1, 3, 224, 224)) print([(k, v.shape) for k, v in out.items()]) # [('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]
作用:
在定義它的時(shí)候注明作用的模型(如下例中的m)和要返回的layer(如下例中的layer1,layer3),得到new_m。
使用時(shí)喂輸入變量,返回的就是對(duì)應(yīng)的layer
。
舉例:
m = torchvision.models.resnet18(pretrained=True) ?# extract layer1 and layer3, giving as names `feat1` and feat2` new_m = torchvision.models._utils.IntermediateLayerGetter(m,{'layer1': 'feat1', 'layer3': 'feat2'}) out = new_m(torch.rand(1, 3, 224, 224)) print([(k, v.shape) for k, v in out.items()])
輸出結(jié)果:
[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]
【3】方法三:鉤子
class TestForHook(nn.Module): ? ? def __init__(self): ? ? ? ? super().__init__() ? ? ? ? self.linear_1 = nn.Linear(in_features=2, out_features=2) ? ? ? ? self.linear_2 = nn.Linear(in_features=2, out_features=1) ? ? ? ? self.relu = nn.ReLU() ? ? ? ? self.relu6 = nn.ReLU6() ? ? ? ? self.initialize() ? ? def forward(self, x): ? ? ? ? linear_1 = self.linear_1(x) ? ? ? ? linear_2 = self.linear_2(linear_1) ? ? ? ? relu = self.relu(linear_2) ? ? ? ? relu_6 = self.relu6(relu) ? ? ? ? layers_in = (x, linear_1, linear_2) ? ? ? ? layers_out = (linear_1, linear_2, relu) ? ? ? ? return relu_6, layers_in, layers_out features_in_hook = [] features_out_hook = [] def hook(module, fea_in, fea_out): ? ? features_in_hook.append(fea_in) ? ? features_out_hook.append(fea_out) ? ? return None net = TestForHook()
第一種寫法,按照類型勾,但如果有重復(fù)類型的layer比較復(fù)雜
net_chilren = net.children() for child in net_chilren: ? ? if not isinstance(child, nn.ReLU6): ? ? ? ? child.register_forward_hook(hook=hook)
推薦下面我改的這種寫法,因?yàn)槲易约旱木W(wǎng)絡(luò)中,在Sequential
中有很多層,
這種方式可以直接先print(net)
一下,找出自己所需要那個(gè)layer
的名稱,按名稱勾出來
layer_name = 'relu_6' for (name, module) in net.named_modules(): ? ? if name == layer_name: ? ? ? ? module.register_forward_hook(hook=hook) print(features_in_hook) ?# 勾的是指定層的輸入 print(features_out_hook) ?# 勾的是指定層的輸出
原文鏈接:https://zhuanlan.zhihu.com/p/362985275
相關(guān)推薦
- 2022-02-28 el-dialog 的關(guān)閉事件執(zhí)行兩次
- 2022-11-17 Android文本與視圖基本操作梳理介紹_Android
- 2022-03-31 C語言類的基本語法詳解_C 語言
- 2023-02-04 C語言設(shè)計(jì)實(shí)現(xiàn)掃描器的自動(dòng)機(jī)的示例詳解_C 語言
- 2022-05-12 Kotlin 初始化陷阱。初始化注意事項(xiàng)
- 2022-07-13 CentOS上Autofs自動(dòng)掛載iso光盤鏡像-Linux
- 2023-04-04 python中class(object)的含義是什么以及用法_python
- 2022-04-15 windows+vscode穿越跳板機(jī)調(diào)試遠(yuǎn)程代碼的圖文教程_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支