news 2026/5/3 15:34:56

PyTorch实战:从ResNet-18代码实现到过拟合解决,我的CIFAR-10训练踩坑全记录

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
PyTorch实战:从ResNet-18代码实现到过拟合解决,我的CIFAR-10训练踩坑全记录

PyTorch实战:从ResNet-18代码实现到过拟合解决,我的CIFAR-10训练踩坑全记录

当第一次在CIFAR-10数据集上训练ResNet-18模型时,我本以为按照教科书上的代码实现就能轻松获得不错的结果。然而现实给了我一记响亮的耳光——模型在训练集上表现优异,测试集上却惨不忍睹。这个典型的过拟合问题让我意识到,深度学习实战远不止于代码的简单堆砌。本文将完整记录我从模型构建到解决过拟合的全过程,希望能为同样在PyTorch实践中遇到类似问题的朋友提供参考。

1. 环境准备与数据加载

在开始任何深度学习项目前,确保环境配置正确是避免后续问题的关键一步。我使用的是Python 3.8和PyTorch 1.9.0,搭配CUDA 11.1以利用GPU加速。建议使用conda创建虚拟环境:

conda create -n pytorch_resnet python=3.8 conda activate pytorch_resnet pip install torch torchvision torchaudio cudatoolkit=11.1 -f https://download.pytorch.org/whl/torch_stable.html

CIFAR-10数据集包含60,000张32x32彩色图像,分为10个类别,每个类别6,000张。PyTorch的torchvision模块提供了便捷的加载方式:

import torchvision import torchvision.transforms as transforms # 定义数据增强和归一化 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) # 加载数据集 train_set = torchvision.datasets.CIFAR10( root='./data', train=True, download=True, transform=transform_train) test_set = torchvision.datasets.CIFAR10( root='./data', train=False, download=True, transform=transform_test) train_loader = torch.utils.data.DataLoader( train_set, batch_size=128, shuffle=True, num_workers=2) test_loader = torch.utils.data.DataLoader( test_set, batch_size=100, shuffle=False, num_workers=2)

这里有几个关键点需要注意:

  • 数据增强:训练时使用随机裁剪和水平翻转来增加数据多样性
  • 数据归一化:使用CIFAR-10的均值和标准差进行归一化
  • 批量大小:根据GPU内存选择适当的batch_size,太大可能导致内存不足

2. ResNet-18模型实现

ResNet的核心思想是通过残差连接解决深层网络训练中的梯度消失问题。标准的ResNet-18由以下部分组成:

  1. 初始卷积层(7x7卷积,步长2)
  2. 最大池化层
  3. 4个残差块组,每组包含2个残差块
  4. 全局平均池化
  5. 全连接分类层

以下是PyTorch实现的关键代码:

import torch import torch.nn as nn import torch.nn.functional as F class BasicBlock(nn.Module): expansion = 1 def __init__(self, in_planes, planes, stride=1): super(BasicBlock, self).__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != self.expansion*planes: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(self.expansion*planes) ) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) out += self.shortcut(x) out = F.relu(out) return out class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super(ResNet, self).__init__() self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) self.linear = nn.Linear(512*block.expansion, num_classes) def _make_layer(self, block, planes, num_blocks, stride): strides = [stride] + [1]*(num_blocks-1) layers = [] for stride in strides: layers.append(block(self.in_planes, planes, stride)) self.in_planes = planes * block.expansion return nn.Sequential(*layers) def forward(self, x): out = F.relu(self.bn1(self.conv1(x))) out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) out = out.view(out.size(0), -1) out = self.linear(out) return out def ResNet18(): return ResNet(BasicBlock, [2, 2, 2, 2])

实现时需要注意的几个细节:

  • 残差连接:当输入输出维度不匹配时(stride≠1或通道数变化),需要通过1x1卷积调整维度
  • 批归一化:每个卷积层后都接BatchNorm层加速训练
  • 激活函数:ReLU激活放在残差相加之后,这是ResNet的标准做法

3. 训练过程与初始结果

有了模型和数据,接下来就是训练过程。我使用交叉熵损失和SGD优化器,初始学习率设为0.1,并加入学习率衰减:

