news 2026/5/25 13:46:33

别再死记硬背Self-Attention公式了!用Python手搓一个Transformer核心模块(附完整代码)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记硬背Self-Attention公式了!用Python手搓一个Transformer核心模块(附完整代码)

从零实现Self-Attention:用NumPy拆解Transformer核心逻辑

当第一次看到Transformer论文中那个著名的Self-Attention公式时,相信不少开发者都有过这样的困惑:这些矩阵乘法究竟在做什么?为什么需要Q、K、V三个矩阵?多头机制又该如何理解?本文将通过纯Python实现,带你亲手构建一个可运行的Self-Attention模块,用代码而非公式回答这些问题。

1. 环境准备与基础概念

在开始编码前,我们需要明确几个关键概念。Self-Attention的本质是一种动态特征加权机制——它让模型能够根据输入序列中不同位置的重要性自动调整关注度。想象你在阅读一段文字时,大脑会不自觉地对关键词给予更多注意力,Self-Attention正是模拟这种认知过程。

必备工具安装

pip install numpy matplotlib

基础实现只需要NumPy库,但为了验证结果准确性,我们可以准备一个对照环境:

import numpy as np from numpy.random import randn # 设置随机种子保证可复现性 np.random.seed(42)

2. 单头注意力实现

让我们从最基础的Scaled Dot-Product Attention开始。这个名称包含三个关键信息:

  • Dot-Product:使用点积计算相似度
  • Scaled:对结果进行缩放防止梯度消失
  • Attention:最终形成注意力权重

2.1 输入投影层

首先实现将输入转换为Q、K、V的线性变换:

def input_projection(X, d_model=64, d_k=8, d_v=8): """ X: 输入序列 [batch_size, seq_len, d_model] 返回: Q, K, V 投影矩阵 """ batch_size, seq_len, _ = X.shape WQ = randn(d_model, d_k) * 0.1 WK = randn(d_model, d_k) * 0.1 WV = randn(d_model, d_v) * 0.1 Q = X @ WQ # [batch_size, seq_len, d_k] K = X @ WK # [batch_size, seq_len, d_k] V = X @ WV # [batch_size, seq_len, d_v] return Q, K, V

2.2 注意力计算核心

接下来实现注意力权重的计算过程:

def scaled_dot_product_attention(Q, K, V, mask=None): d_k = Q.shape[-1] scores = Q @ K.transpose(0,1,3,2) / np.sqrt(d_k) # [batch_size, seq_len, seq_len] if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) weights = softmax(scores, axis=-1) # 沿最后一个维度做softmax output = weights @ V # [batch_size, seq_len, d_v] return output, weights def softmax(x, axis=-1): e_x = np.exp(x - np.max(x, axis=axis, keepdims=True)) return e_x / e_x.sum(axis=axis, keepdims=True)

注意:实际应用中需要添加mask机制处理变长序列,这里简化实现

3. 多头注意力机制

多头注意力就像让模型拥有多组"眼睛",可以从不同角度观察数据。以下是关键实现步骤:

3.1 多头投影

class MultiHeadAttention: def __init__(self, d_model=64, h=8): self.d_model = d_model self.h = h assert d_model % h == 0, "d_model必须能被h整除" self.d_k = d_model // h self.WQ = randn(d_model, d_model) * 0.1 self.WK = randn(d_model, d_model) * 0.1 self.WV = randn(d_model, d_model) * 0.1 self.WO = randn(d_model, d_model) * 0.1 def split_heads(self, x): batch_size = x.shape[0] return x.reshape(batch_size, -1, self.h, self.d_k).transpose(0,2,1,3) def forward(self, X, mask=None): Q = X @ self.WQ K = X @ self.WK V = X @ self.WV Q = self.split_heads(Q) # [batch_size, h, seq_len, d_k] K = self.split_heads(K) V = self.split_heads(V) # 计算缩放点积注意力 attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask) # 合并多头结果 batch_size = attn_output.shape[0] attn_output = attn_output.transpose(0,2,1,3).reshape(batch_size, -1, self.d_model) return attn_output @ self.WO, attn_weights

3.2 效果验证

让我们用实际数据测试这个实现:

