news 2026/5/9 11:58:35

CANN/catlass MLA算子实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CANN/catlass MLA算子实现

CATLASS MLA

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

CATLASS MLA是基于CATLASS Gemm API实现的亲和昇腾AtlasA2硬件的Flash-MLA算子,算子的结构可以分为以下几部分

  • Tiling计算;
  • Kernel实现,具体有两种实现,通用的mla_kernel.cpp以及特化的mla_kernel_tp1_spec.cpp;
  • Kernel中依赖适合Flash-MLA运算的Block组件;
  • 使用的Block组件依赖模板库提供的Tile组件。

Tiling

Tiling计算的逻辑位于mla.cpp文件中,在调用算子前,需要准备好tiling计算所需的各项参数,赋值给MLAInfo结构体,并调用GetMLATilingParam函数。mla.cpp中提供了一个示例

// 准备Tiling计算所需的中间结构体以及Host侧空间 MLATiling::MLAInfo mlaInfo; ... MLATiling::GetMLATilingParam(mlaInfo, blockDim, (uint32_t *)tilingHost);

GetMLATilingParam函数中,调用了两个函数GetMLATilingCommonGetMLATilingSpec,分别对应了通用场景下和特化场景下的分核逻辑

Kernel

本算子提供了两种Kernel实现:

  1. 通用的mla_kernel.cpp,在qHeadNum为16/32/64场景(分别对应模型侧TP8/4/2场景)性能更优。
  2. 特化的mla_kernel_tp1_spec.cpp,在qHeadNum为128场景(对应模型侧TP1场景)性能更优。

mla_kernel.cpp具有以下特性:

  • 采用FlashAttention的四阶段计算流程,对于输入的Q, QRope, K, KRope进行切块后运算。
  • 对输入序列长度kvSeqlen按照blockSize为单位进行切块,每次Attention运算的基块为一个block,使能提前下发一个基块的QK Mmad与softmax,让不同基块的CUBE与VECTOR阶段互相掩盖。
  • 在同一基块的QK与PV的矩阵乘之间,由于K与V共用同一段数据,使能K常驻在L1 buffer上,减少搬入带宽占用。

mla_kernel_tp1_spec.cpp具有以下特性:

  • 采用FlashAttention的四阶段计算流程,对于输入的Q, QRope, K, KRope进行切块后运算。
  • 对输入序列长度kvSeqlen按照blockSize为单位进行切块,每次Attention运算的基块为四个block,使能提前下发一个基块的QK Mmad与softmax,让不同基块的CUBE与VECTOR阶段互相掩盖。
  • 由于基块大小的放大,该Kernel的PV Mmad阶段的搬出数据量降低,减少了搬出带宽占用,相应的,由于硬件buffer大小限制,取消了K的常驻。

在本算子中,使用了Block和Tile层级组件来组装Kernel,具体步骤为:

  1. 组装attention计算中的两个BlockMmad(QK,PV)以及三个BlockEpilogue(softmax, rescaleO, flashDecoding)。
  2. 将Block组合在一起构建成MLAKernel,并在Kernel类中完成对各个Block的循环调用。

这一过程也体现在Kernel入口的代码中(以mla_kernel.cpp为例):

// GEMM Block模块,实现Flash MLA的Q * K^T using DispatchPolicyQK = Gemm::MmadAtlasA2MLAQK; using QType = Gemm::GemmType<ElementQ, LayoutQ>; using KType = Gemm::GemmType<ElementK, LayoutK>; using SType = Gemm::GemmType<ElementS, LayoutS>; using BlockMmadQK = Gemm::Block::BlockMmad<DispatchPolicyQK, L1TileShape, L0TileShape, QType, KType, SType>; // Epilogue Block模块, 实现Flash MLA中当前S基块的softmax using PType = Gemm::GemmType<ElementP, LayoutP>; using MaskType = Gemm::GemmType<ElementMask, LayoutMask>; using EpilogueMLASoftmax = Epilogue::Block::BlockEpilogue<Epilogue::EpilogueAtlasA2MLASoftmax, PType, SType, MaskType>; // GEMM Block模块,实现Flash MLA的P * V using DispatchPolicyPV = Gemm::MmadAtlasA2MLAPV; using VType = Gemm::GemmType<ElementV, LayoutV>; using OTmpType = Gemm::GemmType<ElementOTmp, LayoutOTmp>; using BlockMmadPV = Gemm::Block::BlockMmad<DispatchPolicyPV, L1TileShape, L0TileShape, PType, VType, OTmpType>; // Epilogue Block模块, 实现Flash MLA中当前O基块的更新 using OType = Gemm::GemmType<ElementO, LayoutO>; using OUpdateType = Gemm::GemmType<ElementUpdate, LayoutUpdate>; using EpilogueMLARescaleO = Epilogue::Block::BlockEpilogue<Epilogue::EpilogueAtlasA2MLARescaleO, OType, OUpdateType, OTmpType>; // Epilogue Block模块, 实现Flash MLA中flash decoding using lType = Gemm::GemmType<ElementUpdate, LayoutUpdate>; constexpr uint32_t ComputeEleNum = 6144; using EpilogueMLAFDRescaleO = Epilogue::Block::BlockEpilogue<Epilogue::EpilogueAtlasA2MLAFDRescaleO<ComputeEleNum>, OType, lType>; // Kernel level using MLAKernel = MLAKernel<BlockMmadQK, BlockMmadPV, EpilogueMLASoftmax, EpilogueMLARescaleO, EpilogueMLAFDRescaleO>;

