背景痛点:ChatTTS 里那条“窄窄”的 attention_mask 为啥总炸
第一次把 ChatTTS 塞进生产环境,我差点被一行报错劝退:
RuntimeError: narrow: dimension 1 out of range (narrow at ... attention_mask = attention_mask.narrow(1, 0, max_len)模型在训练阶段跑得好好的,一到推理,输入语音 token 长度一变就炸。根因不复杂:
ChatTTS 的语音解码器为了省显存,会把整句 padding 直接裁掉,用narrow把 attention_mask 切成实际长度。可一旦批内序列不齐,或者提前算好的max_len跟张量对不上,就触发 RuntimeError。
更尴尬的是,TTS 场景下句子长短差异巨大——短句 200 帧,长句 2000 帧,显存峰值直接翻倍。于是“切 mask” 这一步成了 OOM 与速度瓶颈的双重灾区。
技术选型:narrow vs slice vs index_select
| 方案 | 额外显存 | 速度 (RTX-4090, 1×512×2000) | 梯度回传 | 备注 |
|---|---|---|---|---|
| narrow | 0 | 0.18 ms | 支持 | 返回原存储视图,不复制 |
slice ([:, :max_len]) | 0 | 0.19 ms | 支持 | 语法糖,底层即 narrow |
| index_select | 新建张量 | 0.67 ms | 支持 | 显存瞬间×2,长序列慎用 |
| mask 乘 0 | 不裁,只填充 | 0.05 ms | 支持 | 不省显存,计算量仍在 |
结论:
- 只要后面不再
.clone(),narrow是最轻量的“真·裁切”。 - 如果后面要
torch.cat或torch.save,一定记得先.clone(),否则原张量被多视图引用,会踩“张量已释放”的坑。
核心实现:让 narrow 听话的三行代码
下面这段是 ChatTTS 推理入口里实际跑的“安全 narrowing” 封装,可直接贴到项目里用。
import torch from typing import Tuple def safe bespoke mask align( mask: torch.Tensor, # (B, L) 0/1 或 bool target_len: int, dim: int = 1, ) -> Tuple[torch.Tensor, int]: """ 返回 narrow 后的 mask 以及真实有效长度。 兼容训练/推理,支持动态 batch。 """ assert mask.dim() == 2, "ChatTTS 只支持二维 mask" actual_len = mask.sum(dim=dim, keepdim=False).max().item() crop_len = min(actual_len, target_len) # 1. 先裁剪,避免后面 attention 把 padding 也 attend 到 if mask.size(dim) > crop_len: mask = mask.narrow(dim, 0, crop_len) # 关键一行 # 2. 保证返回视图仍连续,方便后面 kernel fuse return mask.contiguous(), crop_len调用示例:
# 假设语音帧最长 2000,实际这句只有 857 帧 raw_mask = torch.ones(4, 2000, dtype=torch.bool, device='cuda') raw_mask[:, 857:] = 0 mask, valid_len = safe_bespoke_mask_align(raw_mask, target_len=2000) print(mask.shape) # torch.Size([4, 857])要点拆解:
- 先做
sum拿到“最大非零长度”,防止批内不齐。 narrow之后随手.contiguous(),否则后面view/flatten会报警告。- 返回
valid_len给上游,用于同步调整hidden_states和pos_embed,保证“一裁全裁”。
性能优化:长序列三板斧
- 显存预分配
在__init__里一次性torch.empty(max_batch, max_seq, device='cuda')当“掩码池”,后面 narrow 出来的视图都指向它,避免 Python 层频繁malloc。 - 分段 attention
长度 >1024 时,把 Q/K/V 拆成 2 段,每段算完再合并。显存占用从 O(n²) 降到 O((n/2)²)×2,RTF(real-time factor)下降 35%。 - 混合精度 + 梯度检查点
裁完 mask 后,把attention_scores强制bfloat16,并在torch.utils.checkpoint里包一层forward,长序列训练显存再省 30%。
避坑指南:生产环境 5 连坑
CUDA OOM 但 nvidia-smi 显示还有空闲
原因:narrow 返回视图,框架提前把原张量整块锁页。解决:在真正需要持久化之前加.clone()。梯度回传后 mask 变全 1
原因:把narrow结果直接requires_grad=True,反向图指向被裁裁掉的区域。解决:用tensor.detach().narrow(...)切断梯度。推理 batch=1 正常,batch>1 报错维度
原因:批内长度不一致,narrow 后第二维对不上。解决:统一先pad_to_multiple_of=8,再 narrow。TorchScript 导出失败
原因:narrow的length参数用了 Python int,导致 trace 时无法静态推断。解决:把crop_len转成torch.tensor,再int()取出。ONNX 导出把 narrow 当动态 slice
原因:opset14 之前不支持Slice动态轴。解决:升级opset_version=14,并在dynamic_axes里把 seq 维标动态。
小结 & 互动
把attention_mask从“随手切片”改成“显式 narrow + 预分配”后,我们 4090 上单卡最长可跑 3200 帧(约 40 秒语音),RTF 从 0.87 降到 0.52,OOM 次数直接归零。
你在超长序列场景还用过哪些 attention 加速奇技淫巧?
如何进一步优化 4000+ 帧的 attention 计算,同时保持 RTF<0.5?欢迎留言聊聊你的实践。