news 2026/5/27 19:20:07

使用torch.compile与梯度累积加速模型训练

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
使用torch.compile与梯度累积加速模型训练

训练一个具有深度Transformer架构的语言模型是耗时的。然而,有些技巧可以用来加速训练。在本文中,你将学习到:

  • 使用 torch.compile() 加速模型
  • 使用梯度累积来训练具有更大有效批次大小的模型

让我们开始吧!

概述

本文分为两个部分:

  • 使用 torch.compile()
  • 梯度累积

使用 torch.compile

当你在PyTorch中编写并运行模型代码时,它是在eager模式下执行的。这意味着代码是一行一行执行的,结果存储在内存中。这是Python的原生方式,因为它是一种解释型语言。你知道这一点是因为当代码出现错误时,只有运行到该行时才会看到错误提示。

在eager模式下运行模型速度较慢。从PyTorch 2.0开始,你可以使用torch.compile()来编译模型以提高性能。这会生成一个经过优化的新模型对象。它不是你用nn.Module创建的原始模型对象,但它与原始模型共享相同的张量。你可以像往常一样使用这个编译后的模型进行前向传播、反向传播和优化器更新。

将模型构建并编译成计算图正是TensorFlow 1.0的设计思路。这使得调试更加困难,因为你执行的模型无法与你编写的代码逐行对应。因此,在运行试验并确认模型没有错误之前,你不应该编译模型。

并非所有模型都可以编译。但是,如果你的模型支持编译,你将立即受益于速度提升。要编译一个模型,你只需要在准备使用模型之前替换模型对象:

... model = LlamaForPretraining(model_config).to(device) model.load_state_dict(checkpoint) model = torch.compile(model) ...

不要在编译后加载模型权重。这是因为编译后的模型是一个与原始模型共享权重的对象。在编译过程中,构建的计算图引用了原始模型的权重张量。如果你在编译后加载权重,模型可能无法按预期工作。

同样,要保存编译后的模型,你应该引用原始模型的状态字典,如下所示:

torch.save(getattr(model, "_orig_mod", model).state_dict(), "model.pth")

可以通过model._orig_mod访问编译模型中的原始模型。在上面的代码中,我们使用getattr(model, "_orig_mod", model)来获取原始模型(如果存在),或者如果不存在则使用模型本身。这行代码对编译模型和原始模型都适用。

梯度累积

当你训练一个模型时,你在反向传播上花费的时间可能是前向传播的两到三倍。这是因为反向传播计算强度更大,并且占用更多内存。

一个简单的加速训练技巧是减少反向传播的次数。这可以通过增加批次大小来实现:对于相同数量的数据样本,更大的批次大小意味着要处理的批次更少。

然而,更大的批次大小需要更多内存。在内存受限的环境中,你可以通过运行多次前向传播并累积梯度来模拟更大的批次大小。这被称为梯度累积

用代码来解释这个想法更容易:

.. accumulate_steps = 4 for epoch in range(num_epochs): optimizer.zero_grad() for i, batch in enumerate(dataloader): # 获取批次数据 input_ids, target_ids = batch # 创建注意力掩码:因果掩码 + 填充掩码 attn_mask = create_causal_mask(input_ids.shape[1], device) + \ create_padding_mask(input_ids, PAD_TOKEN_ID, device) # 从模型提取输出 logits = model(input_ids, attn_mask) # 计算损失:logits与目标之间的交叉熵,忽略填充标记 loss = loss_fn(logits.view(-1, logits.size(-1)), target_ids.view(-1)) loss = loss / accumulate_steps # 运行反向传播,但每`accumulate_steps`步才更新一次 loss.backward() if (i + 1) % accumulate_steps == 0: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() optimizer.zero_grad() scheduler.step()

上面的训练循环摘自上一篇关于在本地GPU上训练Llama模型的文章。

通常,当你运行一次前向传播时,你会计算损失。然后调用loss.backward()通过模型参数反向传播损失梯度。在PyTorch中,backward()方法是累积的,这意味着梯度是相加的。因此,你需要在运行反向传播之前显式调用optimizer.zero_grad()来清除梯度。

在上面的代码中,你故意不在每次迭代中都调用optimizer.zero_grad()。相反,你对损失(除以accumulate_steps)运行反向传播。这样,梯度被缩小但在accumulate_steps次迭代中累积。每经过accumulate_steps次迭代,你才运行优化器来调整模型参数。

这种方法产生的结果与使用更大批次大小获得的结果相当。然而,由于你运行的优化器更新次数更少,学习率调度器应相应调整。这意味着你需要用不同的步数来初始化调度器:

... num_training_steps = (len(dataloader) // accumulate_steps) * num_epochs cosine_scheduler = lr_scheduler.CosineAnnealingLR( optimizer, T_max=num_training_steps - num_warmup_steps, eta_min=0 )

进一步阅读

以下是一些你可能感兴趣的资料:

  • torch.compile 文档
  • PyTorch 文档中的自动混合精度示例

总结

在本文中,你了解到使用torch.compile()可以通过编译计算图来帮助你加速模型。你还了解到,梯度累积是一种通过累积多个小批次的梯度来训练更大有效批次大小的技术。由于这种方式减少了优化器更新次数,你可以节省反向传播和参数更新的时间。
更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)或者 我的个人博客 https://blog.qife122.com/
对网络安全、黑客技术感兴趣的朋友可以关注我的安全公众号(网络安全技术点滴分享)

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/25 16:35:36

Java毕设项目:基于springboot的无人机农田巡查系统的设计与实现(源码+文档,讲解、调试运行,定制等)

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/5/12 8:05:21

BERT微调加速

💓 博客主页:借口的CSDN主页 ⏩ 文章专栏:《热点资讯》 BERT微调加速:边缘计算驱动的范式革新与未来路径目录BERT微调加速:边缘计算驱动的范式革新与未来路径 引言:微调瓶颈与加速的迫切性 维度一&#xff…

作者头像 李华
网站建设 2026/5/22 11:22:11

计算机毕业设计springboot法律咨询援助平台 基于 SpringBoot 的在线法律求助与知识共享系统 轻量级法律智能服务平台的设计与实现(SpringBoot 架构)

计算机毕业设计springboot法律咨询援助平台(配套有源码 程序 mysql数据库 论文) 本套源码可以在文本联xi,先看具体系统功能演示视频领取,可分享源码参考。当“法律”遇上“互联网”,咨询不再受限于时间与地域。城乡差距、费用门槛…

作者头像 李华
网站建设 2026/5/22 23:45:51

Java毕设项目推荐-基于springboot+vue的鲜花盆栽绿植销售系统设计与实现基于springboot的鲜花销售管理系统的设计与实现【附源码+文档,调试定制服务】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/5/5 17:22:05

DDR3带宽计算

一、DDR3带宽计算 1.需要明确其理论带宽 2.结合实际应用需求进行评估 3.计算DDR3带宽的核心参数:有效数据传输速率和总线的位宽 理论带宽(GB/s) ‌有效数据传输速率(MT/s) 内存总线位宽(bit) 8…

作者头像 李华