torch-catlass 测试框架设计文档
1. 总览
tests/optest是 CATLASS 示例算子接入 PyTorch 的端到端测试框架。框架将 CATLASS AscendC kernel 封装为torch.ops.catlass.*算子,并通过 Python 包torch_catlass提供测试入口。
框架按职责分为五层:
Python API torch_catlass.ops.* | v Python package loader load kernel libs and libcatlass_torch.so | v PyTorch C++ extension register torch.ops.catlass.* | v Kernel adapter convert Tensor arguments to CATLASS kernel params | v Kernel implementation prebuilt kernel or JIT compiled template核心原则:
- Python 层只负责用户接口、动态库加载和轻量参数转换。
- C++ extension 层负责 PyTorch op 注册和 NPU dispatch。
- adapter 层负责 Tensor 到 kernel ABI 的转换。
- kernel 层负责 AscendC/CATLASS 代码执行。
- JIT 子系统负责模板参数宏生成、编译、缓存和动态加载。
2. 目录模块
tests/optest/ ├── pyproject.toml ├── CMakeLists.txt ├── build.sh ├── docs/ │ └── design.md ├── include/ │ ├── catlass_kernel.h │ └── catlass_torch.h ├── torch_catlass/ │ ├── __init__.py │ ├── _version.py │ └── ops/ ├── src/ │ ├── catlass_torch.cpp │ ├── common/ │ └── include/ ├── utils/ │ ├── CMakeLists.txt │ ├── include/ │ ├── kernel_utils.cpp │ ├── torch_utils.cpp │ └── type_utils.hpp ├── kernels/ │ ├── CMakeLists.txt │ ├── common/ │ ├── include/ │ ├── jit/ │ └── 00_basic_matmul/ └── tests/ └── test_00_basic_matmul.py| 模块 | 责任 |
|---|---|
torch_catlass/ | Python 包入口、动态库加载、用户侧 op wrapper |
include/ | 框架和 kernel 共享的公共 ABI 声明 |
src/ | PyTorch C++ extension 和 torch op 注册 |
utils/ | dtype/layout/Tensor 工具函数,拆分 torch 依赖和纯 ACL 依赖 |
kernels/ | kernel 构建、JIT compiler、JIT template、kernel entry |
tests/ | pytest 集成测试 |
3. Python 包模块
3.1 包初始化
torch_catlass/__init__.py在 import 时完成运行时初始化:
- 从
_version.py读取构建期版本信息。 - 设置
TORCH_CATLASS_VERSION,供 JIT 编译阶段注入版本宏。 - 设置
TORCH_CATLASS_PKG_DIR,供 JIT compiler 定位安装后的 headers 和 templates。 - 加载 JIT compiler 和 JIT kernel entry 库。
- 根据当前 NPU 架构加载 arch-specific kernel 库。
- 调用
torch.ops.load_library()加载libcatlass_torch.so,注册torch.ops.catlass.*。
动态库加载顺序:
lib/jit/libcatlass_kernel_jit_compiler.so lib/jit/libcatlass_kernel_jit.so lib/<arch>/*.so lib/libcatlass_torch.soJIT compiler 和 kernel entry 必须先于 PyTorch extension 加载,保证 extension 中引用的 kernel 符号可解析。
3.2 架构识别
Python loader 通过torch_npu.npu.get_device_name()识别设备并映射为 CATLASS arch id:
| 设备名 | arch |
|---|---|
Ascend910B.* | 2201 |
Ascend910_93 | 2201 |
Ascend950PR/Ascend950DT | 3510 |
当torch_npu.npu.device_count()为 0 时,loader 抛出明确错误。测试代码在无 NPU 环境下通过 pytest skip 处理,避免在 collection 阶段触发 torch-npu 内部错误。
3.3 Python op wrapper
torch_catlass/ops/保存 Python 用户接口。以basic_matmul为例:
torch_catlass.basic_matmul( mat1, mat2, outDType="float16", transA=False, transB=False, formatA=False, formatB=False, )Python wrapper 只做用户友好的轻量转换,例如 dtype 字符串白名单解析。shape 推导、输出分配、stream 获取和 kernel launch 都在 C++ 层完成,避免 Python 和 C++ 维护两套语义。
4. PyTorch C++ Extension 模块
4.1 注册入口
src/catlass_torch.cpp是 PyTorch extension 的注册入口:
using BasicMatmulOp = MatmulLike<CatlassKernel::BasicMatmul>; static auto& basic_matmul = BasicMatmulOp::Run; REGISTER_TORCH_FUNC(basic_matmul);REGISTER_TORCH_FUNC位于src/include/common/register.h,注册流程为:
- 创建或复用
torch::Library,namespace 固定为catlass。 - 通过 PyTorch schema inference 从 C++ 函数签名生成 schema。
- 将实现注册到
c10::DispatchKey::PrivateUse1。
PrivateUse1是 torch-npu 使用的 NPU dispatch key。Python 调用torch.ops.catlass.basic_matmul时,PyTorch 根据输入 Tensor device 走到该 backend 实现。
4.2 kernel launch 包装
RUN_NPU_FUNC位于src/include/common/run_npu_func.h。它通过 torch-npu 的OpCommand::RunOpApiV2执行 kernel launch:
- launch 前检查函数指针是否为空。
- 将 C++ 异常转为 ACL error code。
- 将 kernel 调用交给 torch-npu runtime 管理。
5. Matmul Adapter 模块
src/include/template/matmul.h提供MatmulLike<KernelFunc>。该模板封装 matmul 类算子的通用流程:
Run() ├─ GetKernelInfo() ├─ AllocOutput() ├─ get current NPU stream ├─ get AIC core count └─ RUN_NPU_FUNC(KernelFunc, ...)5.1 参数拆分
Matmul 参数拆分为两类:
| 参数结构 | 用途 | 是否参与 JIT 编译 |
|---|---|---|
MatmulTParams | dtype、layout、transpose 等模板参数 | 是 |
MatmulParams | M/N/K、输入输出地址等运行时参数 | 否 |
这种拆分保证 dtype/layout 变化会生成新的模板实例,而 shape 和 Tensor 地址变化不会导致重复编译。
5.2GetKernelInfo
GetKernelInfo()负责将 PyTorch Tensor 转为 kernel 参数:
- 将
torch::Dtype转为aclDataType。 - 根据
transA和transB推导m/n/k。 - 检查两个输入矩阵的 K 维一致。
- 填充
MatmulTParams。 - 填充
MatmulParams。 - 将输入 Tensor storage 地址写入
params.inputAddr。
5.3AllocOutput
AllocOutput()根据params.m、params.n和tParams.elementC创建输出 Tensor,并将其 storage 地址写入params.outputAddr[0]。输出 Tensor 生命周期由 PyTorch 管理,kernel ABI 只接收裸地址。
6. 公共 ABI 模块
include/catlass_kernel.h定义 C++ wrapper 和 kernel 实现之间共享的数据结构和函数声明。
6.1 matmul ABI
MatmulTParams表示编译期参数:
elementAelementBelementCtransAtransBtransCuseNzAuseNzBuseNzC
MatmulParams表示运行期参数:
mnkinputAddroutputAddr
kernel entry 签名保持固定:
void BasicMatmul( uint32_t blockNum, aclrtStream stream, const MatmulTParams& tParams, const MatmulParams& params);6.2 扩展 ABI
catlass_kernel.h中还保留了 grouped matmul、quant matmul、conv、flash attention 等参数结构和函数声明。新增算子时优先复用已有 ABI 结构;当参数语义明显不同,再新增独立结构。
7. Utils 模块
utils/将工具函数拆成两个 target:
| target | 文件 | 依赖 | 用途 |
|---|---|---|---|
catlass_kernel_utils | kernel_utils.cpp | ACL | JIT compiler 使用的 dtype 到 bisheng type 转换 |
catlass_torch_utils | torch_utils.cpp | ACL、torch、torch-npu | PyTorch wrapper 使用的 Tensor/dtype/layout 工具 |
7.1 dtype 映射
utils/type_utils.hpp维护 dtype 映射表:
- canonical string name
torch::DtypeaclDataType- bisheng C++ type token
映射表按 index 对齐,TypeCast<S, T>()通过查表完成不同表示之间的转换。
部分 dtype 在 torch-npu 随包 ACL 头中没有枚举名,但 ABI 数值与 CANN 定义一致。此类 dtype 使用static_cast<aclDataType>(value)表达,避免混用系统 CANN ACL 头和 torch-npu 随包 ACL 头导致重复定义。
7.2 Tensor 工具
torch_utils.cpp提供:
GetOutputTensor():在当前 NPU device 上创建 ND 格式输出 Tensor。TypeStrToTorchDtype():字符串到 torch dtype。TorchDtypeToAclDtype():torch dtype 到 ACL dtype。AclDtypeToTorchDtype():ACL dtype 到 torch dtype。GetTransposeStatus():根据 tensor stride 和 NPU format 判断矩阵布局。
8. JIT 子系统
JIT 子系统由四部分组成:
| 组件 | 文件 | 责任 |
|---|---|---|
| JIT entry | kernels/00_basic_matmul/basic_matmul.cpp | 稳定 kernel 入口,负责获取 JIT 函数并调用 |
| JIT template | kernels/00_basic_matmul/basic_matmul_impl.cpp | 被运行时编译的 CATLASS kernel 模板 |
| JIT compiler | kernels/jit/jit_compiler.cpp | 编译、缓存、加载.so |
| macro generator | kernels/include/jit_macro_generator.h | 将模板参数转为-D宏 |
8.1 JIT entry
JIT entry 固定编译进libcatlass_kernel_jit.so。以BasicMatmul为例:
auto* entry = JitCompiler::instance().getKernel( "basic_matmul_impl.cpp", JitMacroGenerator<MatmulTParams>::generate("basic_matmul", tParams)); entry(blockNum, stream, &tParams, ¶ms);entry 的职责是连接 stable ABI 和 runtime-compiled template,不承载具体 GEMM 模板逻辑。
8.2 JIT template
JIT template 使用宏注入类型和布局:
| 宏 | 语义 |
|---|---|
CATLASS_JIT_ELEMENT_A | A 元素类型 |
CATLASS_JIT_ELEMENT_B | B 元素类型 |
CATLASS_JIT_ELEMENT_C | C 元素类型 |
CATLASS_JIT_LAYOUT_A | A layout |
CATLASS_JIT_LAYOUT_B | B layout |
CATLASS_JIT_LAYOUT_C | C layout |
CATLASS_JIT_KERNEL_NAME | device kernel 符号名 |
模板导出稳定 C ABI:
extern "C" void run( uint32_t blockNum, aclrtStream stream, const MatmulTParams* tParams, const MatmulParams* params);JIT loader 固定解析run符号。device kernel 名只用于编译产物可读性和 profiling。
8.3 宏生成
JitMacroGenerator<TParams>是模板策略类。默认模板不生成任何宏,具体参数类型通过特化实现。
JitMacroGenerator<MatmulTParams>生成:
CATLASS_KERNEL_NAMECATLASS_JIT_ELEMENT_ACATLASS_JIT_ELEMENT_BCATLASS_JIT_ELEMENT_CCATLASS_JIT_LAYOUT_ACATLASS_JIT_LAYOUT_BCATLASS_JIT_LAYOUT_CCATLASS_JIT_KERNEL_NAME
新增非 matmul JIT kernel 时,应新增对应参数结构的JitMacroGenerator特化。
8.4 编译和缓存
JitCompiler是进程级单例。初始化内容包括:
- JIT cache 目录。
- bisheng/ccec 路径。
- 当前 NPU arch。
- JIT template 根目录。
缓存分两层:
- 内存缓存:
loaded_保存SharedLib和run指针。 - 磁盘缓存:保存编译后的
.so。
cache key 由 kernel 名、arch 和完整宏集合构成:
<CATLASS_KERNEL_NAME>_arch<arch>_<macro-name>_<macro-value>...宏按 key 排序后拼接,保证unordered_map遍历顺序不影响缓存路径。
8.5 环境变量
| 环境变量 | 作用 |
|---|---|
CATLASS_JIT_LOG_LEVEL | JIT 日志等级,0/1/2 |
TORCH_CATLASS_CACHE_DIR | JIT 磁盘缓存目录 |
MS_SANITIZE_MEMORY | 启用 msSanitizer 编译选项 |
TORCH_CATLASS_VERSION | 注入 package/CATLASS 版本 |
ASCEND_HOME_PATH | 查找 Ascend compiler 和 runtime 库 |
TORCH_CATLASS_PKG_DIR | 安装后的 Python 包根目录 |
JIT 编译使用的 NPU arch 只通过 AscendC platform API 获取,即GetCurrentNPUArch()。运行时不支持通过环境变量覆盖 arch,避免编译参数和实际设备不一致。
环境变量分为外部配置和包内注入两类。外部配置只保留ASCEND_HOME_PATH、TORCH_CATLASS_CACHE_DIR、CATLASS_JIT_LOG_LEVEL和MS_SANITIZE_MEMORY;TORCH_CATLASS_VERSION和TORCH_CATLASS_PKG_DIR由 Python loader 在 import 时设置。JIT template 路径固定由TORCH_CATLASS_PKG_DIR/jit/templates/推导,compiler 优先从ASCEND_HOME_PATH的标准目录查找,未找到时回退到PATH中的ccec。
9. Kernel 构建模块
kernels/CMakeLists.txt提供add_kernel(),统一 JIT 和 prebuilt kernel 的构建入口。
9.1 JIT kernel
add_kernel( NAME basic_matmul NPU_ARCH_LIST 2201 KERNEL_TYPE jit ${CMAKE_CURRENT_SOURCE_DIR}/basic_matmul.cpp TEMPLATE ${CMAKE_CURRENT_SOURCE_DIR}/basic_matmul_impl.cpp)JIT kernel 构建流程:
- entry 源文件加入统一
libcatlass_kernel_jit.so。 - template 文件安装到
jit/templates/。 jit_try_compile_check()在构建期检查 template 可被 bisheng 编译。- 运行时由
JitCompiler根据模板参数编译具体.so。
9.2 prebuilt kernel
prebuilt kernel 按 arch 构建独立动态库:
lib/<arch>/libcatlass_kernel_<arch>_<name>.soprebuilt 模式用于固定参数组合或无需运行时模板编译的 kernel。启用 msSanitizer 时,可额外构建_ms变体。
10. 顶层构建模块
10.1 Python 构建入口
build.sh是主要构建入口:
- 推导 package 版本。
- 写入
torch_catlass/_version.py。 - 自动探测当前 Python 环境中的 Torch CMake 目录。
- 调用 pip/scikit-build-core 驱动 CMake 构建。
常用命令:
bash build.sh --skip-wheel bash build.sh --build-type Debug --skip-wheel bash build.sh --clean10.2 CMake target
顶层CMakeLists.txt负责:
- 查找 ASC、Python、Torch。
- 设置 C++17、PIC、compile commands。
- 安装 public headers。
- 安装 CATLASS headers 到 Python 包内的 JIT include tree。
- 添加
kernels、utils、src子目录。
关键 target:
| target | 输出 | 说明 |
|---|---|---|
catlass_kernel_utils | static lib | JIT compiler 依赖的纯 ACL 工具 |
catlass_torch_utils | static lib | torch wrapper 依赖的 Tensor 工具 |
catlass_kernel_jit_compiler | shared lib | JIT 编译器 |
catlass_kernel_jit | shared lib | JIT entry 集合 |
catlass_torch | shared lib | PyTorch extension |
11. 测试模块
pytest 集成测试验证 Python API 到 kernel 执行的完整链路。
tests/test_00_basic_matmul.py测试流程:
- 检查是否存在可用 Ascend NPU;无设备时跳过。
- 构造 NPU fp16 输入。
- 调用
torch_catlass.basic_matmul()。 - 用
torch.matmul()生成参考结果。 - 校验 shape、dtype、device。
- 用
torch.allclose()校验数值。
本地静态检查:
python3 -m py_compile torch_catlass/__init__.py torch_catlass/ops/basic_matmul.py tests/test_00_basic_matmul.py python3 -m ruff check torch_catlass tests完整集成测试:
bash build.sh --skip-wheel pytest tests/test_00_basic_matmul.py -v -s12. 扩展流程
12.1 新增 matmul 类 JIT 算子
- 在
kernels/<nn_name>/下添加 entry.cpp和 template.cpp。 - 在 entry 中调用
JitCompiler::instance().getKernel()。 - 复用或扩展
JitMacroGenerator<MatmulTParams>。 - 在
kernels/<nn_name>/CMakeLists.txt中调用add_kernel()。 - 在
include/catlass_kernel.h声明 kernel entry。 - 在
src/catlass_torch.cpp使用MatmulLike<Kernel>注册 torch op。 - 在
torch_catlass/ops/添加 Python wrapper。 - 在
tests/添加 pytest,与 PyTorch 参考实现比对。
12.2 新增非 matmul 算子
非 matmul 算子应新增独立 adapter,而不是扩展MatmulLike:
src/include/template/<op_family>.h ├─ GetKernelInfo() ├─ AllocOutput() └─ Run()同时新增对应的参数结构和JitMacroGenerator特化,保持参数解析、宏生成和 kernel ABI 各自独立。
12.3 新增 dtype 或 layout
新增 dtype/layout 时需要同步更新:
utils/type_utils.hppJitMacroGenerator对应特化- JIT template 中的默认宏和类型别名
- pytest 参数覆盖
dtype 映射应优先使用当前编译环境可见的 enum;当 torch-npu 随包 ACL 头未暴露 enum 名但 ABI 数值稳定时,可使用static_cast<aclDataType>(value)保持兼容。
【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考