網站首頁 編程語言 正文
解讀model.named_parameters()與model.parameters()
model.named_parameters()
迭代打印model.named_parameters()將會打印每一次迭代元素的名字和param。
model = DarkNet([1, 2, 8, 8, 4])
for name, param in model.named_parameters():
? ? print(name,param.requires_grad)
? ? param.requires_grad = False
輸出結果為
conv1.weight True
bn1.weight True
bn1.bias True
layer1.ds_conv.weight True
layer1.ds_bn.weight True
layer1.ds_bn.bias True
layer1.residual_0.conv1.weight True
layer1.residual_0.bn1.weight True
layer1.residual_0.bn1.bias True
layer1.residual_0.conv2.weight True
layer1.residual_0.bn2.weight True
layer1.residual_0.bn2.bias True
layer2.ds_conv.weight True
layer2.ds_bn.weight True
layer2.ds_bn.bias True
layer2.residual_0.conv1.weight True
layer2.residual_0.bn1.weight True
layer2.residual_0.bn1.bias True
....
并且可以更改參數的可訓練屬性,第一次打印是True,這是第二次,就是False了
model.parameters()
迭代打印model.parameters()將會打印每一次迭代元素的param而不會打印名字,這是它和named_parameters的區別,兩者都可以用來改變requires_grad的屬性。
for index, param in enumerate(model.parameters()):
? ? print(param.shape)
輸出結果為
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([32])
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([32, 64, 1, 1])
torch.Size([32])
torch.Size([32])
torch.Size([64, 32, 3, 3])
torch.Size([64])
torch.Size([64])
torch.Size([128, 64, 3, 3])
torch.Size([128])
torch.Size([128])
torch.Size([64, 128, 1, 1])
torch.Size([64])
torch.Size([64])
torch.Size([128, 64, 3, 3])
torch.Size([128])
torch.Size([128])
torch.Size([64, 128, 1, 1])
torch.Size([64])
torch.Size([64])
torch.Size([128, 64, 3, 3])
torch.Size([128])
torch.Size([128])
torch.Size([256, 128, 3, 3])
torch.Size([256])
torch.Size([256])
torch.Size([128, 256, 1, 1])
....
將兩者結合進行迭代,同時具有索引,網絡層名字及param
?? ?for index, (name, param) in zip(enumerate(model.parameters()), model.named_parameters()):
?? ??? ?print(index[0])
?? ??? ?print(name, param.shape)
輸出結果為
0
conv1.weight torch.Size([32, 3, 3, 3])
1
bn1.weight torch.Size([32])
2
bn1.bias torch.Size([32])
3
layer1.ds_conv.weight torch.Size([64, 32, 3, 3])
4
layer1.ds_bn.weight torch.Size([64])
5
layer1.ds_bn.bias torch.Size([64])
6
layer1.residual_0.conv1.weight torch.Size([32, 64, 1, 1])
7
layer1.residual_0.bn1.weight torch.Size([32])
8
layer1.residual_0.bn1.bias torch.Size([32])
9
layer1.residual_0.conv2.weight torch.Size([64, 32, 3, 3])
state_dict()、named_parameters()和parameters()的區別
Pytorch中有3個功能極其類似的方法,分別是model.parameters()、model.named_parameters()和model.state_dict(),下面就來探究一下這三種方法的區別。
它們的差異主要體現在3方面:
- 返回值類型不同
- 存儲的模型參數的種類不同
- 返回的值的require_grad屬性不同
測試代碼準備工作
import torch
import torch.nn as nn
import torch.optim as optim
import random
import os
import numpy as np
def seed_torch(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed) # 為了禁止hash隨機化,使得實驗可復現
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
seed_torch() # 固定隨機數
# 定義一個網絡
class net(nn.Module):
def __init__(self, num_class=10):
super(net, self).__init__()
self.pool1 = nn.AvgPool1d(2)
self.bn1 = nn.BatchNorm1d(3)
self.fc1 = nn.Linear(12, 4)
def forward(self, x):
x = self.pool1(x)
x = self.bn1(x)
x = x.reshape(x.size(0), -1)
x = self.fc1(x)
return x
# 定義網絡
model = net()
# 定義loss
loss_fn = nn.CrossEntropyLoss()
# 定義優化器
optimizer = optim.SGD(model.parameters(), lr=1e-2)
# 定義訓練數據
x = torch.randn((3, 3, 8))
兩個概念
可學習參數
可學習參數也可叫做模型參數,其就是要參與學習和更新的,特別注意這里的參數更新是指在優化器的optim.step步驟里更新參數,即需要反向傳播更新的參數
使用nn.parameter.Parameter()創建的變量是可學習參數(模型參數)
模型中的可學習參數的數據類型都是nn.parameter.Parameter
optim.step只能更新nn.parameter.Parameter類型的參數
nn.parameter.Parameter類型的參數的特點是默認requires_grad=True,也就是說訓練過程中需要反向傳播的,就需要使用這個
示例:
在上述定義的網絡中,self.fc1層中的參數(weight和bias)是可學習參數,要在訓練過程中進行學習與更新
print(type(model.fc1.weight))
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
<class 'torch.nn.parameter.Parameter'>
不可學習參數
不可學習參數不參與學習和在優化器中的更新,即不需要參與反向傳播
不可學習參數將會通過Module.register_parameter()注冊在self._buffers中,self._buffers是一個OrderedDict
舉例:上述定義的模型中,self.bn1層中的參數running_mean、running_var和num_batches_tracked均是不可學習參數
self.register_parameter('running_mean', None)
存儲在self._buffers中的不可學習參數不能通過optim.step()更新參數,但例如上述的self.bn1層中的不可學習參數也會更新,其更新是發生在forward的過程中
示例:
在上述定義的網絡中,self.bn1層中的參數(running_mean)是不可學習參數
print(type(model.bn1.running_mean))
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
<class 'torch.Tensor'>
named_parameters()
總述
model.named_parameters()返回的是一個生成器(generator),該生成器中只保存了可學習、可被優化器更新的參數的參數名和具體的參數,可通過循環迭代打印參數名和參數(參見代碼示例一)
該方法可以用來改變可學習、可被優化器更新參數的requires_grad屬性,因此可用于鎖住某些層的參數,讓其在訓練的時候不更新參數(參見代碼示例二)
代碼示例一
# model.named_parameters()的用法
print(type(model.named_parameters()))
for name, param in model.named_parameters():
? ? print(name)
? ? print(param)
結果
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
<class 'generator'>
bn1.weight
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
bn1.bias
Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
fc1.weight
Parameter containing:
tensor([[ 0.0036, ?0.1960, ?0.2315, -0.2408, ?0.1217, ?0.2579, -0.0676, -0.1880,
? ? ? ? ?-0.2855, -0.1587, ?0.0409, ?0.0312],
? ? ? ? [ 0.1057, ?0.1348, -0.0590, -0.1538, ?0.2505, ?0.0651, -0.2461, -0.1856,
? ? ? ? ? 0.2498, -0.1969, ?0.0013, ?0.1979],
? ? ? ? [-0.1812, ?0.1153, ?0.2723, -0.2190, ?0.0371, -0.0341, ?0.2282, ?0.1461,
? ? ? ? ? 0.1890, ?0.1762, ?0.2657, -0.0827],
? ? ? ? [-0.0188, ?0.0081, -0.2674, -0.1858, ?0.1296, ?0.1728, -0.0770, ?0.1444,
? ? ? ? ?-0.2360, -0.1793, ?0.1921, -0.2791]], requires_grad=True)
fc1.bias
Parameter containing:
tensor([-0.0020, ?0.0985, ?0.1859, -0.0175], requires_grad=True)
代碼示例二
print(model.fc1.weight.requires_grad) ?# 可學習參數fc1.weight的requires_grad屬性
for name, param in model.named_parameters():
? ? if ("fc1" in name):
? ? ? ? param.requires_grad = False
print(model.fc1.weight.requires_grad) ?# 修改后可學習參數fc1.weight的requires_grad屬性
結果
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
True
False
parameters()
總述
model.parameters()返回的是一個生成器,該生成器中只保存了可學習、可被優化器更新的參數的具體的參數,可通過循環迭代打印參數。(參見代碼示例一)
與model.named_parameters()相比,model.parameters()不會保存參數的名字。
該方法可以用來改變可學習、可被優化器更新參數的requires_grad屬性,但由于其只有參數,沒有對應的參數名,所以當要修改指定的某些層的requires_grad屬性時,沒有model.named_parameters()方便。(參見
代碼示例二)
代碼示例一
# model.parameters()的用法
print(type(model.parameters()))
for param in model.parameters():
? ? print(param)
結果
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
<class 'generator'>
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0., 0.], requires_grad=True)
Parameter containing:
tensor([[ 0.0036, ?0.1960, ?0.2315, -0.2408, ?0.1217, ?0.2579, -0.0676, -0.1880,
? ? ? ? ?-0.2855, -0.1587, ?0.0409, ?0.0312],
? ? ? ? [ 0.1057, ?0.1348, -0.0590, -0.1538, ?0.2505, ?0.0651, -0.2461, -0.1856,
? ? ? ? ? 0.2498, -0.1969, ?0.0013, ?0.1979],
? ? ? ? [-0.1812, ?0.1153, ?0.2723, -0.2190, ?0.0371, -0.0341, ?0.2282, ?0.1461,
? ? ? ? ? 0.1890, ?0.1762, ?0.2657, -0.0827],
? ? ? ? [-0.0188, ?0.0081, -0.2674, -0.1858, ?0.1296, ?0.1728, -0.0770, ?0.1444,
? ? ? ? ?-0.2360, -0.1793, ?0.1921, -0.2791]], requires_grad=True)
Parameter containing:
tensor([-0.0020, ?0.0985, ?0.1859, -0.0175], requires_grad=True)
代碼示例二
print(model.fc1.weight.requires_grad)
for param in model.parameters():
? ? param.requires_grad = False
print(model.fc1.weight.requires_grad)
結果
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
True
False
state_dict()
總述
model.state_dict()返回的是一個有序字典OrderedDict,該有序字典中保存了模型所有參數的參數名和具體的參數值,所有參數包括可學習參數和不可學習參數,可通過循環迭代打印參數,因此,該方法可用于保存模型,當保存模型時,會將不可學習參數也存下,當加載模型時,也會將不可學習參數進行賦值。(參見代碼示例一)
一般在使用model.state_dict()時會使用該函數的默認參數,model.state_dict()源碼如下:
# torch.nn.modules.module.py
class Module(object):
? ? def state_dict(self, destination=None, prefix='', keep_vars=False):
? ? ? ? if destination is None:
? ? ? ? ? ? destination = OrderedDict()
? ? ? ? ? ? destination._metadata = OrderedDict()
? ? ? ? destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
? ? ? ? for name, param in self._parameters.items():
? ? ? ? ? ? if param is not None:
? ? ? ? ? ? ? ? destination[prefix + name] = param if keep_vars else param.data
? ? ? ? for name, buf in self._buffers.items():
? ? ? ? ? ? if buf is not None:
? ? ? ? ? ? ? ? destination[prefix + name] = buf if keep_vars else buf.data
? ? ? ? for name, module in self._modules.items():
? ? ? ? ? ? if module is not None:
? ? ? ? ? ? ? ? module.state_dict(destination, prefix + name + '.', keep_vars=keep_vars)
? ? ? ? for hook in self._state_dict_hooks.values():
? ? ? ? ? ? hook_result = hook(self, destination, prefix, local_metadata)
? ? ? ? ? ? if hook_result is not None:
? ? ? ? ? ? ? ? destination = hook_result
? ? ? ? return destination
在默認參數下,model.state_dict()保存參數時只會保存參數(Tensor對象)的data屬性,不會保存參數的requires_grad屬性,因此,其保存的參數的requires_grad的屬性變為False,沒有辦法改變requires_grad的屬性,所以改變requires_grad的屬性只能通過上面的兩種方式。(參見代碼示例二)
model.state_dict()本質上是淺拷貝,即返回的OrderedDict對象本身是新創建的對象,但其中的param參數的引用仍是模型參數的data屬性的地址,又因為Tensor是可變對象,因此,若對param參數進行修改(在原地址變更數據內容),會導致對應的模型參數的改變。(參見代碼示例三)
代碼示例一
# model.state_dict()的用法
print(model.state_dict())
for name, param in model.state_dict().items():
? ? print(name)
? ? print(param)
? ? print(param.requires_grad)
結果
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
OrderedDict([('bn1.weight', tensor([1., 1., 1.])), ('bn1.bias', tensor([0., 0., 0.])), ('bn1.running_mean', tensor([0., 0., 0.])), ('bn1.running_var', tensor([1., 1., 1.])), ('bn1.num_batches_tracked', tensor(0)), ('fc1.weight', tensor([[ 0.0036, ?0.1960, ?0.2315, -0.2408, ?0.1217, ?0.2579, -0.0676, -0.1880,
? ? ? ? ?-0.2855, -0.1587, ?0.0409, ?0.0312],
? ? ? ? [ 0.1057, ?0.1348, -0.0590, -0.1538, ?0.2505, ?0.0651, -0.2461, -0.1856,
? ? ? ? ? 0.2498, -0.1969, ?0.0013, ?0.1979],
? ? ? ? [-0.1812, ?0.1153, ?0.2723, -0.2190, ?0.0371, -0.0341, ?0.2282, ?0.1461,
? ? ? ? ? 0.1890, ?0.1762, ?0.2657, -0.0827],
? ? ? ? [-0.0188, ?0.0081, -0.2674, -0.1858, ?0.1296, ?0.1728, -0.0770, ?0.1444,
? ? ? ? ?-0.2360, -0.1793, ?0.1921, -0.2791]])), ('fc1.bias', tensor([-0.0020, ?0.0985, ?0.1859, -0.0175]))])
bn1.weight
tensor([1., 1., 1.])
False
bn1.bias
tensor([0., 0., 0.])
False
bn1.running_mean
tensor([0., 0., 0.])
False
bn1.running_var
tensor([1., 1., 1.])
False
bn1.num_batches_tracked
tensor(0)
False
fc1.weight
tensor([[ 0.0036, ?0.1960, ?0.2315, -0.2408, ?0.1217, ?0.2579, -0.0676, -0.1880,
? ? ? ? ?-0.2855, -0.1587, ?0.0409, ?0.0312],
? ? ? ? [ 0.1057, ?0.1348, -0.0590, -0.1538, ?0.2505, ?0.0651, -0.2461, -0.1856,
? ? ? ? ? 0.2498, -0.1969, ?0.0013, ?0.1979],
? ? ? ? [-0.1812, ?0.1153, ?0.2723, -0.2190, ?0.0371, -0.0341, ?0.2282, ?0.1461,
? ? ? ? ? 0.1890, ?0.1762, ?0.2657, -0.0827],
? ? ? ? [-0.0188, ?0.0081, -0.2674, -0.1858, ?0.1296, ?0.1728, -0.0770, ?0.1444,
? ? ? ? ?-0.2360, -0.1793, ?0.1921, -0.2791]])
False
fc1.bias
tensor([-0.0020, ?0.0985, ?0.1859, -0.0175])
False
代碼示例二
# model.state_dict()的用法
print(model.bn1.weight.requires_grad)
model.bn1.weight.requires_grad = False
print(model.bn1.weight.requires_grad)
for name, param in model.state_dict().items():
? ? if (name == "bn1.weight"):
? ? ? ? param.requires_grad = True
print(model.bn1.weight.requires_grad)
結果
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
True
False
False
代碼示例三
# model.state_dict()的用法
print(model.bn1.weight)
for name, param in model.state_dict().items():
? ? if (name == "bn1.weight"):
? ? ? ? param[0] = 1000
print(model.bn1.weight)
結果
(bbn) jyzhang@admin2-X10DAi:~/test$ python net.py
Parameter containing:
tensor([1., 1., 1.], requires_grad=True)
Parameter containing:
tensor([1000., ? ?1., ? ?1.], requires_grad=True)
原文鏈接:https://blog.csdn.net/weixin_42149550/article/details/117128228
相關推薦
- 2022-07-24 .Net結構型設計模式之裝飾模式(Decorator)_基礎應用
- 2022-12-08 一文帶你搞懂Go如何讀寫Excel文件_Golang
- 2023-03-16 numpy如何獲取array中數組元素的索引位置_python
- 2022-06-22 C語言關于include順序不同導致編譯結果不同的問題_C 語言
- 2023-01-05 python使用jpype導入多個Jar的異常問題及解決_python
- 2024-07-15 GIT同步修改后的遠程分支
- 2022-05-26 Flutter實現抽屜動畫_Android
- 2023-01-07 Python個人博客程序開發實例后臺編寫_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支