import torch.optim as optim device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = ResNet18().to(device) criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200) def train(epoch): model.train() train_loss = 0 correct = 0 total = 0 for batch_idx, (inputs, targets) in enumerate(train_loader): inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() train_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() print(f'Epoch: {epoch} | Loss: {train_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}%') def test(epoch): model.eval() test_loss = 0 correct = 0 total = 0 with torch.no_grad(): for batch_idx, (inputs, targets) in enumerate(test_loader): inputs, targets = inputs.to(device), targets.to(device) outputs = model(inputs) loss = criterion(outputs, targets) test_loss += loss.item() _, predicted = outputs.max(1) total += targets.size(0) correct += predicted.eq(targets).sum().item() print(f'Test Loss: {test_loss/(batch_idx+1):.3f} | Acc: {100.*correct/total:.3f}%') return 100.*correct/total for epoch in range(200): train(epoch) test_acc = test(epoch) scheduler.step()

训练200个epoch后,我观察到了一个典型的过拟合现象:

指标训练集测试集
准确率99.8%85.3%
损失值0.0020.891

训练集和测试集之间的巨大性能差距表明模型已经过拟合。更直观地,通过绘制训练曲线可以看到:

  1. 训练准确率很快接近100%
  2. 测试准确率在约50个epoch后停止提升
  3. 测试损失在后期反而开始上升

4. 过拟合诊断与解决方案

过拟合意味着模型过度记忆了训练数据的特定特征,而未能学习到泛化的模式。针对ResNet-18在CIFAR-10上的过拟合,我尝试了以下几种解决方案:

4.1 数据增强增强

初始的数据增强只包含随机裁剪和水平翻转。我扩展了增强策略:

transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])

新增的增强包括:

  • 随机旋转:±15度范围内旋转图像
  • 颜色扰动:调整亮度、对比度和饱和度

4.2 加入Dropout层

在ResNet的全连接层前加入Dropout:

class ResNet(nn.Module): # ... 其他部分保持不变 ... def __init__(self, block, num_blocks, num_classes=10): # ... self.dropout = nn.Dropout(0.5) self.linear = nn.Linear(512*block.expansion, num_classes) def forward(self, x): # ... out = F.avg_pool2d(out, 4) out = out.view(out.size(0), -1) out = self.dropout(out) out = self.linear(out) return out

Dropout率为0.5,意味着在前向传播中随机丢弃50%的神经元。

4.3 权重衰减调整

增加L2正则化的强度,将weight_decay从5e-4提高到1e-3:

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-3)

4.4 标签平滑

使用标签平滑技术减轻模型对标签的过度自信:

class LabelSmoothingCrossEntropy(nn.Module): def __init__(self, epsilon=0.1): super().__init__() self.epsilon = epsilon def forward(self, preds, target): log_probs = F.log_softmax(preds, dim=-1) nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(1)) nll_loss = nll_loss.squeeze(1) smooth_loss = -log_probs.mean(dim=-1) loss = (1 - self.epsilon) * nll_loss + self.epsilon * smooth_loss return loss.mean() criterion = LabelSmoothingCrossEntropy(epsilon=0.1)

标签平滑通过将一部分概率质量分配给非目标类别,防止模型对训练样本过度自信。

4.5 模型架构调整

原始的ResNet-18设计用于ImageNet(224x224),而CIFAR-10只有32x32。我做了以下调整:

  1. 将初始卷积层从7x7改为3x3,步长从2改为1
  2. 移除第一个最大池化层
  3. 减小最终的全连接层尺寸
class ResNet(nn.Module): def __init__(self, block, num_blocks, num_classes=10): super(ResNet, self).__init__() self.in_planes = 64 self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(64) # 移除maxpool层 self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) # ... 其余部分保持不变 ...

5. 优化后结果对比

实施上述改进后,重新训练模型并对比结果:

改进措施测试准确率训练-测试差距
原始模型85.3%14.5%
+增强数据87.1%10.2%
+Dropout88.5%8.7%
+权重衰减调整89.2%7.9%
+标签平滑90.1%6.3%
+模型调整92.4%4.8%

最终的训练曲线显示:

  • 训练准确率收敛于约97.2%
  • 测试准确率稳定在92.4%左右
  • 两者差距从最初的14.5%缩小到4.8%

6. 其他实用技巧

在调试过程中,我还发现以下几个实用技巧:

学习率预热

对于深度网络,初始阶段使用过大的学习率可能导致不稳定。可以采用线性预热:

