news 2026/5/9 16:46:44

CANN/catlass Block MMAD开发详解

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CANN/catlass Block MMAD开发详解

Block MMAD 代码开发详解

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

1. Block MMAD 概述

Block MMAD(Block Matrix Multiply-Add)是CATLASS模板库中负责块级矩阵乘法的核心组件,位于计算架构的中间层。它上接Kernel层,下接Tile层,负责将全局内存(GM)中的数据高效地加载到本地内存(L1/L0),并调度Tile级别的矩阵乘法运算。

Block MMAD采用高度模块化和模板化的设计,支持多种调度策略、Tile形状和数据类型,能够灵活适应不同的硬件架构和计算需求,本文将以BlockMmadPingpong为例详细讲解。

2. 模板组装机制

Block MMAD实现基于以下基础模板结构:

template < class DispatchPolicy, class L1TileShape, class L0TileShape, class AType, class BType, class CType, class BiasType = void, class TileCopy = Gemm::Tile::TileCopy<typename DispatchPolicy::ArchTag, AType, BType, CType, BiasType>, class TileMmad = Gemm::Tile::TileMmad<typename DispatchPolicy::ArchTag, AType, BType, BiasType> > struct BlockMmad { // 类型定义和核心实现 };

2.1 核心模板参数

参数名描述
DispatchPolicy调度策略,控制计算任务的分配和执行流程
L1TileShapeL1缓存级别Tile的形状,定义M、N、K三个维度的大小
L0TileShapeL0缓存级别Tile的形状,定义M、N、K三个维度的大小
AType/BType/CType输入矩阵A、B和输出矩阵C的数据类型和布局信息
BiasType偏置数据类型(可选)
TileCopyTile级别的数据拷贝组件,负责不同内存层级间的数据传输
TileMmadTile级别的矩阵乘法组件,负责实际的计算操作

2.2 类型导出

Block MMAD通过类型导出系统建立统一的类型接口,方便上层组件使用:

public: // Type Aliases using DispatchPolicy = MmadAtlasA2Pingpong<ENABLE_UNIT_FLAG_>; using ArchTag = typename DispatchPolicy::ArchTag; using L1TileShape = L1TileShape_; using L0TileShape = L0TileShape_; using ElementA = typename AType_::Element; using LayoutA = typename AType_::Layout; using ElementB = typename BType_::Element; using LayoutB = typename BType_::Layout; using ElementC = typename CType_::Element; using LayoutC = typename CType_::Layout; using TileMmad = TileMmad_; using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator; using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; using LayoutCInL0 = layout::zN;

3. 内存管理与缓存设计

Block MMAD负责管理不同层级的内存,包括全局内存(GM)、L1缓存和L0缓存:

3.1 静态常量定义

static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; static constexpr uint32_t STAGES = DispatchPolicy::STAGES; static constexpr uint32_t L1A_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); static constexpr uint32_t L1B_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); static constexpr uint32_t L0A_SIZE = ArchTag::L0A_SIZE; static constexpr uint32_t L0B_SIZE = ArchTag::L0B_SIZE; static constexpr uint32_t L0C_SIZE = ArchTag::L0C_SIZE; static constexpr uint32_t L0A_PINGPONG_BUF_SIZE = L0A_SIZE / STAGES; static constexpr uint32_t L0B_PINGPONG_BUF_SIZE = L0B_SIZE / STAGES;

3.2 内存检查

// Check LayoutC static_assert(std::is_same_v<LayoutC, layout::RowMajor>, "LayoutC only support RowMajor yet!"); // Check L1TileShape static_assert((L1A_SIZE * STAGES + L1B_SIZE * STAGES) <= ArchTag::L1_SIZE, "L1TileShape exceeding the L1 space!"); // Check L0TileShape static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); static constexpr uint32_t L0C_TILE_SIZE = L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); static_assert((L0A_TILE_SIZE * STAGES) <= L0A_SIZE, "L0TileShape exceeding the L0A space!"); static_assert((L0B_TILE_SIZE * STAGES) <= L0B_SIZE, "L0TileShape exceeding the L0B space!"); static_assert(L0C_TILE_SIZE <= L0C_SIZE, "L0TileShape exceeding the L0C space!");

3.3 多阶段缓存设计

Block MMAD采用多阶段流水线设计,使用pingpong技术隐藏内存访问延迟:

protected: // Multi-stage tensors list AscendC::LocalTensor<ElementA> l1ATensorList[STAGES]; AscendC::LocalTensor<ElementB> l1BTensorList[STAGES]; AscendC::LocalTensor<ElementA> l0ATensorList[STAGES]; AscendC::LocalTensor<ElementB> l0BTensorList[STAGES]; AscendC::LocalTensor<ElementAccumulator> l0CTensor; // Multi-stage event id list int32_t l1AEventList[STAGES]; int32_t l1BEventList[STAGES]; int32_t l0AEventList[STAGES]; int32_t l0BEventList[STAGES]; // The id of current stage uint32_t l1ListId{0}; uint32_t l0AListId{0}; uint32_t l0BListId{0};

