news 2026/5/9 21:35:33

CANN稠密索引器梯度KL损失算子

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
CANN稠密索引器梯度KL损失算子

aclnnDenseLightningIndexerGradKLLoss

【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer

产品支持情况

产品是否支持
Ascend 950PR/Ascend 950DT×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品×
Atlas 推理系列产品×
Atlas 训练系列产品×

功能说明

  • 接口功能:DenseLightningIndexerGradKlLoss算子是LightningIndexer的反向算子,再额外融合了Loss计算功能。LightningIndexer算子将QueryToken和KeyToken之间的最高内在联系的TopK个筛选出来,从而减少长序列场景下Attention的计算量,加速长序列的网络的推理和训练的性能。稠密场景下的LightningIndexerGrad的输入query、key、query_index、key_index不用做稀疏化处理。

  • 计算公式:

    1. Top-k value的计算公式:

    $$ I_{t,:}=W_{t,:}@ReLU(\tilde{q}{t,:}@\tilde{K}{:t,:}^\top) $$

    • $W_{t,:}$是第$t$个token对应的$weights$;
    • $\tilde{q}_{t,:}$是$\tilde{q}$矩阵第$t$个token对应的$G$个query头合轴后的结果;
    • $\tilde{K}_{:t,:}$为$t$行$\tilde{K}$矩阵。
    1. 正向的Softmax对应公式:

    $$ p_{t,:} = \text{Softmax}(q_{t,:} @ K_{:t,:}^\top/\sqrt{d}) $$

    • $p_{t,:}$是第$t$个token对应的Softmax结果;
    • $q_{t,:}$是$q$矩阵第$t$个token对应的$G$个query头合轴后的结果;
    • ${K}_{:t,:}$为$t$行$K$矩阵。
    1. npu_lightning_indexer会单独训练,对应的loss function为:

    $$ Loss{=}\sum_tD_{KL}(p_{t,:}||Softmax(I_{t,:})) $$

    其中,$p_{t,:}$是target distribution,通过对main attention score 进行所有的head的求和,然后把求和结果沿着上下文方向进行L1正则化得到。$D_{KL}$为KL散度,其表达式为:

    $$ D_{KL}(a||b){=}\sum_ia_i\mathrm{log}{\left(\frac{a_i}{b_i}\right)} $$

    1. 通过求导可得Loss的梯度表达式:

    $$ dI\mathop{{}}\nolimits_{{t,:}}=Softmax \left( I\mathop{{}}\nolimits_{{t,:}} \left) -p\mathop{{}}\nolimits_{{t,:}}\right. \right. $$

    利用链式法则可以进行weights,query和key矩阵的梯度计算:

    $$ dW\mathop{{}}\nolimits_{{t,:}}=dI\mathop{{}}\nolimits_{{t,:}}\text{@} \left( ReLU \left( S\mathop{{}}\nolimits_{{t,:}} \left) \left) \mathop{{}}\nolimits^{\top}\right. \right. \right. \right. $$

    $$ d\mathop{{\tilde{q}}}\nolimits_{{t,:}}=dS\mathop{{}}\nolimits_{{t,:}}@\tilde{K}\mathop{{}}\nolimits_{{:t,:}} $$

    $$ d\tilde{K}\mathop{{}}\nolimits_{{:t,:}}=\left(dS\mathop{{}}\nolimits_{{t,:}} \left) \mathop{{}}\nolimits^{\top}@\tilde{q}\mathop{{}}\nolimits_{{:t, :}}\right. \right. $$

    其中,$S$为$\tilde{q}$和$K$矩阵乘的结果。

函数原型

算子执行接口为两段式接口,必须先调用“aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize”接口获取入参并根据计算流程计算所需workspace大小,再调用“aclnnDenseLightningIndexerGradKLLoss”接口执行计算。

aclnnStatus aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize( const aclTensor *query, const aclTensor *key, const aclTensor *queryIndex, const aclTensor *keyIndex, const aclTensor *weights, const aclTensor *softmaxMax, const aclTensor *softmaxSum, const aclTensor *softmaxMaxIndex, const aclTensor *softmaxSumIndex, const aclTensor *queryRope, const aclTensor *keyRope, const aclIntArray *actualSeqLengthsQuery, const aclIntArray *actualSeqLengthsKey, double scaleValue, char *layout, int64_t sparseMode, int64_t preTokens, int64_t nextTokens, const aclTensor *dQueryIndex, const aclTensor *dKeyIndex, const aclTensor *dWeights, const aclTensor *loss, uint64_t *workspaceSize, aclOpExecutor *executor)
aclnnStatus aclnnDenseLightningIndexerGradKLLoss( void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)

