news 2026/5/1 10:47:58

sglang 大模型推理框架支持的EAGLE 1,2,3

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
sglang 大模型推理框架支持的EAGLE 1,2,3

文章目录

      • EAGLE 系列模型的演进与核心机制
      • 关键参数与训练逻辑
      • 思考

参考来源:https://docs.sglang.com.cn/backend/speculative_decoding.html
https://github.com/SafeAILab/EAGLE
EAGLE3 https://arxiv.org/pdf/2503.01840

EAGLE 系列模型的演进与核心机制

EAGLE 基础架构
草稿模型通过特征序列和 token 序列预测下一个特征向量,基于原始 LLM 的最后一个隐藏状态生成候选。采样后的 token 与原始序列以树状结构扩展,分支因子由speculative_eagle_topk控制,确保上下文连贯性。扩展后的树结构重新作为输入迭代生成。

EAGLE-2 的优化
引入动态分支评估机制,草稿模型主动评估扩展分支的可能性,提前终止低概率分支的扩展。扩展阶段结束后,通过重排序筛选前speculative_num_draft_tokens个节点作为最终草稿 token,减少冗余计算。

--speculative-token-map参数设置为true以启用高频 token 优化功能。该参数通常在模型推理或训练配置文件中进行设置。

EAGLE-3 的改进
移除特征预测目标,整合低层与中间层特征提升表示能力。采用 on-policy 训练方式,使模型在推理阶段的行为与训练目标更一致,进一步优化生成质量与效率。

关键参数与训练逻辑

  • speculative_eagle_topk:控制每步扩展的分支数量,影响生成多样性与计算开销。
  • speculative_num_draft_tokens:决定保留的候选 token 数量,平衡生成速度与准确性。
  • On-policy 训练:通过对齐训练与推理阶段的策略,减少分布偏移问题。

  • https://github.com/SafeAILab/EAGLE/blob/main/eagle/traineagle3/cnets.py

核心代码部分

def_prepare_decoder_attention_mask(self,attention_mask,input_shape,inputs_embeds,past_key_values_length):# create causal mask# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]combined_attention_mask=Noneifinput_shape[-1]>1:combined_attention_mask=_make_causal_mask(input_shape,inputs_embeds.dtype,device=inputs_embeds.device,past_key_values_length=past_key_values_length,)ifattention_maskisnotNone:# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]expanded_attn_mask=_expand_mask(attention_mask,inputs_embeds.dtype,tgt_len=input_shape[-1]).to(inputs_embeds.device)combined_attention_mask=(expanded_attn_maskifcombined_attention_maskisNoneelseexpanded_attn_mask+combined_attention_mask)returncombined_attention_mask@torch.no_grad()defdataprepare(self,input_ids,attention_mask,loss_mask):device=input_ids.device outs=self.target_model(input_ids=input_ids,attention_mask=attention_mask)hidden_states0=outs.hidden_states[0]hidden_states1=outs.hidden_states[1]hidden_states2=outs.hidden_states[2]hidden_states=torch.cat((hidden_states0,hidden_states1,hidden_states2),dim=-1)# hidden_states=torch.cat((hidden_states0,hidden_states1),dim=-1)target=outs.logits target=padding(target,left=False)input_ids=padding(input_ids,left=False)iftargetisnotNone:target=target.to(device)loss_mask=loss_mask[...,None]loss_mask=loss_mask.to(device)returnhidden_states,target,loss_mask,input_idsdefforward(self,# hidden_states,input_ids,attention_mask:Optional[torch.Tensor]=None,position_ids:Optional[torch.LongTensor]=None,past_key_values:Optional[List[torch.FloatTensor]]=None,use_cache:Optional[bool]=None,output_attentions:Optional[bool]=None,output_hidden_states:Optional[bool]=None,loss_mask:Optional[torch.Tensor]=None,):hidden_states,target,loss_mask,input_ids=self.dataprepare(input_ids,attention_mask,loss_mask)batch_size,seq_length,_=hidden_states.shape seq_length_with_past=seq_length past_key_values_length=0# with torch.no_grad():# inputs_embeds = self.embed_tokens(input_ids)# inputs_embeds = inputs_embeds.detach()ifself.trainingandself.gradient_checkpointingandnothidden_states.requires_grad:hidden_states.requires_grad=Truehidden_states=self.fc(hidden_states)ifpast_key_valuesisnotNone:past_key_values_length=past_key_values[0][0].shape[2]seq_length_with_past=seq_length_with_past+past_key_values_lengthifposition_idsisNone:device=hidden_states.device position_ids=torch.arange(past_key_values_length,seq_length+past_key_values_length,dtype=torch.long,device=device)position_ids=position_ids.unsqueeze(0).view(-1,seq_length)else:position_ids=position_ids.view(-1,seq_length).long()ifattention_maskisNone:attention_mask=torch.ones((batch_size,seq_length_with_past),dtype=torch.bool,device=hidden_states.device)attention_mask=self._prepare_decoder_attention_mask(attention_mask,(batch_size,seq_length),hidden_states,past_key_values_length)ifself.gradient_checkpointingandself.training:ifuse_cache:use_cache=Falseplosses=[]vlosses=[]acces=[]cache_hidden=[[],[]]foridxinrange(self.length):last=idx==self.length-1inputs_embeds=self.embed_tokens(input_ids)ifself.trainingandself.gradient_checkpointingandnotinputs_embeds.requires_grad:inputs_embeds.requires_grad=Trueinputs_embeds=inputs_embeds.to(hidden_states.dtype)ifself.gradient_checkpointingandself.training:defcreate_custom_forward(module):defcustom_forward(*inputs):# None for past_key_valuereturnmodule(*inputs,None,output_attentions)returncustom_forward layer_outputs,cache_hidden=torch.utils.checkpoint.checkpoint(create_custom_forward(self.midlayer),inputs_embeds,hidden_states,cache_hidden,attention_mask,position_ids,)else:layer_outputs,cache_hidden=self.midlayer(input_emb=inputs_embeds,hidden_states=hidden_states,cache_hidden=cache_hidden,attention_mask=attention_mask,position_ids=position_ids,past_key_value=None,output_attentions=output_attentions,use_cache=True,)hidden_states_out=layer_outputs[0]# cache_hidden.append(layer_outputs[1])# kv_cahce = layer_outputs[-1]withtorch.no_grad():# hidden_states_target = padding(hidden_states, left=False)target_head=target target_max_token=target_head.argmax(-1)# Move d2t to the same device as target_max_tokenself.t2d=self.t2d.to(target_max_token.device)target_mask=self.t2d[target_max_token]target_mask=target_mask[...,None].int()position_mask=target_mask*loss_mask target_head=target_head[...,self.t2d]target_head=target_head.float()target_p=nn.Softmax(dim=2)(target_head)target_p=target_p.detach()hidden_states=hidden_states_out hidden_states_out=self.norm(hidden_states_out)logits=self.lm_head(hidden_states_out)logits=logits.float()out_logp=nn.LogSoftmax(dim=2)(logits)plogp=target_p*out_logp loss=-torch.sum(position_mask*plogp,2).mean()plosses.append(loss)withtorch.no_grad():acces.append(((logits.argmax(-1)==target_p.argmax(-1))*position_mask.squeeze(-1)).sum().item()/(loss_mask.sum().item()+1e-6))ifnotlast:input_ids=padding(input_ids,left=False)target=padding(target,left=False)loss_mask=padding(loss_mask,left=False)returnplosses,vlosses,acces