4. 核心接口实现

4.1 构造函数

CATLASS_DEVICE BlockMmad(Arch::Resource<ArchTag> &resource, uint32_t l1BufAddrStart = 0) { uint32_t l1AOffset = l1BufAddrStart; uint32_t l1BOffset = l1BufAddrStart + L1A_SIZE * STAGES; // Init buffers for (uint32_t i = 0; i < STAGES; i++) { // Assign L1/L0A/L0B space for each stages l1ATensorList[i] = resource.l1Buf.template GetBufferByByte<ElementA>(l1AOffset + L1A_SIZE * i); l1BTensorList[i] = resource.l1Buf.template GetBufferByByte<ElementB>(l1BOffset + L1B_SIZE * i); l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte<ElementA>(L0A_PINGPONG_BUF_SIZE * i); l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte<ElementB>(L0B_PINGPONG_BUF_SIZE * i); // Assign event ID for each stages l1AEventList[i] = i; l1BEventList[i] = i + STAGES; l0AEventList[i] = i; l0BEventList[i] = i + STAGES; // The event id that needs to be set before the loop AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]); AscendC::SetFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]); AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]); AscendC::SetFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]); } l0CTensor = resource.l0CBuf.template GetBufferByByte<ElementAccumulator>(0); AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0); }

4.2 析构函数

CATLASS_DEVICE ~BlockMmad() { for (uint32_t i = 0; i < STAGES; i++) { AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[i]); AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[i]); AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0AEventList[i]); AscendC::WaitFlag<AscendC::HardEvent::M_MTE1>(l0BEventList[i]); } AscendC::WaitFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0); }

4.3 核心计算接口 operator()

operator()是Block MMAD的核心接口,负责执行块级矩阵乘法:

CATLASS_DEVICE void operator()( AscendC::GlobalTensor<ElementA> const &gmA, LayoutA const &layoutA, AscendC::GlobalTensor<ElementB> const &gmB, LayoutB const &layoutB, AscendC::GlobalTensor<ElementC> const &gmC, LayoutC const &layoutC, GemmCoord const &actualShape) { // 1. 初始化参数和布局 // 2. 预加载第一批数据到L1缓存 // 3. 主循环: // a. 预加载下一批数据到L1缓存 // b. 将当前L1缓存中的数据加载到L0缓存 // c. 执行Tile级矩阵乘法 // d. 管理多阶段流水线 // 4. 将结果从L0缓存写回全局内存 }

5. 执行流程分析

以BlockMmadPingpong为例,其执行流程如下:

5.1 数据预加载

// load first matrix A tile from GM to L1 AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]); auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]); // load first matrix B tile from GM to L1 AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1BEventList[l1ListId]); auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB); AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1BEventList[l1ListId]);

5.2 主循环(多阶段流水线)

// main loop uint32_t kTileCount = CeilDiv<L1TileShape::K>(actualShape.k()); for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; kLoopIdx++) { uint32_t l1ListIdNext = (l1ListId + 1 < STAGES) ? (l1ListId + 1) : 0; uint32_t kActualNext{0}; // 预加载下一批数据到L1缓存 if (kLoopIdx < kTileCount - 1) { // ... 预加载逻辑 ... } // 处理当前L1缓存中的数据 // Get L1 tensor for current stage auto l1ATensor = l1ATensorList[l1ListId]; auto l1BTensor = l1BTensorList[l1ListId]; // 将L1缓存中的数据加载到L0缓存并执行计算 // ... L0级处理逻辑 ... // 切换到下一个阶段 l1ListId = l1ListIdNext; kActual = kActualNext; }

5.3 L0级处理与计算

for (int mPartIdx = 0; mPartIdx < mPartLoop; mPartIdx++) { // ... M维度循环 ... for (int kPartIdx = 0; kPartIdx < kPartLoop; kPartIdx++) { // ... K维度循环 ... for (int nPartIdx = 0; nPartIdx < nPartLoop; nPartIdx++) { // ... N维度循环 ... // 执行Tile级矩阵乘法 bool initC = ((kLoopIdx == 0) && (kPartIdx == 0)); uint8_t unitFlag = 0b00; // ... unitFlag设置 ... tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); } } }

5.4 结果写回

// copy block out LayoutC layoutBlock = layoutC.GetTileLayout(actualShape.GetCoordMN()); if constexpr (!ENABLE_UNIT_FLAG) { AscendC::SetFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0); AscendC::WaitFlag<AscendC::HardEvent::M_FIX>(EVENT_ID0); copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C); AscendC::SetFlag<AscendC::HardEvent::FIX_M>(EVENT_ID0); } else { copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11); }

6. 多阶段流水线与事件同步

Block MMAD使用事件驱动的同步机制确保多阶段流水线的正确执行:

// 等待事件 AscendC::WaitFlag<AscendC::HardEvent::MTE1_MTE2>(l1AEventList[l1ListId]); // 执行操作 copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); // 触发事件 AscendC::SetFlag<AscendC::HardEvent::MTE2_MTE1>(l1AEventList[l1ListId]);

7. 不同Block MMAD实现的共性与差异

CATLASS提供了多种Block MMAD实现,它们具有以下共性:

  1. 统一的模板接口:所有实现都基于相同的模板参数结构
  2. 多阶段流水线设计:都采用多阶段设计来隐藏内存访问延迟
  3. 事件驱动同步:都使用事件机制确保数据传输和计算的正确顺序
  4. 内存层级管理:都负责管理GM、L1和L0之间的数据传输

不同实现的主要差异在于:

  1. 调度策略:如pingpong、preload、streamk等不同的调度方式
  2. 硬件适配:针对不同硬件架构的优化实现
  3. 特殊功能支持:如量化、稀疏计算、偏置处理等
  4. 性能优化:不同的流水线深度和内存访问模式

8. 总结

Block MMAD是CATLASS模板库中连接Kernel层和Tile层的关键组件,它通过以下设计实现了高效的块级矩阵乘法:

  1. 高度模块化:通过模板参数组装不同的功能组件
  2. 多阶段流水线:使用pingpong技术隐藏内存访问延迟
  3. 事件驱动同步:确保数据传输和计算的正确顺序
  4. 自动内存管理:高效管理不同层级的内存资源
  5. 灵活适配:支持不同的硬件架构和计算需求

Block MMAD的设计体现了现代高性能计算库的核心思想,通过精心的流水线设计和内存管理,充分发挥硬件的计算能力,实现高效的矩阵乘法运算。

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 16:45:31

AI拟人化设计:如何通过外观、行为与交互激发人类共情与道德考量

1. 项目概述&#xff1a;当AI变得“像人”&#xff0c;我们为何会犹豫&#xff1f;最近和几个做机器人伦理研究的朋友聊天&#xff0c;我们讨论了一个挺有意思的现象&#xff1a;当我们在实验室里测试一个功能强大的机械臂时&#xff0c;下达“让它自毁”的指令&#xff0c;大家…

作者头像 李华
网站建设 2026/5/9 16:42:13

CANN/pyasc API文档自动生成工具使用指南

API文档自动生成工具使用指南 【免费下载链接】pyasc 本项目为Python用户提供算子编程接口&#xff0c;支持在昇腾AI处理器上加速计算&#xff0c;接口与Ascend C一一对应并遵守Python原生语法。 项目地址: https://gitcode.com/cann/pyasc 概述 本项目采用Sphinx工具&…

作者头像 李华
网站建设 2026/5/9 16:41:07

OpenClaw AI Agent实战指南:从自动化客服到个人助理的六大场景应用

1. 从工具到伙伴&#xff1a;OpenClaw AI Agent 如何重塑你的工作流如果你还在把AI当作一个简单的聊天机器人&#xff0c;或者一个偶尔帮你写点文案的“外挂”&#xff0c;那你可能错过了这个时代最激动人心的生产力革命。OpenClaw AI Agent&#xff0c;这个听起来有点赛博朋克…

作者头像 李华
网站建设 2026/5/9 16:37:33

Llama模型转ONNX:从PyTorch到跨平台部署的完整指南

1. 项目概述&#xff1a;从Llama到ONNX的模型“翻译官”最近在折腾大语言模型本地部署和推理优化的朋友&#xff0c;估计没少为模型格式转换头疼。特别是那些动辄几十GB的Llama家族模型&#xff0c;原生的PyTorch格式虽然灵活&#xff0c;但在生产环境部署、跨平台推理或者追求…

作者头像 李华
网站建设 2026/5/9 16:31:58

Tailwind CSS如何设置不同断点的内边距_使用p-4 md-p-8类.txt

不能。std::ios::badbit仅反映流内部状态异常&#xff0c;无法可靠捕获硬盘掉线或I/O控制器故障&#xff1b;真实硬件错误需依赖系统调用返回的EIO等errno&#xff0c;而非流状态位。std::ios::badbit 真的能捕获硬盘掉线或 I/O controller 故障吗&#xff1f;不能。它只反映流…

作者头像 李华