aclnnDenseLightningIndexerGradKLLoss

  • 参数说明:

    参数名输入/输出描述使用说明数据类型数据格式维度(shape)非连续Tensor
    query(aclTensor*)输入attention结构的输入Q
    • B: 支持泛化。
    • S1: 支持泛化。
    • N1: 支持128、64、32。
    • D: 128。
    • T1: 多个Batch的S1累加。
    FLOAT16、BFLOAT16ND(B,S1,N1,D);(T1,N1,D)×
    key(aclTensor*)输入attention结构的输入K
    • B: 支持泛化且与query的B保持一致。
    • S2: 支持泛化。
    • N2: 等于N1。
    • D: 128。
    • T2: 多个Batch的S2累加。
    FLOAT16、BFLOAT16ND(B,S2,N2,D);(T2,N2,D)×
    queryIndex(aclTensor*)输入lightningIndexer结构的输入queryIndex。
    • B: 支持泛化且与query的B保持一致。
    • S1: 支持泛化。
    • Nidx1: 64、32、16、8。
    • D: 128。
    • T1: 多个Batch的S1累加。
    FLOAT16、BFLOAT16ND(B,S1,Nidx1,D);(T1,Nidx1,D)×
    keyIndex(aclTensor*)输入lightningIndexer结构的输入keyIndex。
    • B: 支持泛化且与query的B保持一致。
    • S2: 支持泛化。
    • Nidx2: 1。
    • D: 128。
    • T2: 多个Batch的S2累加。
    FLOAT16、BFLOAT16ND(B,S2,Nidx2,D);(T2,Nidx2,D)×
    weights(aclTensor*)输入权重
    • B: 支持泛化且与query的B保持一致。
    • S1: 支持泛化且与query的S1保持一致。
    • Nidx1: 64、32、16、8。
    • T1: 多个Batch的S1累加。
    FLOAT16、BFLOAT16、FLOAT32ND(B,S1,Nidx1);(T1,Nidx1)×
    softmaxMax(aclTensor*)输入Device侧的aclTensor,注意力正向计算的中间输出
    • B: 支持泛化与query的B保持一致。
    • N2: 等于N1。
    • S1: 支持泛化且与query的S1保持一致。
    • G: N1/N2。
    • T1: 多个Batch的S1累加。
    FLOAT32ND(B,N2,S1,G);(N2,T1,G)×
    softmaxSum(aclTensor*)输入Device侧的aclTensor,注意力正向计算的中间输出
    • B: 支持泛化与query的B保持一致。
    • N2: 等于N1。
    • S1: 支持泛化且与query的S1保持一致。
    • G: N1/N2。
    • T1: 多个Batch的S1累加。
    FLOAT32ND(B,N2,S1,G);(N2,T1,G)×
    softmaxMaxIndex(aclTensor*)输入Device侧的aclTensor,注意力正向计算的中间输出
    • B: 支持泛化与query的B保持一致。
    • Nidx2: 1。
    • S1: 支持泛化且与query的S1保持一致。
    • T1: 多个Batch的S1累加。
    FLOAT32ND(B,Nidx2,S1);(Nidx2,T1)×
    softmaxSumIndex(aclTensor*)输入Device侧的aclTensor,注意力正向计算的中间输出
    • B: 支持泛化与query的B保持一致。
    • Nidx2: 1。
    • S1: 支持泛化且与query的S1保持一致。
    • T1: 多个Batch的S1累加。
    FLOAT32ND(B,Nidx2,S1);(Nidx2,T1)
    queryRope(aclTensor*)输入MLA rope部分:Query位置编码的输出。
    • 与query的layout维度保持一致。
    • B: 支持泛化与query的B保持一致。
    • S1: 支持泛化且与query的S1保持一致。
    • N1: 128、64、32。
    • Dr: 64。
    • T1: 多个Batch的S1累加。
    FLOAT16、BFLOAT16ND(B,S1,N1,Dr);(T1,N1,Dr)
    keyRope(aclTensor*)输入MLA rope部分:Key位置编码的输出
    • 与key的layout维度保持一致。
    • B: 支持泛化与query的B保持一致。
    • S2: 支持泛化且与key的S1保持一致。
    • N2: 等于N1。
    • Dr: 64。
    • T2: 多个Batch的S2累加。
    FLOAT16、BFLOAT16ND(B,S2,N2,Dr);(T2,N2,Dr)
    actualSeqLengthsQuery(aclIntArray*)输入每个Batch中,Query的有效token数
    • 值依赖。
    • 长度与B保持一致。
    • 累加和与T1保持一致。
    INT64ND(B,)-
    actualSeqLengthsKey(aclIntArray*)输入每个Batch中,Key的有效token数
    • 值依赖。
    • 长度与B保持一致。
    • 累加和T2保持一致。
    INT64ND(B,)-
    scaleValue(double)输入缩放系数
    • 建议值:公式中d开根号的倒数。
    ----
    layout(char*)输入layout格式仅支持BSND和TND格式。----
    sparseMode(int64_t)输入sparse的模式
    • 表示sparse的模式。sparse不同模式的详细说明请参见约束说明。
    • 仅支持模式3。
    ----
    preTokens(int64_t)输入用于稀疏计算,表示Attention需要和前几个token计算关联>和Attention中的preTokens定义相同,在sparseMode = 0和4的时候生效,默认值2^63-1。----
    nextTokens(int64_t)输入用于稀疏计算,表示Attention需要和后几个token计算关联和Attention中的nextTokens定义相同,在sparseMode = 0和4的时候生效,默认值2^63-1。----
    dQueryIndex(aclTensor*)输出QueryIndex的梯度
    • B: 支持泛化与query的B保持一致。
    • S1:支持泛化,且与query的S1保持一致。
    • Nidx1: 64、32、16、8。
    • D: 128。
    • T1: 多个Batch的S1累加。
    FLOAT16、BFLOAT16ND(B,S1,Nidx1,D);(T1,Nidx1,D)
    dKeyIndex(aclTensor*)输出KeyIndex的梯度
    • B: 支持泛化与query的B保持一致。
    • S2: 支持泛化,且与key的S2保持一致。
    • Nidx2: 1。
    • D: 128。
    • T2: 多个Batch的S2累加。
    FLOAT16、BFLOAT16ND(B,S2,Nidx2,D);(T2,Nidx2,D)
    dWeights(aclTensor*)输出Weights的梯度
    • B: 支持泛化。
    • S1: 支持泛化,不能为Matmul的M轴。
    • Nidx1: 64、32、16、8。
    • T1: 多个Batch的S1累加。
    FLOAT16、BFLOAT16、FLOAT32ND(B,S1,Nidx1);(T1,Nidx1)
    loss(aclTensor*)输出损失函数值-FLOAT32ND(1,)-
    workspaceSize(uint64_t*)输出返回需要在Device侧申请的workspace大小。-----
    executor(aclOpExecutor**)输出返回op执行器,包含了算子计算流程。-----
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码。

    第一段接口完成入参校验,出现以下场景时报错:

    返回值错误码描述
    ACLNN_ERR_PARAM_NULLPTR161001必选参数或者输出是空指针。
    ACLNN_ERR_PARAM_INVALID161002query、key、queryIndex、keyIndex、weights、softmaxMax等输入变量的数据类型和数据格式不在支持的范围内。
    ACLNN_ERR_INNER_TILING_ERROR561002多个输入tensor之间的shape不匹配(详见参数说明)。

