news 2026/6/8 4:06:15

PyTorch学习率调度实战:用CosineAnnealingLR和WarmRestarts搞定图像分类任务(以ResNet18为例)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
PyTorch学习率调度实战:用CosineAnnealingLR和WarmRestarts搞定图像分类任务(以ResNet18为例)

PyTorch学习率调度实战:用CosineAnnealingLR和WarmRestarts搞定图像分类任务(以ResNet18为例)

在深度学习模型训练中,学习率调度策略的选择往往决定了模型能否快速收敛到最优解。想象一下,你正在训练一个ResNet18模型进行图像分类,前几轮训练进展顺利,但到了中后期验证集准确率却停滞不前——这可能正是学习率调度策略需要优化的信号。本文将带你深入实战,通过PyTorch的CosineAnnealingLRCosineAnnealingWarmRestarts两种调度器,为ResNet18模型打造一个高效的学习率调节方案。

1. 环境准备与数据加载

在开始之前,我们需要准备好实验环境。这里推荐使用Python 3.8+和PyTorch 1.10+版本,确保能够支持最新的调度器功能。首先安装必要的依赖:

pip install torch torchvision matplotlib tensorboard

对于图像分类任务,我们选择CIFAR-10数据集作为示例。这个数据集包含10个类别的6万张32x32彩色图像,非常适合验证ResNet18这样的轻量级模型:

import torch from torchvision import datasets, transforms from torch.utils.data import DataLoader # 数据增强和归一化 train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # 加载数据集 train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform) # 创建数据加载器 batch_size = 128 train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4) test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4)

2. 模型构建与基础训练流程

我们使用PyTorch内置的ResNet18模型,并针对CIFAR-10的32x32输入尺寸进行适当调整:

import torch.nn as nn from torchvision.models import resnet18 def create_model(num_classes=10): model = resnet18(pretrained=False) # 调整第一层卷积,适应CIFAR-10的32x32输入 model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) # 移除最后的全连接层,替换为适合类别数的层 model.fc = nn.Linear(model.fc.in_features, num_classes) return model model = create_model().cuda()

基础训练流程包含损失函数和优化器的设置。这里我们使用交叉熵损失和SGD优化器:

criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

3. CosineAnnealingLR调度器实战

CosineAnnealingLR实现了一个简单的余弦退火策略,学习率在给定的周期内按照余弦曲线从初始值下降到最小值。

3.1 参数配置与初始化

关键参数说明:

  • T_max: 半周期长度(epoch数)
  • eta_min: 学习率最小值

对于CIFAR-10训练,通常设置150-200个epoch。我们选择T_max=50,意味着每50个epoch完成半个余弦周期:

from torch.optim.lr_scheduler import CosineAnnealingLR scheduler_cosine = CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-5)

3.2 训练循环集成

将调度器集成到训练循环中,每个epoch后调用scheduler.step()

def train_with_scheduler(model, train_loader, test_loader, scheduler, epochs=150): train_losses = [] test_accuracies = [] learning_rates = [] for epoch in range(epochs): model.train() running_loss = 0.0 for inputs, labels in train_loader: inputs, labels = inputs.cuda(), labels.cuda() optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() # 记录当前学习率 current_lr = optimizer.param_groups[0]['lr'] learning_rates.append(current_lr) # 更新学习率 scheduler.step() # 评估模型 test_acc = evaluate(model, test_loader) test_accuracies.append(test_acc) train_loss = running_loss / len(train_loader) train_losses.append(train_loss) print(f'Epoch {epoch+1}/{epochs} - Loss: {train_loss:.4f}, Acc: {test_acc:.4f}, LR: {current_lr:.6f}') return train_losses, test_accuracies, learning_rates

3.3 学习率变化可视化

训练完成后,我们可以绘制学习率变化曲线:

import matplotlib.pyplot as plt def plot_learning_rates(learning_rates, title): plt.figure(figsize=(10, 5)) plt.plot(learning_rates) plt.xlabel('Epoch') plt.ylabel('Learning Rate') plt.title(title) plt.grid(True) plt.show() # 假设已经完成了训练 # plot_learning_rates(cosine_lrs, "CosineAnnealingLR Learning Rate Schedule")

典型的余弦退火学习率曲线会呈现周期性波动,在50个epoch内从初始值平滑下降到最小值,然后重新开始。

4. CosineAnnealingWarmRestarts调度器实战

CosineAnnealingWarmRestartsCosineAnnealingLR的改进版,在每次重启时保留部分之前的学习率变化"记忆",通常能带来更好的性能。

4.1 参数配置与初始化

关键参数说明:

  • T_0: 第一次重启的epoch数
  • T_mult: 重启周期倍增因子
  • eta_min: 学习率最小值