# 模拟输入数据 batch_size = 2 seq_len = 10 d_model = 64 X = randn(batch_size, seq_len, d_model) # 初始化多头注意力 mha = MultiHeadAttention(d_model=d_model, h=8) # 前向计算 output, weights = mha.forward(X) print(f"输入形状: {X.shape}") print(f"输出形状: {output.shape}") print(f"注意力权重形状: {weights.shape}")

典型输出结果:

输入形状: (2, 10, 64) 输出形状: (2, 10, 64) 注意力权重形状: (2, 8, 10, 10)

4. 与框架实现对比

为了验证我们的实现是否正确,可以与PyTorch官方实现进行对比:

import torch import torch.nn as nn # 使用相同输入数据 X_torch = torch.from_numpy(X) # PyTorch多头注意力层 mha_torch = nn.MultiheadAttention(embed_dim=d_model, num_heads=8, batch_first=True) output_torch, _ = mha_torch(X_torch, X_torch, X_torch) # 比较结果差异 diff = np.mean(np.abs(output.detach().numpy() - output_torch.detach().numpy())) print(f"与PyTorch实现的平均差异: {diff:.6f}")

提示:实际差异主要来自初始化方式不同,核心计算逻辑应该保持一致

5. 常见问题与调试技巧

在实现过程中,开发者常遇到以下几个典型问题:

5.1 梯度消失问题

当维度较大时,点积结果可能变得极大,导致softmax后某些位置接近0或1。解决方法:

  • 使用缩放因子(1/√d_k)
  • 添加微小epsilon值防止数值不稳定

5.2 内存占用优化

注意力矩阵的大小为O(n²),对于长序列:

  • 实现分块计算
  • 使用稀疏注意力模式
  • 考虑线性注意力变体

5.3 训练不稳定

多头注意力可能出现的训练问题:

# 解决方案示例:添加LayerNorm class TransformerLayer: def __init__(self, d_model, h): self.self_attn = MultiHeadAttention(d_model, h) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) def forward(self, x): attn_out, _ = self.self_attn(x) x = self.norm1(x + attn_out) return x

6. 完整实现与扩展

将上述模块组合成完整实现:

class TransformerEncoderLayer: def __init__(self, d_model=512, h=8, d_ff=2048): self.self_attn = MultiHeadAttention(d_model, h) self.ffn = PositionwiseFFN(d_model, d_ff) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) def forward(self, x, mask=None): # 自注意力子层 attn_out, _ = self.self_attn(x, mask) x = self.norm1(x + attn_out) # 前馈网络子层 ffn_out = self.ffn(x) x = self.norm2(x + ffn_out) return x class PositionwiseFFN: def __init__(self, d_model, d_ff): self.w1 = randn(d_model, d_ff) * 0.1 self.w2 = randn(d_ff, d_model) * 0.1 def forward(self, x): return x @ self.w1 @ self.w2

在实际项目中,这种从零实现的方式虽然性能不如优化后的框架代码,但它带来的理解深度是无可替代的。当我在处理一个序列标注任务时,正是通过这种手写实现,才真正理解了为什么某些位置的注意力权重会异常偏高——原来是输入数据中存在特殊标记导致的注意力聚焦。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/25 13:44:47

免费抖音批量下载神器:一键保存无水印视频完整指南

免费抖音批量下载神器:一键保存无水印视频完整指南 【免费下载链接】douyin-downloader A practical Douyin downloader for both single-item and profile batch downloads, with progress display, retries, SQLite deduplication, and browser fallback support.…

作者头像 李华
网站建设 2026/5/25 13:44:43

如何彻底解决AutoCAD字体缺失问题:FontCenter免费插件终极指南

如何彻底解决AutoCAD字体缺失问题:FontCenter免费插件终极指南 【免费下载链接】FontCenter AutoCAD自动管理字体插件 项目地址: https://gitcode.com/gh_mirrors/fo/FontCenter 还在为AutoCAD图纸打开时出现"字体缺失"警告而烦恼吗?Fo…

作者头像 李华
网站建设 2026/5/25 13:42:59

DDrawCompat:3分钟让Windows老游戏重获新生的终极解决方案

DDrawCompat:3分钟让Windows老游戏重获新生的终极解决方案 【免费下载链接】DDrawCompat DirectDraw and Direct3D 1-7 compatibility, performance and visual enhancements for Windows Vista, 7, 8, 10 and 11 项目地址: https://gitcode.com/gh_mirrors/dd/DD…

作者头像 李华