aclnnDenseLightningIndexerGradKLLoss

  • 参数说明:

    参数名输入/输出描述
    workspace输入在Device侧申请的workspace内存地址。
    workspaceSize输入在Device侧申请的workspace大小,由第一段接口aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize获取。
    executor输入op执行器,包含了算子计算流程。
    stream输入指定执行任务的Stream流。
  • 返回值:

    返回aclnnStatus状态码,具体参见aclnn返回码。

约束说明

  • 参数query、key、queryIndex、keyIndex的数据类型应保持一致。

  • 参数weights不为float32时,参数query、key、queryIndex、keyIndex、weights的数据类型应保持一致。

  • 公共约束

    • 确定性计算: aclnnDenseLightningIndexerGradKLLoss默认非确定性实现,支持通过alcrtCtxSetSysParamOpt开启确定性。
    • 入参为空的场景处理:
      • query或key或query_index或key_index或weight为空Tensor:当前不支持,会报错。
    sparseMode含义备注
    0defaultMask模式,如果attenmask未传入则不做mask操作,忽略preTokens和nextTokens;如果传入,则需要传入完整的attenmask矩阵,表示preTokens和nextTokens之间的部分需要计算不支持
    1allMask,必须传入完整的attenmask矩阵不支持
    2leftUpCausal模式的mask,需要传入优化后的attenmask矩阵不支持
    3rightDownCausal模式的mask,对应以右顶点为划分的下三角场景,需要传入优化后的attenmask矩阵支持
    4band模式的mask,需要传入优化后的attenmask矩阵不支持
    5prefix不支持
    6global不支持
    7dilated不支持
    8block_local不支持
  • 规格约束

    规格项规格规格说明
    B1~256-
    S1、S21~128KS1、S2支持不等长
    N132、64、128-
    Nidx18、16、32、64-
    N232、64、128-
    Nidx21-
    D128query与query_index的D相同。
    Drope64-
    layoutBSND/TND-
  • 典型值

    规格项典型值
    queryN1=128/64/32; D=128
    queryIndexNidx1 = 64/32/16/8; D = 128 ; S1 = 64k/128k
    keyIndexD = 128
    qRopeD = 64

