news 2026/6/9 1:13:54

【动手学深度学习】笔记1:简单的线性回归

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【动手学深度学习】笔记1:简单的线性回归

根据我们之前的对话,我为你整理了一份线性回归从零实现的学习笔记。这份笔记涵盖了数据生成、小批量迭代器、模型定义、损失函数、SGD优化器以及完整训练流程。你可以把它保存下来,经常复习。


线性回归从零实现 · 学习笔记

一、生成合成数据

defsynthetic_data(w,b,num_examples):"""生成 y = X w + b + 噪声"""X=torch.normal(0,1,(num_examples,len(w)))# 特征矩阵y=torch.matmul(X,w)+b# 线性部分y+=torch.normal(0,0.01,y.shape)# 加噪声returnX,y.reshape((-1,1))# y 转为列向量
  • 形状理解X形状(样本数, 特征数)w形状(特征数,)y形状(样本数,)。通过reshape((-1,1))变成列向量(样本数, 1)
  • 为什么要加噪声:模拟真实数据的不完美,让模型学会忽略小扰动。
  • 为什么要 reshape:便于后续矩阵运算,避免广播歧义。

二、手动实现小批量数据迭代器

defdata_iter(batch_size,features,labels):num_examples=len(features)indices=list(range(num_examples))random.shuffle(indices)# 每个 epoch 打乱顺序foriinrange(0,num_examples,batch_size):batch_indices=torch.tensor(indices[i:min(i+batch_size,num_examples)])yieldfeatures[batch_indices],labels[batch_indices]
  • indices:原始样本的索引列表[0,1,2,...],打乱后实现随机抽取。
  • batch_indices:当前批次对应的索引(张量形式)。
  • features[batch_indices]:按索引抽取子集,形状(batch_size, 特征数)
  • yield:生成器,每次返回一个批次,并保存状态,下次继续。节省内存。

三、初始化模型参数

w=torch.normal(0,0.01,size=(2,1),requires_grad=True)b=torch.zeros(1,requires_grad=True)
  • w形状(特征数, 1),便于矩阵乘法X @ w得到(batch_size, 1)
  • requires_grad=True:告诉 PyTorch 需要计算梯度。

四、定义线性回归模型

deflinreg(X,w,b):returntorch.matmul(X,w)+b
  • 等价于X @ w + b
  • 广播机制:b会自动加到每个样本的预测值上。

五、定义损失函数(均方误差的一半)

defsquared_loss(y_hat,y):return(y_hat-y.reshape(y_hat.shape))**2/2
  • 返回向量(batch_size, 1),每个元素是一个样本的半平方误差
  • 除以 2 是为了求导后系数为 1(导数 =y_hat - y)。
  • 训练时会对这个向量.sum()再反向传播。

六、定义优化算法(小批量 SGD)

defsgd(params,lr,batch_size):withtorch.no_grad():forparaminparams:param-=lr*param.grad/batch_size param.grad.zero_()
  • with torch.no_grad():禁用梯度追踪,手动更新参数时不要构建计算图。
  • param.grad是累积的梯度(因为之前l.sum().backward()求和后反向,梯度是总损失的导数)。
  • 除以batch_size:将总梯度转换为平均梯度,使更新步长与批量大小无关。
  • param.grad.zero_():梯度清零,否则下一批会累加。

七、训练循环

lr=0.03num_epochs=3net=linreg loss=squared_lossforepochinrange(num_epochs):forX,yindata_iter(batch_size,features,labels):l=loss(net(X,w,b),y)# 小批量损失向量l.sum().backward()# 反向传播计算梯度sgd([w,b],lr,batch_size)# 更新参数withtorch.no_grad():train_l=loss(net(features,w,b),labels)print(f'epoch{epoch+1}, loss:{float(train_l.mean()):f}')
  • 内层循环:每个 epoch 内遍历所有小批量。
  • l.sum().backward():因为l是向量,需先求和成标量再反向传播。梯度会累加到w.gradb.grad
  • 外层循环结束时:打印整个数据集上的平均损失,观察训练进展。

八、关键概念总结

