日本免费高清视频-国产福利视频导航-黄色在线播放国产-天天操天天操天天操天天操|www.shdianci.com

學無先后,達者為師

網站首頁 編程語言 正文

Pytorch從0實現Transformer的實踐_python

作者:原來如此- ? 更新時間: 2022-07-09 編程語言

摘要

With the continuous development of time series prediction, Transformer-like models have gradually replaced traditional models in the fields of CV and NLP by virtue of their powerful advantages. Among them, the Informer is far superior to the traditional RNN model in long-term prediction, and the Swin Transformer is significantly stronger than the traditional CNN model in image recognition. A deep grasp of Transformer has become an inevitable requirement in the field of artificial intelligence. This article will use the Pytorch framework to implement the position encoding, multi-head attention mechanism, self-mask, causal mask and other functions in Transformer, and build a Transformer network from 0.

隨著時序預測的不斷發展,Transformer類模型憑借強大的優勢,在CV、NLP領域逐漸取代傳統模型。其中Informer在長時序預測上遠超傳統的RNN模型,Swin Transformer在圖像識別上明顯強于傳統的CNN模型。深層次掌握Transformer已經成為從事人工智能領域的必然要求。本文將用Pytorch框架,實現Transformer中的位置編碼、多頭注意力機制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網絡。

一、構造數據

1.1 句子長度

# 關于word embedding,以序列建模為例
# 輸入句子有兩個,第一個長度為2,第二個長度為4
src_len = torch.tensor([2, 4]).to(torch.int32)
# 目標句子有兩個。第一個長度為4, 第二個長度為3
tgt_len = torch.tensor([4, 3]).to(torch.int32)
print(src_len)
print(tgt_len)

輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4
目標句子(tgt_len)有兩個。第一個長度為4, 第二個長度為3

1.2 生成句子

用隨機數生成句子,用0填充空白位置,保持所有句子長度一致

src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_src_words, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, max_num_tgt_words, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
print(src_seq)
print(tgt_seq)

src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。
為什么句子是數字?在做中英文翻譯時,每個中文或英文對應的也是一個數字,只有這樣才便于處理。

1.3 生成字典

在該字典中,總共有8個字(行),每個字對應8維向量(做了簡化了的)。注意在實際應用中,應當有幾十萬個字,每個字可能有512個維度。

# 構造word embedding
src_embedding_table = nn.Embedding(9, model_dim)
tgt_embedding_table = nn.Embedding(9, model_dim)
# 輸入單詞的字典
print(src_embedding_table)
# 目標單詞的字典
print(tgt_embedding_table)

字典中,需要留一個維度給class token,故是9行。

1.4 得到向量化的句子

通過字典取出1.2中得到的句子

# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

該階段總程序

import torch
# 句子長度
src_len = torch.tensor([2, 4]).to(torch.int32)
tgt_len = torch.tensor([4, 3]).to(torch.int32)
# 構造句子,用0填充空白處
src_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(src_len)-L)), 0) for L in src_len])
tgt_seq = torch.cat([torch.unsqueeze(F.pad(torch.randint(1, 8, (L, )), (0, max(tgt_len)-L)), 0) for L in tgt_len])
# 構造字典
src_embedding_table = nn.Embedding(9, 8)
tgt_embedding_table = nn.Embedding(9, 8)
# 得到向量化的句子
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)

二、位置編碼

位置編碼是transformer的一個重點,通過加入transformer位置編碼,代替了傳統RNN的時序信息,增強了模型的并發度。位置編碼的公式如下:(其中pos代表行,i代表列)

2.1 計算括號內的值

# 得到分子pos的值
pos_mat = torch.arange(4).reshape((-1, 1))
# 得到分母值
i_mat = torch.pow(10000, torch.arange(0, 8, 2).reshape((1, -1))/8)
print(pos_mat)
print(i_mat)

2.2 得到位置編碼

# 初始化位置編碼矩陣
pe_embedding_table = torch.zeros(4, 8)
# 得到偶數行位置編碼
pe_embedding_table[:, 0::2] =torch.sin(pos_mat / i_mat)
# 得到奇數行位置編碼
pe_embedding_table[:, 1::2] =torch.cos(pos_mat / i_mat)
pe_embedding = nn.Embedding(4, 8)
# 設置位置編碼不可更新參數
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)
print(pe_embedding.weight)

三、多頭注意力

3.1 self mask

有些位置是空白用0填充的,訓練時不希望被這些位置所影響,那么就需要用到self mask。self mask的原理是令這些位置的值為無窮小,經過softmax后,這些值會變為0,不會再影響結果。

3.1.1 得到有效位置矩陣

# 得到有效位置矩陣
vaild_encoder_pos = torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L), (0, max(src_len) - L)), 0)for L in src_len]), 2)
valid_encoder_pos_matrix = torch.bmm(vaild_encoder_pos, vaild_encoder_pos.transpose(1, 2))
print(valid_encoder_pos_matrix)

3.1.2 得到無效位置矩陣

invalid_encoder_pos_matrix = 1-valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)

True代表需要對該位置mask

在這里插入圖片描述

3.1.3 得到mask矩陣
用極小數填充需要被mask的位置

# 初始化mask矩陣
score = torch.randn(2, max(src_len), max(src_len))
# 用極小數填充
mask_score = score.masked_fill(mask_encoder_self_attention, -1e9)
print(mask_score)

算其softmat

mask_score_softmax = F.softmax(mask_score)
print(mask_score_softmax)

可以看到,已經達到預期效果

原文鏈接:https://blog.csdn.net/sunningzhzh/article/details/124786568

欄目分類
最近更新