【手撕Transformer】Transformer输入输出细节以及代码实现(pytorch)
举例讲解transformer的输入输出细节
数据从输入到encoder到decoder输出这个过程中的流程(以机器翻译为例子):
encoder
对于机器翻译来说 ,一个样本是由原始句子和翻译后的句子组成的 。比如原始句子是: “我爱机器学习 ” ,那么翻译后是 ’i love machine learning‘ 。 则该一个样本就是由“我爱机器学习 ”和 “i love machine learning” 组成。
这个样本的原始句子的单词长度是length=4,即‘我’ ‘爱’ ‘机器’ ‘学习’ 。经过embedding后每个词的embedding向量是512 。那么“我爱机器学习 ”这个句子的embedding后的维度是[4,512 ] (若是批量输入 ,则embedding后的维度是[batch, 4, 512]) 。
padding
假设样本中句子的最大长度是10 ,那么对于长度不足10的句子 ,需要补足到10个长度 ,shape就变为[10, 512], 补全的位置上的embedding数值自然就是0了
Padding Mask
对于输入序列一般要进行padding补齐 ,也就是说设定一个统一长度N ,在较短的序列后面填充0到长度为N 。对于那些补零的数据来说 ,attention机制不应该把注意力放在这些位置上 ,所以需要进行一些处理 。具体的做法是 ,把这些位置的值加上一个非常大的负数(负无穷),这样经过softmax后 ,这些位置的权重就会接近0 。Transformer的padding mask实际上是一个张量 ,每个值都是一个Boolean,值为false的地方就是要进行处理的地方 。
Positional Embedding
得到补全后的句子embedding向量后 ,直接输入encoder的话 ,那么是没有考虑到句子中的位置顺序关系的 。此时需要再加一个位置向量,位置向量在模型训练中有特定的方式 ,可以表示每个词的位置或者不同词之间的距离;总之 ,核心思想是在attention计算时提供有效的距离信息 。
初步理解参考我的博客【初理解】Transformer中的Positional Encodingattention
参考我的博文(2021李宏毅)机器学习-Self-attention
FeedForward
略 ,很简单
add/Norm
经过add/norm后的隐藏输出的shape也是[10,512]。
encoder输入输出
从输入开始 ,再从头理一遍单个encoder这个过程:
输入x x 做一个层归一化: x1 = norm(x) 进入多头self-attention: x2 = self_attention(x1) 残差加成:x3 = x + x2 再做个层归一化:x4 = norm(x3) 经过前馈网络: x5 = feed_forward(x4) 残差加成: x6 = x3 + x5 输出x6
这就是Encoder所做的工作decoder
注意encoder的输出并没直接作为decoder的直接输入 。
训练的时候 ,1.初始decoder的time step为1时(也就是第一次接收输入) ,其输入为一个特殊的token ,可能是目标序列开始的token(如) ,也可能是源序列结尾的token(如) ,也可能是其它视任务而定的输入等等,不同源码中可能有微小的差异 ,其目标则是预测翻译后的第1个单词(token)是什么;2.然后和预测出来的第1个单词一起 ,再次作为decoder的输入,得到第2个预测单词;3后续依此类推;
具体的例子如下:
样本:“我/爱/机器/学习 ”和 “i/ love /machine/ learning ”
训练:把“我/爱/机器/学习 ”embedding后输入到encoder里去 ,最后一层的encoder最终输出的outputs [10, 512](假设我们采用的embedding长度为512 ,而且batch size = 1),此outputs 乘以新的参数矩阵,可以作为decoder里每一层用到的K和V;
将<bos>作为decoder的初始输入 ,将decoder的最大概率输出词 A1和‘i’做cross entropy计算error 。
将<bos> ,“i ” 作为decoder的输入 ,将decoder的最大概率输出词 A2 和‘love’做cross entropy计算error。
将<bos> ,“i ” ,“love ” 作为decoder的输入 ,将decoder的最大概率输出词A3和’machine’ 做cross entropy计算error 。
将<bos> ,“i ” ,"love " ,“machine ” 作为decoder的输入,将decoder最大概率输出词A4和‘learning’做cross entropy计算error 。
将<bos> ,“i” ,"love ",“machine ” ,“learning ” 作为decoder的输入 ,将decoder最大概率输出词A5和终止符做cross entropy计算error。
Sequence Mask
上述训练过程是挨个单词串行进行的,那么能不能并行进行呢 ,当然可以 。可以看到上述单个句子训练时候 ,输入到 decoder的分别是
<bos>
<bos> ,“i”
<bos> ,“i ” ,“love ”
<bos> ,“i” ,"love " ,“machine ”
<bos> ,“i ”,"love " ,“machine ” ,“learning ”
那么为何不将这些输入组成矩阵,进行输入呢?这些输入组成矩阵形式如下:
【<bos>
<bos> ,“i ”
<bos> ,“i ”,“love ”
<bos> ,“i ” ,"love " ,“machine ”
<bos> ,“i” ,"love " ,“machine ” ,“learning ” 】
怎么操作得到这个矩阵呢?
将decoder在上述2-6步次的输入补全为一个完整的句子
【
然后将上述矩阵矩阵乘以一个 mask矩阵
【1 0 0 0 0
1 1 0 0 0
1 1 1 0 0
1 1 1 1 0
1 1 1 1 1 】
这样是不是就得到了
【<bos>
<bos> ,“i ”
<bos>,“i” ,“love ”
<bos> ,“i ”,"love " ,“machine”
<bos> ,“i ” ,"love " ,“machine ” ,“learning ” 】
这样将这个矩阵输入到decoder(其实你可以想一下 ,此时这个矩阵是不是类似于批处理 ,矩阵的每行是一个样本 ,只是每行的样本长度不一样 ,每行输入后最终得到一个输出概率分布,作为矩阵输入的话一下可以得到5个输出概率分布) 。
这样就可以进行并行计算进行训练了 。测试
训练好模型 , 测试的时候 ,比如用 机器学习很有趣’当作测试样本,得到其英语翻译 。
这一句经过encoder后得到输出tensor ,送入到decoder(并不是当作decoder的直接输入):
然后用起始符<bos>当作decoder的 输入 ,得到输出 machine
用<bos> + machine 当作输入得到输出 learning
用 <bos> + machine + learning 当作输入得到is
用<bos> + machine + learning + is 当作输入得到interesting
用<bos> + machine + learning + is + interesting 当作输入得到 结束符号<eos>
得到了完整的翻译 ‘machine learning is interesting’
可以看到,在测试过程中 ,只能一个单词一个单词的进行输出 ,是串行进行的 。
Transformer pytorch代码实现
数据准备
import math import torch import numpy as np import torch.nn as nn import torch.optim as optim import torch.utils.data as Data #自制数据集 # Encoder_input Decoder_input Decoder_output sentences = [[我 是 学 生 P , S I am a student , I am a student E], # S: 开始符号 [我 喜 欢 学 习, S I like learning P, I like learning P E], # E: 结束符号 [我 是 男 生 P , S I am a boy , I am a boy E]] # P: 占位符号 ,如果当前句子不足固定长度用P占位 pad补0 src_vocab = {P:0, 我:1, 是:2, 学:3, 生:4, 喜:5, 欢:6,习:7,男:8} # 词源字典 字:索引 src_idx2word = {src_vocab[key]: key for key in src_vocab} src_vocab_size = len(src_vocab) # 字典字的个数 tgt_vocab = {S:0, E:1, P:2, I:3, am:4, a:5, student:6, like:7, learning:8, boy:9} idx2word = {tgt_vocab[key]: key for key in tgt_vocab} # 把目标字典转换成 索引:字的形式 tgt_vocab_size = len(tgt_vocab) # 目标字典尺寸 src_len = len(sentences[0][0].split(" ")) # Encoder输入的最大长度 5 tgt_len = len(sentences[0][1].split(" ")) # Decoder输入输出最大长度 5 src_len,tgt_len (5, 5) # 把sentences 转换成字典索引 def make_data(sentences): enc_inputs, dec_inputs, dec_outputs = [], [], [] for i in range(len(sentences)): enc_input = [[src_vocab[n] for n in sentences[i][0].split()]] dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]] dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]] enc_inputs.extend(enc_input) dec_inputs.extend(dec_input) dec_outputs.extend(dec_output) return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs) enc_inputs, dec_inputs, dec_outputs = make_data(sentences) print(enc_inputs) print(dec_inputs) print(dec_outputs) tensor([[1, 2, 3, 4, 0], [1, 5, 6, 3, 7], [1, 2, 8, 4, 0]]) tensor([[0, 3, 4, 5, 6], [0, 3, 7, 8, 2], [0, 3, 4, 5, 9]]) tensor([[3, 4, 5, 6, 1], [3, 7, 8, 2, 1], [3, 4, 5, 9, 1]])sentences 里一共有三个训练数据 ,中文->英文 。把Encoder_input 、Decoder_input 、Decoder_output转换成字典索引 ,例如"学"->3 、“student ”->6 。再把数据转换成batch大小为2的分组数据 ,3句话一共可以分成两组 ,一组2句话 、一组1句话 。src_len表示中文句子固定最大长度 ,tgt_len 表示英文句子固定最大长度 。
#自定义数据集函数 class MyDataSet(Data.Dataset): def __init__(self, enc_inputs, dec_inputs, dec_outputs): super(MyDataSet, self).__init__() self.enc_inputs = enc_inputs self.dec_inputs = dec_inputs self.dec_outputs = dec_outputs def __len__(self): return self.enc_inputs.shape[0] def __getitem__(self, idx): return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx] loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), 2, True)参数设置
d_model = 512 # 字 Embedding 的维度 d_ff = 2048 # 前向传播隐藏层维度 d_k = d_v = 64 # K(=Q), V的维度 n_layers = 6 # 有多少个encoder和decoder n_heads = 8 # Multi-Head Attention设置为8定义位置信息
首先 ,给出文章中的公式解读:
{
p
k
,
2
i
=
sin
(
k
/
1000
2
i
/
d
)
p
k
,
2
i
+
1
=
cos
(
k
/
1000
2
i
/
d
)
\left\{\begin{array}{l} \boldsymbol{p}_{k, 2 i}=\sin \left(k / 10000^{2 i / d}\right) \\ \boldsymbol{p}_{k, 2 i+1}=\cos \left(k / 10000^{2 i / d}\right) \end{array}\right.
{pk,2i=sin(k/100002i/d)pk,2i+1=cos(k/100002i/d)其中
p
k
,
2
i
,
p
k
,
2
i
+
1
分别是位置
k
的编码向量的第
2
i
,
2
i
+
1
个分量,
d
是向量维度
\text { 其中 } \boldsymbol{p}_{k, 2 i}, \boldsymbol{p}_{k, 2 i+1} \text { 分别是位置 } k \text { 的编码向量的第 } 2 i, 2 i+1 \text { 个分量, } d \text { 是向量维度 }
其中pk,2i,pk,2i+1分别是位置k的编码向量的第2i,2i+1个分量, d是向量维度 class PositionalEncoding(nn.Module): def __init__(self,d_model,dropout=0.1,max_len=5000): super(PositionalEncoding,self).__init__() self.dropout = nn.Dropout(p=dropout) pos_table = np.array([ [pos / np.power(10000, 2 * i / d_model) for i in range(d_model)] if pos != 0 else np.zeros(d_model) for pos in range(max_len)]) pos_table[1:, 0::2] = np.sin(pos_table[1:, 0::2]) # 字嵌入维度为偶数时 pos_table[1:, 1::2] = np.cos(pos_table[1:, 1::2]) # 字嵌入维度为奇数时 self.pos_table = torch.FloatTensor(pos_table).cuda() # enc_inputs: [seq_len, d_model] def forward(self,enc_inputs): # enc_inputs: [batch_size, seq_len, d_model] enc_inputs += self.pos_table[:enc_inputs.size(1),:] return self.dropout(enc_inputs.cuda())生成位置信息矩阵pos_table,直接加上输入的enc_inputs上 ,得到带有位置信息的字向量 ,pos_table是一个固定值的矩阵。这里矩阵加法利用到了广播机制
Mask掉停用词
Mask句子中没有实际意义的占位符,例如’我 是 学 生 P’ ,P对应句子没有实际意义 ,所以需要被Mask,Encoder_input 和Decoder_input占位符都需要被Mask 。
这就是为了处理 ,句子不一样长 ,但是输入有需要定长 ,不够长的pad填充 ,但是计算又不需要这个pad ,所以mask掉这个函数最核心的一句代码是 seq_k.data.eq(0) ,这句的作用是返回一个大小和 seq_k 一样的 tensor ,只不过里面的值只有 True 和 False 。如果 seq_k 某个位置的值等于 0 ,那么对应位置就是 True ,否则即为 False。举个例子,输入为 seq_data = [1, 2, 3, 4, 0] ,seq_data.data.eq(0) 就会返回 [False, False, False, False, True]
def get_attn_pad_mask(seq_q,seq_k): batch_size, len_q = seq_q.size()# seq_q 用于升维 ,为了做attention,mask score矩阵用的 batch_size, len_k = seq_k.size() pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # 判断 输入那些含有P(=0),用1标记 ,[batch_size, 1, len_k] return pad_attn_mask.expand(batch_size,len_q,len_k) # 扩展成多维度 [batch_size, len_q, len_k]Decoder 输入 Mask
用来Mask未来输入信息 ,返回的是一个上三角矩阵 。比如我们在中英文翻译时候 ,会先把"我是学生"整个句子输入到Encoder中,得到最后一层的输出后 ,才会在Decoder输入"S I am a student"(s表示开始),但是"S I am a student"这个句子我们不会一起输入 ,而是在T0时刻先输入"S"预测 ,预测第一个词"I";在下一个T1时刻 ,同时输入"S"和"I"到Decoder预测下一个单词"am";然后在T2时刻把"S,I,am"同时输入到Decoder预测下一个单词"a" ,依次把整个句子输入到Decoder,预测出"I am a student E" 。
下图是 np.triu() 用法
def get_attn_subsequence_mask(seq): # seq: [batch_size, tgt_len] attn_shape = [seq.size(0), seq.size(1), seq.size(1)] # 生成上三角矩阵,[batch_size, tgt_len, tgt_len] subsequence_mask = np.triu(np.ones(attn_shape), k=1) subsequence_mask = torch.from_numpy(subsequence_mask).byte() # [batch_size, tgt_len, tgt_len] return subsequence_mask计算注意力信息 、残差和归一化
class ScaledDotProductAttention(nn.Module): def __init__(self): super(ScaledDotProductAttention, self).__init__() def forward(self, Q, K, V, attn_mask): # Q: [batch_size, n_heads, len_q, d_k] # K: [batch_size, n_heads, len_k, d_k] # V: [batch_size, n_heads, len_v(=len_k), d_v] # attn_mask: [batch_size, n_heads, seq_len, seq_len] scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores : [batch_size, n_heads, len_q, len_k] scores.masked_fill_(attn_mask, -1e9) # 如果是停用词P就等于 0 attn = nn.Softmax(dim=-1)(scores) context = torch.matmul(attn, V) # [batch_size, n_heads, len_q, d_v] return context, attn计算注意力信息 ,
W
Q
,
W
K
,
W
V
W^{Q}, W^{K}, W^{V}
WQ,WK,WV 矩阵会拆分成 8 个小矩阵。注意传入的 input_Q, input_K, input_V, 在Encoder和Decoder的第一次调用传入的三个矩阵是相同的 ,但 Decoder的第二次调用传入的三个矩阵input_Q 等于 input_K 不等于 input_V,因为decoder中是计算的cross attention ,如下图所示. class MultiHeadAttention(nn.Module): def __init__(self): super(MultiHeadAttention, self).__init__() self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False) self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False) self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False) self.fc = nn.Linear(n_heads * d_v, d_model, bias=False) def forward(self, input_Q, input_K, input_V, attn_mask): # input_Q: [batch_size, len_q, d_model] # input_K: [batch_size, len_k, d_model] # input_V: [batch_size, len_v(=len_k), d_model] # attn_mask: [batch_size, seq_len, seq_len] residual, batch_size = input_Q, input_Q.size(0) Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1,2) # Q: [batch_size, n_heads, len_q, d_k] K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1,2) # K: [batch_size, n_heads, len_k, d_k] V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1,2) # V: [batch_size, n_heads, len_v(=len_k), d_v] attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) # attn_mask : [batch_size, n_heads, seq_len, seq_len] context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask) # context: [batch_size, n_heads, len_q, d_v] # attn: [batch_size, n_heads, len_q, len_k] context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v) # context: [batch_size, len_q, n_heads * d_v] output = self.fc(context) # [batch_size, len_q, d_model] return nn.LayerNorm(d_model).cuda()(output + residual), attn前馈神经网络
输入inputs ,经过两个全连接层,得到的结果再加上 inputs (残差) ,再做LayerNorm归一化 。LayerNorm归一化可以理解层是把Batch中每一句话进行归一化 。
class PoswiseFeedForwardNet(nn.Module): def __init__(self): super(PoswiseFeedForwardNet, self).__init__() self.fc = nn.Sequential( nn.Linear(d_model, d_ff, bias=False), nn.ReLU(), nn.Linear(d_ff, d_model, bias=False)) def forward(self, inputs): # inputs: [batch_size, seq_len, d_model] residual = inputs output = self.fc(inputs) return nn.LayerNorm(d_model).cuda()(output + residual) # [batch_size, seq_len, d_model]encoder layer(block)
class EncoderLayer(nn.Module): def __init__(self): super(EncoderLayer, self).__init__() self.enc_self_attn = MultiHeadAttention() # 多头注意力机制 self.pos_ffn = PoswiseFeedForwardNet() # 前馈神经网络 def forward(self, enc_inputs, enc_self_attn_mask): # enc_inputs: [batch_size, src_len, d_model] #输入3个enc_inputs分别与W_q 、W_k 、W_v相乘得到Q 、K 、V # enc_self_attn_mask: [batch_size, src_len, src_len] enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, # enc_outputs: [batch_size, src_len, d_model], enc_self_attn_mask) # attn: [batch_size, n_heads, src_len, src_len] enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size, src_len, d_model] return enc_outputs, attnEncoder
第一步 ,中文字索引进行Embedding,转换成512维度的字向量 。第二步 ,在子向量上面加上位置信息 。第三步 ,Mask掉句子中的占位符号 。第四步,通过6层的encoder(上一层的输出作为下一层的输入) 。
class Encoder(nn.Module): def __init__(self): super(Encoder, self).__init__() self.src_emb = nn.Embedding(src_vocab_size, d_model) self.pos_emb = PositionalEncoding(d_model) self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)]) def forward(self, enc_inputs): enc_inputs: [batch_size, src_len] enc_outputs = self.src_emb(enc_inputs) # [batch_size, src_len, d_model] enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1) # [batch_size, src_len, d_model] enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs) # [batch_size, src_len, src_len] enc_self_attns = [] for layer in self.layers: # enc_outputs: [batch_size, src_len, d_model], enc_self_attn: [batch_size, n_heads, src_len, src_len] enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask) enc_self_attns.append(enc_self_attn) return enc_outputs, enc_self_attnsdecoder layer(block)
decoder两次调用MultiHeadAttention时 ,第一次调用传入的 Q ,K ,V 的值是相同的 ,都等于dec_inputs ,第二次调用 Q 矩阵是来自Decoder的输入 。K ,V 两个矩阵是来自Encoder的输出 ,等于enc_outputs 。
class DecoderLayer(nn.Module): def __init__(self): super(DecoderLayer, self).__init__() self.dec_self_attn = MultiHeadAttention() self.dec_enc_attn = MultiHeadAttention() self.pos_ffn = PoswiseFeedForwardNet() def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask): # dec_inputs: [batch_size, tgt_len, d_model] # enc_outputs: [batch_size, src_len, d_model] # dec_self_attn_mask: [batch_size, tgt_len, tgt_len] # dec_enc_attn_mask: [batch_size, tgt_len, src_len] dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask) # dec_outputs: [batch_size, tgt_len, d_model] # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len] dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask) # dec_outputs: [batch_size, tgt_len, d_model] # dec_enc_attn: [batch_size, h_heads, tgt_len, src_len] dec_outputs = self.pos_ffn(dec_outputs) # dec_outputs: [batch_size, tgt_len, d_model] return dec_outputs, dec_self_attn, dec_enc_attnDecoder
第一步 ,英文字索引进行Embedding ,转换成512维度的字向量 。第二步,在子向量上面加上位置信息。第三步 ,Mask掉句子中的占位符号和输出顺序.第四步 ,通过6层的decoder(上一层的输出作为下一层的输入)
class Decoder(nn.Module): def __init__(self): super(Decoder, self).__init__() self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model) self.pos_emb = PositionalEncoding(d_model) self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)]) def forward(self, dec_inputs, enc_inputs, enc_outputs): dec_inputs: [batch_size, tgt_len] enc_intpus: [batch_size, src_len] enc_outputs: [batch_size, src_len, d_model] dec_outputs = self.tgt_emb(dec_inputs) # [batch_size, tgt_len, d_model] dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1).cuda() # [batch_size, tgt_len, d_model] # Decoder输入序列的pad mask矩阵(这个例子中decoder是没有加pad的,实际应用中都是有pad填充的) dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() # [batch_size, tgt_len, tgt_len] # Masked Self_Attention:当前时刻是看不到未来的信息的 dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() # [batch_size, tgt_len, tgt_len] # Decoder中把两种mask矩阵相加(既屏蔽了pad的信息 ,也屏蔽了未来时刻的信息) dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0).cuda() # [batch_size, tgt_len, tgt_len] # 这个mask主要用于encoder-decoder attention层 # get_attn_pad_mask主要是enc_inputs的pad mask矩阵(因为enc是处理K,V的 ,求Attention时是用v1,v2,..vm去加权的, # 要把pad对应的v_i的相关系数设为0 ,这样注意力就不会关注pad向量) # dec_inputs只是提供expand的size的 dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) # [batc_size, tgt_len, src_len] dec_self_attns, dec_enc_attns = [], [] for layer in self.layers: # dec_outputs: [batch_size, tgt_len, d_model], dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len], dec_enc_attn: [batch_size, h_heads, tgt_len, src_len] dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask) dec_self_attns.append(dec_self_attn) dec_enc_attns.append(dec_enc_attn) return dec_outputs, dec_self_attns, dec_enc_attnsTransformer
Trasformer的整体结构 ,输入数据先通过Encoder ,再同个Decoder ,最后把输出进行多分类 ,分类数为英文字典长度 ,也就是判断每一个字的概率 。
class Transformer(nn.Module): def __init__(self): super(Transformer, self).__init__() self.Encoder = Encoder().cuda() self.Decoder = Decoder().cuda() self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False).cuda() def forward(self, enc_inputs, dec_inputs): # enc_inputs: [batch_size, src_len] # dec_inputs: [batch_size, tgt_len] enc_outputs, enc_self_attns = self.Encoder(enc_inputs) # enc_outputs: [batch_size, src_len, d_model], # enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len] dec_outputs, dec_self_attns, dec_enc_attns = self.Decoder( dec_inputs, enc_inputs, enc_outputs) # dec_outpus : [batch_size, tgt_len, d_model], # dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len], # dec_enc_attn : [n_layers, batch_size, tgt_len, src_len] dec_logits = self.projection(dec_outputs) # dec_logits: [batch_size, tgt_len, tgt_vocab_size] return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns定义网络
model = Transformer().cuda() criterion = nn.CrossEntropyLoss(ignore_index=0) #忽略 占位符 索引为0. optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)训练Transformer
因为batch=2 ,所以一个epoch有两个loss
for epoch in range(1000): for enc_inputs, dec_inputs, dec_outputs in loader: enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs) loss = criterion(outputs,dec_outputs.view(-1)) print(Epoch:, %04d % (epoch+1), loss =, {:.6f}.format(loss)) optimizer.zero_grad() loss.backward() optimizer.step() Epoch: 0001 loss = 0.000002 Epoch: 0001 loss = 0.000002 Epoch: 0002 loss = 0.000002 Epoch: 0002 loss = 0.000002 Epoch: 0003 loss = 0.000002 Epoch: 0003 loss = 0.000002 Epoch: 0004 loss = 0.000004 Epoch: 0004 loss = 0.000002 Epoch: 0005 loss = 0.000003 Epoch: 0005 loss = 0.000004 Epoch: 0006 loss = 0.000003 Epoch: 0006 loss = 0.000002 Epoch: 0007 loss = 0.000003 Epoch: 0007 loss = 0.000002 Epoch: 0008 loss = 0.000003 Epoch: 0008 loss = 0.000003 Epoch: 0009 loss = 0.000003 Epoch: 0009 loss = 0.000002 Epoch: 0010 loss = 0.000004 Epoch: 0010 loss = 0.000002 Epoch: 0011 loss = 0.000002 Epoch: 0011 loss = 0.000002 Epoch: 0012 loss = 0.000004 Epoch: 0012 loss = 0.000003 Epoch: 0013 loss = 0.000003 Epoch: 0013 loss = 0.000003 Epoch: 0014 loss = 0.000002 Epoch: 0014 loss = 0.000002 Epoch: 0015 loss = 0.000003 Epoch: 0015 loss = 0.000003 Epoch: 0016 loss = 0.000003 Epoch: 0016 loss = 0.000002 Epoch: 0017 loss = 0.000001 Epoch: 0017 loss = 0.000002 Epoch: 0018 loss = 0.000002 Epoch: 0018 loss = 0.000003 Epoch: 0019 loss = 0.000003 Epoch: 0019 loss = 0.000002 Epoch: 0020 loss = 0.000003 Epoch: 0020 loss = 0.000002 Epoch: 0021 loss = 0.000002 Epoch: 0021 loss = 0.000004 Epoch: 0022 loss = 0.000003 Epoch: 0022 loss = 0.000002 Epoch: 0023 loss = 0.000003 Epoch: 0023 loss = 0.000002 Epoch: 0024 loss = 0.000003 Epoch: 0024 loss = 0.000002 Epoch: 0025 loss = 0.000003 Epoch: 0025 loss = 0.000002 Epoch: 0026 loss = 0.000003 Epoch: 0026 loss = 0.000002 Epoch: 0027 loss = 0.000002 Epoch: 0027 loss = 0.000002 Epoch: 0028 loss = 0.000002 Epoch: 0028 loss = 0.000002 Epoch: 0029 loss = 0.000003 Epoch: 0029 loss = 0.000002 Epoch: 0030 loss = 0.000003 Epoch: 0030 loss = 0.000003 Epoch: 0031 loss = 0.000002 Epoch: 0031 loss = 0.000002 Epoch: 0032 loss = 0.000002 Epoch: 0032 loss = 0.000003 Epoch: 0033 loss = 0.000002 Epoch: 0033 loss = 0.000002 Epoch: 0034 loss = 0.000001 Epoch: 0034 loss = 0.000002 Epoch: 0035 loss = 0.000003 Epoch: 0035 loss = 0.000002 Epoch: 0036 loss = 0.000003 Epoch: 0036 loss = 0.000002 Epoch: 0037 loss = 0.000003 Epoch: 0037 loss = 0.000003 Epoch: 0038 loss = 0.000002 Epoch: 0038 loss = 0.000002 Epoch: 0039 loss = 0.000002 Epoch: 0039 loss = 0.000002 Epoch: 0040 loss = 0.000002 Epoch: 0040 loss = 0.000002 Epoch: 0041 loss = 0.000003 Epoch: 0041 loss = 0.000002 Epoch: 0042 loss = 0.000003 Epoch: 0042 loss = 0.000003 Epoch: 0043 loss = 0.000003 Epoch: 0043 loss = 0.000002 Epoch: 0044 loss = 0.000003 Epoch: 0044 loss = 0.000002 Epoch: 0045 loss = 0.000002 Epoch: 0045 loss = 0.000003 Epoch: 0046 loss = 0.000002 Epoch: 0046 loss = 0.000002 Epoch: 0047 loss = 0.000003 Epoch: 0047 loss = 0.000002 Epoch: 0048 loss = 0.000003 Epoch: 0048 loss = 0.000002 Epoch: 0049 loss = 0.000002 Epoch: 0049 loss = 0.000004 Epoch: 0050 loss = 0.000003 Epoch: 0050 loss = 0.000002 Epoch: 0051 loss = 0.000002 Epoch: 0051 loss = 0.000002 Epoch: 0052 loss = 0.000003 Epoch: 0052 loss = 0.000003 Epoch: 0053 loss = 0.000002 Epoch: 0053 loss = 0.000002 Epoch: 0054 loss = 0.000002 Epoch: 0054 loss = 0.000001 Epoch: 0055 loss = 0.000002 Epoch: 0055 loss = 0.000003 Epoch: 0056 loss = 0.000002 Epoch: 0056 loss = 0.000003 Epoch: 0057 loss = 0.000002 Epoch: 0057 loss = 0.000003 Epoch: 0058 loss = 0.000002 Epoch: 0058 loss = 0.000002 Epoch: 0059 loss = 0.000003 Epoch: 0059 loss = 0.000004 Epoch: 0060 loss = 0.000002 Epoch: 0060 loss = 0.000003 Epoch: 0061 loss = 0.000002 Epoch: 0061 loss = 0.000002 Epoch: 0062 loss = 0.000002 Epoch: 0062 loss = 0.000003 Epoch: 0063 loss = 0.000003 Epoch: 0063 loss = 0.000002 Epoch: 0064 loss = 0.000002 Epoch: 0064 loss = 0.000003 Epoch: 0065 loss = 0.000003 Epoch: 0065 loss = 0.000002 Epoch: 0066 loss = 0.000002 Epoch: 0066 loss = 0.000004 Epoch: 0067 loss = 0.000001 Epoch: 0067 loss = 0.000003 Epoch: 0068 loss = 0.000003 Epoch: 0068 loss = 0.000004 Epoch: 0069 loss = 0.000002 Epoch: 0069 loss = 0.000002 Epoch: 0070 loss = 0.000001 Epoch: 0070 loss = 0.000003 Epoch: 0071 loss = 0.000004 Epoch: 0071 loss = 0.000002 Epoch: 0072 loss = 0.000003 Epoch: 0072 loss = 0.000002 Epoch: 0073 loss = 0.000002 Epoch: 0073 loss = 0.000003 Epoch: 0074 loss = 0.000003 Epoch: 0074 loss = 0.000002 Epoch: 0075 loss = 0.000003 Epoch: 0075 loss = 0.000002 Epoch: 0076 loss = 0.000002 Epoch: 0076 loss = 0.000003 Epoch: 0077 loss = 0.000001 Epoch: 0077 loss = 0.000002 Epoch: 0078 loss = 0.000001 Epoch: 0078 loss = 0.000002 Epoch: 0079 loss = 0.000003 Epoch: 0079 loss = 0.000002 Epoch: 0080 loss = 0.000002 Epoch: 0080 loss = 0.000002 Epoch: 0081 loss = 0.000002 Epoch: 0081 loss = 0.000005 Epoch: 0082 loss = 0.000003 Epoch: 0082 loss = 0.000002 Epoch: 0083 loss = 0.000003 Epoch: 0083 loss = 0.000003 Epoch: 0084 loss = 0.000002 Epoch: 0084 loss = 0.000003 Epoch: 0085 loss = 0.000002 Epoch: 0085 loss = 0.000002 Epoch: 0086 loss = 0.000003 Epoch: 0086 loss = 0.000001 Epoch: 0087 loss = 0.000002 Epoch: 0087 loss = 0.000002 Epoch: 0088 loss = 0.000001 Epoch: 0088 loss = 0.000002 Epoch: 0089 loss = 0.000002 Epoch: 0089 loss = 0.000003 Epoch: 0090 loss = 0.000002 Epoch: 0090 loss = 0.000002 Epoch: 0091 loss = 0.000004 Epoch: 0091 loss = 0.000002 Epoch: 0092 loss = 0.000002 Epoch: 0092 loss = 0.000002 Epoch: 0093 loss = 0.000003 Epoch: 0093 loss = 0.000002 Epoch: 0094 loss = 0.000002 Epoch: 0094 loss = 0.000003 Epoch: 0095 loss = 0.000001 Epoch: 0095 loss = 0.000002 Epoch: 0096 loss = 0.000003 Epoch: 0096 loss = 0.000002 Epoch: 0097 loss = 0.000002 Epoch: 0097 loss = 0.000002 Epoch: 0098 loss = 0.000001 Epoch: 0098 loss = 0.000003 Epoch: 0099 loss = 0.000003 Epoch: 0099 loss = 0.000003 Epoch: 0100 loss = 0.000002 Epoch: 0100 loss = 0.000003 Epoch: 0101 loss = 0.000003 Epoch: 0101 loss = 0.000002 Epoch: 0102 loss = 0.000004 Epoch: 0102 loss = 0.000002 Epoch: 0103 loss = 0.000003 Epoch: 0103 loss = 0.000002 Epoch: 0104 loss = 0.000003 Epoch: 0104 loss = 0.000003 Epoch: 0105 loss = 0.000003 Epoch: 0105 loss = 0.000002 Epoch: 0106 loss = 0.000002 Epoch: 0106 loss = 0.000001 Epoch: 0107 loss = 0.000003 Epoch: 0107 loss = 0.000003 Epoch: 0108 loss = 0.000002 Epoch: 0108 loss = 0.000002 Epoch: 0109 loss = 0.000003 Epoch: 0109 loss = 0.000002 Epoch: 0110 loss = 0.000002 Epoch: 0110 loss = 0.000002 Epoch: 0111 loss = 0.000002 Epoch: 0111 loss = 0.000003 Epoch: 0112 loss = 0.000003 Epoch: 0112 loss = 0.000002 Epoch: 0113 loss = 0.000003 Epoch: 0113 loss = 0.000003 Epoch: 0114 loss = 0.000003 Epoch: 0114 loss = 0.000002 Epoch: 0115 loss = 0.000001 Epoch: 0115 loss = 0.000003 Epoch: 0116 loss = 0.000002 Epoch: 0116 loss = 0.000002 Epoch: 0117 loss = 0.000003 Epoch: 0117 loss = 0.000002 Epoch: 0118 loss = 0.000002 Epoch: 0118 loss = 0.000001 Epoch: 0119 loss = 0.000003 Epoch: 0119 loss = 0.000002 Epoch: 0120 loss = 0.000002 Epoch: 0120 loss = 0.000002 Epoch: 0121 loss = 0.000002 Epoch: 0121 loss = 0.000003 Epoch: 0122 loss = 0.000003 Epoch: 0122 loss = 0.000002 Epoch: 0123 loss = 0.000003 Epoch: 0123 loss = 0.000002 Epoch: 0124 loss = 0.000002 Epoch: 0124 loss = 0.000002 Epoch: 0125 loss = 0.000002 Epoch: 0125 loss = 0.000003 Epoch: 0126 loss = 0.000002 Epoch: 0126 loss = 0.000002 Epoch: 0127 loss = 0.000002 Epoch: 0127 loss = 0.000002 Epoch: 0128 loss = 0.000002 Epoch: 0128 loss = 0.000002 Epoch: 0129 loss = 0.000002 Epoch: 0129 loss = 0.000003 Epoch: 0130 loss = 0.000002 Epoch: 0130 loss = 0.000002 Epoch: 0131 loss创心域SEO版权声明:以上内容作者已申请原创保护,未经允许不得转载,侵权必究!授权事宜、对本内容有异议或投诉,敬请联系网站管理员,我们将尽快回复您,谢谢合作!