網站首頁 編程語言 正文
摘要
With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of CV and NLP by virtue of their powerful advantages. Among them, the Informer is far superior to the traditional RNN model in long-term prediction, and the Swin Transformer is significantly stronger than the traditional CNN model in image recognition. A deep grasp of Transformer has become an inevitable requirement in the field of artificial intelligence. This article will use the Pytorch framework to implement the position encoding, multi-head attention mechanism, self-mask, causal mask and other functions in Transformer, and build a Transformer network from 0.
隨著時序預測的不斷發展,Transformer類模型憑借強大的優勢,在CV、NLP領域逐漸取代傳統模型。其中Informer在長時序預測上遠超傳統的RNN模型,Swin Transformer在圖像識別上明顯強于傳統的CNN模型。深層次掌握Transformer已經成為從事人工智能領域的必然要求。本文將用Pytorch框架,實現Transformer中的位置編碼、多頭注意力機制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網絡。
一、構造數據
1.1 句子長度
# 關于word embedding,以序列建模為例 # 輸入句子有兩個,第一個長度為2,第二個長度為4 src_len = torch.tensor([2, 4]).to(torch.int32) # 目標句子有兩個。第一個長度為4, 第二個長度為3 tgt_len = torch.tensor([4, 3]).to(torch.int32) print(src_len) print(tgt_len)
輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4
目標句子(tgt_len)有兩個。第一個長度為4, 第二個長度為3
1.2 生成句子
用隨機數生成句子,用0填充空白位置,保持所有句子長度一致
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) print(src_seq) print(tgt_seq)
src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。
為什么句子是數字?在做中英文翻譯時,每個中文或英文對應的也是一個數字,只有這樣才便于處理。
1.3 生成字典
在該字典中,總共有8個字(行),每個字對應8維向量(做了簡化了的)。注意在實際應用中,應當有幾十萬個字,每個字可能有512個維度。
# 構造word embedding src_embedding_table = nn.Embedding(9, model_dim) tgt_embedding_table = nn.Embedding(9, model_dim) # 輸入單詞的字典 print(src_embedding_table) # 目標單詞的字典 print(tgt_embedding_table)
字典中,需要留一個維度給class token,故是9行。
1.4 得到向量化的句子
通過字典取出1.2
中得到的句子
# 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)
該階段總程序
import torch # 句子長度 src_len = torch.tensor([2, 4]).to(torch.int32) tgt_len = torch.tensor([4, 3]).to(torch.int32) # 構造句子,用0填充空白處 src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len]) tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len]) # 構造字典 src_embedding_table = nn.Embedding(9, 8) tgt_embedding_table = nn.Embedding(9, 8) # 得到向量化的句子 src_embedding = src_embedding_table(src_seq) tgt_embedding = tgt_embedding_table(tgt_seq) print(src_embedding) print(tgt_embedding)
二、位置編碼
位置編碼是transformer的一個重點,通過加入transformer位置編碼,代替了傳統RNN的時序信息,增強了模型的并發度。位置編碼的公式如下:(其中pos代表行,i代表列)
2.1 計算括號內的值
# 得到分子pos的值 pos_mat = torch.arange(4).reshape((-1, 1)) # 得到分母值 i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8) print(pos_mat) print(i_mat)
2.2 得到位置編碼
# 初始化位置編碼矩陣 pe_embedding_table = torch.zeros(4, 8) # 得到偶數行位置編碼 pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat) # 得到奇數行位置編碼 pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat) pe_embedding = nn.Embedding(4, 8) # 設置位置編碼不可更新參數 pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False) print(pe_embedding.weight)
三、多頭注意力
3.1 self mask
有些位置是空白用0填充的,訓練時不希望被這些位置所影響,那么就需要用到self mask。self mask的原理是令這些位置的值為無窮小,經過softmax后,這些值會變為0,不會再影響結果。
3.1.1 得到有效位置矩陣
# 得到有效位置矩陣 vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2) valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2)) print(valid_encoder_pos_matrix)
3.1.2 得到無效位置矩陣
invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool) print(mask_encoder_self_attention)
True
代表需要對該位置mask
3.1.3 得到mask矩陣
用極小數填充需要被mask的位置
# 初始化mask矩陣 score = torch.randn(2, max(src_len), max(src_len)) # 用極小數填充 mask_score = score.masked_fill(mask_encoder_self_attention, -1e9) print(mask_score)
算其softmat
mask_score_softmax = F.softmax(mask_score) print(mask_score_softmax)
可以看到,已經達到預期效果
原文鏈接:https://blog.csdn.net/sunningzhzh/article/details/124786568
相關推薦
- 2022-06-15 C#數據類型實現背包、隊列和棧_C#教程
- 2022-06-12 Python語法學習之線程的創建與常用方法詳解_python
- 2022-07-19 Tomcat升級版本出現400問題
- 2022-08-22 .Net彈性和瞬態故障處理庫Polly實現執行策略_實用技巧
- 2022-03-20 Entity?Framework?Core關聯刪除_實用技巧
- 2022-11-29 如果服務器出現內存泄漏,堆內存緩慢上漲,一段時間后觸發了fullGc,如何快速定位?
- 2022-06-07 victoriaMetrics庫布隆過濾器初始化及使用詳解_Golang
- 2022-11-18 React生命周期函數深入全面介紹_React
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支