網站首頁 編程語言 正文
PyTorch nn.Module類的簡介
torch.nn.Module類是所有神經網絡模塊(modules)的基類,它的實現在torch/nn/modules/module.py中。你的模型也應該繼承這個類,主要重載__init__、forward和extra_repr函數。Modules還可以包含其它Modules,從而可以將它們嵌套在樹結構中。
只要在自己的類中定義了forward函數,backward函數就會利用Autograd被自動實現。只要實例化一個對象并傳入對應的參數就可以自動調用forward函數。因為此時會調用對象的__call__方法,而nn.Module類中的__call__方法會調用forward函數。
nn.Module類中函數介紹:
-
__init__
:初始化內部module狀態。 -
register_buffer
:向module添加buffer,不作為模型參數,可作為module狀態的一部分。默認情況下,buffer是持久(persistent)的,將與參數一起保存。buffer是否persistent的區別在于這個buffer是否被放入self.state_dict()中被保存下來。 -
register_parameter
:向module添加參數。 -
add_module
:添加一個submodule(children)到當前module中。 -
apply
:將fn遞歸應用于每個submodule(children),典型用途為初始化模型參數。 -
cuda
:將所有模型參數和buffers轉移到GPU上。 -
xpu
:將所有模型參數和buffers轉移到XPU上。 -
cpu
:將所有模型參數和buffers轉移到CPU上。 -
type
:將所有參數和buffers轉換為所需的類型。 -
float
:將所有浮點參數和buffers轉換為float32數據類型。 -
double
:將所有浮點參數和buffers轉換為double數據類型。 -
half
:將所有浮點參數和buffers轉換為float16數據類型。 -
bfloat16
:將所有浮點參數和buffers轉換為bfloat16數據類型。 -
to
:將參數和buffers轉換為指定的數據類型或轉換到指定的設備上。 -
register_backward_hook
:在module中注冊一個反向鉤子。不推薦使用。 -
register_full_backward_hook
:在module中注冊一個反向鉤子。每次計算梯度時都會調用此鉤子。使用此鉤子時不允許就地(in place)修改輸入或輸出,否則會觸發error。 -
register_forward_pre_hook
:在module中注冊前向pre-hook。每次調用forward之前都會調用此鉤子。 -
register_forward_hook
:在module中注冊一個前向鉤子。每次forward計算輸出后都會調用此鉤子。 -
state_dict
:返回包含了module的整個狀態的字典。其中keys是對應的參數和buffer名稱。 -
load_state_dict
:將參數和buffers從state_dict復制到module及其后代(descendants)中。 -
parameters
:返回module的參數的迭代器。 -
named_parameters
:返回module的參數的迭代器,產生(yield)參數的名稱以及參數本身。不會返回重復的parameter。 -
buffers
:返回module的buffers的迭代器。 -
named_buffers
:返回module的buffers的迭代器,產生(yield)buffer的名稱以及buffer本身。不會返回重復的buffer。 -
children
:返回直接子module的迭代器。 -
named_children
:返回直接子module的迭代器,產生(yield)子module的名稱以及子module本身。不會返回重復的children。 -
modules
:返回網絡中所有modules的迭代器。 -
named_modules
:返回網絡中所有modules的迭代器,產生(yield)module的名稱以及module本身。不會返回重復的module。 -
train
:將module設置為訓練模式。這僅對某些module起作用。module.py實現中會修改self.training并通過self.children()來調整所有submodule的狀態。 -
eval
:將module設置為評估模式。這僅對某些module起作用。module.py實現中直接調用train(False)。 -
requires_grad
_:更改autograd是否應記錄對此module中參數的操作。此方法就地(in place)設置參數的requires_grad屬性。 -
zero_grad
:將所有模型參數的梯度設置為零。 -
extra_repr
:設置module的額外表示。你應該在自己的modules中重新實現此方法。
測試代碼如下:
import torch import torch.nn as nn import torch.nn.functional as F # nn.functional.py中存放激活函數等的實現 ? @torch.no_grad() def init_weights(m): ? ? print("xxxx:", m) ? ? if type(m) == nn.Linear: ? ? ? ? ?m.weight.fill_(1.0) ? ? ? ? ?print("yyyy:", m.weight) ? class Model(nn.Module): ? ? def __init__(self): ? ? ? ? # 在實現自己的__init__函數時,為了正確初始化自定義的神經網絡模塊,一定要先調用super().__init__ ? ? ? ? super(Model, self).__init__() ? ? ? ? self.conv1 = nn.Conv2d(1, 20, 5) # submodule(child module) ? ? ? ? self.conv2 = nn.Conv2d(20, 20, 5) ? ? ? ? self.add_module("conv3", nn.Conv2d(10, 40, 5)) # 添加一個submodule到當前module,等價于self.conv3 = nn.Conv2d(10, 40, 5) ? ? ? ? self.register_buffer("buffer", torch.randn([2,3])) # 給module添加一個presistent(持久的) buffer ? ? ? ? self.param1 = nn.Parameter(torch.rand([1])) # module參數的tensor ? ? ? ? self.register_parameter("param2", nn.Parameter(torch.rand([1]))) # 向module添加參數 ? ? ? ? ? # nn.Sequential: 順序容器,module將按照它們在構造函數中傳遞的順序添加,它允許將整個容器視為單個module ? ? ? ? self.feature = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) ? ? ? ? self.feature.apply(init_weights) # 將fn遞歸應用于每個submodule,典型用途為初始化模型參數 ? ? ? ? self.feature.to(torch.double) # 將參數數據類型轉換為double ? ? ? ? cpu = torch.device("cpu") ? ? ? ? self.feature.to(cpu) # 將參數數據轉換到cpu設備上 ? ? ? def forward(self, x): ? ? ? ?x = F.relu(self.conv1(x)) ? ? ? ?return F.relu(self.conv2(x)) ? model = Model() print("## Model:", model) ? model.cpu() # 將所有模型參數和buffers移動到CPU上 model.float() # 將所有浮點參數和buffers轉換為float數據類型 model.zero_grad() # 將所有模型參數的梯度設置為零 ? # state_dict:返回一個字典,保存著module的所有狀態,參數和persistent buffers都會包含在字典中,字典的key就是參數和buffer的names print("## state_dict:", model.state_dict().keys()) ? for name, parameters in model.named_parameters(): # 返回module的參數(weight and bias)的迭代器,產生(yield)參數的名稱以及參數本身 ? ? print(f"## named_parameters: name: {name}; parameters size: {parameters.size()}") ? for name, buffers in model.named_buffers(): # 返回module的buffers的迭代器,產生(yield)buffer的名稱以及buffer本身 ? ? print(f"## named_buffers: name: {name}; buffers size: {buffers.size()}") ? # 注:children和modules中重復的module只被返回一次 for children in model.children(): # 返回當前module的child module(submodule)的迭代器 ? ? print("## children:", children) ? for name, children in model.named_children(): # 返回直接submodule的迭代器,產生(yield) submodule的名稱以及submodule本身 ? ? print(f"## named_children: name: {name}; children: {children}") ? for modules in model.modules(): # 返回當前模型所有module的迭代器,注意與children的區別 ? ? print("## modules:", modules) ? for name, modules in model.named_modules(): # 返回網絡中所有modules的迭代器,產生(yield)module的名稱以及module本身,注意與named_children的區別 ? ? print(f"## named_modules: name: {name}; module: {modules}") ? model.train() # 將module設置為訓練模式 model.eval() # 將module設置為評估模式 ? print("test finish")
GitHub:https://github.com/fengbingchun/PyTorch_Test
PyTorch中nn.Module理解
nn.Module是Pytorch封裝的一個類,是搭建神經網絡時需要繼承的父類:
import torch import torch.nn as nn # 括號中加入nn.Module(父類)。Test2變成子類,繼承父類(nn.Module)的所有特性。 class Test2(nn.Module): def __init__(self): # Test2類定義初始化方法 super(Test2, self).__init__() # 父類初始化 self.M = nn.Parameter(torch.ones(10)) def weightInit(self): print('Testing') def forward(self, n): # print(2 * n) print(self.M * n) self.weightInit() # 調用方法 network = Test2() network(2) # 2賦值給forward(self, n)中的n。 ……省略一部分代碼…… # 因為Test2是nn.Module的子類,所以也可以執行父類中的方法。如: model_dict = network.state_dict() # 調用父類中的方法state_dict(),將Test2中訓練參數賦值model_dict。 for k, v in model_dict.items(): # 查看自己網絡參數各層名稱、數值 print(k) # 輸出網絡參數名字 # print(v) # 輸出網絡參數數值
繼承nn.Module的子類程序是從forward()方法開始執行的,如果要想執行其他方法,必須把它放在forward()方法中。這一點與python中繼承有稍許的不同。
總結
原文鏈接:https://blog.csdn.net/fengbingchun/article/details/122023299
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2022-06-19 selenium?IDE自動化測試腳本的實現_其它綜合
- 2022-12-08 React源碼state計算流程和優先級實例解析_React
- 2023-03-22 一文帶你學會Python?Flask框架設置響應頭_python
- 2022-04-25 ASP.NET?Core中Cookie驗證身份用法詳解_實用技巧
- 2024-07-15 GIT同步修改后的遠程分支
- 2022-10-16 Django完整增刪改查系統實例代碼_python
- 2021-11-25 C++實現截圖截屏的示例代碼_C 語言
- 2022-05-12 android okHttp網絡請求封裝
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支