def warmup_lr(epoch): if epoch < 5: return 0.01 + (0.1-0.01) * epoch / 5 else: return 0.1 scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_lr)

混合精度训练

使用AMP(Automatic Mixed Precision)加速训练并减少显存占用:

scaler = torch.cuda.amp.GradScaler() for epoch in range(200): model.train() for inputs, targets in train_loader: inputs, targets = inputs.to(device), targets.to(device) optimizer.zero_grad() with torch.cuda.amp.autocast(): outputs = model(inputs) loss = criterion(outputs, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()

模型EMA

使用模型参数的指数移动平均(EMA)可以获得更稳定的测试性能:

class ModelEMA: def __init__(self, model, decay=0.999): self.model = model self.decay = decay self.shadow = {} self.backup = {} def register(self): for name, param in self.model.named_parameters(): if param.requires_grad: self.shadow[name] = param.data.clone() def update(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] self.shadow[name] = new_average.clone() def apply_shadow(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.shadow self.backup[name] = param.data param.data = self.shadow[name] def restore(self): for name, param in self.model.named_parameters(): if param.requires_grad: assert name in self.backup param.data = self.backup[name] self.backup = {} ema = ModelEMA(model) ema.register() # 在训练循环中 for epoch in range(200): train(epoch) ema.update() ema.apply_shadow() test_acc = test(epoch) # 使用EMA模型测试 ema.restore()
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/3 15:32:50

如何在macOS上获得完美的桌面歌词体验:LyricsX完整指南

如何在macOS上获得完美的桌面歌词体验&#xff1a;LyricsX完整指南 【免费下载链接】LyricsX &#x1f3b6; Ultimate lyrics app for macOS. 项目地址: https://gitcode.com/gh_mirrors/ly/LyricsX 还在为macOS上找不到合适的歌词显示工具而烦恼吗&#xff1f;LyricsX这…

作者头像 李华
网站建设 2026/5/3 15:28:38

通过审计日志功能追踪APIKey使用情况加强安全管控

通过审计日志功能追踪APIKey使用情况加强安全管控 1. APIKey安全治理的核心挑战 在企业级大模型应用场景中&#xff0c;APIKey作为访问凭证直接关联计费与权限。传统管理方式往往面临三大痛点&#xff1a;密钥分发后难以追溯实际使用者、异常调用无法及时预警、多模型访问缺乏…

作者头像 李华
网站建设 2026/5/3 15:27:33

在macOS上实现完美歌词体验的终极解决方案:LyricsX深度指南

在macOS上实现完美歌词体验的终极解决方案&#xff1a;LyricsX深度指南 【免费下载链接】LyricsX &#x1f3b6; Ultimate lyrics app for macOS. 项目地址: https://gitcode.com/gh_mirrors/ly/LyricsX 还在为macOS上找不到合适的歌词显示工具而烦恼吗&#xff1f;Lyri…

作者头像 李华
网站建设 2026/5/3 15:24:31

2025届最火的降重复率助手横评

Ai论文网站排名&#xff08;开题报告、文献综述、降aigc率、降重综合对比&#xff09; TOP1. 千笔AI TOP2. aipasspaper TOP3. 清北论文 TOP4. 豆包 TOP5. kimi TOP6. deepseek 应当把AIGC在最终内容里所占的比例予以降低&#xff0c;这就需要从源头着手去把控生成内容的…

作者头像 李华
网站建设 2026/5/3 15:19:21

一夜爆火!这个4千星的开源项目让Agent重回文档

一个登上 GitHub 热榜的桌面端 GUI在 AI Agent 的开源战场上&#xff0c;一个名字正在被越来越多开发者反复提起&#xff1a;lukilabs/craft-agents-oss。4 月中旬&#xff0c;这个项目登上 GitHub 日热榜 AI 类榜单&#xff0c;短时间内积累四千余 Star。与一众「命令行型」智…

作者头像 李华
网站建设 2026/5/3 15:19:04

生态系统集成篇:DeepSeek V4 在生产级应用开发中的赋能与实践指南

生态系统集成篇&#xff1a;DeepSeek V4 在生产级应用开发中的赋能与实践指南 引言&#xff1a;从模型能力到可部署生态系统 拥有一个高性能的基础模型&#xff08;如 DeepSeek V4&#xff09;仅仅是起点。一个真正的商业化产品&#xff0c;其核心价值在于其能否被高效、可靠、…

作者头像 李华