網站首頁 編程語言 正文
前言
在實現Per-FedAvg的代碼時,遇到如下問題:
可以發現,我們需要求損失函數對模型參數的Hessian矩陣。
模型定義
我們定義一個比較簡單的模型:
class ANN(nn.Module):
def __init__(self):
super(ANN, self).__init__()
self.sigmoid = nn.Sigmoid()
self.fc1 = nn.Linear(3, 4)
self.fc2 = nn.Linear(4, 5)
def forward(self, data):
x = self.fc1(data)
x = self.fc2(x)
return x
輸出一下模型的參數:
model = ANN()
for param in model.parameters():
print(param.size())
輸出如下:
torch.Size([4, 3])
torch.Size([4])
torch.Size([5, 4])
torch.Size([5])
求解Hessian矩陣
我們首先定義數據:
data = torch.tensor([1, 2, 3], dtype=torch.float)
label = torch.tensor([1, 1, 5, 7, 8], dtype=torch.float)
pred = model(data)
loss_fn = nn.MSELoss()
loss = loss_fn(pred, label)
然后求解一階梯度:
grads = torch.autograd.grad(loss, model.parameters(), retain_graph=True, create_graph=True)
輸出一下grads:
(tensor([[-1.0530, -2.1059, -3.1589],
[ 2.3615, 4.7229, 7.0844],
[-1.5046, -3.0093, -4.5139],
[-2.0272, -4.0543, -6.0815]], grad_fn=<TBackward0>), tensor([-1.0530, 2.3615, -1.5046, -2.0272], grad_fn=<SqueezeBackward1>), tensor([[ 0.2945, -0.2725, -0.8159, -0.6720],
[ 0.1936, -0.1791, -0.5362, -0.4416],
[ 1.0800, -0.9993, -2.9918, -2.4641],
[ 1.3448, -1.2444, -3.7255, -3.0683],
[ 1.2436, -1.1507, -3.4450, -2.8373]], grad_fn=<TBackward0>), tensor([-0.6045, -0.3972, -2.2165, -2.7600, -2.5522],
grad_fn=<MseLossBackwardBackward0>))
可以發現一共4個Tensor,分別為損失函數對四個參數Tensor(兩層,每層都有權重和偏置)的梯度。
然后針對每一個Tensor求解二階梯度:
hessian_params = []
for k in range(len(grads)):
hess_params = torch.zeros_like(grads[k])
for i in range(grads[k].size(0)):
# 判斷是w還是b
if len(grads[k].size()) == 2:
# w
for j in range(grads[k].size(1)):
hess_params[i, j] = torch.autograd.grad(grads[k][i][j], model.parameters(), retain_graph=True)[k][i, j]
else:
# b
hess_params[i] = torch.autograd.grad(grads[k][i], model.parameters(), retain_graph=True)[k][i]
hessian_params.append(hess_params)
這里需要注意:由于模型一共兩層,每一層都有權重和偏置,其中權重參數為二維,偏置參數為一維,在進行具體的二階梯度求導時,需要進行判斷。
最終得到的hessian_params是一個列表,列表中包含四個Tensor,對應損失函數對兩層網絡權重和偏置的二階梯度。
原文鏈接:https://blog.csdn.net/Cyril_KI/article/details/124562109
相關推薦
- 2022-11-25 CentOS?7.9?升級內核?kernel-ml-5.6.14版本的方法_云其它
- 2022-11-22 python正則表達式中匹配次數與貪心問題詳解(+??*)_python
- 2023-03-13 Android?Hilt依賴注入的使用講解_Android
- 2023-11-12 jetson nano報錯Cannot allocate memory的問題——解決辦法
- 2022-01-13 Vite2+TS+el3獲取DOM元素設置類型并進行表單校驗
- 2023-04-21 python查找指定依賴包簡介信息實現_python
- 2023-02-09 Flask如何獲取用戶的ip,查詢用戶的登錄次數,并且封ip_python
- 2022-07-06 Flutter?點擊兩次退出app的實現示例_Android
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支