news 2026/5/6 15:27:38

别再只用MNIST了!Permuted/Split MNIST数据集实战:用PyTorch搭建你的第一个连续学习评估环境

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
别再只用MNIST了!Permuted/Split MNIST数据集实战:用PyTorch搭建你的第一个连续学习评估环境

突破传统MNIST:用PyTorch构建连续学习实战环境的完整指南

当大多数机器学习教程还在使用标准MNIST数据集时,前沿研究早已转向更具挑战性的变体。Permuted MNIST和Split MNIST不仅是学术论文中的常客,更是检验连续学习算法性能的黄金标准。本文将带你从零开始,用PyTorch搭建完整的连续学习评估环境,通过代码实践理解这两种经典数据集的精髓。

1. 为什么需要超越标准MNIST?

传统MNIST数据集作为机器学习入门教材已有二十余年历史,其简单性既是优点也是局限。784个像素点的固定排列方式让模型很快就能达到接近人类水平的准确率,但这恰恰掩盖了现实世界中的核心挑战——数据分布的动态变化。

连续学习研究需要能模拟以下场景的数据集:

  • 任务序列化:模型需要按顺序学习多个相关但不完全相同的任务
  • 灾难性遗忘:新任务学习时旧任务性能的下降程度需要量化
  • 知识迁移:先前学到的知识如何帮助后续任务学习

提示:Permuted MNIST通过像素重排改变输入分布,Split MNIST通过类别划分创建任务序列,两者分别对应Domain-IL和Class-IL场景

下表对比了三种MNIST变体的核心差异:

数据集类型变化维度适用场景挑战重点
标准MNIST无变化基础分类单一任务性能
Permuted MNIST像素排列Domain-IL分布变化适应
Split MNIST类别划分Class-IL类别增量学习

2. 环境准备与基础配置

开始前确保已安装最新版PyTorch和标准科学计算库:

import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms import numpy as np import matplotlib.pyplot as plt

定义基础多层感知机(MLP)模型,这是连续学习研究的常用架构:

class BaseMLP(nn.Module): def __init__(self, input_size=784, hidden_size=400, output_size=10): super(BaseMLP, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, output_size) self.relu = nn.ReLU() def forward(self, x): x = x.view(x.size(0), -1) # 展平输入 x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) return self.fc3(x)

3. Permuted MNIST实战实现

Permuted MNIST的核心是为每个任务生成唯一的像素排列顺序。以下是关键实现步骤:

  1. 生成排列矩阵:为每个任务创建随机但固定的像素排列
  2. 数据转换器:应用排列并标准化图像
  3. 任务序列构建:创建多个不同排列的数据加载器
def generate_permutation(): """生成随机像素排列顺序""" return torch.randperm(784) class PermuteTransform: """应用像素排列的数据转换器""" def __init__(self, permutation): self.permutation = permutation def __call__(self, x): return x.view(-1)[self.permutation].view(1, 28, 28) # 创建5个不同排列的任务 num_tasks = 5 permutations = [generate_permutation() for _ in range(num_tasks)] task_loaders = [] for perm in permutations: transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), PermuteTransform(perm) ]) train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, transform=transform), batch_size=64, shuffle=True) task_loaders.append(train_loader)

可视化不同排列下的图像样本:

def show_permuted_samples(permutations, num_samples=5): """展示不同排列下的MNIST样本""" fig, axes = plt.subplots(len(permutations), num_samples, figsize=(15, 10)) for i, perm in enumerate(permutations): transform = PermuteTransform(perm) for j in range(num_samples): img, _ = datasets.MNIST('../data', train=True)[j] axes[i,j].imshow(transform(img).squeeze(), cmap='gray') axes[i,j].axis('off') plt.show() show_permuted_samples(permutations[:3]) # 展示前3种排列

4. Split MNIST的精细实现

Split MNIST将10个数字类别划分为多个二元分类任务,典型划分方式如下:

  • Task 1: 识别0和1
  • Task 2: 识别2和3
  • Task 3: 识别4和5
  • Task 4: 识别6和7
  • Task 5: 识别8和9

实现关键点在于数据过滤和标签重映射:

def create_split_mnist_loaders(): """创建Split MNIST任务序列""" base_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # 定义任务划分 (数字对列表) task_pairs = [(0,1), (2,3), (4,5), (6,7), (8,9)] task_loaders = [] for pair in task_pairs: # 过滤只包含当前数字对的样本 def filter_func(data, target): mask = (target == pair[0]) | (target == pair[1]) return data[mask], target[mask] # 重映射标签为0/1 def remap_labels(target): return (target == pair[1]).long() # 自定义数据集 class SplitMNIST(torch.utils.data.Dataset): def __init__(self, train=True): self.mnist = datasets.MNIST('../data', train=train, download=True) self.data, self.targets = filter_func(self.mnist.data, self.mnist.targets) self.targets = remap_labels(self.targets) self.transform = base_transform def __len__(self): return len(self.data) def __getitem__(self, idx): img, target = self.data[idx], self.targets[idx] img = self.transform(img.numpy()) return img, target train_loader = torch.utils.data.DataLoader( SplitMNIST(train=True), batch_size=64, shuffle=True) task_loaders.append(train_loader) return task_loaders split_loaders = create_split_mnist_loaders()

5. 连续学习评估框架

