FusedCausalConv1d
【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
算子功能:对序列执行因果一维卷积,沿序列维度使用缓存数据(长度为卷积核宽减1)对各序列头部进行padding,确保输出依赖当前及历史输入;卷积完成后,将当前序列尾部的数据(长度为卷积核宽减1)更新到缓存;在因果一维卷积输出的基础上,将原始输入加到输出上以实现残差连接。
本算子支持以下场景:
场景一(prefill场景):
x: [cu_seq_len, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1, dim] query_start_loc: [batch+1] cache_indices: [batch] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch](无作用) y: [cu_seq_len, dim] run_mode: 0其中cu_seq_len为batch内所有变长序列拼接后的总长度。
场景二(decode场景 - 变长序列):
x: [cu_seq_len, dim] weight: [K, dim],其中K=3 conv_states: [-1, state_len, dim] query_start_loc: [batch+1] cache_indices: [batch] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch](用于投机解码) y: [cu_seq_len, dim] run_mode: 1其中state_len必须大于所有batch中最大的token个数加1。
场景三(decode场景 - 固定batch):
x: [batch, m+1, dim] weight: [K, dim],其中K=3 conv_states: [-1, K-1+m, dim] query_start_loc: [batch+1](无作用) cache_indices: [batch] initial_state_mode: [batch] bias: [dim](无作用) num_accepted_tokens: [batch](用于投机解码,m为投机token个数) y: [batch, m+1, dim] run_mode: 1
计算公式:
K是卷积核宽度(固定为3),L是原始序列长度,dim是特征维度。
- 缓存拼接:
$$ x'[i, dim] = \begin{cases} cacheState[i, dim], & 0 \leq i < K-1 \ x[i - (K-1), dim], & K-1 \leq i < L + K - 1 \end{cases} $$
- 因果1维卷积:
$$ y[i, dim] = \sum_{k=0}^{K-1} w[k, dim] \cdot x'[i + k, dim] $$
- 缓存更新:
$$ cacheState[i, dim] = x'[L + i, dim], \quad i = 0, 1, \dots, K-2 $$
- 残差连接(可选):
$$ y[i, dim] += x[i, dim] $$
参数说明
| 参数名 | 输入/输出/属性 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 输入序列,对应公式中x。 | FLOAT16、BFLOAT16 | ND |
| weight | 输入 | 因果1维卷积核,K固定为3,对应公式中w。 | 数据类型与x一致 | ND |
| conv_states | 输入/输出 | 缓存状态张量,存储各序列的历史token数据,各序列计算完成后原地更新,对应公式中cacheState。 | 数据类型与x一致 | ND |
| query_start_loc | 可选输入 | 序列起始位置索引,记录各序列在拼接张量x中的起始位置。query_start_loc[i]表示第i个序列的起始偏移。 | INT32 | ND |
| cache_indices | 可选输入 | 缓存索引,指定每个序列对应的缓存状态在conv_states中的索引。 | INT32 | ND |
| initial_state_mode | 可选输入 | 初始状态标志,表示各序列是否使用缓存数据:0=零填充,1=使用缓存,2=使用缓存但前K-1个输出置0。 | INT32 | ND |
| bias | 可选输入 | 卷积的偏置。 | 数据类型与x一致 | ND |
| num_accepted_tokens | 可选输入 | decode场景下的投机token个数。 | INT32 | ND |
| activation_mode | 属性 | 激活函数类型,取值为0、1、2。 0:None; 1:silu; 2:swish。 | INT | - |
| pad_slot_id | 属性 | 用于跳过不需要参与计算的batch,-1表示不跳过。当cache_indices[i]==pad_slot_id时跳过该batch。 | INT | - |
| run_mode | 属性 | 用于判断是prefill场景或decode场景,取值为0、1。 0:prefill场景; 1:decode场景。 | INT | - |
| residual_connection | 属性 | 是否做残差连接,取值为0、1。 0:不做残差连接; 1:输出y和输入x相加后输出。 | INT | - |
| y | 输出 | 输出序列,shape与x一致,对应公式中y。 | 数据类型与x一致 | ND |
约束说明
输入shape限制:
- prefill场景:
- x支持2维[cu_seq_len, dim]。
- weight必须是2维[K, dim],其中K固定为3。
- conv_states必须是3维[..., K-1, dim],第0维大小不固定且大于等于batch。
- cu_seq_len范围[batch, 65536],dim范围[128, 16384]且是128的倍数,batch范围[1, 256]。
- decode场景(固定batch):
- x支持3维[batch, m+1, dim]。
- weight必须是2维[K, dim],其中K固定为3。
- conv_states必须是3维[..., K-1+m, dim],第0维大小不固定且大于等于batch。
- m范围[0, 5],dim范围[128, 16384]且是128的倍数,batch范围[1, 256]。
- decode场景(变长序列):
- x支持2维[cu_seq_len, dim]。
- weight必须是2维[K, dim],其中K固定为3。
- conv_states必须是3维[..., state_len, dim],第0维大小不固定且大于等于batch,state_len必须大于所有batch中最大的token个数加K-1。
- cu_seq_len范围[batch, batch*6],每个batch的token个数范围为[1, 6]。dim范围[128, 16384]且是128的倍数,batch范围[1, 256]。
- prefill场景:
输入值域限制:
- query_start_loc是累计偏移量,取值范围[0, cu_seq_len],长度为batch+1,query_start_loc[i]表示第i个序列的起始偏移,query_start_loc[batch+1]表示最后一个序列的结束位置。
- cache_indices长度为batch,指定每个序列对应的缓存槽索引。
- num_accepted_tokens分为None和非None,非None情况下长度为batch,每个元素取值不超过当前batch的token个数且大于0。
调用说明
调用方式 样例代码 说明 aclnn接口 test_aclnn_fused_causal_conv1d 通过aclnnFusedCausalConv1d调用FusedCausalConv1d算子
【免费下载链接】ops-transformer本项目是CANN提供的transformer类大模型算子库,实现网络在NPU上加速计算。项目地址: https://gitcode.com/cann/ops-transformer
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考