news 2026/5/9 12:27:48

CANN/catlass GEMM Kernel开发详解

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CANN/catlass GEMM Kernel开发详解

GEMM Kernel 代码开发详解

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

1. Kernel代码结构概述

CATLASS模板库中的GEMM Kernel采用了高度模块化的设计,通过模板参数组装不同的组件来实现各种矩阵乘法功能。本文将以BasicMatmul为例,详细拆解Kernel代码的核心结构和关键组件。

2. 模板组装机制

所有GEMM Kernel都采用模板类的形式定义,通过模板参数来组装不同的功能组件。以BasicMatmul为例:

template < class BlockMmad_, class BlockEpilogue_, class BlockScheduler_ > class BasicMatmul { public: using BlockMmad = BlockMmad_; using ArchTag = typename BlockMmad::ArchTag; using L1TileShape = typename BlockMmad::L1TileShape; using ElementA = typename BlockMmad::ElementA; using LayoutA = typename BlockMmad::LayoutA; using ElementB = typename BlockMmad::ElementB; using LayoutB = typename BlockMmad::LayoutB; using ElementC = typename BlockMmad::ElementC; using LayoutC = typename BlockMmad::LayoutC; using ElementAccumulator = typename BlockMmad::ElementAccumulator; using BlockScheduler = BlockScheduler_; // ... };

2.1 核心模板参数

模板参数描述
BlockMmad_负责矩阵乘法的核心计算组件
BlockEpilogue_负责计算结果的后处理(如激活函数、量化等)
BlockScheduler_负责调度和分配计算任务到不同的计算核心

2.2 类型导出

通过模板参数导出的类型形成了Kernel的核心类型系统,包括:

  • 架构标签(ArchTag)
  • L1缓存 tile 形状(L1TileShape)
  • 数据类型(ElementA/B/C/Accumulator)
  • 数据布局(LayoutA/B/C)

3. 参数传递机制

Kernel采用了两层参数结构:Arguments(用户接口层)和Params(内核执行层)。

3.1 Arguments结构

Arguments是用户直接使用的参数结构,包含最基本的输入输出信息:

struct Arguments { GemmCoord problemShape; GM_ADDR ptrA; GM_ADDR ptrB; GM_ADDR ptrC; };

3.2 Params结构

Params是内核实际执行时使用的参数结构,包含更详细的执行信息:

struct Params { // Data members GemmCoord problemShape; GM_ADDR ptrA; LayoutA layoutA; GM_ADDR ptrB; LayoutB layoutB; GM_ADDR ptrC; LayoutC layoutC; // Methods CATLASS_HOST_DEVICE Params() {} CATLASS_HOST_DEVICE Params(GemmCoord const &problemShape_, GM_ADDR ptrA_, LayoutA layoutA_, GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrC_, LayoutC layoutC_) : problemShape(problemShape_), ptrA(ptrA_), layoutA(layoutA_), ptrB(ptrB_), layoutB(layoutB_), ptrC(ptrC_), layoutC(layoutC_) {} };

3.3 参数转换

通过ToUnderlyingArguments函数将Arguments转换为Params

static Params ToUnderlyingArguments(const Arguments &args, uint8_t *workspace) { LayoutA layoutA{args.problemShape.m(), args.problemShape.k()}; LayoutB layoutB{args.problemShape.k(), args.problemShape.n()}; LayoutC layoutC{args.problemShape.m(), args.problemShape.n()}; Params params{args.problemShape, args.ptrA, layoutA, args.ptrB, layoutB, args.ptrC, layoutC}; return params; }

4. 关键函数解析

4.1 CanImplement

检查当前硬件和环境是否支持实现该Kernel:

static bool CanImplement(const Arguments &args) { return true; }

4.2 GetWorkspaceSize

获取Kernel执行所需的工作区大小:

static size_t GetWorkspaceSize(const Arguments &args) { return 0; }

4.3 operator()

Kernel的核心执行函数,通过模板特化支持不同的核心类型(如AIC、AIV):

template <int32_t CORE_TYPE = g_coreType> CATLASS_DEVICE void operator()(Params const &params); /// Executes one Matmul template <> CATLASS_DEVICE void operator()<AscendC::AIC>(Params const &params) { BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops(); Arch::Resource<ArchTag> resource; BlockMmad blockMmad(resource); // Represent the full gm AscendC::GlobalTensor<ElementA> gmA; gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); AscendC::GlobalTensor<ElementB> gmB; gmB.SetGlobalBuffer((__gm__ ElementB *)params.ptrB); AscendC::GlobalTensor<ElementC> gmC; gmC.SetGlobalBuffer((__gm__ ElementC *)params.ptrC); for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { // Compute block location GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); // Compute initial location in logical coordinates MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; MatrixCoord offsetC{blockCoord.m() * L1TileShape::M, blockCoord.n() * L1TileShape::N}; int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); int64_t gmOffsetB = params.layoutB.GetOffset(offsetB); int64_t gmOffsetC = params.layoutC.GetOffset(offsetC); // Compute block-scoped matrix multiply-add blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], params.layoutC, actualBlockShape); } AscendC::PipeBarrier<PIPE_ALL>(); }

5. 执行流程分析

Kernel的执行流程可以概括为以下几个步骤:

5.1 初始化调度器

BlockScheduler matmulBlockScheduler(params.problemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); uint32_t coreLoops = matmulBlockScheduler.GetCoreLoops();

5.2 初始化资源和计算组件

Arch::Resource<ArchTag> resource; BlockMmad blockMmad(resource);

5.3 设置全局内存张量

AscendC::GlobalTensor<ElementA> gmA; gmA.SetGlobalBuffer((__gm__ ElementA *)params.ptrA); // 设置gmB和gmC...

5.4 循环处理每个计算块

for (uint32_t loopIdx = AscendC::GetBlockIdx(); loopIdx < coreLoops; loopIdx += AscendC::GetBlockNum()) { // 1. 计算块坐标 GemmCoord blockCoord = matmulBlockScheduler.GetBlockCoord(loopIdx); GemmCoord actualBlockShape = matmulBlockScheduler.GetActualBlockShape(blockCoord); // 2. 计算内存偏移 MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; // 计算offsetB和offsetC... int64_t gmOffsetA = params.layoutA.GetOffset(offsetA); // 计算gmOffsetB和gmOffsetC... // 3. 执行块级矩阵乘法 blockMmad(gmA[gmOffsetA], params.layoutA, gmB[gmOffsetB], params.layoutB, gmC[gmOffsetC], params.layoutC, actualBlockShape); }

5.5 同步操作

AscendC::PipeBarrier<PIPE_ALL>();

6. 不同Kernel的扩展与差异

通过对比BasicMatmulBatchedMatmulQuantMatmulOptimizedMatmul,我们可以看到它们在基础结构上的共性和扩展差异:

6.1 BatchedMatmul扩展

BatchedMatmulBasicMatmul的基础上增加了批处理支持:

struct Params { // Data members uint32_t batchCount; // 增加批处理计数 GemmCoord problemShape; GM_ADDR ptrA; LayoutA layoutA; int64_t strideA; // 增加A矩阵的批处理 stride GM_ADDR ptrB; LayoutB layoutB; int64_t strideB; // 增加B矩阵的批处理 stride GM_ADDR ptrC; LayoutC layoutC; int64_t strideC; // 增加C矩阵的批处理 stride // ... };

6.2 QuantMatmul扩展

QuantMatmul增加了量化相关的参数和处理:

struct Params { // Data members GemmCoord problemShape; __gm__ ElementA *ptrA; LayoutA layoutA; __gm__ ElementB *ptrB; LayoutB layoutB; __gm__ ElementScale *ptrScale; // 增加缩放参数 LayoutScale layoutScale; __gm__ ElementPerTokenScale *ptrPerTokenScale; // 增加每token缩放参数 LayoutPerTokenScale layoutPerTokenScale; __gm__ ElementD *ptrD; // 增加输出D矩阵 LayoutD layoutD; GM_ADDR ptrWorkspace; // 增加工作区 // ... };

6.3 OptimizedMatmul扩展

OptimizedMatmul增加了Prologue处理和更复杂的参数结构:

template < class PrologueA, // 增加A矩阵预处理 class PrologueB, // 增加B矩阵预处理 class BlockMmad_, class BlockEpilogue_, class BlockScheduler_ > class OptimizedMatmul { // ... template<bool IsPaddingA = true, bool IsPaddingB = true> struct KernelParams : public ParamsBase { // 增加填充相关参数 GM_ADDR ptrWA; LayoutWA layoutWA; GM_ADDR ptrWB; LayoutWB layoutWB; // ... }; // ... };

7. 总结

CATLASS GEMM Kernel采用了高度模块化和模板化的设计,具有以下特点:

  1. 模板组装:通过模板参数灵活组装不同的功能组件,实现代码复用和功能扩展
  2. 分层参数:使用Arguments和Params两层结构,分离用户接口和内核执行参数
  3. 统一执行流程:所有Kernel遵循相似的执行流程,包括初始化、调度、计算和同步
  4. 可扩展性:通过扩展基础结构,可以轻松实现批处理、量化、优化等高级功能

这种设计使得CATLASS模板库能够高效地支持各种GEMM操作,同时保持代码的可维护性和可扩展性。

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 12:26:41

多模态可解释AI:从热图到文本的跨模态决策解释技术

1. 项目概述&#xff1a;为什么我们需要“看得懂”的AI决策&#xff1f;最近几年&#xff0c;AI模型&#xff0c;特别是那些基于深度学习的“黑箱”模型&#xff0c;在图像识别、自然语言处理乃至多模态任务上取得了惊人的成绩。但一个越来越突出的问题是&#xff1a;我们往往只…

作者头像 李华
网站建设 2026/5/9 12:25:28

视频直播点播/高清点播/音视频点播/云点播/云直播EasyDSS构建社交娱乐全链路视频解决方案

社交娱乐行业的核心竞争力&#xff0c;始终源于场景持续创新与互动形式多元化。行业早已从早期单一的语音聊天、点对点视频通话&#xff0c;逐步延伸至线上桌游、远程剧本杀、线上才艺展演等全新业态。场景边界持续拓宽&#xff0c;也对底层视频系统的功能兼容性、多场景适配能…

作者头像 李华
网站建设 2026/5/9 12:24:11

基于GB/T 8567-2006的《数据需求说明(DRD)》完整案例(7/25)

项目名称&#xff1a; 某无人机系统 项目编号&#xff1a; LEMF-UAS-DRD-2024-001 版本号&#xff1a; 1.0 编制单位&#xff1a; XXX公司 编制&#xff1a; XXX 审核&#xff1a; XXX 批准&#xff1a; XXX 编制日期&#xff1a;1 引言1.1 标识本文档是“某无人机系统&#xf…

作者头像 李华
网站建设 2026/5/9 12:22:49

CANN运行时主机到设备同步内存复制示例

1_h2d_sync_memory_copy 【免费下载链接】runtime 本项目提供CANN运行时组件和维测功能组件。 项目地址: https://gitcode.com/cann/runtime 描述 本样例展示了Host到Device的内存复制&#xff0c;使用aclrtMemcpy内存复制接口。 产品支持情况 本样例支持以下产品&am…

作者头像 李华
网站建设 2026/5/9 12:21:33

【Vue3 + SVG 饼图组件单一数据类别显示异常问题】

问题描述 在项目统计功能模块中&#xff0c;"章节状态分布"饼图存在显示异常问题&#xff1a; ❌ 单状态场景&#xff1a;当统计数据仅包含单个状态类别时&#xff0c;饼图显示错误 ✅ 多状态场景&#xff1a;当统计数据包含两个及以上状态类别时&#xff0c;饼图显…

作者头像 李华
网站建设 2026/5/9 12:18:03

CANN/ATVC BroadcastTo算子样例

【免费下载链接】atvc ATVC&#xff08;Ascend C Templates for Vector Compute&#xff09;&#xff0c;是为基于Ascend C开发的典型Vector算子封装的一系列模板头文件的集合&#xff0c;可帮助用户快速开发典型Vector算子。 项目地址: https://gitcode.com/cann/atvc …

作者头像 李华