網站首頁 編程語言 正文
pytorch/transformers 最后一層不加激活函數原因
之前看bert及其各種變種模型,發現模型最后一層都是FC (full connect)的線性層Linear層,現在講解原因
實驗:筆者試著在最后一層后加上了softmax激活函數,用來做多分類,發現模型無法收斂。去掉激活函數后收斂很好。
說明加的不對,因此深入研究了一下。
前言
對于分類問題,pytorch最后一層為啥都是linear層,沒有激活函數?
一、原因在于損失方式CrossEntropy
CrossEntropy:該損失函數集成了log_softmax和nll_loss。因此,相當于FC層后接上CrossEntropy,實際上是有經過softmax處理的。只是內置到損失函數CrossEntropy中去了。
This criterion combines `log_softmax` and `nll_loss` in a single function.
二、為什么CrossEntropy要用log_softmax而不是softmax
1.查看CrossEntropy定義:
其中p為真實分布,q為預測分布。
根據CrossEntropyLoss公式,分類問題中,所以標簽中只有一個類別(設為z)分量為1,其他類別全為0,我們代入公式,即求和之后只剩下一項。
其中:
是模型FC層輸出后需要接上softmax后,得到的概率。因此,這個公式就可以表示為:-log(softmax(FC的輸出)),因此,這里就直接變成一個函數,叫log_softmax,便于計算CrossEntropy。
2.如果想要的到模型輸出的概率值,需要在FC層輸出后,人為的接上F.Softmax()就好了
代碼如下(示例):
import torch from torch.autograd import Variable import torch.nn.functional as F import matplotlib.pyplot as plt n_data = torch.ones(100,2) x0 = torch.normal(2*n_data, 1) y0 = torch.zeros(100) x1 = torch.normal(-2*n_data, 1) y1 = torch.ones(100) x = torch.cat((x0, x1), 0).type(torch.FloatTensor) # 組裝(連接) y = torch.cat((y0, y1), 0).type(torch.LongTensor) x, y = Variable(x), Variable(y) class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.out = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) x = self.out(x) return x net = Net(2, 10, 2) optimizer = torch.optim.SGD(net.parameters(), lr = 0.012) for t in range(100): out = net(x) loss = torch.nn.CrossEntropyLoss()(out, y) optimizer.zero_grad() loss.backward() optimizer.step() if (t+1) % 20 == 0: plt.cla() prediction = torch.max(F.softmax(out), 1)[1] # 在第1維度取最大值并返回索引值 pred_y = prediction.data.numpy().squeeze() target_y = y.data.numpy() plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:,1], c=pred_y, s=100, lw=0, cmap='RdYlGn') accuracy = sum(pred_y == target_y)/200 plt.text(1.5, -4, 'Accu=%.2f' % accuracy, fontdict={'size': 20, 'color': 'red'}) plt.pause(0.1)
上述代碼中,F.softmax(out)表示的就是模型輸出的概率。
torch.max(F.softmax(out), 1)[1] # 在第1維度取表示取概率最大的列最為預測標簽值,不是概率,而是標簽了。
3.bert模型的輸出端展示
代碼如下(示例):
class Model(nn.Module): def __init__(self, config): super(Model, self).__init__() self.bert = BertModel.from_pretrained(config.bert_path) for param in self.bert.parameters(): param.requires_grad = True self.fc = nn.Linear(config.hidden_size, config.num_classes) def forward(self, x): context = x[0] # 輸入的句子 mask = x[2] # 對padding部分進行mask,和句子一個size,padding部分用0表示,如:[1, 1, 1, 1, 0, 0] bert_out = self.bert(context, attention_mask=mask, output_hidden_states=False) out = self.fc(bert_out.pooler_output) return out
也可以看到,bert中的self.fc = nn.Linear(config.hidden_size, config.num_classes)僅僅為Linear層,沒有激活函數。
如果想得到bert的多分類概率,最后在模型的out輸出后,需要接上一個
F.softmax(out)
總結
這里給大家解釋一下為什么bert模型最后都不加激活函數。是因為損失函數選擇的原因。
原文鏈接:https://blog.csdn.net/weixin_43290383/article/details/128580386
相關推薦
- 2022-03-29 C#算法之各位相加_C#教程
- 2023-01-03 Redis實現優惠券限一單限制詳解_Redis
- 2022-12-12 Docker制作tomcat鏡像并部署項目_docker
- 2023-01-14 ubuntu開機后ROS程序自啟動問題_Linux
- 2022-03-23 shell腳本設置防止暴力破解ssh_Linux
- 2022-11-07 Python實現簡易凱撒密碼的示例代碼_python
- 2022-11-05 Android開發使用Databinding實現關注功能mvvp_Android
- 2022-08-13 DHCP服務簡介及Linux配置實例
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支