news 2026/5/1 6:12:54

Day 41 Dataset 与 DataLoader

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
Day 41 Dataset 与 DataLoader

文章目录

  • Day 41 · Dataset 与 DataLoader
      • torchvision 模块速览
      • Step 1 · 定义 `transforms` 管道
    • 一、Dataset:定义“单份数据”
      • 1. 图片观察
      • 2. 两个必须的魔术方法
        • `__getitem__`:让对象支持索引
        • `__len__`:让对象支持 `len()`
      • 3. 自定义 `Dataset` 的伪代码
    • 二、DataLoader:批量调度器
    • 三、总结

Day 41 · Dataset 与 DataLoader

在训练大规模数据集时,显存通常无法一次性装下所有样本,因此必须按批次把数据送入模型。PyTorch 为此提供了两个密不可分的组件:

  1. Dataset:描述每一条数据长什么样、如何读取、是否需要预处理。
  2. DataLoader:负责把一个Dataset切成批次、决定是否乱序、是否并行加载。

下面以经典的MNIST 手写数字数据集为例(训练集 60k、测试集 10k、每张 28×28 灰度图),逐步梳理两者分工。

importtorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataLoader,Datasetfromtorchvisionimportdatasets,transformsimportmatplotlib.pyplotasplt# 为可复现性固定随机种子torch.manual_seed(42)
<torch._C.Generator at 0x77eb27fe76f0>

torchvision 模块速览

torchvision ├── datasets # 视觉数据集(如 MNIST、CIFAR) ├── transforms # 视觉数据预处理(裁剪、翻转、归一化等) ├── models # 各类预训练模型 ├── utils # 目标检测等常用工具函数 └── io # 图像 / 视频 IO

Step 1 · 定义transforms管道

transforms.Compose可以像数据管道一样串联多步操作,这里先把 PIL 图转成张量,再用 MNIST 的均值、方差做标准化。

# 1. 数据预处理,该写法非常类似于管道pipeline# transforms 模块提供了一系列常用的图像预处理操作# 先归一化,再标准化transform=transforms.Compose([transforms.ToTensor(),# 转换为张量并归一化到[0,1]transforms.Normalize((0.1307,),(0.3081,))# 标准化。MNIST数据集的均值和标准差,这个值很出名,所以直接使用])
# Step 2 · 加载数据集。如果本地没有,会自动下载到 ./datatrain_dataset=datasets.MNIST(root='./data',train=True,download=True,transform=transform)test_dataset=datasets.MNIST(root='./data',train=False,transform=transform)
100%|██████████| 9.91M/9.91M [00:07<00:00, 1.37MB/s] 100%|██████████| 28.9k/28.9k [00:00<00:00, 152kB/s] 100%|██████████| 1.65M/1.65M [00:00<00:00, 1.70MB/s] 100%|██████████| 4.54k/4.54k [00:00<00:00, 15.8MB/s]

PyTorch 的思路是:在“读取数据”这一环就完成预处理,因此transform直接写进datasets.MNIST的构造函数里。

一、Dataset:定义“单份数据”

  • 负责描述数据的来源、存储方式以及取出单个样本所需的所有步骤。
  • 必须能够在索引访问时返回(features, target),并能报告自身的长度。

1. 图片观察

Dataset实例支持下标操作,因此可以像访问列表一样通过索引获取单张图像及其标签。

# 随机选择一张图片,可以重复运行,每次都会随机选择sample_idx=torch.randint(0,len(train_dataset),size=(1,)).item()# 随机选择一张图片的索引# len(train_dataset) 表示训练集的图片数量;size=(1,)表示返回一个索引;torch.randint() 函数用于生成一个指定范围内的随机数,item() 方法将张量转换为 Python 数字image,label=train_dataset[sample_idx]# 获取图片和标签

2. 两个必须的魔术方法

torch.utils.data.Dataset是一个抽象基类,自定义数据集需要重写:

  • __len__():返回样本总数,供len(dataset)或 DataLoader 计算迭代次数。
  • __getitem__(idx):根据索引返回单个样本(通常是(data, label))。

因此train_dataset[sample_idx]才会得到(image, label)

