news 2026/5/22 17:25:25

torch-catlass 测试框架设计文档

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
torch-catlass 测试框架设计文档

torch-catlass 测试框架设计文档

1. 总览

tests/optest是 CATLASS 示例算子接入 PyTorch 的端到端测试框架。框架将 CATLASS AscendC kernel 封装为torch.ops.catlass.*算子,并通过 Python 包torch_catlass提供测试入口。

框架按职责分为五层:

Python API torch_catlass.ops.* | v Python package loader load kernel libs and libcatlass_torch.so | v PyTorch C++ extension register torch.ops.catlass.* | v Kernel adapter convert Tensor arguments to CATLASS kernel params | v Kernel implementation prebuilt kernel or JIT compiled template

核心原则:

  • Python 层只负责用户接口、动态库加载和轻量参数转换。
  • C++ extension 层负责 PyTorch op 注册和 NPU dispatch。
  • adapter 层负责 Tensor 到 kernel ABI 的转换。
  • kernel 层负责 AscendC/CATLASS 代码执行。
  • JIT 子系统负责模板参数宏生成、编译、缓存和动态加载。

2. 目录模块

tests/optest/ ├── pyproject.toml ├── CMakeLists.txt ├── build.sh ├── docs/ │ └── design.md ├── include/ │ ├── catlass_kernel.h │ └── catlass_torch.h ├── torch_catlass/ │ ├── __init__.py │ ├── _version.py │ └── ops/ ├── src/ │ ├── catlass_torch.cpp │ ├── common/ │ └── include/ ├── utils/ │ ├── CMakeLists.txt │ ├── include/ │ ├── kernel_utils.cpp │ ├── torch_utils.cpp │ └── type_utils.hpp ├── kernels/ │ ├── CMakeLists.txt │ ├── common/ │ ├── include/ │ ├── jit/ │ └── 00_basic_matmul/ └── tests/ └── test_00_basic_matmul.py
模块责任
torch_catlass/Python 包入口、动态库加载、用户侧 op wrapper
include/框架和 kernel 共享的公共 ABI 声明
src/PyTorch C++ extension 和 torch op 注册
utils/dtype/layout/Tensor 工具函数,拆分 torch 依赖和纯 ACL 依赖
kernels/kernel 构建、JIT compiler、JIT template、kernel entry
tests/pytest 集成测试

3. Python 包模块

3.1 包初始化

torch_catlass/__init__.py在 import 时完成运行时初始化:

  1. _version.py读取构建期版本信息。
  2. 设置TORCH_CATLASS_VERSION,供 JIT 编译阶段注入版本宏。
  3. 设置TORCH_CATLASS_PKG_DIR,供 JIT compiler 定位安装后的 headers 和 templates。
  4. 加载 JIT compiler 和 JIT kernel entry 库。
  5. 根据当前 NPU 架构加载 arch-specific kernel 库。
  6. 调用torch.ops.load_library()加载libcatlass_torch.so,注册torch.ops.catlass.*

动态库加载顺序:

