網站首頁 編程語言 正文
自定義loss
的方法有很多,但是在博主查資料的時候發(fā)現有挺多寫法會有問題,靠譜一點的方法是把loss作為一個pytorch的模塊,
比如:
class CustomLoss(nn.Module): # 注意繼承 nn.Module ? ? def __init__(self): ? ? ? ? super(CustomLoss, self).__init__() ? ? def forward(self, x, y): ? ? ? ? # .....這里寫x與y的處理邏輯,即loss的計算方法 ? ? ? ? return loss # 注意最后只能返回Tensor值,且?guī)荻龋?loss.requires_grad == True
示例代碼:
以一個pytorch求解線性回歸的代碼為例:
import torch import torch.nn as nn import numpy as np import os os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" def get_x_y(): ? ? np.random.seed(0) ? ? x = np.random.randint(0, 50, 300) ? ? y_values = 2 * x + 21 ? ? x = np.array(x, dtype=np.float32) ? ? y = np.array(y_values, dtype=np.float32) ? ? x = x.reshape(-1, 1) ? ? y = y.reshape(-1, 1) ? ? return x, y class LinearRegressionModel(nn.Module): ? ? def __init__(self, input_dim, output_dim): ? ? ? ? super(LinearRegressionModel, self).__init__() ? ? ? ? self.linear = nn.Linear(input_dim, output_dim) ?# 輸入的個數,輸出的個數 ? ? def forward(self, x): ? ? ? ? out = self.linear(x) ? ? ? ? return out if __name__ == '__main__': ? ? input_dim = 1 ? ? output_dim = 1 ? ? x_train, y_train = get_x_y() ? ? model = LinearRegressionModel(input_dim, output_dim) ? ? epochs = 1000 ?# 迭代次數 ? ? optimizer = torch.optim.SGD(model.parameters(), lr=0.001) ? ? model_loss = nn.MSELoss() # 使用MSE作為loss ? ? # 開始訓練模型 ? ? for epoch in range(epochs): ? ? ? ? epoch += 1 ? ? ? ? # 注意轉行成tensor ? ? ? ? inputs = torch.from_numpy(x_train) ? ? ? ? labels = torch.from_numpy(y_train) ? ? ? ? # 梯度要清零每一次迭代 ? ? ? ? optimizer.zero_grad() ? ? ? ? # 前向傳播 ? ? ? ? outputs: torch.Tensor = model(inputs) ? ? ? ? # 計算損失 ? ? ? ? loss = model_loss(outputs, labels) ? ? ? ? # 返向傳播 ? ? ? ? loss.backward() ? ? ? ? # 更新權重參數 ? ? ? ? optimizer.step() ? ? ? ? if epoch % 50 == 0: ? ? ? ? ? ? print('epoch {}, loss {}'.format(epoch, loss.item()))
步驟1:添加自定義的類
我們就用自定義的寫法來寫與MSE相同的效果,MSE計算公式如下:
添加一個類:
class CustomLoss(nn.Module): ? ? def __init__(self): ? ? ? ? super(CustomLoss, self).__init__() ? ? ? ? self.mse_loss = nn.MSELoss() ? ? def forward(self, x, y): ? ? ? ? mse_loss = torch.mean(torch.pow((x - y), 2)) # x與y相減后平方,求均值即為MSE ? ? ? ? return mse_loss
步驟2:修改使用的loss函數
只需要把原始代碼中的:
model_loss = nn.MSELoss() # 使用MSE作為loss
改為:
model_loss = CustomLoss() ?# 自定義loss
即可
完整代碼:
import torch import torch.nn as nn import numpy as np import os os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" def get_x_y(): ? ? np.random.seed(0) ? ? x = np.random.randint(0, 50, 300) ? ? y_values = 2 * x + 21 ? ? x = np.array(x, dtype=np.float32) ? ? y = np.array(y_values, dtype=np.float32) ? ? x = x.reshape(-1, 1) ? ? y = y.reshape(-1, 1) ? ? return x, y class LinearRegressionModel(nn.Module): ? ? def __init__(self, input_dim, output_dim): ? ? ? ? super(LinearRegressionModel, self).__init__() ? ? ? ? self.linear = nn.Linear(input_dim, output_dim) ?# 輸入的個數,輸出的個數 ? ? def forward(self, x): ? ? ? ? out = self.linear(x) ? ? ? ? return out class CustomLoss(nn.Module): ? ? def __init__(self): ? ? ? ? super(CustomLoss, self).__init__() ? ? ? ? self.mse_loss = nn.MSELoss() ? ? def forward(self, x, y): ? ? ? ? mse_loss = torch.mean(torch.pow((x - y), 2)) ? ? ? ? return mse_loss if __name__ == '__main__': ? ? input_dim = 1 ? ? output_dim = 1 ? ? x_train, y_train = get_x_y() ? ? model = LinearRegressionModel(input_dim, output_dim) ? ? epochs = 1000 ?# 迭代次數 ? ? optimizer = torch.optim.SGD(model.parameters(), lr=0.001) ? ? # model_loss = nn.MSELoss() # 使用MSE作為loss ? ? model_loss = CustomLoss() ?# 自定義loss ? ? # 開始訓練模型 ? ? for epoch in range(epochs): ? ? ? ? epoch += 1 ? ? ? ? # 注意轉行成tensor ? ? ? ? inputs = torch.from_numpy(x_train) ? ? ? ? labels = torch.from_numpy(y_train) ? ? ? ? # 梯度要清零每一次迭代 ? ? ? ? optimizer.zero_grad() ? ? ? ? # 前向傳播 ? ? ? ? outputs: torch.Tensor = model(inputs) ? ? ? ? # 計算損失 ? ? ? ? loss = model_loss(outputs, labels) ? ? ? ? # 返向傳播 ? ? ? ? loss.backward() ? ? ? ? # 更新權重參數 ? ? ? ? optimizer.step() ? ? ? ? if epoch % 50 == 0: ? ? ? ? ? ? print('epoch {}, loss {}'.format(epoch, loss.item()))
原文鏈接:https://blog.csdn.net/weixin_35757704/article/details/122865272
相關推薦
- 2022-05-02 Python中的變量和數據類型詳情_python
- 2022-08-30 C語言例題講解指針與數組_C 語言
- 2022-09-05 Spring是如何解決循環(huán)依賴的?
- 2022-07-10 同時啟動兩個項目,產生的跨域問題
- 2022-12-14 正則表達式匹配0-10的正整數以及使用細節(jié)_正則表達式
- 2023-12-02 富文本組件中圖片間空白處理小技巧
- 2022-05-12 Kotlin 集合也可以進行+= -= 還可以根據條件進行刪除(removeIf)
- 2022-07-02 在React中使用axios發(fā)送請求
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細win安裝深度學習環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發(fā)現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支