Block Mmad

算子总共使用了两类Block Mmad组件,分别为:

  • BlockMmadQK为BlockMmad模板类的偏特化,用于处理Flash-MLA中的Q与K的矩阵乘操作,头文件block_mmad_mla_qk.hpp中的实现对应通用的mla_kernel.cpp,头文件block_mmad_mla_qk_tp1_spec.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。
  • BlockMmadPV为BlockMmad模板类的偏特化,用于处理Flash-MLA中的P与V的矩阵乘操作,头文件block_mmad_mla_pv.hpp中的实现对应通用的mla_kernel.cpp,头文件block_mmad_mla_pv_tp1_spec.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。

Block Epilogue

算子总共使用了三类Block Epilogue组件,分别为:

  • EpilogueMLASoftmax为BlockEpilogue模板类的偏特化,用于处理Flash-MLA中的online softmax操作,头文件block_epilogue_mla_softmax.hpp中的实现对应通用的mla_kernel.cpp,头文件block_epilogue_mla_tp1_softmax.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。
  • EpilogueMLARescaleO为BlockEpilogue模板类的偏特化,用于处理Flash-MLA中的rescaleO操作,头文件block_epilogue_mla_rescale_o.hpp中的实现对应通用的mla_kernel.cpp,头文件block_epilogue_mla_tp1_rescale_o.hpp中的实现则对应特化的mla_kernel_tp1_spec.cpp。
  • EpilogueMLAFDRescaleO为BlockEpilogue模板类的偏特化,用于处理Flash-MLA中的flashDecoding操作(如有必要),头文件block_epilogue_mla_fd_rescale_o.hpp中的实现为mla_kernel.cpp与mla_kernel_tp1_spec.cpp两者共用。

Tile Mmad & Tile Copy

在通用Kernel使用的Block组件中,使用了位于tile_mmad.hpp中的tileMmad组件和位于tile_copy.hpp中的tileCopy组件,例如:

using TileMmad = TileMmad_; using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<ElementA, ElementB>::ElementAccumulator;

【免费下载链接】catlass本项目是CANN的算子模板库,提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 11:56:37

CANN TensorFlow HCCL广播操作

broadcast 【免费下载链接】tensorflow Ascend TensorFlow Adapter 项目地址: https://gitcode.com/cann/tensorflow 功能说明 集合通信算子Broadcast的操作接口&#xff0c;将通信域内root节点的数据广播到其他rank。 函数原型 def broadcast(tensor, root_rank, fus…

作者头像 李华
网站建设 2026/5/9 11:54:06

CANN/pto-isa复杂操作指令集

复杂操作 【免费下载链接】pto-isa Parallel Tile Operation (PTO) is a virtual instruction set architecture designed by Ascend CANN, focusing on tile-level operations. This repository offers high-performance, cross-platform tile operations across Ascend platf…

作者头像 李华
网站建设 2026/5/9 11:53:45

CANN/pyasc按位或运算API

asc.language.basic.bitwise_or 【免费下载链接】pyasc 本项目为Python用户提供算子编程接口&#xff0c;支持在昇腾AI处理器上加速计算&#xff0c;接口与Ascend C一一对应并遵守Python原生语法。 项目地址: https://gitcode.com/cann/pyasc asc.language.basic.bitwis…

作者头像 李华
网站建设 2026/5/9 11:51:35

[QML] Qt6/Qt5四大渐变效果实战指南

一、模块导入import QtQuick import QtQuick.Shapes 1.8 as QT6Style // Qt6 Shape渐变 import Qt5Compat.GraphicalEffects as QT5Style // Qt5兼容效果渐变二、四种渐变对比渐变类型模块效果适用场景GradientQtQuick线性&#xff08;水平/垂直&#xff09;简单背景Line…

作者头像 李华
网站建设 2026/5/9 11:50:30

深蓝BREAKER:全球首家ORIVO认证南极磷虾油原料商,树立品质新标杆

近日&#xff0c;深蓝BREAKER&#xff08;江苏深蓝生物科技有限公司&#xff09;成功通过权威海洋脂质纯度验证机构——ORIVO 的认证&#xff0c;成为全球首家斩获该认证的南极磷虾油原料商&#xff0c;并获得其颁发的 100% 纯南极磷虾油证书&#xff0c;跻身全球极少数获此认证…

作者头像 李华
网站建设 2026/5/9 11:35:36

SystemC与SystemCrafter在DES加密硬件加速中的实践

1. SystemC与SystemCrafter在DES加密中的协同设计实践作为一名长期从事硬件加速开发的工程师&#xff0c;我亲历了从传统HDL到高层次综合&#xff08;HLS&#xff09;的技术演进。本文将分享如何利用SystemC和SystemCrafter SC工具链实现DES加密算法的硬件加速&#xff0c;这个…

作者头像 李华