概念说明
特征矩阵 X形状(num_examples, num_features),每行一个样本
权重 w形状(num_features, 1),每个特征对应一个权重
偏置 b标量或(1,),所有样本共享
预测值X @ w + b,形状(num_examples, 1)
损失函数向量输出(每个样本一个损失),求和后反向传播
梯度下降param -= lr * (param.grad / batch_size),然后清零梯度
生成器 (yield)按需产生数据,节省内存

九、常见问题自查

  1. 为什么y要 reshape 成列向量?
    为了与y_hat形状一致,便于矩阵运算,避免广播错误。

  2. 为什么要除以batch_size
    因为l.sum().backward()产生的梯度是总损失的梯度,而我们需要平均梯度,所以除以批量大小。

  3. 为什么损失函数要除以 2?
    为了让求导后的梯度表达式为(y_hat - y),没有系数 2,纯粹是为了数学简洁。

  4. with torch.no_grad()有什么用?
    告诉 PyTorch 不要跟踪这部分操作,避免构建不需要的计算图,节省内存和计算。

  5. 为什么要手动param.grad.zero_()
    梯度默认会累加,不清零的话下一次反向传播会将新旧梯度相加,导致错误。


十、扩展学习路径

  • 使用 PyTorch 高层 API:nn.Linear,MSELoss,optim.SGD
  • 增加验证集,绘制训练曲线
  • 尝试不同的学习率和批量大小,观察收敛情况
  • 扩展到多项式回归、逻辑回归

笔记到此结束。你可以根据自己的理解,在空白处补充例子或疑问。需要我进一步解释某个部分,随时可以继续提问。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/9 1:13:43

【环形缓冲区】1-概念与编程

【环形缓冲区】1-概念与编程 文章目录【环形缓冲区】1-概念与编程一、环形缓冲区的引入二、环形缓冲区 编程这一点很重要,后续还会做补充 一、环形缓冲区的引入 //下面是一个按键控制LCD的程序 main() {while(1){key read_key(); // 读取按键值LCD(key); // 耗…

作者头像 李华
网站建设 2026/6/9 1:11:58

UABEA:终极Unity游戏资源编辑完全指南

UABEA:终极Unity游戏资源编辑完全指南 【免费下载链接】UABEA c# uabe for newer versions of unity 项目地址: https://gitcode.com/gh_mirrors/ua/UABEA 你是否曾经想过深入探索Unity游戏内部,查看那些精美的纹理、音效和模型是如何工作的&…

作者头像 李华
网站建设 2026/6/9 1:09:02

C语言中的递归

C语言中的递归 递归是一种编程技巧,它允许函数直接或间接地调用自身。在C语言中,递归是一种强大的编程工具,它可以帮助我们解决许多问题,特别是那些可以分解为相似子问题的算法。本文将详细介绍C语言中的递归,包括递归的基本概念、递归函数的编写、递归的优缺点以及递归在…

作者头像 李华
网站建设 2026/6/9 1:08:02

Claude Code-Dynamic Workflows:1.为什么用工作流?

Claude Code-Dynamic Workflows:1.为什么用工作流? 为什么用工作流如果你经常让 Claude 做长任务,应该见过这种情况:它一开始很认真,越往后越像在“凭感觉收尾”。不是模型突然变差了,而是我们把太多事情塞…

作者头像 李华
网站建设 2026/6/9 1:00:01

SQLite数据操作避坑指南:从‘insert失败’到‘select显示乱’的常见问题排查(附字段名修改方法)

SQLite数据操作避坑指南:从‘insert失败’到‘select显示乱’的常见问题排查当你第一次尝试在SQLite中插入或查询数据时,可能会遇到各种意料之外的问题。这些问题看似简单,却足以让新手开发者陷入长时间的调试困境。本文将带你深入剖析SQLite…

作者头像 李华
网站建设 2026/6/9 0:56:02

小程序毕业设计-基于微信小程序的扶贫助农系统及其小程序的实现基于springboot+微信小程序的扶贫助农系统及其小程序的实现(源码+LW+部署文档+全bao+远程调试+代码讲解等)

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华