Beyond Standard MNIST: A Complete Guide to Building a Continual Learning Environment with PyTorch
While most machine learning tutorials still use the standard MNIST dataset, frontier research moved on to more challenging variants long ago. Permuted MNIST and Split MNIST are not only regulars in academic papers; they are a de facto standard for benchmarking continual learning algorithms. This article walks you through building a complete continual learning evaluation environment in PyTorch from scratch, using hands-on code to understand what makes these two classic datasets tick.
1. Why Go Beyond Standard MNIST?
The traditional MNIST dataset has served as introductory material for machine learning for more than twenty years, and its simplicity is both a strength and a limitation. With a fixed arrangement of 784 pixels, models quickly reach near-human accuracy, but that very fact masks a core challenge of the real world: data distributions change over time.
Continual learning research needs datasets that can simulate the following scenarios:
- Task sequencing: the model must learn multiple related but distinct tasks in order
- Catastrophic forgetting: the performance drop on old tasks while learning new ones must be quantifiable
- Knowledge transfer: previously acquired knowledge should be able to help with later tasks
Tip: Permuted MNIST changes the input distribution by permuting pixels, while Split MNIST builds a task sequence by partitioning classes; they correspond to the Domain-IL and Class-IL scenarios, respectively.
The table below compares the key differences between the three MNIST variants:
| Dataset variant | What changes | Scenario | Main challenge |
|---|---|---|---|
| Standard MNIST | Nothing | Basic classification | Single-task performance |
| Permuted MNIST | Pixel ordering | Domain-IL | Adapting to distribution shift |
| Split MNIST | Class partition | Class-IL | Class-incremental learning |
2. Environment Setup and Basic Configuration
Before starting, make sure you have a recent version of PyTorch and the standard scientific Python stack installed:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
```
Define a basic multilayer perceptron (MLP), a common architecture in continual learning research:
```python
class BaseMLP(nn.Module):
    def __init__(self, input_size=784, hidden_size=400, output_size=10):
        super(BaseMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten the input
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)
```
3. Implementing Permuted MNIST
The core of Permuted MNIST is generating a unique pixel permutation for each task. The key implementation steps are:
- Generate the permutations: create a random but fixed pixel ordering per task
- Build a transform: apply the permutation and normalize the images
- Construct the task sequence: create a data loader for each permutation
```python
def generate_permutation():
    """Generate a random pixel permutation."""
    return torch.randperm(784)

class PermuteTransform:
    """Transform that applies a fixed pixel permutation."""
    def __init__(self, permutation):
        self.permutation = permutation

    def __call__(self, x):
        return x.view(-1)[self.permutation].view(1, 28, 28)

# Create 5 tasks with different permutations
num_tasks = 5
permutations = [generate_permutation() for _ in range(num_tasks)]

task_loaders = []
for perm in permutations:
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        PermuteTransform(perm)
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True, transform=transform),
        batch_size=64, shuffle=True)
    task_loaders.append(train_loader)
```
Visualize image samples under different permutations:
```python
def show_permuted_samples(permutations, num_samples=5):
    """Show MNIST samples under different permutations."""
    fig, axes = plt.subplots(len(permutations), num_samples, figsize=(15, 10))
    to_tensor = transforms.ToTensor()
    mnist = datasets.MNIST('../data', train=True, download=True)
    for i, perm in enumerate(permutations):
        transform = PermuteTransform(perm)
        for j in range(num_samples):
            img, _ = mnist[j]  # a PIL image; convert to a tensor before permuting
            axes[i, j].imshow(transform(to_tensor(img)).squeeze(), cmap='gray')
            axes[i, j].axis('off')
    plt.show()

show_permuted_samples(permutations[:3])  # show the first 3 permutations
```
4. A Careful Implementation of Split MNIST
Split MNIST partitions the 10 digit classes into several binary classification tasks. A typical split looks like this:
- Task 1: distinguish 0 and 1
- Task 2: distinguish 2 and 3
- Task 3: distinguish 4 and 5
- Task 4: distinguish 6 and 7
- Task 5: distinguish 8 and 9
The key implementation points are filtering the data and remapping the labels:
```python
def create_split_mnist_loaders():
    """Create the Split MNIST task sequence."""
    base_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    class SplitMNIST(torch.utils.data.Dataset):
        """MNIST restricted to one digit pair, with labels remapped to 0/1."""
        def __init__(self, pair, train=True):
            mnist = datasets.MNIST('../data', train=train, download=True)
            # Keep only samples belonging to the current digit pair
            mask = (mnist.targets == pair[0]) | (mnist.targets == pair[1])
            self.data = mnist.data[mask]
            # Remap labels: pair[0] -> 0, pair[1] -> 1
            self.targets = (mnist.targets[mask] == pair[1]).long()
            self.transform = base_transform

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            img, target = self.data[idx], self.targets[idx]
            # img is a uint8 tensor; ToTensor expects a numpy array or PIL image
            img = self.transform(img.numpy())
            return img, target

    # Task partition: a list of digit pairs
    task_pairs = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
    task_loaders = []
    for pair in task_pairs:
        train_loader = torch.utils.data.DataLoader(
            SplitMNIST(pair, train=True), batch_size=64, shuffle=True)
        task_loaders.append(train_loader)
    return task_loaders

split_loaders = create_split_mnist_loaders()
```
5. A Continual Learning Evaluation Framework
A complete continual learning evaluation needs to track the following metrics:
- Accuracy on the current task
- Degree of forgetting on previous tasks
- Overall average accuracy
The core of the evaluation loop:
```python
def evaluate(model, task_id, test_loaders, device='cpu'):
    """Evaluate the model on all tasks learned so far."""
    model.eval()
    accuracies = []
    with torch.no_grad():
        for t in range(task_id + 1):
            correct = 0
            total = 0
            for images, labels in test_loaders[t]:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            accuracies.append(100 * correct / total)
    return accuracies

def continual_learning_training(model, task_loaders, num_epochs=5):
    """Sequential (continual) training loop."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    all_accuracies = []
    for task_id, train_loader in enumerate(task_loaders):
        print(f"\n=== Training on Task {task_id + 1} ===")
        for epoch in range(num_epochs):
            model.train()
            running_loss = 0.0
            for images, labels in train_loader:
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}")
        # Evaluate on every task learned so far. Note: for brevity this reuses
        # the training loaders; a rigorous evaluation should use held-out
        # test-set loaders instead.
        current_acc = evaluate(model, task_id, task_loaders, device)
        all_accuracies.append(current_acc)
        print(f"Task Accuracies after Task {task_id+1}: {current_acc}")
    return all_accuracies
```
6. Advanced Techniques and Directions for Improvement
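The nested accuracy lists returned by `continual_learning_training` can be reduced to two summary numbers. Below is a minimal sketch (the helper names are my own, not from any library) of the two standard summaries: average accuracy after the final task, and average forgetting, i.e. how far each old task has dropped from its best accuracy so far.

```python
def average_accuracy(all_accuracies):
    """Mean accuracy over all tasks, measured after training on the last task."""
    final = all_accuracies[-1]
    return sum(final) / len(final)

def average_forgetting(all_accuracies):
    """For each non-final task, the drop from its best accuracy so far to its
    accuracy after the final task, averaged over those tasks."""
    num_tasks = len(all_accuracies)
    if num_tasks < 2:
        return 0.0  # nothing to forget with a single task
    final = all_accuracies[-1]
    drops = []
    for t in range(num_tasks - 1):
        # best accuracy task t ever reached, from the step it was learned onward
        best = max(step[t] for step in all_accuracies[t:])
        drops.append(best - final[t])
    return sum(drops) / len(drops)
```

For example, with `[[99.0], [95.0, 98.5], [90.0, 93.0, 98.0]]` the average accuracy is (90 + 93 + 98) / 3 and the average forgetting is ((99 − 90) + (98.5 − 93)) / 2 = 7.25.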
Once the basic pipeline works, consider the following enhancements:
Techniques against forgetting:
- Elastic Weight Consolidation (EWC): add a regularization term that protects important parameters
- Memory replay: retain a small set of old-task samples for joint training
- Dynamic architectures: allocate dedicated model components to each task
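Of the three, memory replay is the easiest to prototype. A minimal, framework-free sketch of a reservoir-sampling buffer follows (the class and method names are illustrative, not from the article); in practice the stored (image, label) pairs would be mixed into each new task's training batches:

```python
import random

class ReplayBuffer:
    """Bounded buffer that keeps an approximately uniform sample of
    everything added across all tasks (reservoir sampling)."""
    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.buffer = []
        self.seen = 0
        self.rng = random.Random(seed)

    def add(self, example):
        self.seen += 1
        if len(self.buffer) < self.capacity:
            self.buffer.append(example)
        else:
            # Replace a random slot with probability capacity / seen, which
            # keeps every example seen so far equally likely to be retained.
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.buffer[j] = example

    def sample(self, k):
        """Draw up to k stored examples to mix into the current batch."""
        return self.rng.sample(self.buffer, min(k, len(self.buffer)))
```

The reservoir update is what makes the buffer suitable for a task stream: no matter how many tasks have passed, old tasks are never crowded out deterministically.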
```python
# Example EWC penalty. Assumes fisher_matrix (per-parameter Fisher estimates)
# and previous_params (a snapshot of the parameters after the previous task)
# have already been computed.
def ewc_loss(model, fisher_matrix, previous_params, lambda_ewc):
    loss = 0
    for name, param in model.named_parameters():
        if name in fisher_matrix:
            loss += (fisher_matrix[name] * (param - previous_params[name])**2).sum()
    return lambda_ewc * loss

# Add the EWC penalty inside the training loop:
total_loss = criterion(outputs, labels) + ewc_loss(model, fisher, prev_params, lambda_ewc=1000)
```
Visualizing the evaluation metrics:
```python
def plot_learning_curve(accuracies):
    """Plot one accuracy-over-time curve per task."""
    plt.figure(figsize=(10, 6))
    for task in range(len(accuracies)):
        x = range(task + 1, len(accuracies) + 1)
        y = [acc[task] for acc in accuracies[task:]]
        plt.plot(x, y, marker='o', label=f'Task {task+1}')
    plt.xlabel('Task Number')
    plt.ylabel('Accuracy (%)')
    plt.title('Continual Learning Performance')
    plt.legend()
    plt.grid()
    plt.show()
```
In real projects I have found Permuted MNIST to be quite sensitive to the random seed: different permutation orders can make accuracy fluctuate by 5-8 percentage points. The remedy is to fix the random seed or average over multiple runs. Split MNIST, on the other hand, suffers from class imbalance: the sample counts within some digit pairs can differ by up to 20%, so it is worth adding class weights to the loss function.
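For the class-imbalance point, per-class weights can be derived by simple inverse-frequency counting. A minimal sketch (the helper name is my own; it assumes every class actually appears in `labels`), whose output can be passed to `nn.CrossEntropyLoss(weight=torch.tensor(...))`:

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """Weight each class by total / (num_classes * count), so rarer
    classes get proportionally larger weight and a perfectly balanced
    dataset gets all-ones."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (num_classes * counts[c]) for c in range(num_classes)]
```

For a perfectly balanced pair the weights are `[1.0, 1.0]`; for a 3:1 split such as `[0, 0, 0, 1]` they become `[2/3, 2.0]`, pushing the loss to pay more attention to the rarer digit.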