lib/jit/libcatlass_kernel_jit_compiler.so lib/jit/libcatlass_kernel_jit.so lib/<arch>/*.so lib/libcatlass_torch.so

JIT compiler 和 kernel entry 必须先于 PyTorch extension 加载,保证 extension 中引用的 kernel 符号可解析。

3.2 架构识别

Python loader 通过torch_npu.npu.get_device_name()识别设备并映射为 CATLASS arch id:

设备名arch
Ascend910B.*2201
Ascend910_932201
Ascend950PR/Ascend950DT3510

torch_npu.npu.device_count()为 0 时,loader 抛出明确错误。测试代码在无 NPU 环境下通过 pytest skip 处理,避免在 collection 阶段触发 torch-npu 内部错误。

3.3 Python op wrapper

torch_catlass/ops/保存 Python 用户接口。以basic_matmul为例:

torch_catlass.basic_matmul( mat1, mat2, outDType="float16", transA=False, transB=False, formatA=False, formatB=False, )

Python wrapper 只做用户友好的轻量转换,例如 dtype 字符串白名单解析。shape 推导、输出分配、stream 获取和 kernel launch 都在 C++ 层完成,避免 Python 和 C++ 维护两套语义。

4. PyTorch C++ Extension 模块

4.1 注册入口

src/catlass_torch.cpp是 PyTorch extension 的注册入口:

using BasicMatmulOp = MatmulLike<CatlassKernel::BasicMatmul>; static auto& basic_matmul = BasicMatmulOp::Run; REGISTER_TORCH_FUNC(basic_matmul);

REGISTER_TORCH_FUNC位于src/include/common/register.h,注册流程为:

  1. 创建或复用torch::Library,namespace 固定为catlass
  2. 通过 PyTorch schema inference 从 C++ 函数签名生成 schema。
  3. 将实现注册到c10::DispatchKey::PrivateUse1

PrivateUse1是 torch-npu 使用的 NPU dispatch key。Python 调用torch.ops.catlass.basic_matmul时,PyTorch 根据输入 Tensor device 走到该 backend 实现。

4.2 kernel launch 包装

RUN_NPU_FUNC位于src/include/common/run_npu_func.h。它通过 torch-npu 的OpCommand::RunOpApiV2执行 kernel launch:

  • launch 前检查函数指针是否为空。
  • 将 C++ 异常转为 ACL error code。
  • 将 kernel 调用交给 torch-npu runtime 管理。

5. Matmul Adapter 模块

src/include/template/matmul.h提供MatmulLike<KernelFunc>。该模板封装 matmul 类算子的通用流程:

Run() ├─ GetKernelInfo() ├─ AllocOutput() ├─ get current NPU stream ├─ get AIC core count └─ RUN_NPU_FUNC(KernelFunc, ...)

5.1 参数拆分

Matmul 参数拆分为两类:

参数结构用途是否参与 JIT 编译
MatmulTParamsdtype、layout、transpose 等模板参数
MatmulParamsM/N/K、输入输出地址等运行时参数

这种拆分保证 dtype/layout 变化会生成新的模板实例,而 shape 和 Tensor 地址变化不会导致重复编译。

5.2GetKernelInfo

GetKernelInfo()负责将 PyTorch Tensor 转为 kernel 参数:

  • torch::Dtype转为aclDataType
  • 根据transAtransB推导m/n/k
  • 检查两个输入矩阵的 K 维一致。
  • 填充MatmulTParams
  • 填充MatmulParams
  • 将输入 Tensor storage 地址写入params.inputAddr

5.3AllocOutput

AllocOutput()根据params.mparams.ntParams.elementC创建输出 Tensor,并将其 storage 地址写入params.outputAddr[0]。输出 Tensor 生命周期由 PyTorch 管理,kernel ABI 只接收裸地址。

6. 公共 ABI 模块

include/catlass_kernel.h定义 C++ wrapper 和 kernel 实现之间共享的数据结构和函数声明。

6.1 matmul ABI

MatmulTParams表示编译期参数:

  • elementA
  • elementB
  • elementC
  • transA
  • transB
  • transC
  • useNzA
  • useNzB
  • useNzC

MatmulParams表示运行期参数:

  • m
  • n
  • k
  • inputAddr
  • outputAddr

kernel entry 签名保持固定:

void BasicMatmul( uint32_t blockNum, aclrtStream stream, const MatmulTParams& tParams, const MatmulParams& params);

6.2 扩展 ABI

catlass_kernel.h中还保留了 grouped matmul、quant matmul、conv、flash attention 等参数结构和函数声明。新增算子时优先复用已有 ABI 结构;当参数语义明显不同,再新增独立结构。

7. Utils 模块

utils/将工具函数拆成两个 target:

target文件依赖用途
catlass_kernel_utilskernel_utils.cppACLJIT compiler 使用的 dtype 到 bisheng type 转换
catlass_torch_utilstorch_utils.cppACL、torch、torch-npuPyTorch wrapper 使用的 Tensor/dtype/layout 工具

7.1 dtype 映射

utils/type_utils.hpp维护 dtype 映射表:

  • canonical string name
  • torch::Dtype
  • aclDataType
  • bisheng C++ type token

映射表按 index 对齐,TypeCast<S, T>()通过查表完成不同表示之间的转换。

部分 dtype 在 torch-npu 随包 ACL 头中没有枚举名,但 ABI 数值与 CANN 定义一致。此类 dtype 使用static_cast<aclDataType>(value)表达,避免混用系统 CANN ACL 头和 torch-npu 随包 ACL 头导致重复定义。

7.2 Tensor 工具

torch_utils.cpp提供:

  • GetOutputTensor():在当前 NPU device 上创建 ND 格式输出 Tensor。
  • TypeStrToTorchDtype():字符串到 torch dtype。
  • TorchDtypeToAclDtype():torch dtype 到 ACL dtype。
  • AclDtypeToTorchDtype():ACL dtype 到 torch dtype。
  • GetTransposeStatus():根据 tensor stride 和 NPU format 判断矩阵布局。

8. JIT 子系统

JIT 子系统由四部分组成:

组件文件责任
JIT entrykernels/00_basic_matmul/basic_matmul.cpp稳定 kernel 入口,负责获取 JIT 函数并调用
JIT templatekernels/00_basic_matmul/basic_matmul_impl.cpp被运行时编译的 CATLASS kernel 模板
JIT compilerkernels/jit/jit_compiler.cpp编译、缓存、加载.so
macro generatorkernels/include/jit_macro_generator.h将模板参数转为-D

8.1 JIT entry

JIT entry 固定编译进libcatlass_kernel_jit.so。以BasicMatmul为例:

auto* entry = JitCompiler::instance().getKernel( "basic_matmul_impl.cpp", JitMacroGenerator<MatmulTParams>::generate("basic_matmul", tParams)); entry(blockNum, stream, &tParams, &params);

entry 的职责是连接 stable ABI 和 runtime-compiled template,不承载具体 GEMM 模板逻辑。

8.2 JIT template

JIT template 使用宏注入类型和布局:

语义
CATLASS_JIT_ELEMENT_AA 元素类型
CATLASS_JIT_ELEMENT_BB 元素类型
CATLASS_JIT_ELEMENT_CC 元素类型
CATLASS_JIT_LAYOUT_AA layout
CATLASS_JIT_LAYOUT_BB layout
CATLASS_JIT_LAYOUT_CC layout
CATLASS_JIT_KERNEL_NAMEdevice kernel 符号名

模板导出稳定 C ABI:

extern "C" void run( uint32_t blockNum, aclrtStream stream, const MatmulTParams* tParams, const MatmulParams* params);

JIT loader 固定解析run符号。device kernel 名只用于编译产物可读性和 profiling。

8.3 宏生成

JitMacroGenerator<TParams>是模板策略类。默认模板不生成任何宏,具体参数类型通过特化实现。

JitMacroGenerator<MatmulTParams>生成:

  • CATLASS_KERNEL_NAME
  • CATLASS_JIT_ELEMENT_A
  • CATLASS_JIT_ELEMENT_B
  • CATLASS_JIT_ELEMENT_C
  • CATLASS_JIT_LAYOUT_A
  • CATLASS_JIT_LAYOUT_B
  • CATLASS_JIT_LAYOUT_C
  • CATLASS_JIT_KERNEL_NAME

新增非 matmul JIT kernel 时,应新增对应参数结构的JitMacroGenerator特化。

8.4 编译和缓存

JitCompiler是进程级单例。初始化内容包括:

  • JIT cache 目录。
  • bisheng/ccec 路径。
  • 当前 NPU arch。
  • JIT template 根目录。

缓存分两层:

  1. 内存缓存:loaded_保存SharedLibrun指针。
  2. 磁盘缓存:保存编译后的.so

cache key 由 kernel 名、arch 和完整宏集合构成:

<CATLASS_KERNEL_NAME>_arch<arch>_<macro-name>_<macro-value>...

宏按 key 排序后拼接,保证unordered_map遍历顺序不影响缓存路径。

8.5 环境变量

环境变量作用
CATLASS_JIT_LOG_LEVELJIT 日志等级,0/1/2
TORCH_CATLASS_CACHE_DIRJIT 磁盘缓存目录
MS_SANITIZE_MEMORY启用 msSanitizer 编译选项
TORCH_CATLASS_VERSION注入 package/CATLASS 版本
ASCEND_HOME_PATH查找 Ascend compiler 和 runtime 库
TORCH_CATLASS_PKG_DIR安装后的 Python 包根目录

JIT 编译使用的 NPU arch 只通过 AscendC platform API 获取,即GetCurrentNPUArch()。运行时不支持通过环境变量覆盖 arch,避免编译参数和实际设备不一致。

环境变量分为外部配置和包内注入两类。外部配置只保留ASCEND_HOME_PATHTORCH_CATLASS_CACHE_DIRCATLASS_JIT_LOG_LEVELMS_SANITIZE_MEMORYTORCH_CATLASS_VERSIONTORCH_CATLASS_PKG_DIR由 Python loader 在 import 时设置。JIT template 路径固定由TORCH_CATLASS_PKG_DIR/jit/templates/推导,compiler 优先从ASCEND_HOME_PATH的标准目录查找,未找到时回退到PATH中的ccec

9. Kernel 构建模块

kernels/CMakeLists.txt提供add_kernel(),统一 JIT 和 prebuilt kernel 的构建入口。

9.1 JIT kernel

add_kernel( NAME basic_matmul NPU_ARCH_LIST 2201 KERNEL_TYPE jit ${CMAKE_CURRENT_SOURCE_DIR}/basic_matmul.cpp TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/basic_matmul_impl.cpp)

JIT kernel 构建流程:

  1. entry 源文件加入统一libcatlass_kernel_jit.so
  2. template 文件安装到jit/templates/
  3. jit_try_compile_check()在构建期检查 template 可被 bisheng 编译。
  4. 运行时由JitCompiler根据模板参数编译具体.so

9.2 prebuilt kernel

prebuilt kernel 按 arch 构建独立动态库:

lib/<arch>/libcatlass_kernel_<arch>_<name>.so

prebuilt 模式用于固定参数组合或无需运行时模板编译的 kernel。启用 msSanitizer 时,可额外构建_ms变体。

10. 顶层构建模块

10.1 Python 构建入口

build.sh是主要构建入口:

  • 推导 package 版本。
  • 写入torch_catlass/_version.py
  • 自动探测当前 Python 环境中的 Torch CMake 目录。
  • 调用 pip/scikit-build-core 驱动 CMake 构建。

常用命令:

bash build.sh --skip-wheel bash build.sh --build-type Debug --skip-wheel bash build.sh --clean

10.2 CMake target

顶层CMakeLists.txt负责:

  • 查找 ASC、Python、Torch。
  • 设置 C++17、PIC、compile commands。
  • 安装 public headers。
  • 安装 CATLASS headers 到 Python 包内的 JIT include tree。
  • 添加kernelsutilssrc子目录。

关键 target:

target输出说明
catlass_kernel_utilsstatic libJIT compiler 依赖的纯 ACL 工具
catlass_torch_utilsstatic libtorch wrapper 依赖的 Tensor 工具
catlass_kernel_jit_compilershared libJIT 编译器
catlass_kernel_jitshared libJIT entry 集合
catlass_torchshared libPyTorch extension

11. 测试模块

pytest 集成测试验证 Python API 到 kernel 执行的完整链路。

tests/test_00_basic_matmul.py测试流程:

  1. 检查是否存在可用 Ascend NPU;无设备时跳过。
  2. 构造 NPU fp16 输入。
  3. 调用torch_catlass.basic_matmul()
  4. torch.matmul()生成参考结果。
  5. 校验 shape、dtype、device。
  6. torch.allclose()校验数值。

本地静态检查:

python3 -m py_compile torch_catlass/__init__.py torch_catlass/ops/basic_matmul.py tests/test_00_basic_matmul.py python3 -m ruff check torch_catlass tests

完整集成测试:

bash build.sh --skip-wheel pytest tests/test_00_basic_matmul.py -v -s

12. 扩展流程

12.1 新增 matmul 类 JIT 算子

  1. kernels/<nn_name>/下添加 entry.cpp和 template.cpp
  2. 在 entry 中调用JitCompiler::instance().getKernel()
  3. 复用或扩展JitMacroGenerator<MatmulTParams>
  4. kernels/<nn_name>/CMakeLists.txt中调用add_kernel()
  5. include/catlass_kernel.h声明 kernel entry。
  6. src/catlass_torch.cpp使用MatmulLike<Kernel>注册 torch op。
  7. torch_catlass/ops/添加 Python wrapper。
  8. tests/添加 pytest,与 PyTorch 参考实现比对。

12.2 新增非 matmul 算子

非 matmul 算子应新增独立 adapter,而不是扩展MatmulLike

src/include/template/<op_family>.h ├─ GetKernelInfo() ├─ AllocOutput() └─ Run()

同时新增对应的参数结构和JitMacroGenerator特化,保持参数解析、宏生成和 kernel ABI 各自独立。

12.3 新增 dtype 或 layout

新增 dtype/layout 时需要同步更新:

  • utils/type_utils.hpp
  • JitMacroGenerator对应特化
  • JIT template 中的默认宏和类型别名
  • pytest 参数覆盖

dtype 映射应优先使用当前编译环境可见的 enum;当 torch-npu 随包 ACL 头未暴露 enum 名但 ABI 数值稳定时,可使用static_cast<aclDataType>(value)保持兼容。

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/22 17:22:17

Legba终极指南:如何用Rust构建高性能多协议凭证破解工具

Legba终极指南&#xff1a;如何用Rust构建高性能多协议凭证破解工具 【免费下载链接】legba The fastest and more comprehensive multiprotocol credentials bruteforcer / password sprayer and enumerator. &#x1f977; 项目地址: https://gitcode.com/gh_mirrors/le/le…

作者头像 李华
网站建设 2026/5/22 17:22:16

CANN/pypto最大元素函数

&#xfeff;# pypto.maximum 【免费下载链接】pypto PyPTO&#xff08;发音: pai p-t-o&#xff09;&#xff1a;Parallel Tensor/Tile Operation编程范式。 项目地址: https://gitcode.com/cann/pypto 产品支持情况 产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3…

作者头像 李华
网站建设 2026/5/22 17:17:13

AI 工作范式下的研发新范式:从需求到测试的全链路落地指南

写在前面 最近一年&#xff0c;团队里几乎每个 Java 后端、前端、甚至产品经理&#xff0c;都在用 AI 编辑器写代码。Cursor、Qoder、Claude Code、Trae、Copilot……工具的迭代速度肉眼可见。 但用着用着大家发现一个尴尬的现实&#xff1a;工具升级了&#xff0c;研发流程没升…

作者头像 李华
网站建设 2026/5/22 17:16:23

PDF Arranger终极指南:快速掌握免费PDF页面管理工具

PDF Arranger终极指南&#xff1a;快速掌握免费PDF页面管理工具 【免费下载链接】pdfarranger Small python-gtk application, which helps the user to merge or split PDF documents and rotate, crop and rearrange their pages using an interactive and intuitive graphic…

作者头像 李华
网站建设 2026/5/22 17:16:18

B站直播弹幕机器人实战手册:3步打造高互动自动化直播间

B站直播弹幕机器人实战手册&#xff1a;3步打造高互动自动化直播间 【免费下载链接】MagicalDanmaku 本仓库及所有相关项目已永久停止开发、维护和任何形式的分发。 项目地址: https://gitcode.com/gh_mirrors/bi/MagicalDanmaku 还在为直播时手忙脚乱回复弹幕而焦虑吗&…

作者头像 李华
网站建设 2026/5/22 17:15:59

Git-Sim技术工具:可视化分析Git操作的终极开发效率解决方案

Git-Sim技术工具&#xff1a;可视化分析Git操作的终极开发效率解决方案 【免费下载链接】git-sim Visually simulate Git operations in your own repos with a single terminal command. 项目地址: https://gitcode.com/gh_mirrors/gi/git-sim 在复杂的软件开发流程中&…

作者头像 李华