本文基于昇腾CANN和昇腾NPU,围绕 Cube MatMul 矩阵乘法技术展开。
想象你在一个巨大的停车场里搬箱子。方案 A:一次搬一个箱子,走 100 趟——这是 Vector 的做法。方案 B:用叉车一次叉起 16×16 个箱子,一趟搞定——这是 Cube 的做法。
AI 计算的核心是矩阵乘法——95% 以上的浮点运算都花在 GEMM 上。一个 Attention 层里的 Q×K^T、Attention×V、FFN 的两次 MatMul——全是 GEMM。硬件优化的第一优先级就是让 GEMM 跑得尽可能快。这就是为什么达芬奇 Core 里专门塞了一个 Cube Unit——它生来只干一件事:16×16×16 的 FP16 乘累加。
Cube Unit 怎么算 GEMM
Cube Unit 每 cycle 做一次 C[16,16] += A[16,16] × B[16,16] 的矩阵乘累加。用 FP16 计算,FP32 累加——输入低精度省内存和带宽,中间累积用高精度保留数值稳定性。
一次 Attention 的 Q×K^T 计算:Q 是 [B, H, S, D],K^T 是 [B, H, D, S]。Cube Unit 把 M=S、K=D、N=S 的 GEMM 切成一堆 16×16 的小块,一块一块算。每个小块独立在 Cube Unit 上完成,结果累积到 C 矩阵。
GEMM C[M,N] = A[M,K] × B[K,N] 在 Cube Unit 上的分解: ┌───────────────┐ ┌───────────┐ ┌───────────────┐ │ A (M×K) │ │ B (K×N) │ │ C (M×N) │ │ │ × │ │ = │ │ │ 切成 M_TILE× │ │ 切成 K_TILE│ │ 切成 M_TILE× │ │ K_TILE 小块 │ │ ×N_TILE │ │ N_TILE 小块 │ └───────────────┘ └───────────┘ └───────────────┘ 每个 M_TILE×N_TILE 的小块 C: for k_step in range(K / K_TILE): C_tile += A_tile[k_step] @ B_tile[k_step] // 16×16×16 的 MACTile 分块为什么是核心
Cube Unit 一次只算 16×16。一个 4096×4096 的 GEMM 要分解成 (4096/16)² = 65536 个小块——但不需要全存在 L1 上同时算。
Tiling 的关键在 K 维度上的循环累积。A 和 B 沿着 K 维度切段,每次载入一段到 L1、Cube 算完这段的乘累加、结果加到 C 上——然后载入下一段。C 一直在 L1 上不动,A 和 B 轮流上场。
这样 L1 只需要装下:一段 A [M_TILE × K_TILE]、一段 B [K_TILE × N_TILE]、和 C [M_TILE × N_TILE]。三个块加起来几十到一百多 KB——刚好塞进 192KB 的 L1。
昇腾NPU的 GEMM 融合
单纯的 MatMul 后面往往跟着 Bias Add、Activation、Residual Add。CANN 的算子库把这些操作融进 MatMul——不是"先算 GEMM 再调 Activation",而是在 GEMM 的最后一个 K-step 还没结束时,Vector Unit 就已经开始处理前面累积好的 C 元素做 Activation。
融合 MatMul + Bias + GELU 的执行流: Cube Unit: [GEMM K-step 0] [GEMM K-step 1] ... [GEMM 最后 K-step] Vector Unit: [Bias Add] [GELU] Scalar Unit: [地址计算] [循环控制] ... [地址计算] [循环控制] 三个单元在不同 K-step 上并行——融合的 GEMM 比"先 GEMM 再 Activation" 快约 30%,因为省掉了 GEMM 输出写回 DDR 和 Activation 从 DDR 读入。// Ascend C 里的融合 MatMul 模板——简化版classFusedMatMulGELU:publicAscendC::Kernel{__aicore__voidProcess()override{// Tiling 参数constexprintM_TILE=128,N_TILE=256,K_TILE=32;LocalTensor<fp16>a_tile,b_tile,c_tile;LocalAlloc(a_tile,M_TILE*K_TILE);LocalAlloc(b_tile,K_TILE*N_TILE);LocalAlloc(c_tile,M_TILE*N_TILE);for(intm=0;m<M;m+=M_TILE){for(intn=0;n<N;n+=N_TILE){SetZero(c_tile);for(intk=0;k<K;k+=K_TILE){DataCopy(a_tile,gm_a[m][k],M_TILE*K_TILE);DataCopy(b_tile,gm_b[k][n],K_TILE*N_TILE);MatMul(c_tile,a_tile,b_tile,M_TILE,N_TILE,K_TILE);// 累加:C += A @ B(Cube Unit 原生支持 FP32 累加)}// GEMM 完成后,Vector Unit 直接对 C 做 GELU——不写回 DDRGELU(c_tile,c_tile,M_TILE*N_TILE);DataCopy(gm_c[m][n],c_tile,M_TILE*N_TILE);}}}};参考仓库
ops-blas 高性能 GEMM
ops-nn 神经网络算子
catlass 算子模板库
CANN 学习中心