news 2026/5/31 12:00:18

【李沐 | 动手实现深度学习】9-2 Pytorch神经网络基础

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【李沐 | 动手实现深度学习】9-2 Pytorch神经网络基础

前面整理了第5章的前半部分可以移步m【李沐 | 动手实现深度学习】9-1 Pytorch神经网络基础

下面是后半部分。

3 自定义层

前面我们深刻感受到了深度学习神经网络的灵活性:我们可以创造性地组合不同的层/块,从而设计出适用于目标任务的架构。有些情况下,我们可能要自己“Create”一个高级API没有提供的层,下面我们来看如何构建自定义层。

3.1 自定义无参数的层

我们自定义的CenteredLayer 在__init__中仅完成了父类的初始化,没有通过sefl.xxx的方式定义任何参数;其前向传播→从输入中减去均值。

import torch import torch.nn.functional as F from torch import nn class CenteredLayer(nn.Module): def __init__(self): super().__init__() def forward(self, X): return X - X.mean()

我们看看它能否按预期工作。

layer = CenteredLayer() layer(torch.FloatTensor([1, 2, 3, 4, 5]))

//输出示例:可以看到CenteredLayer()如期完成了工作

现在将层作为组件合并到构建更复杂的模型中,我们向该网络传入随机数据,检查均值是否为0。

net = nn.Sequential(nn.Linear(8, 128), CenteredLayer()) Y = net(torch.rand(4, 8)) Y.mean()

//输出示例:可以看到均值为0,由于处理的是浮点数,由于精度的原因看到是一个非常小的非零数

3.2 自定义带参数的层

在MyLinear()的初始化函数中:使用nn.Parameter封装了两个torch.randn(随机正态分布)初始化的张量:self.weightself.bias。

前向传播实现了线性变换 Y = XW + B 随后接上 ReLU 激活函数

注意:

  • 当在forward函数中使用.data时,实质上是在告诉 PyTorch:“这次计算中,把self.weightself.bias看作普通的、不参与梯度追踪的常量。”
  • 尽管参数的requires_grad属性仍然是True,但在这次特定的前向计算中,它们被当作非可训练张量使用,导致模型无法学习和更新参数
class MyLinear(nn.Module): def __init__(self, in_units, units): super().__init__() # in_units-输入单元数,units-输出的特征数量 self.weight = nn.Parameter(torch.randn(in_units, units)) self.bias = nn.Parameter(torch.randn(units,)) def forward(self, X): linear = torch.matmul(X, self.weight.data) + self.bias.data return F.relu(linear)

下面我们实例化MyLinear类并访问其模型参数。

我们可以使用自定义层直接执行前向传播计算。

使用自定义层构建模型。

4 读写文件

有时候我们希望将训练好的参数保存下载复用到其他场景,怎么做呢?此外,当运行一个耗时较长的训练过程时, 最佳的做法是定期保存中间结果, 以确保在服务器电源被不小心断掉时,我们不会损失几天的计算结果。

4.1 加载和保存张量

import torch from torch import nn from torch.nn import functional as F x = torch.arange(4) torch.save(x, 'x-file') # 保存为名为'x-file'的文件 x2 = torch.load('x-file') # 下载名为'x-file'的文件 x2

//输出示例:

我们可以存储一个张量列表,再把它读回内存。

y = torch.zeros(4) torch.save([x, y], 'xy-file') x2, y2 = torch.load('xy-file') (x2,y2)

//输出示例:

我们可以读取或写入从字符串映射到张量的字典

mydict = {'x':x, 'y':y} torch.save(mydict, 'mydict') mydict2 = torch.load('mydict') mydict2

4.2 加载和保存模型参数

保存单个权重张量确实有用,不过往往在实际应用中可能会涉及成百上千个参数,难以逐个保存。因此,深度学习框架提供了内置函数来保存和加载整个网络

需要注意的是,机器做的是保存模型参数而非整个模型,因为模型本身本身可以包含任意代码,是难以序列化的。因此,为了恢复模型,我们需要用代码生成架构, 然后从磁盘加载参数。

以一个3层MLP为例:

class MLP(nn.Module): def __init__(self): super().__init__() self.hidden = nn.Linear(20, 256) self.output = nn.Linear(256, 10) def forward(self, x): return self.output(F.relu(self.hidden(x))) net = MLP() X = torch.randn(size=(2, 20)) Y = net(X)

