網站首頁 編程語言 正文
torch.nn.CrossEntropyLoss交叉熵損失
本文只考慮基本情況,未考慮加權。
torch.nnCrossEntropyLosss使用的公式
目標類別采用one-hot編碼
其中,class表示當前樣本類別在one-hot編碼中對應的索引(從0開始),
x[j]表示預測函數的第j個輸出
公式(1)表示先對預測函數使用softmax計算每個類別的概率,再使用log(以e為底)計算后的相反數表示當前類別的損失,只表示其中一個樣本的損失計算方式,非全部樣本。
每個樣本使用one-hot編碼表示所屬類別時,只有一項為1,因此與基本的交叉熵損失函數相比,省略了其它值為0的項,只剩(1)所表示的項。
sample
torch.nn.CrossEntropyLoss使用流程
torch.nn.CrossEntropyLoss為一個類,并非單獨一個函數,使用到的相關簡單參數會在使用中說明,并非對所有參數進行說明。
首先創建類對象
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")
參數reduction默認為"mean",表示對所有樣本的loss取均值,最終返回只有一個值
參數reduction取"none",表示保留每一個樣本的loss
計算損失
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: tensor([1.5210, 0.6247]) # 與上述【sample】計算一致
實際計算損失值調用函數時,傳入pred預測值與class_index類別索引
在傳入每個類別時,class_index應為一維,長度為樣本個數,每個元素表示對應樣本的類別索引,非one-hot編碼方式傳入
測試torch.nn.CrossEntropyLoss的reduction參數為默認值"mean"
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: 1.073 # 與上述【sample】計算一致
交叉熵損失nn.CrossEntropyLoss()的真正計算過程
對于多分類損失函數Cross Entropy Loss,就不過多的解釋,網上的博客不計其數。在這里,講講對于CE Loss的一些真正的理解。
首先大部分博客給出的公式如下:
其中p為真實標簽值,q為預測值。
在低維復現此公式,結果如下。在此強調一點,pytorch中CE Loss并不會將輸入的target映射為one-hot編碼格式,而是直接取下標進行計算。
import torch
import torch.nn as nn
import math
import numpy as np
#官方的實現
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#輸出 tensor(1.1142)
#自己實現
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
output = 0
length = len(target)
for i in range(length):
hou = 0
for j in input[i]:
hou += np.log(input[i][target[i]])
output += -hou
return np.around(output / length, 4)
print(cross_entorpy(input, target))
#輸出 3.8162
我們按照官方給的CE Loss和根據公式得到的答案并不相同,說明公式是有問題的。
正確公式
實現代碼如下
import torch
import torch.nn as nn
import math
import numpy as np
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#輸出 tensor(1.1142)
#%%
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
output = 0
length = len(target)
for i in range(length):
hou = 0
for j in input[i]:
hou += np.exp(j)
output += -input[i][target[i]] + np.log(hou)
return np.around(output / length, 4)
print(cross_entorpy(input, target))
#輸出 1.1142
對比自己實現的公式和官方給出的結果,可以驗證公式的正確性。
觀察公式可以發現其實nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合版本。
nn.logSoftmax(),公式如下
nn.NLLLoss(),公式如下
將nn.logSoftmax()作為變量帶入nn.NLLLoss()可得
因為
可看做一個常量,故上式可化簡為:
對比nn.Cross Entropy Loss公式,結果顯而易見。
驗證代碼如下。
import torch
import torch.nn as nn
import math
import numpy as np
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
# 輸出為tensor(1.1142)
m = nn.LogSoftmax()
loss = nn.NLLLoss()
input=m(input)
output = loss(input, target)
print(output)
# 輸出為tensor(1.1142)
綜上,可得兩個結論
1.nn.Cross Entropy Loss的公式。
2.nn.Cross Entropy Loss為nn.logSoftmax()和nn.NLLLoss()的整合版本。
總結
原文鏈接:https://blog.csdn.net/u012633319/article/details/111093144
相關推薦
- 2022-09-24 如何將一個CSV格式的文件分割成兩個CSV文件_python
- 2024-02-01 webstorm中Line comment at first column,Block commen
- 2022-11-22 Golang分布式鎖詳細介紹_Golang
- 2022-03-01 箭頭函數的this 構造函數的this 全局環境下的this各是什么
- 2022-11-27 Docker?容器互聯互通的實現方法_docker
- 2022-05-14 C++使用new和delete進行動態內存分配與數組封裝_C 語言
- 2022-08-13 Android自定義ProgressBar實現漂亮的進度提示框_Android
- 2023-01-20 C++利用模板實現消息訂閱和分發功能_C 語言
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支