news 2026/5/6 13:52:32

别再死记硬背Transformer结构了!用PyTorch手写一个,从Self-Attention到完整模型

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记硬背Transformer结构了!用PyTorch手写一个,从Self-Attention到完整模型

用PyTorch手写Transformer:从Self-Attention到完整模型的实战指南

当你第一次接触Transformer时,是否曾被那些复杂的矩阵运算和维度变换搞得晕头转向?别担心,今天我们将用PyTorch从零开始构建一个完整的Transformer模型,通过代码实践带你真正理解这个革命性架构的精髓。不同于单纯的理论讲解,我们将采用"实现-解释-调试"的循环方式,确保每个组件都清晰可理解。

1. 环境准备与基础概念

在开始编码之前,我们需要明确几个核心概念。Transformer本质上是一个基于自注意力机制的序列到序列模型,它彻底改变了传统RNN处理序列数据的方式。让我们先准备好开发环境:

import torch import torch.nn as nn import math import numpy as np from torch.nn import functional as F print(f"PyTorch版本: {torch.__version__}") print(f"CUDA可用: {torch.cuda.is_available()}")

Transformer的核心创新在于完全摒弃了循环结构,转而使用自注意力机制来捕捉序列中元素间的长距离依赖关系。这种设计带来了几个显著优势:

  • 并行计算能力:不再受限于序列的时序依赖
  • 全局视野:每个位置都能直接访问序列中所有其他位置的信息
  • 灵活的特征组合:通过多头机制学习不同子空间的特征表示

2. 实现Scaled Dot-Product Attention

自注意力机制是Transformer的基础构建块,我们先从最核心的缩放点积注意力开始实现。这个模块计算query和key的相似度,然后对value进行加权求和。

class ScaledDotProductAttention(nn.Module): def __init__(self, dropout=0.1): super().__init__() self.dropout = nn.Dropout(dropout) def forward(self, Q, K, V, mask=None): """ 参数: Q: Query矩阵 [batch_size, n_head, seq_len, d_k] K: Key矩阵 [batch_size, n_head, seq_len, d_k] V: Value矩阵 [batch_size, n_head, seq_len, d_v] mask: 可选掩码 [batch_size, 1, seq_len, seq_len] 返回: output: 注意力输出 [batch_size, n_head, seq_len, d_v] attn_weights: 注意力权重 [batch_size, n_head, seq_len, seq_len] """ d_k = Q.size(-1) # 计算QK^T/sqrt(d_k) scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) # 应用掩码(如需要) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) # softmax归一化得到注意力权重 attn_weights = F.softmax(scores, dim=-1) attn_weights = self.dropout(attn_weights) # 对Value加权求和 output = torch.matmul(attn_weights, V) return output, attn_weights

关键点解析

  1. 缩放因子1/sqrt(d_k)的作用是防止点积结果过大导致softmax梯度消失
  2. 掩码机制在decoder中用于防止当前位置关注到未来信息
  3. dropout应用于注意力权重,是一种有效的正则化手段

为了验证我们的实现,让我们构造一些测试数据:

batch_size, n_head, seq_len, d_k, d_v = 2, 3, 5, 64, 64 Q = torch.randn(batch_size, n_head, seq_len, d_k) K = torch.randn(batch_size, n_head, seq_len, d_k) V = torch.randn(batch_size, n_head, seq_len, d_v) attention = ScaledDotProductAttention() output, attn_weights = attention(Q, K, V) print(f"输出形状: {output.shape}") # 应为 [2, 3, 5, 64] print(f"注意力权重形状: {attn_weights.shape}") # 应为 [2, 3, 5, 5]

3. 构建Multi-Head Attention

多头注意力允许模型在不同的表示子空间中学习信息,这是Transformer强大表征能力的关键。下面是具体实现:

