Hand-Writing a Transformer in PyTorch: A Hands-On Guide from Self-Attention to the Complete Model
When you first encountered the Transformer, were you left dizzy by the complicated matrix operations and dimension gymnastics? Don't worry: in this post we will build a complete Transformer from scratch in PyTorch, using hands-on code to help you genuinely understand this revolutionary architecture. Rather than pure theory, we follow an "implement, explain, debug" loop so that every component is clear and understandable.
1. Environment Setup and Basic Concepts
Before writing any code, we need to pin down a few core concepts. The Transformer is essentially a sequence-to-sequence model built on self-attention, and it fundamentally changed how sequence data is processed compared with traditional RNNs. Let's set up the development environment first:
```python
import torch
import torch.nn as nn
import math
import numpy as np
from torch.nn import functional as F

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
```

The Transformer's core innovation is that it abandons recurrence entirely and instead uses self-attention to capture long-range dependencies between elements of a sequence. This design brings several notable advantages:
- Parallel computation: no longer constrained by the step-by-step dependencies of a sequence
- Global receptive field: every position can directly attend to every other position in the sequence
- Flexible feature composition: the multi-head mechanism learns representations in different subspaces
2. Implementing Scaled Dot-Product Attention
The self-attention mechanism is the Transformer's basic building block, so we start with its core, scaled dot-product attention. This module computes the similarity between queries and keys, then uses it to take a weighted sum of the values.
```python
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: query matrix  [batch_size, n_head, seq_len, d_k]
            K: key matrix    [batch_size, n_head, seq_len, d_k]
            V: value matrix  [batch_size, n_head, seq_len, d_v]
            mask: optional mask [batch_size, 1, seq_len, seq_len]
        Returns:
            output: attention output        [batch_size, n_head, seq_len, d_v]
            attn_weights: attention weights [batch_size, n_head, seq_len, seq_len]
        """
        d_k = Q.size(-1)
        # Compute QK^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        # Apply the mask, if given
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Normalize with softmax to obtain the attention weights
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Weighted sum over the values
        output = torch.matmul(attn_weights, V)
        return output, attn_weights
```

Key points:
- The scaling factor 1/sqrt(d_k) keeps the dot products from growing so large that softmax saturates and its gradients vanish (a quick demonstration follows this list)
- The mask is used in the decoder to prevent the current position from attending to future information
- Applying dropout to the attention weights is an effective form of regularization
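To make the first point concrete, here is a minimal sketch (our own addition, not from the original tutorial) comparing softmax over raw versus scaled dot products at a large d_k:

```python
# A minimal sketch: why dividing by sqrt(d_k) matters. With unit-variance
# entries, the dot product q·k has variance ~d_k, so for large d_k the
# unscaled softmax collapses onto one position and its gradients vanish.
torch.manual_seed(0)
d_k = 512
q = torch.randn(1, d_k)
k = torch.randn(10, d_k)

raw = q @ k.T                   # scores with variance ~ d_k
scaled = raw / math.sqrt(d_k)   # scores with variance ~ 1

print(F.softmax(raw, dim=-1).max().item())     # close to 1.0: nearly one-hot
print(F.softmax(scaled, dim=-1).max().item())  # noticeably flatter distribution
```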
To verify the implementation, let's construct some test data:
```python
batch_size, n_head, seq_len, d_k, d_v = 2, 3, 5, 64, 64
Q = torch.randn(batch_size, n_head, seq_len, d_k)
K = torch.randn(batch_size, n_head, seq_len, d_k)
V = torch.randn(batch_size, n_head, seq_len, d_v)

attention = ScaledDotProductAttention()
output, attn_weights = attention(Q, K, V)
print(f"Output shape: {output.shape}")                    # expected: [2, 3, 5, 64]
print(f"Attention weights shape: {attn_weights.shape}")   # expected: [2, 3, 5, 5]
```

3. Building Multi-Head Attention
Multi-head attention lets the model learn information in different representation subspaces, which is key to the Transformer's expressive power. Here is the implementation:
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.d_v = d_model // n_head

        # Linear projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(dropout)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, mask=None):
        """
        Args:
            Q: query [batch_size, seq_len, d_model]
            K: key   [batch_size, seq_len, d_model]
            V: value [batch_size, seq_len, d_model]
            mask: optional mask, [batch_size, seq_len, seq_len] or anything
                  already broadcastable to [batch_size, n_head, seq_len, seq_len]
        Returns:
            output: [batch_size, seq_len, d_model]
            attn_weights: [batch_size, n_head, seq_len, seq_len]
        """
        batch_size = Q.size(0)

        # 1. Linear projections + head split
        Q = self.W_q(Q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_head, self.d_v).transpose(1, 2)

        # 2. Scaled dot-product attention; a 3-D mask gets a head dimension,
        #    already-4-D masks (e.g. padding masks) pass through unchanged
        if mask is not None and mask.dim() == 3:
            mask = mask.unsqueeze(1)
        x, attn_weights = self.attention(Q, K, V, mask=mask)

        # 3. Merge heads + final linear projection
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(x)
        return output, attn_weights
```

Implementation details:
- Linear transformations project the input into multiple subspaces
- Each head computes attention independently, and the results are concatenated at the end
- The residual-connection and layer-normalization hooks are kept here but applied in the full model
Testing the behavior of multi-head attention:
```python
d_model, n_head = 512, 8
mha = MultiHeadAttention(d_model, n_head)
Q = torch.randn(3, 10, d_model)  # [batch_size, seq_len, d_model]
K = V = Q  # for self-attention all three are the same
output, attn = mha(Q, K, V)
print(f"Multi-head attention output shape: {output.shape}")  # [3, 10, 512]
```

4. Positional Encoding and the Embedding Layer
The Transformer needs position information injected explicitly, because self-attention by itself has no notion of sequence order. We implement the sinusoidal positional encoding:
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            the input tensor with positional encodings added
        """
        return x + self.pe[:, :x.size(1)]
```

Combining token embeddings and positional encoding into the full input pipeline:
```python
class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.position_encoding = PositionalEncoding(d_model, max_len)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: [batch_size, seq_len]
        token_emb = self.token_embedding(x)  # [batch_size, seq_len, d_model]
        pos_emb = self.position_encoding(token_emb)
        output = self.layer_norm(pos_emb)
        output = self.dropout(output)
        return output
```

An intuitive way to think about positional encoding:
- Each position gets its own unique "waveform" encoding
- Different dimensions correspond to sine/cosine functions of different frequencies
- The model can learn to exploit this positional information (the quick check below makes the pattern concrete)
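As a quick sanity check (our own snippet, not from the original), we can verify that the encodings are bounded and that no two positions share the same encoding:

```python
# Quick check: the sinusoidal encodings are bounded in [-1, 1]
# and every position receives a distinct encoding vector.
pe = PositionalEncoding(d_model=512, max_len=100).pe.squeeze(0)  # [100, 512]
print(pe.abs().max().item() <= 1.0)   # True: sin/cos values are bounded

dists = torch.cdist(pe, pe)           # [100, 100] pairwise L2 distances
dists.fill_diagonal_(float('inf'))    # ignore the zero self-distances
print(dists.min().item() > 0)         # True: no two positions coincide
```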
5. The Feed-Forward Network and Residual Connections
The feed-forward network in a Transformer is simply a two-layer MLP applied to each position independently; the residual connection and layer normalization around it are added by the encoder and decoder layers in Section 6:
```python
class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        x = F.relu(self.linear1(x))
        x = self.dropout(x)
        x = self.linear2(x)
        x = self.dropout(x)
        return x
```

Design points:
- The hidden dimension d_ff is usually larger than d_model (typically 4x)
- A ReLU activation sits between the two linear layers
- The residual connection and layer normalization are applied once, by the encoder/decoder layers in Section 6, so this module stays a pure position-wise MLP (a quick shape check follows)
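In keeping with the article's test-as-you-go style, a quick shape check (our own snippet):

```python
# Quick shape check for the position-wise feed-forward network
ffn = PositionwiseFeedForward(d_model=512, d_ff=2048)
x = torch.randn(3, 10, 512)
print(f"FFN output shape: {ffn(x).shape}")  # expected: [3, 10, 512]
```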
6. Building the Encoder Layer and Decoder Layer
Now we can assemble the complete encoder and decoder layers:
```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # Self-attention sublayer
        residual = x
        x, _ = self.self_attn(x, x, x, mask)
        x = self.dropout(x)
        x = self.norm1(x + residual)

        # Feed-forward sublayer
        residual = x
        x = self.ffn(x)
        x = self.norm2(x + residual)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_head, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_head, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Self-attention sublayer (with the target-sequence mask)
        residual = x
        x, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.dropout(x)
        x = self.norm1(x + residual)

        # Cross-attention sublayer (queries from the decoder, keys/values from the encoder)
        if enc_output is not None:
            residual = x
            x, _ = self.cross_attn(x, enc_output, enc_output, src_mask)
            x = self.dropout(x)
            x = self.norm2(x + residual)

        # Feed-forward sublayer
        residual = x
        x = self.ffn(x)
        x = self.norm3(x + residual)
        return x
```

What makes the decoder different:
- The first, self-attention sublayer is masked so that no position can peek at future information
- The second, cross-attention sublayer connects the decoder to the encoder
- Every sublayer has a residual connection followed by layer normalization (a short smoke test follows this list)
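Before moving on, a short smoke test (our own snippet) runs one layer of each kind with random inputs and a causal mask:

```python
# Smoke test: one encoder layer and one decoder layer with random inputs
enc_layer = EncoderLayer(d_model=512, n_head=8, d_ff=2048)
dec_layer = DecoderLayer(d_model=512, n_head=8, d_ff=2048)

src = torch.randn(2, 7, 512)   # encoder input
tgt = torch.randn(2, 5, 512)   # decoder input
# Causal mask: True where attention is allowed (current and earlier positions)
causal = torch.tril(torch.ones(5, 5)).bool()

enc_out = enc_layer(src)
dec_out = dec_layer(tgt, enc_out, src_mask=None, tgt_mask=causal)
print(enc_out.shape, dec_out.shape)  # [2, 7, 512] [2, 5, 512]
```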
7. Assembling the Complete Transformer
Now we can combine all the components into the complete Transformer model:
```python
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_head,
                 num_layers, d_ff, max_seq_len, dropout=0.1):
        super().__init__()
        # Encoder
        self.encoder_embedding = TransformerEmbedding(src_vocab_size, d_model, max_seq_len, dropout)
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(num_layers)
        ])

        # Decoder
        self.decoder_embedding = TransformerEmbedding(tgt_vocab_size, d_model, max_seq_len, dropout)
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(num_layers)
        ])

        # Output projection
        self.output_linear = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        # Encoder pass
        enc_output = self.encoder_embedding(src)
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)

        # Decoder pass
        dec_output = self.decoder_embedding(tgt)
        for layer in self.decoder_layers:
            dec_output = layer(dec_output, enc_output, src_mask, tgt_mask)

        # Output logits
        output = self.output_linear(dec_output)
        return output
```

An example model configuration:
```python
config = {
    'src_vocab_size': 10000,  # source vocabulary size
    'tgt_vocab_size': 10000,  # target vocabulary size
    'd_model': 512,           # model dimension
    'n_head': 8,              # number of attention heads
    'num_layers': 6,          # number of encoder/decoder layers
    'd_ff': 2048,             # feed-forward hidden dimension
    'max_seq_len': 100,       # maximum sequence length
    'dropout': 0.1            # dropout rate
}

model = Transformer(**config)
print(f"Parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
```

8. Training Tips and Debugging Advice
With the Transformer implemented, the training process also deserves attention on the following points:
Learning-rate scheduling: use a warmup strategy
```python
class WarmupScheduler:
    """Inverse-square-root schedule with linear warmup, as in the original paper."""
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def step(self):
        self.current_step += 1
        lr = (self.d_model ** -0.5) * min(
            self.current_step ** -0.5,
            self.current_step * (self.warmup_steps ** -1.5)
        )
        # Apply the rate to the optimizer -- merely returning it would have no effect
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        return lr
```

Label smoothing: mitigating overfitting
```python
# Manual label smoothing with KL divergence (epsilon, num_classes and
# one_hot_labels are assumed to be defined by the surrounding training code)
criterion = nn.KLDivLoss(reduction='batchmean')
smoothed_labels = (1 - epsilon) * one_hot_labels + epsilon / num_classes
# Recent PyTorch also supports this directly: F.cross_entropy(..., label_smoothing=epsilon)
```

Gradient clipping: preventing exploding gradients
```python
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

Mask generation: handling variable-length sequences
```python
def create_padding_mask(seq, pad_idx):
    # True where the token is real, False at padding positions
    return (seq != pad_idx).unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]

def create_look_ahead_mask(seq_len):
    # True where attention is allowed (the current and earlier positions).
    # Built with tril so it can be combined with the padding mask via &,
    # since the attention modules mask out positions where mask == 0.
    return torch.tril(torch.ones(seq_len, seq_len)).bool()
```
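A quick check (our own snippet) of how the two masks broadcast together into the 4-D shape the attention modules expect:

```python
# Combining padding and look-ahead masks via broadcasting
tgt = torch.tensor([[5, 6, 7, 0, 0]])           # one sequence, padded with 0
pad_mask = create_padding_mask(tgt, pad_idx=0)  # [1, 1, 1, 5]
causal_mask = create_look_ahead_mask(5)         # [5, 5]
tgt_mask = pad_mask & causal_mask               # broadcasts to [1, 1, 5, 5]
print(tgt_mask.shape)
print(tgt_mask[0, 0].int())  # lower triangle with the padded columns zeroed
```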
Common problems and how to track them down:
- Attention weights are all uniform: check the scaling factor and the initialization (the diagnostic sketch after this list shows one way to inspect them)
- Training loss does not decrease: verify that the masks are applied correctly
- Validation performance fluctuates wildly: tune the learning rate and batch size
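One way to catch the first two problems early is to look at the attention weights directly; a small diagnostic sketch (our own, not from the original):

```python
# Diagnostic sketch: every row of the attention weights should sum to 1,
# and near-uniform rows point to a scaling or initialization problem.
mha = MultiHeadAttention(d_model=512, n_head=8)
mha.eval()  # disable dropout so we see the raw softmax output
x = torch.randn(2, 10, 512)
with torch.no_grad():
    _, attn = mha(x, x, x)
row_sums = attn.sum(dim=-1)
print(torch.allclose(row_sums, torch.ones_like(row_sums)))  # should print True
print((attn - 1.0 / attn.size(-1)).abs().max().item())      # ~0 would mean uniform attention
```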
9. Example Application: Machine Translation
Let's look at a simplified machine-translation pipeline:
```python
# We assume preprocessed data and a train_loader that yields
# (src, tgt_in, tgt_out) batches of token indices.
src_vocab_size = 10000  # source-language (e.g. German) vocabulary size
tgt_vocab_size = 10000  # target-language (e.g. English) vocabulary size

# Initialize the model
model = Transformer(
    src_vocab_size=src_vocab_size,
    tgt_vocab_size=tgt_vocab_size,
    d_model=512,
    n_head=8,
    num_layers=6,
    d_ff=2048,
    max_seq_len=100,
    dropout=0.1
)

# Example training loop
optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
scheduler = WarmupScheduler(optimizer, d_model=512)

for epoch in range(10):
    for batch in train_loader:
        src, tgt_in, tgt_out = batch

        # Build the masks
        src_mask = create_padding_mask(src, pad_idx=0)
        tgt_mask = create_padding_mask(tgt_in, pad_idx=0) & \
                   create_look_ahead_mask(tgt_in.size(1))

        # Forward pass
        logits = model(src, tgt_in, src_mask, tgt_mask)

        # Compute the loss, ignoring padding positions
        loss = F.cross_entropy(
            logits.view(-1, tgt_vocab_size),
            tgt_out.view(-1),
            ignore_index=0
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scheduler.step()   # set the warmup learning rate before the update
        optimizer.step()
```

Decoding strategies at inference time:
- Greedy decoding: pick the most probable token at each step (a sketch follows this list)
- Beam search: keep several candidate sequences in parallel
- Sampling strategies: temperature-controlled sampling
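To illustrate the first strategy, here is a minimal greedy-decoding sketch (our own; bos_idx and eos_idx are assumed special-token ids that the original post never defines):

```python
# Minimal greedy-decoding sketch. bos_idx and eos_idx are assumed special-token
# ids; pad_idx=0 matches the training code above. Note: the encoder is re-run
# on every step here -- fine for a sketch, wasteful in practice.
@torch.no_grad()
def greedy_decode(model, src, bos_idx, eos_idx, max_len=100):
    model.eval()
    src_mask = create_padding_mask(src, pad_idx=0)
    tgt = torch.full((src.size(0), 1), bos_idx, dtype=torch.long)
    for _ in range(max_len - 1):
        tgt_mask = create_padding_mask(tgt, pad_idx=0) & \
                   create_look_ahead_mask(tgt.size(1))
        logits = model(src, tgt, src_mask, tgt_mask)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)  # most probable next token
        tgt = torch.cat([tgt, next_token], dim=1)
        if (next_token == eos_idx).all():  # stop once every sequence emitted EOS
            break
    return tgt
```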
10. Model Variants and Extensions
The original Transformer has inspired many improved variants; here are a few worth knowing:
- Transformer-XL: adds a recurrence mechanism for long sequences
- Reformer: uses locality-sensitive hashing (LSH) to reduce the cost of attention
- Sparse Transformer: sparse attention patterns
- Performer: linear approximations of attention
As an example, we can implement a simple linear attention (a simplified, unmasked sketch in the spirit of Performer-style variants):
```python
class LinearAttention(nn.Module):
    def __init__(self, d_model, n_head, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head
        # A simple learned feature map (Performer proper approximates softmax
        # with random features; ReLU(Wx) is a common simplification)
        self.proj = nn.Linear(self.d_k, self.d_k)
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        # Note: this simplified version ignores the mask argument
        batch_size = Q.size(0)

        # Linear projections + head split
        Q = self.W_q(Q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        # Non-negative feature map
        Q = torch.relu(self.proj(Q))
        K = torch.relu(self.proj(K))

        # Linear attention: O(n) in sequence length instead of O(n^2)
        KV = torch.einsum("bhnd,bhnk->bhdk", K, V)                       # sum over positions
        Z = 1 / (torch.einsum("bhnd,bhd->bhn", Q, K.sum(dim=2)) + 1e-6)  # normalizer
        output = torch.einsum("bhnd,bhdk,bhn->bhnk", Q, KV, Z)

        # Merge heads + output projection
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output), None
```

Having built everything by hand like this, you should now have a much deeper understanding of the Transformer's inner workings. Remember: the best way to truly master a model is to implement it yourself, then keep debugging and refining it in practice.