调用示例

调用示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。

#include <iostream> #include <vector> #include <cstdint> #include <cmath> #include "acl/acl.h" #include "aclnnop/aclnn_dense_lightning_indexer_grad_kl_loss.h" #define CHECK_RET(cond, return_expr) \ do { \ if (!(cond)) { \ return_expr; \ } \ } while (0) #define LOG_PRINT(message, ...) \ do { \ printf(message, ##__VA_ARGS__); \ } while (0) int64_t GetShapeSize(const std::vector<int64_t>& shape) { int64_t shapeSize = 1; for (auto i : shape) { shapeSize *= i; } return shapeSize; } void PrintOutResult(std::vector<int64_t> &shape, void** deviceAddr) { auto size = GetShapeSize(shape); std::vector<aclFloat16> resultData(size, 0); auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return); for (int64_t i = 0; i < size; i++) { LOG_PRINT("mean result[%ld] is: %f\n", i, aclFloat16ToFloat(resultData[i])); } } int Init(int32_t deviceId, aclrtContext* context, aclrtStream* stream) { // 固定写法,AscendCL初始化 auto ret = aclInit(nullptr); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret); ret = aclrtSetDevice(deviceId); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret); ret = aclrtCreateContext(context, deviceId); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateContext failed. ERROR: %d\n", ret); return ret); ret = aclrtSetCurrentContext(*context); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetCurrentContext failed. ERROR: %d\n", ret); return ret); ret = aclrtCreateStream(stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret); return 0; } template <typename T> int CreateAclTensor(const std::vector<T>& hostData, const std::vector<int64_t>& shape, void** deviceAddr, aclDataType dataType, aclTensor** tensor) { auto size = GetShapeSize(shape) * sizeof(T); // 调用aclrtMalloc申请device侧内存 auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret); // 调用aclrtMemcpy将host侧数据拷贝到device侧内存上 ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret); // 计算连续tensor的strides std::vector<int64_t> strides(shape.size(), 1); for (int64_t i = shape.size() - 2; i >= 0; i--) { strides[i] = shape[i + 1] * strides[i + 1]; } // 调用aclCreateTensor接口创建aclTensor *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, shape.data(), shape.size(), *deviceAddr); return 0; } int main() { // 1. (固定写法)device/context/stream初始化,参考AscendCL对外接口列表 // 根据自己的实际device填写deviceId int32_t deviceId = 0; aclrtContext context; aclrtStream stream; auto ret = Init(deviceId, &context, &stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("Init acl failed. ERROR: %d\n", ret); return ret); // 2. 构造输入与输出,需要根据API的接口自定义构造 int64_t s1 = 1; int64_t s2 = 1; int64_t n1 = 32; int64_t n2 = n1; int64_t n1Index = 8; int64_t n2Index = 1; int64_t dQuery = 128; int64_t dRope = 64; int64_t dQueryIndex = 128; int64_t t1 = s1; int64_t t2 = s2; int64_t G = n1 / n2; std::vector<int64_t> qShape = {t1, n1, dQuery}; std::vector<int64_t> kShape = {t2, n2, dQuery}; std::vector<int64_t> qRopeShape = {t1, n1, dRope}; std::vector<int64_t> kRopeShape = {t2, n2, dRope}; std::vector<int64_t> qIndexShape = {t1, n1Index, dQueryIndex}; std::vector<int64_t> kIndexShape = {t2, n2Index, dQueryIndex}; std::vector<int64_t> weightShape = {t1, n1Index}; std::vector<int64_t> softmaxMaxShape = {n2, t1, G}; std::vector<int64_t> softmaxSumShape = {n2, t1, G}; std::vector<int64_t> softmaxMaxIndexShape = {n2Index, t1}; std::vector<int64_t> softmaxSumIndexShape = {n2Index, t1}; std::vector<int64_t> dQIndexShape = {t1, n1Index, dQueryIndex}; std::vector<int64_t> dKIndexShape = {t2, n2Index, dQueryIndex}; std::vector<int64_t> dWeightShape = {t1, n1Index}; std::vector<int64_t> lossShape = {1}; void* qDeviceAddr = nullptr; void* kDeviceAddr = nullptr; void* qRopeDeviceAddr = nullptr; void* kRopeDeviceAddr = nullptr; void* qIndexDeviceAddr = nullptr; void* kIndexDeviceAddr = nullptr; void* weightDeviceAddr = nullptr; void* softmaxMaxDeviceAddr = nullptr; void* softmaxSumDeviceAddr = nullptr; void* softmaxMaxIndexDeviceAddr = nullptr; void* softmaxSumIndexDeviceAddr = nullptr; void* dQIndexDeviceAddr = nullptr; void* dKIndexDeviceAddr = nullptr; void* dWeightDeviceAddr = nullptr; void* lossDeviceAddr = nullptr; aclTensor* q = nullptr; aclTensor* k = nullptr; aclTensor* qRope = nullptr; aclTensor* kRope = nullptr; aclTensor* qIndex = nullptr; aclTensor* kIndex = nullptr; aclTensor* weight = nullptr; aclTensor* softmaxMax = nullptr; aclTensor* softmaxSum = nullptr; aclTensor* softmaxMaxIndex = nullptr; aclTensor* softmaxSumIndex = nullptr; aclTensor* dQIndex = nullptr; aclTensor* dKIndex = nullptr; aclTensor* dWeight = nullptr; aclTensor* loss = nullptr; std::vector<aclFloat16> qHostData(t1 * n1 * dQuery, aclFloatToFloat16(0.1)); std::vector<aclFloat16> kHostData(t2 * n2 * dQuery, aclFloatToFloat16(0.2)); std::vector<aclFloat16> qRopeHostData(t1 * n1 * dRope, aclFloatToFloat16(0.1)); std::vector<aclFloat16> kRopeHostData(t2 * n2 * dRope, aclFloatToFloat16(0.2)); std::vector<aclFloat16> qIndexHostData(t1 * n1Index * dQueryIndex, aclFloatToFloat16(0.2)); std::vector<aclFloat16> kIndexHostData(t2 * n2Index * dQueryIndex, aclFloatToFloat16(0.1)); std::vector<aclFloat16> weightHostData(t1 * n1Index, aclFloatToFloat16(0.005)); std::vector<float> softmaxMaxHostData(t1 * n2, 25.4483f); std::vector<float> softmaxSumHostData(t1 * n2, 1.0f); std::vector<float> softmaxMaxIndexHostData(t1 * n2Index, 25.4483f); std::vector<float> softmaxSumIndexHostData(t1 * n2Index, 1.0f); std::vector<aclFloat16> dQIndexHostData(t1 * n1Index * dQueryIndex); std::vector<aclFloat16> dKIndexHostData(t2 * n2Index * dQueryIndex); std::vector<aclFloat16> dWeightHostData(t1 * n1Index); std::vector<float> lossHostData(1, 1.0f); ret = CreateAclTensor(qHostData, qShape, &qDeviceAddr, aclDataType::ACL_FLOAT16, &q); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(kHostData, kShape, &kDeviceAddr, aclDataType::ACL_FLOAT16, &k); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(qRopeHostData, qRopeShape, &qRopeDeviceAddr, aclDataType::ACL_FLOAT16, &qRope); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(kRopeHostData, kRopeShape, &kRopeDeviceAddr, aclDataType::ACL_FLOAT16, &kRope); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(qIndexHostData, qIndexShape, &qIndexDeviceAddr, aclDataType::ACL_FLOAT16, &qIndex); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(kIndexHostData, kIndexShape, &kIndexDeviceAddr, aclDataType::ACL_FLOAT16, &kIndex); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(weightHostData, weightShape, &weightDeviceAddr, aclDataType::ACL_FLOAT16, &weight); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(softmaxMaxHostData, softmaxMaxShape, &softmaxMaxDeviceAddr, aclDataType::ACL_FLOAT, &softmaxMax); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(softmaxSumHostData, softmaxSumShape, &softmaxSumDeviceAddr, aclDataType::ACL_FLOAT, &softmaxSum); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(softmaxMaxIndexHostData, softmaxMaxIndexShape, &softmaxMaxIndexDeviceAddr, aclDataType::ACL_FLOAT, &softmaxMaxIndex); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(softmaxSumIndexHostData, softmaxSumIndexShape, &softmaxSumIndexDeviceAddr, aclDataType::ACL_FLOAT, &softmaxSumIndex); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(dQIndexHostData, dQIndexShape, &dQIndexDeviceAddr, aclDataType::ACL_FLOAT16, &dQIndex); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(dKIndexHostData, dKIndexShape, &dKIndexDeviceAddr, aclDataType::ACL_FLOAT16, &dKIndex); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(dWeightHostData, dWeightShape, &dWeightDeviceAddr, aclDataType::ACL_FLOAT16, &dWeight); CHECK_RET(ret == ACL_SUCCESS, return ret); ret = CreateAclTensor(lossHostData, lossShape, &lossDeviceAddr, aclDataType::ACL_FLOAT, &loss); CHECK_RET(ret == ACL_SUCCESS, return ret); std::vector<int64_t> acSeqQLenOp = {t1}; std::vector<int64_t> acSeqKvLenOp = {t2}; aclIntArray* acSeqQLen = aclCreateIntArray(acSeqQLenOp.data(), acSeqQLenOp.size()); aclIntArray* acSeqKvLen = aclCreateIntArray(acSeqKvLenOp.data(), acSeqKvLenOp.size()); float scaleValue = 1.0 / sqrt(dQuery); int64_t preTokens = 2147483647; int64_t nextTokens = 2147483647; int64_t sparseMode = 3; bool deterministic = false; char layOut[5] = {'T', 'N', 'D', 0}; // 3. 调用CANN算子库API,需要修改为具体的Api名称 uint64_t workspaceSize = 0; aclOpExecutor* executor; // 调用aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize第一段接口 ret = aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize( q, k, qIndex, kIndex, weight, softmaxMax, softmaxSum, softmaxMaxIndex, softmaxSumIndex, qRope, kRope, acSeqQLen, acSeqKvLen, scaleValue, layOut, sparseMode, preTokens, nextTokens, dQIndex, dKIndex, dWeight, loss, &workspaceSize, &executor); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnDenseLightningIndexerGradKLLossGetWorkspaceSize failed. ERROR: %d\n", ret); return ret); // 根据第一段接口计算出的workspaceSize申请device内存 void* workspaceAddr = nullptr; if (workspaceSize > 0) { ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("allocate workspace failed. ERROR: %d\n", ret); return ret); } // 调用aclnnDenseLightningIndexerGradKLLoss第二段接口 ret = aclnnDenseLightningIndexerGradKLLoss(workspaceAddr, workspaceSize, executor, stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclnnDenseLightningIndexerGradKLLoss failed. ERROR: %d\n", ret); return ret); // 4. (固定写法)同步等待任务执行结束 ret = aclrtSynchronizeStream(stream); CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", ret); return ret); // 5. 获取输出的值,将device侧内存上的结果拷贝至host侧,需要根据具体API的接口定义修改 PrintOutResult(dQIndexShape, &dQIndexDeviceAddr); PrintOutResult(dKIndexShape, &dKIndexDeviceAddr); PrintOutResult(dWeightShape, &dWeightDeviceAddr); PrintOutResult(lossShape, &lossDeviceAddr); // 6. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改 aclDestroyTensor(q); aclDestroyTensor(k); aclDestroyTensor(qIndex); aclDestroyTensor(kIndex); aclDestroyTensor(qRope); aclDestroyTensor(kRope); aclDestroyTensor(weight); aclDestroyTensor(softmaxMax); aclDestroyTensor(softmaxSum); aclDestroyTensor(softmaxMaxIndex); aclDestroyTensor(softmaxSumIndex); aclDestroyTensor(dQIndex); aclDestroyTensor(dKIndex); aclDestroyTensor(dWeight); aclDestroyTensor(loss); // 7. 释放device资源 aclrtFree(qDeviceAddr); aclrtFree(kDeviceAddr); aclrtFree(qIndexDeviceAddr); aclrtFree(kIndexDeviceAddr); aclrtFree(qRopeDeviceAddr); aclrtFree(kRopeDeviceAddr); aclrtFree(weightDeviceAddr); aclrtFree(softmaxMaxDeviceAddr); aclrtFree(softmaxSumDeviceAddr); aclrtFree(softmaxMaxIndexDeviceAddr); aclrtFree(softmaxSumIndexDeviceAddr); aclrtFree(dQIndexDeviceAddr); aclrtFree(dKIndexDeviceAddr); aclrtFree(dWeightDeviceAddr); aclrtFree(lossDeviceAddr); if (workspaceSize > 0) { aclrtFree(workspaceAddr); } aclrtDestroyStream(stream); aclrtDestroyContext(context); aclrtResetDevice(deviceId); aclFinalize(); return 0; }

