news 2026/5/9 12:53:35

CANN NPU SwiGLU分组量化

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CANN NPU SwiGLU分组量化

custom-npu_swiglu_group_quant

【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法,提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer

产品支持情况

产品是否支持
Ascend 950PR/Ascend 950DT

功能说明

在SwiGlu激活函数后添加量化操作,实现输入x的SwiGluQuant计算。根据quant_mode不同,量化分组大小有所差异:quant_mode=1或3时groupSize固定为128,quant_mode=2时groupSize固定为32。计算过程见计算公式。

计算公式

$$ Y_{tmp} = SwiGLU(x) = Swish(A)*B $$

$$ scale=row_max(abs(Y_{tmp}))/dstTypeScale $$

$$ Y = Cast(Mul(Y_{tmp}, Scale)) $$ 其中,A表示输入x的前半部分,B表示输入x的后半部分。

函数原型

custom.npu_swiglu_group_quant(Tensor x, *, Tensor? topk_weight=None, Tensor? group_index=None, ScalarType dst_type, int quant_mode=1, int group_size=128, bool round_scale=False, bool ue8m0_scale=False, bool output_origin=False, int group_list_type=0, float clamp_value=0.0) -> (Tensor, Tensor, Tensor)

参数说明

说明:

  • b(batch size)表示输入样本批量大小、s(sequence length)表示输入样本序列长度、d(head dimension)表示注意力头的维度数、T表示bs合轴后的大小、G表示MoE场景下的分组数量。
  • xTensor):必选参数,输入tensor,公式中的输入 x 。不支持非连续,数据格式支持ND,数据类型支持float16bfloat16,shape为[ b, s, d ]或[ T, d ](T为bs合轴后的大小)。

  • *:代表其之前的参数是位置相关的,必须按照顺序输入,属于必选参数;其之后的参数是键值对赋值,与位置无关,属于可选参数(不传入会使用默认值)。

  • topk_weightTensor,可选):MoE场景下的topk权重tensor,用于对SwiGLU输出进行加权。数据类型支持float32,shape为[ T, 1 ]。默认为None,表示不进行加权操作。

  • group_indexTensor,可选):MoE场景下分组索引tensor,用于指定每个group的token个数。数据类型支持int64,shape为[G],其中G表示分组数量。该Tensor的数值总和代表输入x中的有效Token数。默认值为None,表示非MoE场景,此时所有Token都有效。

  • dst_typeScalarType,可选):量化输出数据类型,支持torch.float8_e5m2torch.float8_e4m3fn。默认为torch.float8_e4m3fn

  • quant_modeint, 可选):量化模式,取值范围为1、2、3。

    • 1:PerGroup fp8量化,量化输出scale为float32类型,groupSize固定为128。
    • 2:MX fp8量化,量化输出scale为float8_e8m0类型,groupSize固定为32。
    • 3:PerGroup fp8量化增强模式,groupSize固定为128。scale类型由ue8m0_scale参数决定:ue8m0_scale=False时scale为float32类型(与quant_mode=1功能相同),ue8m0_scale=True时scale为float8_e8m0类型。相比quant_mode=1,该模式还支持round_scale、output_origin等额外参数。
  • group_sizeint, 可选):量化分组大小。quant_mode=1或3时固定为128;quant_mode=2时固定为32。默认为128。

  • round_scalebool, 可选):是否对scale进行舍入处理。默认为False。仅在quant_mode=3时有效。

  • ue8m0_scalebool, 可选):是否使用float8_e8m0格式输出scale。默认为False,scale为float32类型;设置为True时,scale为float8_e8m0类型。仅在quant_mode=3时有效。

  • output_originbool, 可选):是否输出SwiGLU计算后的原始fp16/bf16结果。默认为False。仅在quant_mode=3时有效。

  • group_list_typeint, 可选):分组列表类型,固定为0,默认为0,表示count模式。在count模式下,group_index的每个元素代表对应分组的元素个数。在quant_mode=1、2、3时均生效。

  • clamp_valuefloat, 可选):SwiGLU激活函数的截断值。当clamp_value>0时,对激活值进行截断:A部分只有最大值约束,截断到不超过clamp_value;B部分截断到[-clamp_value, clamp_value]。默认为0.0,表示不进行截断。

