当千亿参数遇上显存瓶颈,ZeRO用一套“分片存储、按需加载”的哲学,让模型的三个“内存大户”不再冗余,实现了显存占用的线性解耦。
在搭建“训练篇”大模型加速体系的过程中,我们先后学习了梯度累积、激活重计算与3D并行。然而,当模型膨胀到千亿参数级别,即使将这些技术用到极致,单张GPU卡的显存依然捉襟见肘。
在单卡完成梯度累积与激活重计算,通过混合并行将模型切分后,ZeRO(零冗余优化器)将分布式训练的显存优化推向了极致。
一、ZeRO的核心思想:分片存储,按需加载
ZeRO全称Zero Redundancy Optimizer——“零冗余优化器”。在普通的数据并行(DDP)中,每张GPU各自保留一份完整模型副本,存储优化器状态、梯度和参数。当一个百亿参数模型使用Adam优化器训练时,即便参数本身(FP16)只要20GB,其优化器状态(动量等 FP32)与梯度等加起来,总显存需求会飙升到参数量的十几倍,这便是ZeRO要解决的“冗余”与“浪费”。
ZeRO的核心思想是:分片存储、按需加载。它将庞大的“模型状态”(优化器状态、梯度、参数)切分成多个分片,分布式地存储在不同GPU上,并在执行某一层的计算时,通过高效的集体通信,临时从其他设备拉取当前所需分片,计算完成后立即释放。
二、ZeRO三步走:从ZeRO-1到ZeRO-3
ZeRO策略通过三个阶段的渐进式优化,实现显存占用