class MultiHeadAttention(nn.Module): def __init__(self, d_model, n_head, dropout=0.1): super().__init__() assert d_model % n_head == 0, "d_model必须能被n_head整除" self.d_model = d_model self.n_head = n_head self.d_k = d_model // n_head self.d_v = d_model // n_head # 线性变换层 self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) self.attention = ScaledDotProductAttention(dropout) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model) def forward(self, Q, K, V, mask=None): """ 参数: Q: Query [batch_size, seq_len, d_model] K: Key [batch_size, seq_len, d_model] V: Value [batch_size, seq_len, d_model] mask: 可选掩码 [batch_size, seq_len, seq_len] 返回: output: [batch_size, seq_len, d_model] attn_weights: [batch_size, n_head, seq_len, seq_len] """ batch_size = Q.size(0) # 1. 线性投影 + 分头 Q = self.W_q(Q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) K = self.W_k(K).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) V = self.W_v(V).view(batch_size, -1, self.n_head, self.d_v).transpose(1, 2) # 2. 计算缩放点积注意力 if mask is not None: mask = mask.unsqueeze(1) # 为多头扩展维度 x, attn_weights = self.attention(Q, K, V, mask=mask) # 3. 合并多头 + 最终线性变换 x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model) output = self.W_o(x) return output, attn_weights

实现细节

  • 通过线性变换将输入投影到多个子空间
  • 每个头独立计算注意力,最后合并结果
  • 保留了残差连接和层归一化的接口(将在完整模型中实现)

测试多头注意力的行为:

d_model, n_head = 512, 8 mha = MultiHeadAttention(d_model, n_head) Q = torch.randn(3, 10, d_model) # [batch_size, seq_len, d_model] K = V = Q # 自注意力情况下三者相同 output, attn = mha(Q, K, V) print(f"多头注意力输出形状: {output.shape}") # [3, 10, 512]

4. 位置编码与嵌入层

Transformer需要显式地注入位置信息,因为自注意力机制本身不具备感知序列顺序的能力。我们实现正弦位置编码:

class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() position = torch.arange(max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(max_len, d_model) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # [1, max_len, d_model] self.register_buffer('pe', pe) def forward(self, x): """ 参数: x: [batch_size, seq_len, d_model] 返回: 添加位置编码后的张量 """ return x + self.pe[:, :x.size(1)]

结合词嵌入和位置编码的完整输入处理:

class TransformerEmbedding(nn.Module): def __init__(self, vocab_size, d_model, max_len, dropout=0.1): super().__init__() self.token_embedding = nn.Embedding(vocab_size, d_model) self.position_encoding = PositionalEncoding(d_model, max_len) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x): # x: [batch_size, seq_len] token_emb = self.token_embedding(x) # [batch_size, seq_len, d_model] pos_emb = self.position_encoding(token_emb) output = self.layer_norm(pos_emb) output = self.dropout(output) return output

位置编码的直观理解

  • 每个位置对应一个独特的"波形"编码
  • 不同维度对应不同频率的正弦/余弦函数
  • 模型可以学习利用这些位置信息

5. 前馈神经网络与残差连接

Transformer中的前馈网络实际上是一个两层的MLP,应用于每个位置独立:

class PositionwiseFeedForward(nn.Module): def __init__(self, d_model, d_ff, dropout=0.1): super().__init__() self.linear1 = nn.Linear(d_model, d_ff) self.linear2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(d_model) def forward(self, x): # x: [batch_size, seq_len, d_model] residual = x x = F.relu(self.linear1(x)) x = self.dropout(x) x = self.linear2(x) x = self.dropout(x) x = self.layer_norm(x + residual) return x

设计要点

  • 中间层的维度d_ff通常比d_model大(如4倍)
  • 使用ReLU激活函数
  • 应用了残差连接和层归一化

6. 构建Encoder层与Decoder层

现在我们可以组装完整的Encoder层和Decoder层了:

class EncoderLayer(nn.Module): def __init__(self, d_model, n_head, d_ff, dropout=0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, n_head, dropout) self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout) self.dropout = nn.Dropout(dropout) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) def forward(self, x, mask=None): # 自注意力子层 residual = x x, _ = self.self_attn(x, x, x, mask) x = self.dropout(x) x = self.norm1(x + residual) # 前馈网络子层 residual = x x = self.ffn(x) x = self.norm2(x + residual) return x class DecoderLayer(nn.Module): def __init__(self, d_model, n_head, d_ff, dropout=0.1): super().__init__() self.self_attn = MultiHeadAttention(d_model, n_head, dropout) self.cross_attn = MultiHeadAttention(d_model, n_head, dropout) self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) def forward(self, x, enc_output, src_mask=None, tgt_mask=None): # 自注意力子层(带目标序列掩码) residual = x x, _ = self.self_attn(x, x, x, tgt_mask) x = self.dropout(x) x = self.norm1(x + residual) # 交叉注意力子层(Query来自解码器,Key/Value来自编码器) if enc_output is not None: residual = x x, _ = self.cross_attn(x, enc_output, enc_output, src_mask) x = self.dropout(x) x = self.norm2(x + residual) # 前馈网络子层 residual = x x = self.ffn(x) x = self.norm3(x + residual) return x

