news 2026/5/24 1:27:04

别再死记硬背CRF公式了!用Python手写一个BIO命名实体识别Demo,带你直观理解发射与转移矩阵

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再死记硬背CRF公式了!用Python手写一个BIO命名实体识别Demo,带你直观理解发射与转移矩阵

用Python从零实现CRF:BIO标注中的发射与转移矩阵实战解析

在自然语言处理领域,命名实体识别(NER)是信息抽取的基础任务之一。当我们第一次接触条件随机场(CRF)时,那些复杂的公式和抽象的概率图模型常常让人望而生畏。本文将通过一个完整的Python实现案例,带您直观理解CRF中最核心的两个概念:发射矩阵(emission matrix)和转移矩阵(transition matrix)。

1. 环境准备与数据构建

首先确保已安装必要的Python库:

import numpy as np import torch from torch import nn import matplotlib.pyplot as plt

我们构建一个极简的中文NER标注示例,采用BIO标注体系:

# 样本数据:句子和对应的BIO标签 sentences = [["吃", "米饭"], ["喝", "汤"]] labels = [["O", "B"], ["O", "B"]]

BIO标注规则简单明了:

  • B:实体开头(Begin)
  • I:实体内部(Inside)
  • O:非实体(Outside)

2. CRF核心组件实现

2.1 标签与特征映射

首先建立标签与索引的双向映射:

tag2idx = {'B':0, 'I':1, 'O':2, '<START>':3, '<END>':4} idx2tag = {v:k for k,v in tag2idx.items()} num_tags = len(tag2idx)

2.2 初始化转移矩阵

转移矩阵定义了标签之间的转换概率:

# 随机初始化转移矩阵 transitions = torch.randn(num_tags, num_tags, requires_grad=True) # 添加约束:B不能直接转B constraint_matrix = torch.ones_like(transitions) constraint_matrix[tag2idx['B'], tag2idx['B']] = 0 # B→B禁止 constrained_transitions = transitions * constraint_matrix

2.3 构建发射矩阵

发射矩阵表示从输入特征到标签的映射概率:

# 简单示例:基于字符的one-hot编码 def char_to_vec(char): return torch.tensor([1 if c == char else 0 for c in ['吃','米','饭','喝','汤']], dtype=torch.float) # 随机初始化发射参数 emission_params = torch.randn(5, num_tags, requires_grad=True)

3. 前向计算与损失函数

3.1 序列得分计算

定义计算序列得分的函数:

def sequence_score(emissions, tags, transitions): score = torch.zeros(1) tags = [tag2idx['<START>']] + [tag2idx[t] for t in tags] + [tag2idx['<END>']] for i in range(len(emissions)): # 发射得分 score += emissions[i, tags[i+1]] # 转移得分 score += transitions[tags[i], tags[i+1]] return score

3.2 计算所有可能路径得分

def total_score(emissions, transitions): # 使用动态规划高效计算 alpha = torch.zeros(num_tags) alpha = transitions[tag2idx['<START>']] + emissions[0] for emission in emissions[1:]: alpha = torch.logsumexp(alpha.unsqueeze(1) + transitions + emission, dim=0) return torch.logsumexp(alpha + transitions[:, tag2idx['<END>']], dim=0)

3.3 定义CRF损失

def crf_loss(emissions, tags, transitions): gold_score = sequence_score(emissions, tags, transitions) total = total_score(emissions, transitions) return total - gold_score

4. 训练与可视化

4.1 训练过程

optimizer = torch.optim.SGD([transitions, emission_params], lr=0.01) for epoch in range(100): total_loss = 0 for sentence, tag_seq in zip(sentences, labels): # 准备发射分数 emissions = torch.stack([emission_params @ char_to_vec(c) for c in sentence]) # 计算损失 loss = crf_loss(emissions, tag_seq, constrained_transitions) total_loss += loss.item() # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() print(f"Epoch {epoch}, Loss: {total_loss/len(sentences)}")

4.2 矩阵可视化

训练完成后,我们可以可视化学习到的转移矩阵:

def plot_matrix(matrix, title): fig, ax = plt.subplots() cax = ax.matshow(matrix.detach().numpy()) fig.colorbar(cax) ax.set_xticks(range(num_tags)) ax.set_yticks(range(num_tags)) ax.set_xticklabels([idx2tag[i] for i in range(num_tags)]) ax.set_yticklabels([idx2tag[i] for i in range(num_tags)]) plt.title(title) plt.show() plot_matrix(constrained_transitions, "Learned Transition Matrix")

5. 解码与预测

5.1 维特比解码实现

def viterbi_decode(emissions, transitions): backpointers = [] # 初始化 viterbi = transitions[tag2idx['<START>']] + emissions[0] backpointers.append(torch.argmax(viterbi, dim=1)) # 递推 for emission in emissions[1:]: viterbi, backpointer = torch.max(viterbi.unsqueeze(1) + transitions + emission, dim=0) backpointers.append(backpointer) # 终止 best_score, best_tag = torch.max(viterbi + transitions[:, tag2idx['<END>']], dim=0) # 回溯 best_path = [best_tag.item()] for backpointer in reversed(backpointers): best_tag = backpointer[best_tag] best_path.append(best_tag.item()) return list(reversed(best_path))[1:]

