news 2026/6/19 17:31:30

Day50 PythonStudy

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day50 PythonStudy
import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt import numpy as np # 设置中文字体支持 plt.rcParams["font.family"] = ["SimHei"] plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题 # 检查GPU是否可用 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 1. 数据预处理 # 训练集:使用多种数据增强方法提高模型泛化能力 train_transform = transforms.Compose([ # 随机裁剪图像,从原图中随机截取32x32大小的区域 transforms.RandomCrop(32, padding=4), # 随机水平翻转图像(概率0.5) transforms.RandomHorizontalFlip(), # 随机颜色抖动:亮度、对比度、饱和度和色调随机变化 transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), # 随机旋转图像(最大角度15度) transforms.RandomRotation(15), # 将PIL图像或numpy数组转换为张量 transforms.ToTensor(), # 标准化处理:每个通道的均值和标准差,使数据分布更合理 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # 测试集:仅进行必要的标准化,保持数据原始特性,标准化不损失数据信息,可还原 test_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # 2. 加载CIFAR-10数据集 train_dataset = datasets.CIFAR10( root='./data', train=True, download=True, transform=train_transform # 使用增强后的预处理 ) test_dataset = datasets.CIFAR10( root='./data', train=False, transform=test_transform # 测试集不使用增强 ) # 3. 创建数据加载器 batch_size = 64 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F # 设置设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 1. 定义CNN模型 class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() # ---------------------- 卷积特征提取部分 ---------------------- # 第一个卷积块 self.conv_block1 = nn.Sequential( nn.Conv2d(3, 32, kernel_size=3, padding=1), # [batch, 3, 32, 32] -> [batch, 32, 32, 32] nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2) # [batch, 32, 32, 32] -> [batch, 32, 16, 16] ) # 第二个卷积块 self.conv_block2 = nn.Sequential( nn.Conv2d(32, 64, kernel_size=3, padding=1), # [batch, 32, 16, 16] -> [batch, 64, 16, 16] nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2) # [batch, 64, 16, 16] -> [batch, 64, 8, 8] ) # 第三个卷积块 self.conv_block3 = nn.Sequential( nn.Conv2d(64, 128, kernel_size=3, padding=1), # [batch, 64, 8, 8] -> [batch, 128, 8, 8] nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(2) # [batch, 128, 8, 8] -> [batch, 128, 4, 4] ) # ---------------------- 全连接分类部分 ---------------------- self.classifier = nn.Sequential( nn.Linear(128 * 4 * 4, 512), nn.ReLU(inplace=True), nn.Dropout(0.5), nn.Linear(512, 256), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(256, 10) ) def forward(self, x): # 卷积特征提取 x = self.conv_block1(x) x = self.conv_block2(x) x = self.conv_block3(x) # 展平 x = x.view(x.size(0), -1) # [batch, 128, 4, 4] -> [batch, 2048] # 分类 x = self.classifier(x) return x # 2. 初始化模型 model = CNN().to(device) print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}") print(f"可训练参数量: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}") # 3. 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4) # 添加L2正则化 scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, # 第一个参数是optimizer,不要用关键字参数 mode='min', factor=0.5, patience=5, threshold=0.01, min_lr=1e-5 )
# 5. 训练模型(记录每个 iteration 的损失) def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs): model.train() # 设置为训练模式 # 记录每个 iteration 的损失 all_iter_losses = [] # 存储所有 batch 的损失 iter_indices = [] # 存储 iteration 序号 # 记录每个 epoch 的准确率和损失 train_acc_history = [] test_acc_history = [] train_loss_history = [] test_loss_history = [] for epoch in range(epochs): running_loss = 0.0 correct = 0 total = 0 for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) # 移至GPU optimizer.zero_grad() # 梯度清零 output = model(data) # 前向传播 loss = criterion(output, target) # 计算损失 loss.backward() # 反向传播 optimizer.step() # 更新参数 # 记录当前 iteration 的损失 iter_loss = loss.item() all_iter_losses.append(iter_loss) iter_indices.append(epoch * len(train_loader) + batch_idx + 1) # 统计准确率和损失 running_loss += iter_loss _, predicted = output.max(1) total += target.size(0) correct += predicted.eq(target).sum().item() # 每100个批次打印一次训练信息 if (batch_idx + 1) % 100 == 0: print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} ' f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}') # 计算当前epoch的平均训练损失和准确率 epoch_train_loss = running_loss / len(train_loader) epoch_train_acc = 100. * correct / total train_acc_history.append(epoch_train_acc) train_loss_history.append(epoch_train_loss) # 测试阶段 model.eval() # 设置为评估模式 test_loss = 0 correct_test = 0 total_test = 0 with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) output = model(data) test_loss += criterion(output, target).item() _, predicted = output.max(1) total_test += target.size(0) correct_test += predicted.eq(target).sum().item() epoch_test_loss = test_loss / len(test_loader) epoch_test_acc = 100. * correct_test / total_test test_acc_history.append(epoch_test_acc) test_loss_history.append(epoch_test_loss) # 更新学习率调度器 scheduler.step(epoch_test_loss) print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%') # 绘制所有 iteration 的损失曲线 plot_iter_losses(all_iter_losses, iter_indices) # 绘制每个 epoch 的准确率和损失曲线 plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history) return epoch_test_acc # 返回最终测试准确率 # 6. 绘制每个 iteration 的损失曲线 def plot_iter_losses(losses, indices): plt.figure(figsize=(10, 4)) plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss') plt.xlabel('Iteration(Batch序号)') plt.ylabel('损失值') plt.title('每个 Iteration 的训练损失') plt.legend() plt.grid(True) plt.tight_layout() plt.show() # 7. 绘制每个 epoch 的准确率和损失曲线 def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss): epochs = range(1, len(train_acc) + 1) plt.figure(figsize=(12, 4)) # 绘制准确率曲线 plt.subplot(1, 2, 1) plt.plot(epochs, train_acc, 'b-', label='训练准确率') plt.plot(epochs, test_acc, 'r-', label='测试准确率') plt.xlabel('Epoch') plt.ylabel('准确率 (%)') plt.title('训练和测试准确率') plt.legend() plt.grid(True) # 绘制损失曲线 plt.subplot(1, 2, 2) plt.plot(epochs, train_loss, 'b-', label='训练损失') plt.plot(epochs, test_loss, 'r-', label='测试损失') plt.xlabel('Epoch') plt.ylabel('损失值') plt.title('训练和测试损失') plt.legend() plt.grid(True) plt.tight_layout() plt.show() # 8. 执行训练和测试 epochs = 20 # 增加训练轮次以获得更好效果 print("开始使用CNN训练模型...") final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs) print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%") # # 保存模型 # torch.save(model.state_dict(), 'cifar10_cnn_model.pth') # print("模型已保存为: cifar10_cnn_model.pth")

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/15 12:34:17