【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/5/9 21:35:14

构建可信AI解释:从概念到落地的四层评估框架与实践指南

1. 项目概述&#xff1a;为什么我们需要一个“解释”AI的框架&#xff1f;最近几年&#xff0c;AI模型&#xff0c;特别是那些被称为“黑箱”的深度神经网络&#xff0c;在图像识别、自然语言处理乃至决策支持领域取得了惊人的成功。然而&#xff0c;当这些模型被部署在医疗诊断…

作者头像 李华
网站建设 2026/5/9 21:35:04

CANN / pypto - PReLU API文档

pypto.prelu 【免费下载链接】pypto PyPTO&#xff08;发音: pai p-t-o&#xff09;&#xff1a;Parallel Tensor/Tile Operation编程范式。 项目地址: https://gitcode.com/cann/pypto 产品支持情况 产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品/At…

作者头像 李华
网站建设 2026/5/9 21:31:52

CANN/driver DCMI获取卡电子标签API

dcmi_get_card_elabel 【免费下载链接】driver 本项目是CANN提供的驱动模块&#xff0c;实现基础驱动和资源管理及调度等功能&#xff0c;使能昇腾芯片。 项目地址: https://gitcode.com/cann/driver 函数原型 int dcmi_get_card_elabel(int card_id, struct dcmi_elab…