返回值说明

  • yTensor):输出tensor,量化后的value。不支持非连续,数据格式支持ND,数据类型支持float8_e4m3fnfloat8_e5m2,shape为[ b, s, d/2 ]或[ T, d/2 ]。当使用group_index时,前有效Token数行的数据有效,其余为填充数据。

  • scaleTensor):输出tensor,量化后的scale。不支持非连续,数据格式支持ND。当quant_mode=1或quant_mode=3且ue8m0_scale=False时,数据类型为float32;当quant_mode=2或quant_mode=3且ue8m0_scale=True时,数据类型为float8_e8m0

  • yOriginTensor):输出tensor,SwiGLU计算后的原始结果。仅在quant_mode=3且output_origin=True时有效输出,其他情况下返回空tensor。数据类型与输入x相同(float16bfloat16),shape为[ b, s, d/2 ]或[ T, d/2 ]。

约束说明

  • shape 字段取值范围约束 | quant_mode | d(最后一维)约束 | |------------|------------------| | 1 或 3 | 必须是256的倍数 | | 2 | 必须是128的倍数 |

  • quant_mode 取值范围为1、2、3。具体说明详见参数说明中的quant_mode字段。

  • round_scale、ue8m0_scale、output_origin参数仅在quant_mode=3时有效。

  • group_list_type固定为0,表示count模式,在quant_mode=1、2、3时均生效。

  • group_index的数值总和必须不超过输入x的第一维大小。

  • 该接口支持推理场景下使用。

  • 该接口支持aclgraph入图。

  • 该接口与PyTorch配合使用时,需要保证CANN相关包与PyTorch相关包的版本匹配。

调用示例

  • 详见 test_npu_swiglu_group_quant.py

【免费下载链接】cann-recipes-infer本项目针对LLM与多模态模型推理业务中的典型模型、加速算法,提供基于CANN平台的优化样例项目地址: https://gitcode.com/cann/cann-recipes-infer

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 12:53:34

CANN/driver获取设备NDIE ID

dcmi_get_device_ndie 【免费下载链接】driver 本项目是CANN提供的驱动模块,实现基础驱动和资源管理及调度等功能,使能昇腾芯片。 项目地址: https://gitcode.com/cann/driver 函数原型 int dcmi_get_device_ndie(int card_id, int device_id, s…

作者头像 李华
网站建设 2026/5/9 12:52:56

3步快速解密:让网易云音乐加密文件重获自由的完整指南

3步快速解密:让网易云音乐加密文件重获自由的完整指南 【免费下载链接】ncmdump 项目地址: https://gitcode.com/gh_mirrors/ncmd/ncmdump 你是否曾遇到过这样的困扰:从网易云音乐精心下载的歌曲,却只能在特定软件中播放,…

作者头像 李华
网站建设 2026/5/9 12:51:38

WeChatPad技术揭秘:如何让您的安卓手机同时登录两个微信账号?

WeChatPad技术揭秘:如何让您的安卓手机同时登录两个微信账号? 【免费下载链接】WeChatPad 强制使用微信平板模式 项目地址: https://gitcode.com/gh_mirrors/we/WeChatPad 作为一名忙碌的开发者,您是否曾面临这样的困境:工…

作者头像 李华
网站建设 2026/5/9 12:51:35

CANN Runtime事件管理API参考

# 7. Event管理 【免费下载链接】runtime 本项目提供CANN运行时组件和维测功能组件。 项目地址: https://gitcode.com/cann/runtime 本章节描述 CANN Runtime 的 Event 管理接口,用于事件的创建、记录、同步、计时及 IPC 跨进程共享。 aclError…

作者头像 李华
网站建设 2026/5/9 12:51:30

CANN/pyasc合并排序API文档

asc.language.basic.mrg_sort4 【免费下载链接】pyasc 本项目为Python用户提供算子编程接口,支持在昇腾AI处理器上加速计算,接口与Ascend C一一对应并遵守Python原生语法。 项目地址: https://gitcode.com/cann/pyasc asc.language.basic.mrg_sor…

作者头像 李华
网站建设 2026/5/9 12:50:09

CANN/asc-devkit AddrReg地址寄存器API

AddrReg 【免费下载链接】asc-devkit 本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言,原生支持C和C标准规范,主要由类库和语言扩展层构成,提供多层级API,满足多维场景算子开发诉求。 项目地址: https://gitcode.com/ca…

作者头像 李华