網站首頁 編程語言 正文
Pytorch損失函數torch.nn.NLLLoss()
在各種深度學習框架中,我們最常用的損失函數就是交叉熵(torch.nn.CrossEntropyLoss),熵是用來描述一個系統的混亂程度,通過交叉熵我們就能夠確定預測數據與真是數據之間的相近程度。
交叉熵越小,表示數據越接近真實樣本。
交叉熵計算公式
就是我們預測的概率的對數與標簽的乘積,當qk->1的時候,它的損失接近零。
nn.NLLLoss
官方文檔中介紹稱:
nn.NLLLoss輸入是一個對數概率向量和一個目標標簽,它與nn.CrossEntropyLoss的關系可以描述為:softmax(x)+log(x)+nn.NLLLoss====>nn.CrossEntropyLoss
CrossEntropyLoss()=log_softmax() + NLLLoss()
其中softmax函數又稱為歸一化指數函數,它可以把一個多維向量壓縮在(0,1)之間,并且它們的和為1.
計算公式
?
示例代碼:
import math
z = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]
z_exp = [math.exp(i) for i in z]
print(z_exp) # Result: [2.72, 7.39, 20.09, 54.6, 2.72, 7.39, 20.09]
sum_z_exp = sum(z_exp)
print(sum_z_exp) # Result: 114.98
softmax = [round(i / sum_z_exp, 3) for i in z_exp]
print(softmax) # Result: [0.024, 0.064, 0.175, 0.475, 0.024, 0.064, 0.175]
log_softmax
log_softmax是指在softmax函數的基礎上,再進行一次log運算,此時結果有正有負,log函數的值域是負無窮到正無窮,當x在0—1之間的時候,log(x)值在負無窮到0之間。
nn.NLLLoss
此時,nn.NLLLoss的結果就是把上面的輸出與Label對應的那個值拿出來,再去掉負號,再求均值。
代碼示例:
import torch
input=torch.randn(3,3)
soft_input = torch.nn.Softmax(dim=0)
soft_input(input)
Out[20]:
tensor([[0.7284, 0.7364, 0.3343],
[0.1565, 0.0365, 0.0408],
[0.1150, 0.2270, 0.6250]])
#對softmax結果取log
torch.log(soft_input(input))
Out[21]:
tensor([[-0.3168, -0.3059, -1.0958],
[-1.8546, -3.3093, -3.1995],
[-2.1625, -1.4827, -0.4701]])
假設標簽是[0,1,2],第一行取第0個元素,第二行取第1個,第三行取第2個,去掉負號,即[0.3168,3.3093,0.4701],求平均值,就可以得到損失值。
(0.3168+3.3093+0.4701)/3
Out[22]: 1.3654000000000002
#驗證一下
loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(0.1365)
nn.CrossEntropyLoss
loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(-0.1399)
loss =torch.nn.CrossEntropyLoss()
input = torch.tensor([[ 1.1879, 1.0780, 0.5312],
[-0.3499, -1.9253, -1.5725],
[-0.6578, -0.0987, 1.1570]])
target = torch.tensor([0,1,2])
loss(input,target)
Out[30]: tensor(0.1365)
以上為全部實驗驗證兩個loss函數之間的關系!??!
總結
原文鏈接:https://blog.csdn.net/Jeremy_lf/article/details/102725285
相關推薦
- 2023-04-19 Android使用gradle讀取并保存數據到BuildConfg流程詳解_Android
- 2021-12-13 C++ 之常量const(常對象、常數據成員、常成員函數)
- 2022-08-16 Hive導入csv文件示例_數據庫其它
- 2022-08-03 python中time庫使用詳解_python
- 2022-09-17 Python?seaborn數據可視化繪圖(直方圖,密度圖,散點圖)_python
- 2023-06-16 Python中ArcPy柵格裁剪柵格(批量對齊柵格圖像范圍并統一行數與列數)_python
- 2022-07-22 HTML+CSS之背景圖片的設置
- 2022-03-07 Android顯示系統SurfaceFlinger分析_Android
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支