news 2026/6/15 15:57:21

第P2周:CIFAR10彩色图片识别

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
第P2周:CIFAR10彩色图片识别
  • 🍨本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖原作者:K同学啊

目录

一、 前期准备

1. 设置GPU

2. 导入数据

3. 数据可视化

二、构建简单的CNN网络

三、 训练模型

1. 设置超参数

2. 编写训练函数

3. 编写测试函数

4. 正式训练

四、 结果可视化

五、个人总结

一、 前期准备

1. 设置GPU

import torch import torch.nn as nn import matplotlib.pyplot as plt import torchvision device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device

2. 导入数据

train_ds = torchvision.datasets.CIFAR10('data', train=True, transform=torchvision.transforms.ToTensor(), download=True) test_ds = torchvision.datasets.CIFAR10('data', train=False, transform=torchvision.transforms.ToTensor(), download=True) batch_size = 32 train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True) test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size) imgs, labels = next(iter(train_dl)) imgs.shape

3. 数据可视化

import numpy as np plt.figure(figsize=(20, 5)) for i, imgs in enumerate(imgs[:20]): npimg = imgs.numpy().transpose((1, 2, 0)) plt.subplot(2, 10, i+1) plt.imshow(npimg, cmap=plt.cm.binary) plt.axis('off')

二、构建简单的CNN网络

import torch.nn.functional as F num_classes = 10 class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3) self.pool1 = nn.MaxPool2d(kernel_size=2) self.conv2 = nn.Conv2d(64, 64, kernel_size=3) self.pool2 = nn.MaxPool2d(kernel_size=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3) self.pool3 = nn.MaxPool2d(kernel_size=2) self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(256, num_classes) def forward(self, x): x = self.pool1(F.relu(self.conv1(x))) x = self.pool2(F.relu(self.conv2(x))) x = self.pool3(F.relu(self.conv3(x))) x = torch.flatten(x, start_dim=1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x from torchinfo import summary # 将模型转移到GPU中(我们模型运行均在GPU中进行) model = Model().to(device) summary(model)

三、 训练模型

1. 设置超参数

loss_fn = nn.CrossEntropyLoss() # 创建损失函数 learn_rate = 1e-2 # 学习率 opt = torch.optim.SGD(model.parameters(),lr=learn_rate)

2. 编写训练函数

# 训练循环 def train(dataloader, model, loss_fn, optimizer): size = len(dataloader.dataset) # 训练集的大小,一共60000张图片 num_batches = len(dataloader) # 批次数目,1875(60000/32) train_loss, train_acc = 0, 0 # 初始化训练损失和正确率 for X, y in dataloader: # 获取图片及其标签 X, y = X.to(device), y.to(device) # 计算预测误差 pred = model(X) # 网络输出 loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失 # 反向传播 optimizer.zero_grad() # grad属性归零 loss.backward() # 反向传播 optimizer.step() # 每一步自动更新 # 记录acc与loss train_acc += (pred.argmax(1) == y).type(torch.float).sum().item() train_loss += loss.item() train_acc /= size train_loss /= num_batches return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn): size = len(dataloader.dataset) # 测试集的大小,一共10000张图片 num_batches = len(dataloader) # 批次数目,313(10000/32=312.5,向上取整) test_loss, test_acc = 0, 0 # 当不进行训练时,停止梯度更新,节省计算内存消耗 with torch.no_grad(): for imgs, target in dataloader: imgs, target = imgs.to(device), target.to(device) # 计算loss target_pred = model(imgs) loss = loss_fn(target_pred, target) test_loss += loss.item() test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item() test_acc /= size test_loss /= num_batches return test_acc, test_loss

4. 正式训练

epochs = 10 train_loss = [] train_acc = [] test_loss = [] test_acc = [] for epoch in range(epochs): model.train() epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt) model.eval() epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn) train_acc.append(epoch_train_acc) train_loss.append(epoch_train_loss) test_acc.append(epoch_test_acc) test_loss.append(epoch_test_loss) template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}') print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss)) print('Done')

四、 结果可视化

import matplotlib.pyplot as plt #隐藏警告 import warnings warnings.filterwarnings("ignore") #忽略警告信息 plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号 plt.rcParams['figure.dpi'] = 100 #分辨率 from datetime import datetime current_time = datetime.now() # 获取当前时间 epochs_range = range(epochs) plt.figure(figsize=(12, 3)) plt.subplot(1, 2, 1) plt.plot(epochs_range, train_acc, label='Training Accuracy') plt.plot(epochs_range, test_acc, label='Test Accuracy') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效 plt.subplot(1, 2, 2) plt.plot(epochs_range, train_loss, label='Training Loss') plt.plot(epochs_range, test_loss, label='Test Loss') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.show()

五、个人总结

逐渐熟悉CNN模型构建过程,并逐步理解其原理。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/14 16:08:51

2026毕设ssm+vue基于框架的华建汽车出租系统设计与实现论文+程序

本系统(程序源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容一、选题背景 关于动漫内容管理与推荐系统的研究,现有研究主要以传统CMS(内容管理系统)或通用媒体平台为…

作者头像 李华
网站建设 2026/6/15 12:39:06

AI应用架构师干货:GNN在医疗病历分析中的架构设计

AI应用架构师干货:GNN在医疗病历分析中的架构设计 一、引言 (Introduction) 钩子:医疗病历里的“隐藏关系”陷阱 凌晨3点,急诊室的张医生盯着电脑屏幕上的电子病历(EHR)眉头紧锁:52岁的糖尿病患者李阿姨&am…

作者头像 李华
网站建设 2026/6/15 13:32:34

1分钟原型开发:Vue3组件通信的即时验证方案

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 请设置一个即开即用的Vue3组件通信沙盒环境:1.预置父子组件基础结构;2.内置3种常用通信方法的代码片段(emit示例、provide示例、ref示例&#xf…

作者头像 李华
网站建设 2026/6/15 12:12:19

AI助力Windows Server 2016部署:自动化下载与配置指南

快速体验 打开 InsCode(快马)平台 https://www.inscode.net输入框内输入如下内容: 创建一个AI辅助工具,能够自动从微软官方或可信源下载Windows Server 2016 ISO镜像,并生成对应的校验信息。工具应包含以下功能:1)自动检测用户网…

作者头像 李华