Decoder的特殊处理

  • 第一层自注意力使用掩码防止信息泄露
  • 第二层交叉注意力连接编码器和解码器
  • 每子层都有残差连接和层归一化

7. 组装完整Transformer模型

现在我们可以将所有组件组合成完整的Transformer模型:

class Transformer(nn.Module): def __init__(self, src_vocab_size, tgt_vocab_size, d_model, n_head, num_layers, d_ff, max_seq_len, dropout=0.1): super().__init__() # 编码器部分 self.encoder_embedding = TransformerEmbedding(src_vocab_size, d_model, max_seq_len, dropout) self.encoder_layers = nn.ModuleList([ EncoderLayer(d_model, n_head, d_ff, dropout) for _ in range(num_layers) ]) # 解码器部分 self.decoder_embedding = TransformerEmbedding(tgt_vocab_size, d_model, max_seq_len, dropout) self.decoder_layers = nn.ModuleList([ DecoderLayer(d_model, n_head, d_ff, dropout) for _ in range(num_layers) ]) # 输出层 self.output_linear = nn.Linear(d_model, tgt_vocab_size) def forward(self, src, tgt, src_mask=None, tgt_mask=None): # 编码器处理 enc_output = self.encoder_embedding(src) for layer in self.encoder_layers: enc_output = layer(enc_output, src_mask) # 解码器处理 dec_output = self.decoder_embedding(tgt) for layer in self.decoder_layers: dec_output = layer(dec_output, enc_output, src_mask, tgt_mask) # 输出预测 output = self.output_linear(dec_output) return output

模型配置示例