5.2 预测示例

test_sentence = ["喝", "可乐"] emissions = torch.stack([emission_params @ char_to_vec(c) for c in test_sentence]) best_path = viterbi_decode(emissions, constrained_transitions) print("预测标签序列:", [idx2tag[idx] for idx in best_path])

6. 工程实践中的优化技巧

在实际项目中,我们还需要考虑以下优化点:

  • 特征工程:除了字符本身,可以加入词性、上下文窗口等特征
  • 批量处理:实现批量化计算提升训练效率
  • 正则化:添加L2正则防止过拟合
  • 学习率调度:使用学习率衰减策略
  • 早停机制:基于验证集性能提前终止训练
# 示例:添加L2正则化 def regularized_loss(emissions, tags, transitions, l2_lambda=0.01): base_loss = crf_loss(emissions, tags, transitions) l2_reg = torch.norm(transitions, p=2) + torch.norm(emission_params, p=2) return base_loss + l2_lambda * l2_reg

7. 扩展与进阶

理解基础CRF实现后,可以进一步探索:

  • BiLSTM-CRF:结合神经网络自动学习特征表示
  • BERT-CRF:利用预训练语言模型提升性能
  • 半监督学习:利用未标注数据提升模型泛化能力
  • 领域适应:将通用NER模型迁移到特定领域
# BiLSTM-CRF架构示意 class BiLSTM_CRF(nn.Module): def __init__(self, vocab_size, tag2idx): super().__init__() self.embedding = nn.Embedding(vocab_size, 64) self.lstm = nn.LSTM(64, 64//2, bidirectional=True) self.hidden2tag = nn.Linear(64, len(tag2idx)) self.crf = CRF(len(tag2idx)) def forward(self, x): embeds = self.embedding(x) lstm_out, _ = self.lstm(embeds.view(len(x), 1, -1)) emissions = self.hidden2tag(lstm_out.view(len(x), -1)) return emissions
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/24 1:21:22

用AI解决电源最复杂PDN问题的实战设计案例

用AI解决电源最复杂PDN问题的实战设计案例在过去的几年里&#xff0c;我们见证了AI在图像识别、自然语言处理领域的统治。但在硬件物理设计领域&#xff0c;尤其是电源完整性PI和 信号完整性SI这种顶层物理战场上&#xff0c;AI 似乎一直像个门外汉。为什么&#xff1f;因为硬件…

作者头像 李华
网站建设 2026/5/24 1:20:09

Unity Device Simulator:深度解析UI适配调试核心机制

1. 这个“设备模拟器”不是让你在电脑上玩手游的很多人第一次看到Device Simulator&#xff0c;下意识觉得&#xff1a;“哦&#xff0c;Unity里又出了个能预览手机效果的窗口&#xff1f;”——这理解方向就偏了。它根本不是个“截图预览工具”&#xff0c;而是 Unity 编辑器原…

作者头像 李华
网站建设 2026/5/24 1:09:14

cmake和makefile

一、什么是cmake和makefile简单来说&#xff0c;CMake 是一个用来生成 Makefile 的工具。它们是构建 C/C 项目过程中不同层级的工具&#xff0c;通常配合使用&#xff0c;而不是相互替代。&#x1f9f1; Makefile&#xff1a;编译的“施工队”角色&#xff1a; 它是具体的“执行…

作者头像 李华
网站建设 2026/5/24 1:09:11

04-系统技术架构师必备——设计模式在系统架构中的应用

关键词:GoF设计模式、SOLID原则、工厂模式、观察者模式、策略模式、适配器模式、装饰器模式、架构师 设计模式 GoF SOLID原则 系统架构 架构师 面向对象 Java 代码重构 系统技术架构师必备——设计模式在系统架构中的应用 摘要 GoF 23种设计模式是系统技术架构师必须掌握的&…

作者头像 李华
网站建设 2026/5/24 1:07:06

仓储海量货物人车混跑,无感定位并发能力碾压UWB上限瓶颈技术白皮书方案

仓储海量货物人车混跑&#xff0c;无感定位并发能力碾压UWB上限瓶颈技术白皮书方案一、方案概述随着现代智能仓储向高密度、高周转、无人化、集约化模式快速迭代&#xff0c;立体仓储库区普遍形成海量货物堆叠、多叉车穿梭、人员高频作业、人车密集混跑的复杂动态工况。仓储作业…

作者头像 李华
网站建设 2026/5/24 1:03:04

数据科学概述与方法论

数据科学概述与方法论 1. 技术分析 1.1 数据科学概述 数据科学是从数据中提取知识的跨学科领域&#xff1a; 数据科学组成统计学: 数据分析方法机器学习: 预测模型数据工程: 数据处理领域知识: 业务理解数据科学流程:问题定义数据收集数据清洗数据分析模型构建结果部署1.2 CRIS…

作者头像 李华