浔川社团关于福利发放方案再次调整的征求意见稿公告

浔川社团关于福利发放方案再次调整的征求意见稿公告各位社团成员:为保障社团核心项目推进,结合实际工作安排,现就福利发放方案再次调整事宜征求全体成员意见。因浔川代码编辑器v2.1.0正式版内测工作将于明年2月底启动,该项目占用存…

作者头像 李华
网站建设 2026/6/18 13:23:24

Windows NVMe技术革新与性能跃迁

在存储技术高速迭代的今天,NVMe(NVM Express)作为PCIe时代的存储协议标杆,早已成为高性能计算、数据中心乃至消费级设备的核心支撑。而微软作为操作系统生态的核心玩家,其在Windows系统中对NVMe技术的优化与革新,直接决定了硬件性能的释放上限。微软披露的Windows更新、原…

作者头像 李华
网站建设 2026/6/15 18:47:15

CloudWatch 使用技巧与方法大全

一、概述 Amazon CloudWatch 是 AWS 的核心监控服务,提供指标收集、日志管理、告警通知和可视化能力。 核心组件 组件 功能 典型场景 Metrics 指标收集与存储 CPU、内存、自定义业务指标 Logs 日志收集与分析 应用日志、系统日志 Alarms 告警与自动响应 阈值告警、自动伸缩触…

作者头像 李华
网站建设 2026/6/15 13:10:31

甲骨文文字检测数据集VOC+YOLO格式6079张1类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件)图片数量(jpg文件个数):6079标注数量(xml文件个数):6079标注数量(txt文件个数):6079标注类别…

作者头像 李华
网站建设 2026/6/17 16:45:29

基于微信小程序的民宿预订管理系统设计与实现(毕设源码+文档)

背景 本课题聚焦基于微信小程序的民宿预订管理系统的设计与实现,旨在解决游客预订民宿流程繁琐、房源信息不透明、订单管理混乱、房东与游客沟通低效等痛点,依托微信小程序的轻量化、高触达优势,构建集房源展示、在线预订、订单管理、房东管理…

作者头像 李华
网站建设 2026/6/14 22:10:03

基于微信小程序的生猪养殖信息化管理系统(毕设源码+文档)

背景 本课题聚焦基于微信小程序的生猪养殖信息化管理系统的设计与实现,旨在解决传统生猪养殖过程中数据记录繁琐、养殖状态监控不实时、疫病预警滞后、养殖流程不规范等痛点,依托微信小程序的轻量化、高触达优势,构建集养殖数据记录、生猪状态…

作者头像 李华