網(wǎng)站首頁 編程語言 正文
大多數(shù)卷積神經(jīng)網(wǎng)絡(luò)都是直接通過寫一個(gè)Model類來定義的,這樣寫的代碼其實(shí)是比較好懂的,特別是在魔改網(wǎng)絡(luò)的時(shí)候也很方便。然后也有一些會(huì)通過cfg配置文件進(jìn)行模型的定義。在yolov5中可以看到是通過yaml文件進(jìn)行網(wǎng)絡(luò)的定義【個(gè)人感覺通過配置文件魔改網(wǎng)絡(luò)有些不方便,當(dāng)然每個(gè)人習(xí)慣不同】,可能很多人也用過,如果自己去寫一個(gè)yaml文件,自己能不能定義出來呢?很多人不知道是如何具體通過yaml文件將里面的參數(shù)傳入自己定義的網(wǎng)絡(luò)中,這也就給自己修改網(wǎng)絡(luò)帶來了不便。這篇文章將仿照yolov5的方式,利用yaml定義一個(gè)自己的網(wǎng)絡(luò)。
定義卷積塊
我們可以先定義一個(gè)卷積塊CBL,C指卷積Conv,B指BN層,L為激活函數(shù),這里我用ReLu.
class BaseConv(nn.Module):
def __init__(self, in_channels, out_channels, k=1, s=1, p=None):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(in_channels, out_channels, k, s, autopad(k, p))
self.bn = nn.BatchNorm2d(out_channels)
self.act_fn = nn.ReLU(inplace=True)
def forward(self, x):
return self.act_fn(self.bn(self.conv(x)))
卷積中的autopad是自動(dòng)補(bǔ)充pad,代碼如下:
def autopad(k, p=None):
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
return p
定義一個(gè)Bottleneck?
可以仿照yolov5定義一個(gè)Bottleneck,參考了殘差塊的思想。
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels, shortcut=True):
super(Bottleneck, self).__init__()
self.conv1 = BaseConv(in_channels, out_channels, k=1, s=1)
self.conv2 = BaseConv(out_channels, out_channels, k=3, s=1)
self.add = shortcut and in_channels == out_channels
def forward(self, x):
"""
x-->conv1-->conv2-->add
|_________________|
"""
return x + self.conv2(self.conv1(x)) if self.add else self.conv2(self.conv1(x))
攥寫yaml配置文件
然后我們來寫一下yaml配置文件,網(wǎng)絡(luò)不要很復(fù)雜,就由兩個(gè)卷積和兩個(gè)Bottleneck組成就行。同理,仿v5的方法,我們的網(wǎng)絡(luò)中的backone也是個(gè)列表,每行為一個(gè)卷積層,每列有4個(gè)參數(shù),分別代表from(指該層的輸入通道數(shù)為上一層的輸出通道數(shù),所以是-1),number【yaml中的1,1,2指該層的深度,或者說是重復(fù)幾次】,Module_nams【該層的名字】,args【網(wǎng)絡(luò)參數(shù),包含輸出通道數(shù),k,s,p等設(shè)置】
# define own model
backbone:
[[-1, 1, BaseConv, [32, 3, 1]], # out_channles=32, k=3, s=1
[-1, 1, BaseConv, [64, 1, 1]],
[-1, 2, Bottleneck, [64]]
]
我們現(xiàn)在用yaml工具來打開我們的配置文件,看看都有什么內(nèi)容
import yaml
# 獲得yaml文件名字
yaml_file = Path('Model.yaml').name
with open(yaml_file,errors='ignore') as f:
yaml_ = yaml.safe_load(f)
print(yaml_)
輸出:?
?{'backbone': [[-1, 1, 'BaseConv', [32, 3, 1]], [-1, 1, 'BaseConv', [64, 1, 1]], [-1, 2, 'Bottleneck', [64]]]}
然后我們可以定義下自己Model類,也就是定義自己的網(wǎng)絡(luò)。可以看到與前面讀取yaml文件相比,多了一行 ? ?ch = self.yaml["ch"] = self.yaml["ch"] = 3 ? 這個(gè)是在原yaml內(nèi)容中加入一個(gè)key和valuse,3指的3通道,因?yàn)槲覀兊膱D像是3通道。parse_model是下面要說的傳參過程。
class Model(nn.Module):
def __init__(self, cfg='./Model.yaml', ch=3, ):
super().__init__()
self.yaml = cfg
import yaml
yaml_file = Path(cfg).name
with open(yaml_file, errors='ignore')as f:
self.yaml = yaml.safe_load(f)
ch = self.yaml["ch"] = self.yaml["ch"] = 3
self.backbone = parse_model(deepcopy(self.yaml), ch=[ch])
def forward(self, x):
output = self.backbone(x)
return output
傳入?yún)?shù)
這一步也是最關(guān)鍵的一步,我們需要定義傳參的函數(shù),將yaml中的卷積參數(shù)傳入我們定義的網(wǎng)絡(luò)中,這里會(huì)用的一個(gè)非常非常重要的函數(shù)eval(),后面也會(huì)介紹到這個(gè)函數(shù)的用法。
這里先附上完整代碼:
def parse_model(yaml_cfg, ch):
"""
:param yaml_cfg: yaml file
:param ch: init in_channels default is 3
:return: model
"""
layer, out_channels = [], ch[-1]
for i, (f, number, Module_name, args) in enumerate(yaml_cfg['backbone']):
"""
f:上一層輸出通道
number:該模塊有幾層,就是該模塊要重復(fù)幾次
Mdule_name:卷積層名字
args:參數(shù),包含輸出通道數(shù),k,s,p等
"""
# 通過eval,將str類型轉(zhuǎn)自己定義的BaseConv
m = eval(Module_name) if isinstance(Module_name, str) else Module_name
for j, a in enumerate(args):
# 通過eval,將str轉(zhuǎn)int,獲得輸出通道數(shù)
args[j] = eval(a) if isinstance(a, str) else a
# 更新通道
# args[0]是輸出通道
if m in [BaseConv, Bottleneck]:
in_channels, out_channels = ch[f], args[0]
args = [in_channels, out_channels, *args[1:]] # args=[in_channels, out_channels, k, s, p]
# 將參數(shù)傳入模型
model_ = nn.Sequential(*[m(*args) for _ in range(number)]) if number > 1 else m(*args)
# 更新通道列表,每次獲取輸出通道
ch.append(out_channels)
layer.append(model_)
return nn.Sequential(*layer)
下面開始分析代碼 。
這行代碼是通過列表用來存放每層內(nèi)容以及輸出通道數(shù)。
# 這行代碼是通過列表用來存放每層內(nèi)容以及輸出通道數(shù)
layer, out_channels = [], ch[-1]
然后進(jìn)入我們的for循環(huán),在每一次循環(huán)中可以獲得我們yaml文件中的每一層網(wǎng)絡(luò):f是上一層網(wǎng)絡(luò)的輸出通道【用來作為本層的輸入通道】,number【網(wǎng)絡(luò)深度,也就是該層重復(fù)幾次而已】,Module_name是該層的名字,args是該層的一些參數(shù)。
for i, (f, number, Module_name, args) in enumerate(yaml_cfg['backbone']):
接下來會(huì)碰到一個(gè)很重要的函數(shù)eval()。下行的代碼首先需要判斷一下我們的Module_name類型是不是字符串類型,也就是判斷一下yaml中“BaseConv”是不是字符串類型,如果是,則用eval進(jìn)行對應(yīng)類型的轉(zhuǎn)化,轉(zhuǎn)成我們的BaseConv類型。?
m = eval(Module_name) if isinstance(Module_name, str) else Module_name
這里我將對eval函數(shù)在深入點(diǎn),如果知道這個(gè)函數(shù)用法的,就可以略去這部分。
我們先舉個(gè)例子,比如我現(xiàn)在有個(gè)變量a="123",這個(gè)a的類型是什么呢?他是一個(gè)str類型,不是int類型。 現(xiàn)在我們用eval函數(shù)轉(zhuǎn)一下,看看會(huì)變成什么樣子。
>>> b = eval(a) if isinstance(a,str) else a
>>> b
123
>>> type(b)
<class 'int'>
我們可以看到,經(jīng)過eval函數(shù)以后,會(huì)自動(dòng)識(shí)別并轉(zhuǎn)為int類型。那么我繼續(xù)舉例子,如果現(xiàn)在a="BaseConv",經(jīng)過eval以后會(huì)變成什么?可以看到,這里報(bào)錯(cuò)了!這是為什么?這是因?yàn)槲覀儧]有導(dǎo)入BaseConv這個(gè)類,所以eval函數(shù)并不知道我們希望轉(zhuǎn)為什么類型。所以我們需要用import導(dǎo)入BaseConv這個(gè)類才可以。
>>> a="BaseConv"
>>> b = eval(a) if isinstance(a,str) else a
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<string>", line 1, in <module>
NameError: name 'BaseConv' is not defined
當(dāng)我們導(dǎo)入BaseConv以后,在經(jīng)過eval就可以獲得:
<class 'models.BaseConv'>
接下來是獲得args中的網(wǎng)絡(luò)參數(shù),也是通過eval進(jìn)行轉(zhuǎn)化
for j, a in enumerate(args):
# 通過eval,將str轉(zhuǎn)int,獲得輸出通道數(shù)
args[j] = eval(a) if isinstance(a, str) else a
獲取通道數(shù),并在每次循環(huán)中對通道進(jìn)行更新:可以仔細(xì)看一下ch[f]指的上一層輸出通道,剛開始默認(rèn)為[3],那么ch[-1]=3,我們yaml中第一層的BaseConv args[0]為32,表示輸出32通道。因此在第一次循環(huán)中有in_channels = 3,out_channels=32。args也要更新,*args前面的"*"并不是指針的意思,也不是乘的意思,而是解壓操作,因此我們第一次循環(huán)中得到的args=[3,32,3,1]。
# 更新通道
# args[0]是輸出通道
if m in [BaseConv, Bottleneck]:
in_channels, out_channels = ch[f], args[0]
args = [in_channels, out_channels, *args[1:]] # args=[in_channels, out_channels, k, s, p]
將參數(shù)傳入模型
這里用for _ in range(number)來判斷網(wǎng)絡(luò)的深度【或者說該模塊重復(fù)幾次】,這里的m就是前面經(jīng)過eval轉(zhuǎn)化的 <class 'models.BaseConv'>。通過*args解壓操作將args列表中的內(nèi)容放入m中,再通過*解壓操作放入nn.Sequential。
model_ = nn.Sequential(*[m(*args) for _ in range(number)]) if number > 1 else m(*args)
這樣就可以獲得我們第一次循環(huán)BaseConv了。后面的循環(huán)也是同樣的反復(fù)操作而已。
BaseConv(
(conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
然后是更新通道列表和layer列表,為的是獲取每次循環(huán)的輸出通道,沒有這一步,再下一次循環(huán)的時(shí)候?qū)⒉荒苷_得到通道數(shù)。
# 更新通道列表,每次獲取輸出通道
ch.append(out_channels)
layer.append(model_)
然后我們就可以對模型調(diào)用進(jìn)行實(shí)例化了,可以打印下模型:
Model(
(backbone): Sequential(
(0): BaseConv(
(conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
(1): BaseConv(
(conv): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
(2): Sequential(
(0): Bottleneck(
(conv1): BaseConv(
(conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
(conv2): BaseConv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
)
(1): Bottleneck(
(conv1): BaseConv(
(conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
(conv2): BaseConv(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(act_fn): ReLU(inplace=True)
)
)
)
)
)
同時(shí)我們也可以對模型每層可視化看一下。可以看到和我們定義的模型是一樣的。
完整的代碼
from copy import deepcopy
from models import BaseConv, Bottleneck
import torch.nn as nn
import os
path = os.getcwd()
from pathlib import Path
import torch
def parse_model(yaml_cfg, ch):
"""
:param yaml_cfg: yaml file
:param ch: init in_channels default is 3
:return: model
"""
layer, out_channels = [], ch[-1]
for i, (f, number, Module_name, args) in enumerate(yaml_cfg['backbone']):
"""
f:上一層輸出通道
number:該模塊有幾層,就是該模塊要重復(fù)幾次
Mdule_name:卷積層名字
args:參數(shù),包含輸出通道數(shù),k,s,p等
"""
# 通過eval,將str類型轉(zhuǎn)自己定義的BaseConv
m = eval(Module_name) if isinstance(Module_name, str) else Module_name
for j, a in enumerate(args):
# 通過eval,將str轉(zhuǎn)int,獲得輸出通道數(shù)
args[j] = eval(a) if isinstance(a, str) else a
# 更新通道
# args[0]是輸出通道
if m in [BaseConv, Bottleneck]:
in_channels, out_channels = ch[f], args[0]
args = [in_channels, out_channels, *args[1:]] # args=[in_channels, out_channels, k, s, p]
# 將參數(shù)傳入模型
model_ = nn.Sequential(*[m(*args) for _ in range(number)]) if number > 1 else m(*args)
# 更新通道列表,每次獲取輸出通道
ch.append(out_channels)
layer.append(model_)
return nn.Sequential(*layer)
class Model(nn.Module):
def __init__(self, cfg='./Model.yaml', ch=3, ):
super().__init__()
self.yaml = cfg
import yaml
yaml_file = Path(cfg).name
with open(yaml_file, errors='ignore')as f:
self.yaml = yaml.safe_load(f)
ch = self.yaml["ch"] = self.yaml["ch"] = 3
self.backbone = parse_model(deepcopy(self.yaml), ch=[ch])
def forward(self, x):
output = self.backbone(x)
return output
if __name__ == "__main__":
cfg = path + '/Model.yaml'
model = Model()
model.eval()
print(model)
x = torch.ones(1, 3, 512, 512)
output = model(x)
torch.save(model, "model.pth")
# model = torch.load('model.pth')
# model.eval()
# x = torch.ones(1,3,512,512)
# input_name = ['input']
# output_name = ['output']
# torch.onnx.export(model, x, 'myonnx.onnx', verbose=True)
原文鏈接:https://blog.csdn.net/Vertira/article/details/127417327
相關(guān)推薦
- 2022-11-17 解讀Python中字典的key都可以是什么_python
- 2022-12-06 C++類成員函數(shù)后面加const問題_C 語言
- 2022-05-02 python?異常捕獲詳解流程_python
- 2022-10-06 C++?pimpl機(jī)制詳細(xì)講解_C 語言
- 2022-03-16 .net6環(huán)境下使用RestSharp請求GBK編碼網(wǎng)頁亂碼的解決方案_實(shí)用技巧
- 2022-11-29 Mybatis中如何傳入map參數(shù)呢?
- 2022-04-28 C#操作進(jìn)程的方法介紹_C#教程
- 2023-07-13 遍歷對象并改變對象某個(gè)屬性的值
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支