我们设置初始周期T_0=30,倍增因子T_mult=2,意味着:

  • 第一个周期:30个epoch
  • 第二个周期:60个epoch
  • 第三个周期:120个epoch
  • 以此类推...
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts scheduler_warm = CosineAnnealingWarmRestarts(optimizer, T_0=30, T_mult=2, eta_min=1e-5)

4.2 训练循环集成

训练循环与之前类似,只是更换了调度器:

# 使用相同的train_with_scheduler函数,传入不同的scheduler train_losses_warm, test_accuracies_warm, learning_rates_warm = train_with_scheduler( model, train_loader, test_loader, scheduler_warm, epochs=150 )

4.3 学习率变化对比

将两种调度器的学习率曲线绘制在一起对比:

def compare_schedulers(lr_cosine, lr_warm, epochs): plt.figure(figsize=(12, 6)) plt.plot(range(epochs), lr_cosine, label='CosineAnnealingLR (T_max=50)') plt.plot(range(epochs), lr_warm, label='CosineAnnealingWarmRestarts (T_0=30, T_mult=2)') plt.xlabel('Epoch') plt.ylabel('Learning Rate') plt.title('Comparison of Learning Rate Schedules') plt.legend() plt.grid(True) plt.show() # 假设已经获取了两种调度器的学习率记录 # compare_schedulers(cosine_lrs, warm_lrs, epochs=150)

5. 性能分析与实践建议

在实际项目中,我们需要综合考虑训练时间、资源消耗和模型性能。以下是两种调度器的对比分析:

特性CosineAnnealingLRCosineAnnealingWarmRestarts
收敛速度中等较快
最终准确率良好优秀
超参数敏感性中等较低
计算开销略高
适合场景小中型数据集大型复杂数据集

实践中的几点建议:

  1. 初始学习率选择:对于ResNet18和CIFAR-10,0.1是一个不错的起点
  2. T_max/T_0设置:通常设为总epoch数的1/3到1/2
  3. eta_min设置:一般设为初始学习率的1/100到1/1000
  4. Warm Restarts优势:在训练后期能帮助模型跳出局部最优
# 一个综合了最佳实践的配置示例 optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=40, T_mult=2, eta_min=1e-4)

在CIFAR-10上的实验表明,使用Warm Restarts的调度器通常能比固定学习率或简单余弦退火获得1-2%的准确率提升。更重要的是,这种提升不需要额外的训练时间或计算资源,只需要正确配置调度器参数。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/8 4:05:15

Jenkinsapi高级技巧:提升CI/CD效率的10个实用方法

Jenkinsapi高级技巧:提升CI/CD效率的10个实用方法 【免费下载链接】jenkinsapi A Python API for accessing resources and configuring Hudson & Jenkins continuous-integration servers 项目地址: https://gitcode.com/gh_mirrors/je/jenkinsapi Jenk…

作者头像 李华
网站建设 2026/6/8 3:57:03

EdgeKey开源:全栈数字商品商城,丢Cloudflare Workers零成本开店

卖卡密、卖激活码这种小生意,租服务器一个月几十块,一年下来够买好多杯奶茶了。EdgeKey 的思路很野:基于 Vike 框架把前端后端数据库全打包,丢到 Cloudflare Workers 上跑,免费额度就够开店。核心流程:商品…

作者头像 李华
网站建设 2026/6/8 3:57:00

PawPal开源:屏幕角落的透明小狗,专治久坐走神不喝水

坐了三小时没动过、杯子空了俩小时、打开微博刷了半小时—— PawPal 是一只赖在屏幕角落里的小狗,专门用物理方式把你拽回正经事上。到休息时间了它从屏幕这头窜到那头,想忽略都难;杯子空了它到点蹦出来盯你去倒水;偷偷切到社交网…

作者头像 李华
网站建设 2026/6/8 3:56:37

Ludic框架性能优化:7个提升Web应用响应速度的关键技巧

Ludic框架性能优化:7个提升Web应用响应速度的关键技巧 【免费下载链接】ludic 🌳 A type-safe HTML template engine for Python. Build dynamic web pages using Python components with a React-like approach. 项目地址: https://gitcode.com/gh_mi…

作者头像 李华
网站建设 2026/6/8 3:51:56

Webpack Bundle Size Analyzer最佳实践:10个优化打包体积的技巧

Webpack Bundle Size Analyzer最佳实践:10个优化打包体积的技巧 【免费下载链接】webpack-bundle-size-analyzer A tool for finding out what contributes to the size of Webpack bundles 项目地址: https://gitcode.com/gh_mirrors/we/webpack-bundle-size-ana…

作者头像 李华