我们将模型的参数存储为“xxx.params”的文件。(第二个数指定路径,若只有文件名默认与代码文件同一文件夹)

  • xxx.params格式:这是 PyTorch 中常用的文件扩展名,用于存储模型的权重和缓冲区张量
torch.save(net.state_dict(), 'mlp.params')

//输出示例:

  • torch.load('mlp.params'):从磁盘文件中加载一个序列化的 Python 对象
  • clone.load_state_dict(...): 这是核心的加载操作。它将磁盘文件中加载的状态字典(包含所有参数和缓冲区的张量值)精确地复制clone模型的对应参数位置上。
  • clone.eval(): 将模型clone切换到评估模式 (Evaluation Mode)

→ nn.Module的模式管理:PyTorch 中的某些层(例如nn.Dropoutnn.BatchNorm,即批量归一化层)在训练时和推理时需要有不同的行为:

  • 训练模式 (.train()):Dropout随机丢弃神经元;BatchNorm计算并更新运行平均值和方差 (running mean/var)。
  • 评估模式 (.eval()):Dropout失效(不丢弃任何神经元);BatchNorm冻结其统计数据,使用在训练过程中学到的固定的运行平均值和方差。

在加载模型进行推理时,必须调用clone.eval()来确保这些层以正确的方式运行,保证输出的确定性和准确性。

clone = MLP() # 先把模型声明出来,再把参数放回去 clone.load_state_dict(torch.load('mlp.params')) clone.eval()

//输出示例:可以看到 clone(X) 的结果与 net(X) 的结果完全一致,说明模型参数成功加载到了 clone的相应位置。

总结

Pytorch深度学习框架为我们自定义层提供了强大的灵活性,可以自定义带参/无参的层,并可以约其他层/块组合使用。此外,有时候训练模型是一件很“贵”的事情,我们可以通过保存模型参数的方式使得模型可以复用到其他环境,save-保存,load-加载。

愉快的一天鼬鼬鼬鼬过去了~~~我又学废了😴😴😴

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/30 16:00:12

使用RPCA算法对图像进行稀疏低秩分解

使用RPCA(鲁棒主成分分析)算法对图像进行稀疏低秩分解。 RPCA能够将图像分解为低秩部分(背景/主要成分)和稀疏部分(前景/噪声/异常)。 RPCA算法原理 RPCA旨在解决以下优化问题: min ‖L‖* λ‖…

作者头像 李华
网站建设 2026/5/30 14:22:26

嘉楠携手SynVista打造可再生能源驱动的自适应比特币矿机

嘉楠耘智与SynVista合作打造可再生能源矿机 比特币矿机及硬件制造商嘉楠耘智已达成一项合作协议,将共同开发一个可再生能源自适应的比特币挖矿平台。此举扩大了该公司对绿色能源的关注,因为整个行业正在寻求可持续的方式来满足其电力需求。 嘉楠耘智周一…

作者头像 李华
网站建设 2026/5/31 5:30:02

如何在云服务器上用Miniconda快速部署大模型训练环境?

如何在云服务器上用Miniconda快速部署大模型训练环境?在如今的大模型时代,一个常见的场景是:你刚申请了一台带有GPU的云服务器,准备复现一篇论文或启动新的训练任务。可还没开始写代码,就被各种依赖问题卡住——Python…

作者头像 李华
网站建设 2026/5/29 4:27:18

介绍 from typing import Optional

from typing import Optional 引入的是 Python 类型注解体系中的一个基础工具。下面给你一个不兜圈子、直接到位的说明,并顺便指出很多人理解上的误区。一句话定义 Optional[T] 表示:一个值要么是 T 类型,要么是 None。 等价写法:…

作者头像 李华
网站建设 2026/5/30 18:04:55

Qwen3-14B与主流transformer模型的推理速度对比

Qwen3-14B与主流Transformer模型的推理速度对比 在当前企业级AI系统的设计中,一个核心挑战逐渐浮现:如何让大语言模型既具备强大的语义理解能力,又能以毫秒级响应满足真实业务场景的需求。尤其是在智能客服、合同审查、自动化工单等对延迟敏感…

作者头像 李华
网站建设 2026/5/31 5:38:32

vLLM vs 传统推理框架:性能对比实测报告

vLLM vs 传统推理框架:性能对比实测报告 在大模型落地进入深水区的今天,一个现实问题摆在每个AI工程师面前:为什么训练好的千亿参数模型,一到线上就“卡成PPT”?用户等得不耐烦,服务器烧钱如流水——这背后…

作者头像 李华