config = { 'src_vocab_size': 10000, # 源语言词汇表大小 'tgt_vocab_size': 10000, # 目标语言词汇表大小 'd_model': 512, # 模型维度 'n_head': 8, # 注意力头数 'num_layers': 6, # 编码器/解码器层数 'd_ff': 2048, # 前馈网络中间层维度 'max_seq_len': 100, # 最大序列长度 'dropout': 0.1 # dropout率 } model = Transformer(**config) print(f"模型参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

8. 训练技巧与调试建议

实现Transformer后,训练过程也需要特别注意以下几点:

  1. 学习率调度:使用warmup策略

    class WarmupScheduler: def __init__(self, d_model, warmup_steps=4000): self.d_model = d_model self.warmup_steps = warmup_steps self.current_step = 0 def step(self): self.current_step += 1 return (self.d_model ** -0.5) * min( self.current_step ** -0.5, self.current_step * (self.warmup_steps ** -1.5) )
  2. 标签平滑:缓解过拟合

    criterion = nn.KLDivLoss(reduction='batchmean') smoothed_labels = (1 - epsilon) * one_hot_labels + epsilon / num_classes
  3. 梯度裁剪:防止梯度爆炸

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
  4. 掩码生成:处理变长序列

    def create_padding_mask(seq, pad_idx): return (seq != pad_idx).unsqueeze(1).unsqueeze(2) # [batch_size, 1, 1, seq_len] def create_look_ahead_mask(seq_len): return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

常见问题排查

  • 注意力权重全为均匀分布:检查缩放因子和初始化
  • 训练损失不下降:验证掩码是否正确应用
  • 验证集性能波动大:调整学习率和batch size

9. 模型应用示例:机器翻译

让我们看一个简化的机器翻译流程示例:

# 假设我们已经有了预处理好的数据 src_vocab_size = 10000 # 源语言(如德语)词汇表大小 tgt_vocab_size = 10000 # 目标语言(如英语)词汇表大小 # 初始化模型 model = Transformer( src_vocab_size=src_vocab_size, tgt_vocab_size=tgt_vocab_size, d_model=512, n_head=8, num_layers=6, d_ff=2048, max_seq_len=100, dropout=0.1 ) # 训练循环示例 optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9) scheduler = WarmupScheduler(d_model=512) for epoch in range(10): for batch in train_loader: src, tgt_in, tgt_out = batch # 生成掩码 src_mask = create_padding_mask(src, pad_idx=0) tgt_mask = create_padding_mask(tgt_in, pad_idx=0) & \ create_look_ahead_mask(tgt_in.size(1)) # 前向传播 logits = model(src, tgt_in, src_mask, tgt_mask) # 计算损失 loss = F.cross_entropy( logits.view(-1, tgt_vocab_size), tgt_out.view(-1), ignore_index=0 ) # 反向传播 optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() scheduler.step()

推理时的解码策略

  • 贪心解码:每一步选择概率最高的词
  • Beam Search:保留多个候选序列
  • 采样策略:基于温度的参数化采样

10. 模型变体与扩展

原始Transformer有许多改进版本,以下是几种值得关注的变体:

  1. Transformer-XL:引入循环机制处理长序列
  2. Reformer:使用局部敏感哈希(LSH)降低注意力复杂度
  3. Sparse Transformer:稀疏注意力模式
  4. Performer:线性注意力近似

例如,我们可以实现一个简单的线性注意力:

class LinearAttention(nn.Module): def __init__(self, d_model, n_head, dropout=0.1): super().__init__() self.d_model = d_model self.n_head = n_head self.d_k = d_model // n_head # 使用随机特征映射近似softmax self.proj = nn.Linear(self.d_k, self.d_k) self.W_q = nn.Linear(d_model, d_model) self.W_k = nn.Linear(d_model, d_model) self.W_v = nn.Linear(d_model, d_model) self.W_o = nn.Linear(d_model, d_model) def forward(self, Q, K, V, mask=None): batch_size = Q.size(0) # 线性投影 Q = self.W_q(Q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) K = self.W_k(K).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) V = self.W_v(V).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2) # 随机特征映射 Q = torch.relu(self.proj(Q)) K = torch.relu(self.proj(K)) # 线性注意力计算 KV = torch.einsum("bhnd,bhnk->bhdk", K, V) Z = 1 / (torch.einsum("bhnd,bhd->bhn", Q, K.sum(dim=2)) + 1e-6) output = torch.einsum("bhnd,bhdk,bhn->bhnk", Q, KV, Z) # 合并多头 output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model) return self.W_o(output), None

通过这种手把手的实现方式,相信你已经对Transformer的内部机制有了更深入的理解。记住,真正掌握一个模型的最佳方式就是亲自实现它,并在实践中不断调试和优化。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/6 13:50:46

实战应用:基于快马平台构建带回收站功能的c盘管家软件

实战应用:基于快马平台构建带回收站功能的C盘管家软件 最近我的C盘又亮起了红色警告,这已经是今年第三次了。每次手动清理都特别麻烦,要小心翼翼地避开系统文件,还得担心误删重要文档。于是我想,为什么不自己开发一个…

作者头像 李华
网站建设 2026/5/6 13:50:16

如何快速绕过MTK设备保护?这个Python工具3步搞定

如何快速绕过MTK设备保护?这个Python工具3步搞定 【免费下载链接】bypass_utility 项目地址: https://gitcode.com/gh_mirrors/by/bypass_utility 你是否曾遇到过MTK设备刷机失败,提示"保护机制已启用"的困扰?当你想为联发…

作者头像 李华
网站建设 2026/5/6 13:50:10

告别OV2640颜色错乱:深入STM32 DCMI的RGB565数据格式与LSB/MSB配置详解

告别OV2640颜色错乱:深入STM32 DCMI的RGB565数据格式与LSB/MSB配置详解 当你在STM32平台上成功驱动OV2640摄像头后,最令人沮丧的莫过于屏幕上出现的红蓝颜色错位——本该湛蓝的天空呈现诡异的紫红色,而红色物体却变成了深蓝色。这种颜色错乱问…

作者头像 李华
网站建设 2026/5/6 13:50:07

手机号快速查询QQ号:终极简单解决方案完整指南

手机号快速查询QQ号:终极简单解决方案完整指南 【免费下载链接】phone2qq 项目地址: https://gitcode.com/gh_mirrors/ph/phone2qq 你是否曾经因为忘记QQ号而无法登录账号?手机号查QQ号工具为你提供了一种快速、免费的解决方案。这款基于Python开…

作者头像 李华
网站建设 2026/5/6 13:45:48

微信小程序定位开发全流程:从wx.getLocation申请到app.json配置避坑指南

微信小程序定位功能开发实战:从权限申请到高精度定位优化 校园导航、外卖配送、共享单车…这些我们每天使用的小程序服务,都离不开一个核心技术——地理位置定位。作为开发者,当你兴致勃勃地写完了wx.getLocation的调用代码,却在真…

作者头像 李华