__getitem__:让对象支持索引
classMyList:# 仅为解释魔术方法的示例def__init__(self):self.data=[10,20,30,40,50]def__getitem__(self,idx):returnself.data[idx]my_list_obj=MyList()print(my_list_obj[2])# 输出 30
30
__len__:让对象支持len()
classMyList:def__init__(self):self.data=[10,20,30,40,50]def__len__(self):returnlen(self.data)my_list_obj=MyList()print(len(my_list_obj))# 输出 5
5

3. 自定义Dataset的伪代码

常见写法是在构造函数中读入路径或内存数据,并保存transform__getitem__返回预处理后的样本。

classMNIST(Dataset):def__init__(self,root,train=True,transform=None):self.data,self.targets=fetch_mnist_data(root,train)# 假设这里完成原始数据读取self.transform=transformdef__len__(self):returnlen(self.data)def__getitem__(self,idx):img,target=self.data[idx],self.targets[idx]ifself.transformisnotNone:img=self.transform(img)returnimg,target
组件职责关键方法
Dataset1. 存储数据和标签的映射关系
2. 定义单样本的获取方式
3. 应用样本级预处理(如缩放、裁剪)
__getitem__(idx)
__len__()
DataLoader1. 批量组织样本
2. 并行加载数据
3. 打乱数据顺序
4. 处理多进程问题
迭代器接口(iter()next()
  • 可以把Dataset 想成“厨师”:负责挑选食材、清洗、调味(预处理)。
  • DataLoader 则像“服务员”:按订单把菜(批次)端给模型。
defimshow(img):img=img*0.3081+0.1307# 反标准化回原始像素npimg=img.numpy()plt.imshow(npimg[0],cmap='gray')plt.axis('off')plt.show()print(f"Label:{label}")imshow(image)
Label: 6

二、DataLoader:批量调度器

DataLoader 根据我们提供的Dataset产出一个可迭代对象,它负责:

  • batch_size聚合样本;
  • 根据shuffle决定是否随机打乱顺序;
  • 通过num_workers控制并行加载进程数。
train_loader=DataLoader(train_dataset,batch_size=64,shuffle=True)test_loader=DataLoader(test_dataset,batch_size=1000# 测试集通常不需要 shuffle)

三、总结

维度DatasetDataLoader
核心职责定义“数据是什么”以及如何得到单个样本决定怎样批量、按顺序或乱序地取数据
核心接口__getitem____len__通过参数控制加载逻辑,无需继承
预处理__getitem__transform中完成不做预处理,直接消费 Dataset 的输出
并行能力单线程读取num_workers>0时可多进程读取
典型参数roottransformtarget_transformbatch_sizeshufflenum_workers

@浙大疏锦行

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/1 6:10:16

MySQL UPDATE 更新操作详解

MySQL UPDATE 更新操作详解 引言 MySQL 是一款广泛使用的开源关系型数据库管理系统,其灵活的查询语句和强大的数据管理能力使其在各类应用场景中扮演着重要角色。本文将详细介绍 MySQL 的 UPDATE 更新操作,包括其语法、使用场景以及注意事项。 一、UPDATE 语法 UPDATE 语…

作者头像 李华
网站建设 2026/5/1 6:10:31

LobeChat能否支持脑机接口?未来人机交互形态设想

LobeChat能否支持脑机接口&#xff1f;未来人机交互形态设想 在智能设备越来越“懂”人的今天&#xff0c;我们对交互方式的期待早已超越了键盘敲击和语音唤醒。想象这样一个场景&#xff1a;一位渐冻症患者躺在床上&#xff0c;仅靠凝视与思维&#xff0c;就能通过AI助手向家人…

作者头像 李华
网站建设 2026/4/27 1:12:26

大数据领域数据仓库的流处理框架选型

大数据领域数据仓库的流处理框架选型关键词&#xff1a;数据仓库、流处理、Apache Kafka、Apache Flink、Apache Spark、实时计算、批流一体摘要&#xff1a;本文深入探讨大数据领域中数据仓库的流处理框架选型问题。我们将从流处理的基本概念出发&#xff0c;分析主流流处理框…

作者头像 李华
网站建设 2026/4/13 8:56:06

20、Vim搜索功能全解析

Vim搜索功能全解析 1. 重复搜索选项 在Vim中,有多种方式可以重复搜索,以下是相关命令及其效果: | 效果 | 命令 | | — | — | | 保持方向和偏移,跳转到下一个匹配项 | n | | 保持方向和偏移,跳转到上一个匹配项 | N | | 向前跳转到相同模式的下一个匹配项 | / | |…

作者头像 李华