模型训练中的浮点数格式:FP32、FP16、BF16
1. 浮点数基础
IEEE 754 浮点数由三部分组成:
x=(−1)sign×2exponent−bias×(1+mantissa)x = (-1)^{sign} \times 2^{exponent - bias} \times (1 + mantissa)x=(−1)sign×2exponent−bias×(1+mantissa)
- 符号位 (Sign):决定正负
- 指数位 (Exponent):决定数值范围
- 尾数位 (Mantissa/Fraction):决定数值精度
直觉理解:尺子比喻
可以把浮点数想象成一把刻度尺:
指数位 = 尺子的长度(能量到多远、多近的地方) 尾数位 = 尺子的刻度密度(相邻两个可表示数之间有多细)FP32: ||||||||||||||||||||||||||||||||||||||||||||||||| 很长的尺子,刻度极密 FP16: ||||||||||| 短尺子,刻度中等 BF16: |||| |||| |||| |||| |||| |||| |||| |||| |||| 和 FP32 一样长,但刻度稀疏- 指数位越多→ 尺子越长,能量到极大和极小的数(范围大,不容易溢出)
- 尾数位越多→ 刻度越密,相邻两个数的间距越小(精度高,舍入误差小)
- BF16 的设计哲学:宁可刻度粗一点,也要保证尺子够长——对深度学习来说,“量不到”(下溢为 0)比"量得粗"(精度低)危害大得多
2. 三种格式详解
2.1 FP32(单精度浮点数)
| 1 bit sign | 8 bits exponent | 23 bits mantissa | | S | EEEEEEEE | MMMMMMMMMMMMMMMMMMMMMMM |- 总位宽:32 bits(4 bytes)
- 指数位:8 bits → 范围≈±3.4×1038\approx \pm 3.4 \times 10^{38}≈±3.4×1038
- 尾数位:23 bits → 精度约7 位有效十进制数字
- 深度学习中的默认格式,精度高但显存占用大
2.2 FP16(半精度浮点数)
| 1 bit sign | 5 bits exponent | 10 bits mantissa | | S | EEEEE | MMMMMMMMMM |- 总位宽:16 bits(2 bytes)
- 指数位:5 bits → 范围≈±6.5×104\approx \pm 6.5 \times 10^{4}≈±6.5×104
- 尾数位:10 bits → 精度约3 位有效十进制数字
- 范围窄,容易出现上溢 (overflow)和下溢 (underflow)
2.3 BF16(Brain Floating Point 16)
| 1 bit sign | 8 bits exponent | 7 bits mantissa | | S | EEEEEEEE | MMMMMMM |- 总位宽:16 bits(2 bytes)
- 指数位:8 bits → 范围≈±3.4×1038\approx \pm 3.4 \times 10^{38}≈±3.4×1038(与 FP32 相同)
- 尾数位:7 bits → 精度约2 位有效十进制数字
- 由 Google Brain 提出,专为深度学习设计
3. 特性对比
| 特性 | FP32 | FP16 | BF16 |
|---|---|---|---|
| 总位宽 | 32 bits | 16 bits | 16 bits |
| 符号位 | 1 | 1 | 1 |
| 指数位 | 8 | 5 | 8 |
| 尾数位 | 23 | 10 | 7 |
| 数值范围 | ±3.4×1038\pm 3.4 \times 10^{38}±3.4×1038 | ±6.5×104\pm 6.5 \times 10^{4}±6.5×104 | ±3.4×1038\pm 3.4 \times 10^{38}±3.4×1038 |
| 精度 | ~7 位十进制 | ~3 位十进制 | ~2 位十进制 |
| 显存占用 | 1× (基准) | 0.5× | 0.5× |
| 计算速度 | 1× (基准) | ~2× | ~2× |
| 溢出风险 | 低 | 高 | 低 |
| 精度损失 | 无 | 中等 | 较大 |
| 硬件支持 | 所有 GPU | 大部分 GPU | A100+, TPU, AMD MI300X GPU( 最近我train model 在使用的) |
4. BF16 vs FP16 下溢特性对比
4.1 什么是下溢 (Underflow)
当一个数的绝对值小于格式能表示的最小正数时,就会发生下溢——该值被截断为 0,信息彻底丢失。
∣x∣<ϵmin⇒x→0|x| < \epsilon_{min} \Rightarrow x \to 0∣x∣<ϵmin⇒x→0
4.2 最小可表示正数
| 属性 | FP16 | BF16 |
|---|---|---|
| 指数位 | 5 bits | 8 bits |
| 指数偏移 (bias) | 15 | 127 |
| 最小正规数 (normal) | 2−14≈6.1×10−52^{-14} \approx 6.1 \times 10^{-5}2−14≈6.1×10−5 | 2−126≈1.18×10−382^{-126} \approx 1.18 \times 10^{-38}2−126≈1.18×10−38 |
| 最小非正规数 (subnormal) | 2−24≈5.96×10−82^{-24} \approx 5.96 \times 10^{-8}2−24≈5.96×10−8 | 2−133≈9.18×10−412^{-133} \approx 9.18 \times 10^{-41}2−133≈9.18×10−41 |
BF16 的最小可表示数比 FP16 小了约33 个数量级,下溢空间远大于 FP16。
4.3 训练中的下溢场景
| 场景 | 典型数值量级 | FP16 | BF16 |
|---|---|---|---|
| 小梯度值 | 10−5∼10−810^{-5} \sim 10^{-8}10−5∼10−8 | 容易下溢为 0 | 安全 |
| 学习率 × 梯度 | 10−4×10−5=10−910^{-4} \times 10^{-5} = 10^{-9}10−4×10−5=10−9 | 下溢 | 安全 |
| 权重衰减项 | λ⋅w≈10−7\lambda \cdot w \approx 10^{-7}λ⋅w≈10−7 | 边缘危险 | 安全 |
| Softmax 中间值 | e−xe^{-x}e−x可达10−2010^{-20}10−20 | 下溢 | 安全 |
| BatchNorm 方差 | 可能极小 | 有风险 | 安全 |
4.4 实际影响
FP16 下溢链: 小梯度 → 下溢为 0 → 权重停止更新 → 训练停滞或 loss 发散 ↳ 必须用 Loss Scaling 放大梯度来缓解 BF16 下溢链: 小梯度 → 仍在可表示范围内 → 正常更新(但精度较低) ↳ 不需要 Loss Scaling4.5 总结对比
| 对比维度 | FP16 | BF16 |
|---|---|---|
| 下溢阈值 | ≈10−8\approx 10^{-8}≈10−8 | ≈10−41\approx 10^{-41}≈10−41 |
| 下溢风险 | 高(训练中频繁触发) | 极低(与 FP32 一致) |
| 是否需要 Loss Scaling | 必须,否则训练不稳定 | 不需要 |
| 精度表现 | 不下溢时精度更高 (10 bit 尾数) | 精度略低 (7 bit 尾数),但不会丢失为 0 |
| 核心取舍 | 高精度,窄范围 → 易丢失小值 | 低精度,宽范围 → 小值保留但精度粗 |
关键结论:FP16 的下溢问题是混合精度训练复杂性的根源(需要 GradScaler + Loss Scaling)。BF16 通过保留 FP32 的 8 位指数,从根本上消除了下溢问题,以牺牲少量精度换取训练的稳定性和简洁性。
5. 训练中的实际应用
4.1 纯 FP32 训练
- 最安全,精度最高
- 显存占用大,速度慢
- 适合小模型或对精度要求极高的场景
4.2 混合精度训练 (Mixed Precision, FP16)
- 前向 & 反向传播:用 FP16 加速计算
- 权重主副本 (Master Weights):用 FP32 保存,防止梯度更新丢失
- Loss Scaling:将 loss 放大后再反向传播,防止小梯度下溢为 0, 所以deepspeed 中有loss scale的参数
- 框架支持:PyTorch
torch.cuda.amp、TFtf.keras.mixed_precision
# PyTorch 混合精度示例scaler=torch.cuda.amp.GradScaler()withtorch.cuda.amp.autocast():# 自动选择 FP16/FP32output=model(input)loss=criterion(output,target)scaler.scale(loss).backward()# loss scalingscaler.step(optimizer)scaler.update()4.3 BF16 训练
- 范围与 FP32 相同,无需 loss scaling
- 使用更简单,直接替换即可
- 大模型训练的主流选择(GPT-3、LLaMA 等均使用 BF16)
# PyTorch BF16 示例withtorch.cuda.amp.autocast(dtype=torch.bfloat16):output=model(input)loss=criterion(output,target)loss.backward()# 不需要 GradScaleroptimizer.step()6. 如何选择
| 场景 | 推荐格式 |
|---|---|
| 硬件不支持半精度 | FP32 |
| 旧 GPU (V100 等) | FP16 混合精度 |
| 新 GPU (A100/H100) | BF16(首选) |
| 大语言模型训练 | BF16 |
| 推理部署 | FP16 或 INT8/INT4 量化 |
一句话总结
BF16 = FP32 的范围 + FP16 的速度,是当前大模型训练的最佳平衡点。