news 2026/5/16 5:26:16

逐行拆解DeepSpeed Chat的RLHF三阶段:从SFT、RM到PPO的源码实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
逐行拆解DeepSpeed Chat的RLHF三阶段:从SFT、RM到PPO的源码实现

1. DeepSpeed Chat与RLHF技术全景

如果你正在寻找一个完整的RLHF实现方案,微软开源的DeepSpeed Chat(简称DSC)绝对值得深入研究。这个项目将ChatGPT风格的三阶段训练流程完整呈现,从监督微调(SFT)到奖励模型(RM)训练,再到近端策略优化(PPO),每个阶段的代码实现都清晰可见。

我在实际使用中发现,DSC最突出的优势是它的DeepSpeedHybridEngine技术。传统方案中,模型在训练(参数更新)和推理(经验采集)模式间切换会显著拖慢速度。而DSC通过这个创新引擎,让模型同时获得两种模式的优化,整体训练速度提升明显。

1.1 RLHF三阶段核心逻辑

让我们先快速梳理三个阶段的关键任务:

  • Phase1 SFT:用优质对话数据微调基座模型,使其掌握指令跟随能力
  • Phase2 RM训练:通过人类偏好数据训练奖励模型,学会给回复质量打分
  • Phase3 PPO:通过强化学习迭代优化策略,使模型输出更符合人类偏好

1.2 代码结构概览

DSC的代码组织非常清晰:

DeepSpeedExamples/ └── applications/ └── DeepSpeed-Chat/ ├── training/ │ ├── step1_supervised_finetuning/ │ ├── step2_reward_model_finetuning/ │ └── step3_rlhf_finetuning/ └── utils/ ├── data/ # 数据处理工具 └── model/ # 模型定义

2. 阶段一:监督微调(SFT)源码解析

2.1 数据准备关键代码

SFT阶段使用chosen_sentence作为训练数据,即人类偏好的完整对话。数据加载流程在create_prompt_dataset()函数中实现:

def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer): if train_phase == 1: for tmp_data in current_dataset: chosen_sentence = raw_dataset.get_prompt_and_chosen(tmp_data) chosen_token = tokenizer(chosen_sentence, truncation=True) chosen_dataset.append(chosen_token) return PromptDataset(chosen_dataset)

这里有个实用技巧:当处理自定义数据集时,需要继承PromptRawDataset类并重写数据加载方法。我在项目中就遇到过需要适配特殊数据格式的情况,这个设计让扩展变得非常灵活。

2.2 模型训练流程

主训练脚本main.py的核心逻辑:

  1. 加载tokenizer和基座模型
  2. 可选启用LoRA技术
  3. 准备DataLoader
  4. 使用DeepSpeedEngine封装模型
  5. 以困惑度(perplexity)为指标进行训练

困惑度计算值得关注:

perplexity = torch.exp(losses) # 通过交叉熵损失转换得到

2.3 LoRA实现细节

DSC支持通过以下代码快速添加LoRA适配器:

if args.lora_dim > 0: model = convert_linear_layer_to_lora(model) if args.only_optimize_lora: model = only_optimize_lora_parameters(model)

实际测试发现,在7B模型上使用LoRA能减少约75%的可训练参数,显存占用从22GB降至14GB。

3. 阶段二:奖励模型(RM)训练详解

3.1 数据格式设计

RM训练使用成对的(chosen, rejected)数据,例如:

{ "chosen": "Human: Explain AI\nAssistant: AI is...", # 优质回复 "rejected": "Human: Explain AI\nAssistant: I don't know" # 劣质回复 }

3.2 模型架构关键点

RM由主干网络+线性评分头组成:

class RewardModel(nn.Module): def __init__(self, base_model): self.rwtransformer = base_model # 主干网络 self.v_head = nn.Linear(hidden_size, 1) # 评分头

实际使用时发现,不同主干网络的选择对RM性能影响很大。我们在OPT-1.3B和LLaMA-7B上的对比实验显示,更大的主干网络能显著提升评分准确性。

3.3 成对排序损失实现

核心损失函数计算:

def forward(self, input_ids): rewards = self.v_head(hidden_states).squeeze(-1) chosen_rewards = rewards[:bs] # 前半batch是chosen rejected_rewards = rewards[bs:] # 后半batch是rejected loss = -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean() return loss

这里有个工程细节:DSC会截取chosen和rejected回复中第一个分叉点之后的内容进行比较,确保对比的公平性。

3.4 评估指标设计

RM使用排序准确率作为评估指标:

correct_predictions = (chosen_scores > rejected_scores).sum() accuracy = correct_predictions / total_predictions

我们在实际训练中发现,当验证集准确率达到70%以上时,RLHF阶段的效果就会有明显提升。

4. 阶段三:PPO强化学习实现

4.1 经验数据生成流程

PPO阶段最复杂的部分就是经验数据的生成,主要步骤:

  1. 用当前actor生成回复
  2. 计算4个关键值:
    • actor_logprobs:当前策略的概率
    • ref_logprobs:参考策略的概率
    • values:critic的价值估计
    • rewards:RM的评分
def generate_experience(self, prompts): seq = self.actor.generate(prompts) # 生成回复 with torch.no_grad(): actor_logits = self.actor(seq).logits ref_logits = self.ref(seq).logits rewards = self.reward_model.forward_value(seq) values = self.critic.forward_value(seq) return { 'logprobs': gather_log_probs(actor_logits, seq), 'ref_logprobs': gather_log_probs(ref_logits, seq), 'values': values, 'rewards': rewards }

4.2 PPO核心算法实现

4.2.1 奖励计算

加入KL散度惩罚的奖励计算:

def compute_rewards(self, log_probs, ref_log_probs, rewards): kl_penalty = -self.kl_ctl * (log_probs - ref_log_probs) return kl_penalty + rewards
4.2.2 优势估计

使用GAE(广义优势估计)计算优势值:

def get_advantages(self, rewards, values): lastgaelam = 0 advantages = [] for t in reversed(range(len(rewards))): delta = rewards[t] + gamma * values[t+1] - values[t] lastgaelam = delta + gamma * lam * lastgaelam advantages.append(lastgaelam) return advantages[::-1]
4.2.3 策略损失

包含裁剪的策略梯度损失:

def actor_loss(logprobs, old_logprobs, advantages): ratio = torch.exp(logprobs - old_logprobs) pg_loss1 = -advantages * ratio pg_loss2 = -advantages * torch.clamp(ratio, 1-eps, 1+eps) return torch.max(pg_loss1, pg_loss2).mean()
4.2.4 价值损失

带裁剪的价值函数更新:

def critic_loss(values, old_values, returns): values_clipped = old_values + (values - old_values).clamp(-eps, eps) vf_loss1 = (values - returns).pow(2) vf_loss2 = (values_clipped - returns).pow(2) return 0.5 * torch.max(vf_loss1, vf_loss2).mean()

4.3 训练流程优化

DSC使用了几个关键优化技术:

  1. 混合引擎:DeepSpeedHybridEngine实现训练/推理模式无缝切换
  2. 经验回放:通过MiniDataset管理PPO的mini-batch
  3. EMA平滑:使用指数移动平均稳定训练
rlhf_engine = DeepSpeedRLHFEngine( actor_model_name_or_path, critic_model_name_or_path, enable_hybrid_engine=True )

5. 实战经验与调参技巧

经过多个项目的实践,我总结出以下关键经验:

5.1 数据准备建议

  1. 三阶段数据一致性:尽量使用同一数据源的不同切片
  2. RM训练数据质量:chosen/rejected的差异要明显
  3. Prompt多样性:覆盖各种可能的用户输入类型

5.2 关键超参数设置

参数推荐值说明
PPO epochs1-3过大容易过拟合
KL系数0.1-0.2控制策略变化幅度
学习率1e-6~5e-6需要精细调整

5.3 常见问题解决

问题1:训练初期reward快速上升后崩溃
解决方案:降低学习率,增加KL惩罚系数

问题2:模型输出过于简短
解决方案:在reward设计中加入长度惩罚项

问题3:显存不足
解决方案:启用ZeRO-3优化或梯度检查点

6. 自定义扩展实践

6.1 添加新数据集

  1. 创建继承自PromptRawDataset的子类
  2. 实现数据加载方法
  3. data_utils.py中注册新数据集
class MyDataset(PromptRawDataset): def get_prompt_and_chosen(self, sample): return f"Human: {sample['question']}\nAssistant: {sample['good_answer']}" # 在data_utils.py中注册 def get_raw_dataset(): elif "my_dataset" in dataset_name: return MyDataset(...)

6.2 修改奖励函数

可以通过继承RewardModel类实现自定义奖励计算:

class MyRewardModel(RewardModel): def forward_value(self, seq): rewards = super().forward_value(seq) # 添加长度惩罚 length_penalty = seq.size(1) * 0.01 return rewards - length_penalty

6.3 多GPU训练技巧

DSC原生支持分布式训练,启动命令示例:

deepspeed --num_gpus 8 step3_rlhf_finetuning/main.py \ --actor_model_name_or_path path_to_sft_model \ --critic_model_name_or_path path_to_rm_model

在8×A100上训练7B模型时,合理配置batch size很关键。我们发现per-device batch size设为8时,GPU利用率能达到90%以上。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/16 5:25:15

Stata统计结果自动化导出:从命令到Excel/Word精美表格

1. 为什么需要自动化导出统计结果 做数据分析的朋友们应该都深有体会,每次跑完回归或者统计检验,最头疼的就是怎么把结果整理成论文或报告需要的格式。手动复制粘贴不仅费时费力,还容易出错。我刚开始用Stata的时候,经常因为手误把…

作者头像 李华
网站建设 2026/5/16 5:22:05

量子计算如何革新化学模拟?AFQMC方法解析

1. 量子计算与化学模拟的范式转变在计算化学领域,我们正经历着一场由量子计算技术驱动的革命性变革。传统电子结构计算方法如密度泛函理论(DFT)和耦合簇理论(CCSD(T))在处理强关联系统时面临着根本性挑战——这类系统在…

作者头像 李华
网站建设 2026/5/16 5:21:38

书成紫微动,律定凤凰驯:一破一立《第一大道》与《凰标》双生记

不破不立,先破后成。 一破启天道,一立安众生。破 局旧秩序新震源资本垄断话语权《第一大道》——思想利刃圈层壁垒森严破妄归真,以道开新草根无通道撕开裂隙,照进天光 书成紫微动,非玄虚天命, 是万千普通人…

作者头像 李华
网站建设 2026/5/16 5:19:43

当比你资历浅的人成了你的上级,技术人的心态调整指南

阶段一:缺陷定位——从审视“测试用例”开始当问题出现时,优秀的测试工程师不会立刻指责开发,而是先检查自己的测试环境、数据和步骤。面对年轻领导的晋升,我们同样需要运用这套严谨的思维,进行一次彻底的“根因分析”…

作者头像 李华
网站建设 2026/5/16 5:19:10

线性自抗扰PMSM模型预测控制【附代码】

✨ 长期致力于模型预测转矩控制、永磁同步电机、线性自抗扰控制、权重因子、模拟退火粒子群算法研究工作,擅长数据搜集与处理、建模仿真、程序编写、仿真设计。 ✅ 专业定制毕设、代码 ✅ 如需沟通交流,点击《获取方式》 (1)模拟退…

作者头像 李华