完整的连续学习评估需要跟踪以下指标:

  • 当前任务准确率
  • 旧任务遗忘程度
  • 整体平均准确率

实现评估流程的核心代码:

def evaluate(model, task_id, test_loaders, device='cpu'): """评估模型在所有已学习任务上的表现""" model.eval() accuracies = [] with torch.no_grad(): for t in range(task_id + 1): correct = 0 total = 0 for images, labels in test_loaders[t]: images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracies.append(100 * correct / total) return accuracies def continual_learning_training(model, task_loaders, num_epochs=5): """连续学习训练流程""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) optimizer = optim.Adam(model.parameters(), lr=0.001) criterion = nn.CrossEntropyLoss() all_accuracies = [] for task_id, train_loader in enumerate(task_loaders): print(f"\n=== Training on Task {task_id + 1} ===") for epoch in range(num_epochs): model.train() running_loss = 0.0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}") # 评估当前模型在所有已学习任务上的表现 current_acc = evaluate(model, task_id, task_loaders, device) all_accuracies.append(current_acc) print(f"Task Accuracies after Task {task_id+1}: {current_acc}") return all_accuracies

6. 高级技巧与优化方向

基础实现后,可以考虑以下增强方案:

对抗遗忘技术

  • 弹性权重固化(EWC):添加正则项保护重要参数
  • 记忆回放:保留少量旧任务样本进行联合训练
  • 动态架构:为每个任务分配专用模型组件
# EWC实现示例 def ewc_loss(model, fisher_matrix, previous_params, lambda_ewc): loss = 0 for name, param in model.named_parameters(): if name in fisher_matrix: loss += (fisher_matrix[name] * (param - previous_params[name])**2).sum() return lambda_ewc * loss # 在训练循环中添加EWC损失 total_loss = criterion(outputs, labels) + ewc_loss(model, fisher, prev_params, lambda_ewc=1000)

评估指标可视化

def plot_learning_curve(accuracies): """绘制连续学习曲线""" plt.figure(figsize=(10, 6)) for task in range(len(accuracies)): x = range(task + 1, len(accuracies) + 1) y = [acc[task] for acc in accuracies[task:]] plt.plot(x, y, marker='o', label=f'Task {task+1}') plt.xlabel('Task Number') plt.ylabel('Accuracy (%)') plt.title('Continual Learning Performance') plt.legend() plt.grid() plt.show()

实际项目中,我发现Permuted MNIST对初始化种子非常敏感,不同排列顺序可能导致性能波动达5-8%。解决方法是固定随机种子或进行多次实验取平均。Split MNIST则面临类别不平衡问题,某些数字对的样本量可能相差20%,需要在损失函数中添加类别权重。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/6 15:26:18

3个场景告诉你:PowerToys文本提取器如何成为你的数字助手

3个场景告诉你:PowerToys文本提取器如何成为你的数字助手 【免费下载链接】PowerToys Microsoft PowerToys is a collection of utilities that supercharge productivity and customization on Windows 项目地址: https://gitcode.com/GitHub_Trending/po/PowerT…

作者头像 李华
网站建设 2026/5/6 15:26:16

Grasscutter Tools:5分钟学会原神私服终极管理指南

Grasscutter Tools:5分钟学会原神私服终极管理指南 【免费下载链接】grasscutter-tools A cross-platform client that combines launcher, command generation, and mod management to easily play Grasscutter; 一个结合了启动器、命令生成、MOD管理等功能的跨平台…

作者头像 李华
网站建设 2026/5/6 15:25:43

Gitee CodePecker SCA:开源治理的终极解决方案

在数字化转型浪潮中,开源组件已成为软件开发的基石,但随之而来的安全风险也日益凸显。最新行业数据显示,超过90%的企业IT系统依赖开源组件,而其中70%以上的安全漏洞源于开源或第三方组件。从震惊业界的Log4j漏洞到日益猖獗的供应链…

作者头像 李华
网站建设 2026/5/6 15:23:51

LightOnOCR-2-1B高算力适配:CUDA Graph优化OCR推理延迟降低40%

LightOnOCR-2-1B高算力适配:CUDA Graph优化OCR推理延迟降低40% 在OCR应用场景中,推理速度直接影响用户体验。本文将详细介绍如何通过CUDA Graph技术优化LightOnOCR-2-1B模型,实现40%的延迟降低。 1. LightOnOCR-2-1B模型概述 LightOnOCR-2-1…

作者头像 李华
网站建设 2026/5/6 15:23:34

本地AI智能体PocketPaw:开源框架实现数据私有化与自动化

1. 项目概述:一个真正属于你的本地AI智能体 如果你和我一样,对把个人数据、对话历史和任务委托给云端AI服务商这件事,始终心存疑虑,但又眼馋那些能帮你写代码、查资料、管理日程的智能助手,那么PocketPaw的出现&#…

作者头像 李华
网站建设 2026/5/6 15:20:06

【绝密泄露】某省级政务云MCP 2026单节点吞吐量从1.2万TPS飙升至8.7万TPS的3项内核级优化(含sysctl.conf定制模板及验证脚本)

更多请点击: https://intelliparadigm.com 第一章:MCP 2026国产化部署优化方法总览 MCP 2026(Mission-Critical Platform 2026)是面向高可靠政务与能源场景的国产化中间件平台,其部署优化需兼顾信创生态兼容性、资源轻…

作者头像 李华