背景痛点:高峰期“慢、卡、爆”三连击
去年双十一,我们内部客服系统第一次大促压测就翻车了:
- 平均响应 2.8 s,P99 飙到 12 s,用户疯狂点“转人工”。
- 8 张 A100 打满,GPU 内存占用 95%,新 Pod 起不来。
- 长对话(>20 轮)的上下文在显存里越堆越高,OOM 把推理实例一波带走。
根因一句话:LLM 推理是“计算+内存”双密集,而传统同步框架把两个瓶颈串行放大了。
- 计算:自回归生成,每 token 都要走一次完整 forward。
- 内存:KV-Cache 随序列长度线性增长,显存瞬间吃光。
- 业务:高峰并发 3 k QPS,同步阻塞线程池打满,CPU 空转 GPU 却排队。
想扛住大流量,必须在“架构层”把同步改异步,在“模型层”把计算和显存双降,在“数据层”把重复算力省下来。下面按这三层拆方案。
技术方案对比:三板斧怎么选
| 优化手段 | 提速倍数 | 适用场景 | 副作用 |
|---|---|---|---|
| 动态批处理(continuous batching) | 2~4× | 高并发、短答案 | 实现复杂,需调度器 |
| 模型量化(INT8/INT4) | 1.5~2× | 显存吃紧 | 精度下降 1-2% |
| KV-Cache 缓存+复用 | 3~10× | 多轮对话、重复问题 | 缓存命中率决定收益 |
经验:
- 如果流量峰谷差距大→优先上“动态批处理”,把 GPU 打满。
- 如果显存先爆→“KV-Cache 缓存+量化”组合拳,先省内存再提吞吐。
- 若业务答案短且重复度高→“缓存” ROI 最高,几天就能回本。
核心实现:代码直接搬
1. FastAPI 异步推理端点
把同步的model.generate()包一层async线程池,FastAPI 主线程永不阻塞。
# server.py import asyncio, time from fastapi import FastAPI, Request from transformers import AutoTokenizer, AutoModelForCausalLM from concurrent.futures import ThreadPoolExecutor app = FastAPIAPI() executor = ThreadPoolExecutor(max_workers=4) tokenizer = AutoTokenizer.from_pretrained("your-llm") model = AutoModelForCausalLM.from_pretrained("your-llm", device_map="auto") async def generate_async(prompt: str, max_new_tokens: int): loop = asyncio.get_event_loop() return await loop.run_in_executor( executor, lambda: model.generate( tokenizer(prompt, return_tensors="pt").input_ids.cuda(), max_new_tokens=max_new_tokens, do_sample=False ) ) @app.post("/chat") async def chat(req: Request): data = await req.json() tokens = await generate_async(data["prompt"], 128) return {"answer": tokenizer.decode(tokens[0], skip_special_tokens=True)}要点:
ThreadPoolExecutor大小 ≤ GPU 物理流多处理器数,避免 CUDA 上下文切换。- 生产环境用
uvicorn --workers 1 --loop uvloop进一步压 latency。
2. Redis 对话状态管理(TTL+序列化)
长对话最怕重复传 4 k tokens,用 Redis 把“历史上下文”缓存起来,key 用user_id+session_id,value 直接 pickle 整段 token array,TTL 设 30 min。
# redis_cache.py import pickle, redis, time r = redis.Redis(host='redis', port=6379, decode_responses=False) def make_key(uid, sid): return f"chat:{uid}:{sid}" def get_history(uid, sid, max_len=2048): raw = r.get(make_key(uid, sid)) if raw: tokens = pickle.loads(raw) return tokens[-max_len:] # 超长截断 return [] def set_history(uid, sid, tokens, ttl=1800): pipe = r.pipeline() pipe.set(make_key(uid, sid), pickle.dumps(tokens)) pipe.expire(make_key(uid, sid), ttl) pipe.execute()好处:
- 显存里只留当前 batch,历史踢到内存,GPU 侧 OOM 概率直线下降。
- 30 min TTL 自动清掉僵尸会话,防止 Redis 膨胀。
3. 动态批处理算法(简易版)
思路:维护一个“等待队列”,当队列累计到max_batch_size或超时batch_timeout就整包推理。下面用asyncio.Queue实现,真实生产可换成 Redis Stream。
# dynamic_batch.py import asyncio, time from typing import List class BatchScheduler: def __init__(self, max_bs=8, timeout=0.1): self.queue = asyncio.Queue() self.max_bs = max_bs self.timeout = timeout async def submit(self, prompt: str) -> str: future = asyncio.Future() await self.queue.put((prompt, future)) return await future async def loop(self, generate_func): while True: batch: List[tuple] = [] try: # 等待第一个请求 item = await asyncio.wait_for(self.queue.get(), timeout=1) batch.append(item) deadline = time.time() + self.timeout # 继续捞直到满或超时 while len(batch) < self.max_bs and time.time() < deadline: try: item = await asyncio.wait_for(self.queue.get(), timeout=0.02) batch.append(item) except asyncio.TimeoutError: break prompts = [p for (p, _) in batch] # 批量推理(这里简化成 list) answers = await generate_func(prompts) for (_, fut), ans in zip(batch, answers): fut.set_result(ans) except asyncio.TimeoutError: continue把generate_func换成前面generate_async的批量版,就能吃到 dynamic batching 红利。压测显示 8 卡 A100 上 QPS 从 120 → 410,提升 3.4×。
性能测试:优化前后硬指标
| 指标 | 优化前(同步) | 优化后(异步+dynamic batch) | 收益 |
|---|---|---|---|
| QPS | 120 | 410 | ↑241% |
| 平均延迟 | 2.8 s | 0.9 s | ↓68% |
| P99 延迟 | 12 s | 2.3 s | ↓81% |
| GPU 显存峰值 | 80 GB | 52 GB | ↓35% |
| 单卡利用率 | 42% | 92% | ↑50% |
测试条件:
- 输入 300 tokens,输出 100 tokens,8×A100-40G,Triton+TensorRT 未介入。
- 压测工具:locust+自定义客户端,持续 15 min,流量按秒级阶梯爬坡到 3 k QPS。
避坑指南:血泪经验打包
长对话内存泄漏
现象:显存每隔 30 min 跳涨 2 GB。
根因:KV-Cache 的 block 表在旧版本 transformers 里没回收。
解法:升级到 4.35+,或手动torch.cuda.empty_cache()每 100 轮。模型冷启动
现象:Pod 刚起第一包延迟 20 s。
根因:CUDA kernel 编译+权重 lazy load。
解法:- 启动脚本里先跑一条 warm-up prompt;
- 用
nvidia-ptxjitcompiler缓存卷挂载到 emptyDir,缩短重建时间 60%。
异常流量降级
突发 10× 流量时,先把“动态批”超时从 100 ms 降到 10 ms,牺牲延迟保吞吐;
若队列长度 > 5×max_bs,直接返回“系统繁忙,请稍后”,防止雪球。
写在最后的开放问题
目前我们流式响应(SSE)还是“整包生成→一次性 push”,首 token 时间只能压到 400 ms。
你有没有试过在 transformer 内部把use_cache=True与past_key_values逐 token 传出,配合 asyncio 的StreamResponse实现真正的“逐字 SSE”?
如果还能把 speculative decoding 融进来,理论上首 token 能再砍 30%。欢迎一起脑洞。