思考

》 FASTMTP与EAGLE3相比,谁更快一些?

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/1 8:17:27

打开软件出现找不到vcomp140.dll文件 无法运行的情况 下载修复解决

在使用电脑系统时经常会出现丢失找不到某些文件的情况,由于很多常用软件都是采用 Microsoft Visual Studio 编写的,所以这类软件的运行需要依赖微软Visual C运行库,比如像 QQ、迅雷、Adobe 软件等等,如果没有安装VC运行库或者安装…

作者头像 李华
网站建设 2026/5/1 5:03:30

汇编语言全接触-27.工具提示控件

我们将学习工具提示控件:它是什么如何创建和使用.下载例子理论:工具提示是当鼠标在某特定区域上停留时显示的一个矩形窗口.工具提示窗口包含一些编程者想要显示的文本.在这点上,工具提示同状态栏的作用是一样的,所不同的是工具提示当单击或者远离指定区域的时候就会消逝,你可能…

作者头像 李华
网站建设 2026/5/1 8:55:28

测试左移:构建软件质量的早期防线

在快速迭代的现代软件开发周期中,缺陷发现的时机直接影响项目成本、发布节奏与最终用户体验。传统软件测试模式中,测试活动往往集中于开发后期,导致缺陷修复成本高昂、返工风险加剧。测试左移作为一种前瞻性质量保障策略,通过将测…

作者头像 李华
网站建设 2026/5/1 8:45:00

串口通讯的android 封装开箱即用!提供源代码!

功能概述 本文档总结了在Android应用中使用serialportlibrary实现串口通讯功能的完整过程。通过本次开发,成功添加了以下核心功能: 串口设备的打开与关闭 数据的发送与接收 用户友好的操作界面 实现细节 1. UI界面修改 在activity_main.xml中添加…

作者头像 李华
网站建设 2026/5/1 9:38:59

张量并行 (Tensor Parallelism, TP) 深度解析

张量并行 (Tensor Parallelism, TP) 深度解析 1. TP 只能用于 Transformer 吗? 答案是:不,但它在 Transformer 上用得最多,也最有效。 咱们从 CV 的角度来类比。TP 的核心思想是“拆分矩阵乘法”。任何包含巨大矩阵乘法 (YX⋅WY …

作者头像 李华