作者头像 李华
网站建设 2026/5/9 21:28:38

脑电信号实时预测:从CNN+Transformer+RNN混合模型到工程部署全解析

1. 项目概述&#xff1a;从脑电信号到实时预测的工程实践脑电图信号处理&#xff0c;听起来像是实验室里的高深学问&#xff0c;离我们很远。但如果你接触过神经反馈训练、专注力监测设备&#xff0c;或者对脑机接口有点兴趣&#xff0c;那你其实已经摸到了它的边。简单说&…

作者头像 李华
网站建设 2026/5/9 21:26:40

大气层系统:从零开始玩转Switch自定义固件的完整指南

大气层系统&#xff1a;从零开始玩转Switch自定义固件的完整指南 【免费下载链接】Atmosphere-stable 大气层整合包系统稳定版 项目地址: https://gitcode.com/gh_mirrors/at/Atmosphere-stable 大气层&#xff08;Atmosphere&#xff09;是一款为任天堂Switch设计的开源…

作者头像 李华
网站建设 2026/5/9 21:25:15

Trafilatura:高精度网页正文提取的Python利器与实战指南

1. 项目概述&#xff1a;一个被低估的文本提取利器 如果你经常需要从网页上批量抓取文章正文&#xff0c;并且受够了那些杂乱无章的HTML标签、导航栏、广告和评论&#xff0c;那么“adbar/trafilatura”这个项目很可能就是你一直在寻找的解决方案。这不是一个简单的正则表达式脚…

作者头像 李华