news 2026/4/15 16:10:36

KV Cache 详解:大模型推理的核心优化技术

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
KV Cache 详解:大模型推理的核心优化技术

一、什么是 KV Cache?

KV Cache(Key-Value 缓存)是 Transformer 模型在自回归推理过程中,为了避免重复计算而存储的中间状态。它是提高大模型推理速度的关键技术。

核心概念

  1. KV指的是 Transformer 注意力机制中的Key 和 Value 向量

  2. Cache是指将这些向量缓存起来,供后续 token 生成时复用

  3. 目的:避免对已处理过的 token 重新计算 Key 和 Value

二、为什么需要 KV Cache?

问题:Transformer 的注意力计算

在 Transformer 中,每个 token 在注意力层需要:

  • Query(Q):查询向量

  • Key(K):键向量

  • Value(V):值向量

注意力分数计算公式:

Attention(Q, K, V) = softmax(Q·K^T / √d) · V

推理时的困境

没有缓存的情况

生成第1个token:计算 Token1 的 Q, K, V 生成第2个token:计算 Token1, Token2 的 Q, K, V ← 重复计算Token1! 生成第3个token:计算 Token1, Token2, Token3 的 Q, K, V ← 重复计算Token1,2!

每次生成新 token 时,都需要为所有历史token重新计算K和V,计算量随序列长度平方增长。

有缓存的情况

生成第1个token:计算 Token1 的 Q, K, V,并缓存 K1, V1 生成第2个token:计算 Token2 的 Q, K, V,从缓存读取 K1, V1 生成第3个token:计算 Token3 的 Q, K, V,从缓存读取 K1, V1, K2, V2

只需为新 token 计算 K 和 V,历史 token 的 K, V 从缓存读取。

三、KV Cache 的数学表示

Transformer 层中的计算

对于第l层,输入x

Q^l = x · W_Q^l # [batch, seq_len, d_model] -> [batch, seq_len, d_k] K^l = x · W_K^l # [batch, seq_len, d_model] -> [batch, seq_len, d_k] V^l = x · W_V^l # [batch, seq_len, d_model] -> [batch, seq_len, d_v]

缓存内容

KV Cache 存储的是每个层的:

  • Key 矩阵 K^l:形状 [batch, seq_len, d_k]

  • Value 矩阵 V^l:形状 [batch, seq_len, d_v]

推理过程伪代码

class TransformerDecoderWithKVCache: def __init__(self, model, max_seq_len): self.model = model self.kv_cache = { 'keys': torch.zeros(max_seq_len, num_layers, d_k), 'values': torch.zeros(max_seq_len, num_layers, d_v) } self.cache_position = 0 def generate_next_token(self, input_token): # 1. 前向传播到每一层 for layer_idx, layer in enumerate(self.model.layers): # 计算当前token的Q, K, V q, k, v = layer.attention.qkv_projection(input_token) # 2. 更新缓存:将新token的K,V存入缓存 self.kv_cache['keys'][self.cache_position, layer_idx] = k self.kv_cache['values'][self.cache_position, layer_idx] = v # 3. 注意力计算:使用当前Q和所有缓存的K,V # 从缓存获取到当前位置的所有K,V cached_keys = self.kv_cache['keys'][:self.cache_position+1, layer_idx] cached_values = self.kv_cache['values'][:self.cache_position+1, layer_idx] # 计算注意力 attention_output = self.attention(q, cached_keys, cached_values) # 4. 继续前向传播 input_token = layer.ffn(attention_output) self.cache_position += 1 return output_logits

四、实际例子分析

案例:LLaMA-7B 模型

模型参数

  • 层数:32

  • 注意力头数:32

  • 每个头的维度:128

  • 上下文长度:4096

KV Cache 大小计算

# 每层的K/V矩阵大小 per_layer_kv_size = 2 * (d_model * d_k) # K和V各一个 # 对于每个token,每层的缓存大小 per_token_per_layer = 2 * (num_heads * head_dim) # 假设为 2 * (32 * 128) = 8192 个浮点数 # 所有层的缓存大小(单个token) per_token_all_layers = per_token_per_layer * num_layers # 8192 * 32 = 262,144 浮点数 # 浮点数大小(假设float16) per_token_bytes = 262,144 * 2 # ≈ 524 KB # 完整序列的缓存(4096个token) full_sequence_bytes = 524 KB * 4096 ≈ 2.1 GB

内存占用分析

  • 模型权重:7B 参数,float16 格式 ≈ 14 GB

  • KV Cache(最大长度):≈ 2.1 GB

  • 总内存:约 16.1 GB

实际推理步骤示例

假设我们让模型生成 "The quick brown fox"

步骤1:输入 "The"

计算过程: - 嵌入层:将 "The" 转换为向量 - 每一层:计算 "The" 的 Q1, K1, V1 - 保存:K1, V1 到 KV Cache - 输出:预测下一个词的概率分布 - 选择:选择概率最高的词 "quick"

KV Cache 状态

Layer1: K1, V1 Layer2: K1, V1 ... Layer32: K1, V1

步骤2:输入 "quick"(当前序列:"The quick")

计算过程: - 嵌入层:将 "quick" 转换为向量 - 每一层: * 计算 "quick" 的 Q2, K2, V2 * 从缓存读取:Layer1: K1, V1 * 注意力计算:Attention(Q2, [K1, K2], [V1, V2]) * 保存:K2, V2 到 KV Cache - 输出:预测下一个词 - 选择:"brown"

KV Cache 状态

Layer1: K1, V1, K2, V2 Layer2: K1, V1, K2, V2 ... Layer32: K1, V1, K2, V2

步骤3:输入 "brown"(当前序列:"The quick brown")

计算过程类似,但缓存中有3个token的K,V 注意力计算:Attention(Q3, [K1, K2, K3], [V1, V2, V3])

五、KV Cache 的优化技术

1.PagedAttention(vLLM)

  • 问题:传统KV Cache是连续内存,导致内存碎片

  • 解决方案:将KV Cache分页管理

  • 效果:提高内存利用率,支持更长的上下文

# 传统KV Cache(连续内存) kv_cache = torch.zeros(max_len, num_layers, d_kv) # PagedAttention(分页管理) class KVCachePage: def __init__(self, page_size): self.keys = torch.zeros(page_size, d_kv) self.values = torch.zeros(page_size, d_kv) # 管理多个页 kv_cache_pages = [KVCachePage(page_size) for _ in range(num_pages)]

2.Multi-Query Attention(MQA)

  • 传统:每个头有自己的K,V → 缓存大

  • MQA:多个头共享K,V → 缓存小

  • 内存节省:约减少为 1/num_heads

3.Grouped-Query Attention(GQA)

  • 介于MHA和MQA之间

  • 将头分组,组内共享K,V

  • 平衡效果和内存

4.KV Cache 量化

  • 将KV Cache从float16量化为int8

  • 内存减半,精度损失小

  • 公式KV_int8 = round(KV_fp16 / scale)

5.滑动窗口注意力

  • 只缓存最近N个token的K,V

  • 适用于长文本,内存恒定

  • 缺点:无法处理长距离依赖

def sliding_window_kv_cache(kv_cache, new_k, new_v, window_size): # 添加新的K,V kv_cache.append((new_k, new_v)) # 如果超过窗口大小,移除最旧的 if len(kv_cache) > window_size: kv_cache.pop(0) return kv_cache

六、代码实现示例

简单KV Cache实现

import torch import torch.nn as nn class KVCache: def __init__(self, max_batch_size, max_seq_len, num_layers, num_heads, head_dim, dtype=torch.float16): self.max_seq_len = max_seq_len self.num_layers = num_layers self.num_heads = num_heads self.head_dim = head_dim # 初始化缓存 self.key_cache = torch.zeros( max_batch_size, num_layers, max_seq_len, num_heads, head_dim, dtype=dtype ) self.value_cache = torch.zeros( max_batch_size, num_layers, max_seq_len, num_heads, head_dim, dtype=dtype ) # 当前序列长度 self.seq_len = 0 def update(self, layer_idx, new_key, new_value, batch_idx=0): """更新指定层的KV缓存""" # new_key形状: [batch, num_heads, seq_len=1, head_dim] # new_value形状: [batch, num_heads, seq_len=1, head_dim] self.key_cache[batch_idx, layer_idx, self.seq_len] = new_key.squeeze(2) self.value_cache[batch_idx, layer_idx, self.seq_len] = new_value.squeeze(2) def get(self, layer_idx, batch_idx=0): """获取指定层的KV缓存(到当前seq_len)""" keys = self.key_cache[batch_idx, layer_idx, :self.seq_len] values = self.value_cache[batch_idx, layer_idx, :self.seq_len] return keys, values def increment_seq_len(self): """增加序列长度""" self.seq_len += 1 if self.seq_len > self.max_seq_len: raise ValueError(f"序列长度超过最大值 {self.max_seq_len}") def clear(self): """清空缓存""" self.seq_len = 0 self.key_cache.zero_() self.value_cache.zero_()

在Transformer中的使用

class TransformerDecoderLayerWithKVCache(nn.Module): def __init__(self, d_model, num_heads, ff_dim): super().__init__() self.self_attn = MultiHeadAttention(d_model, num_heads) self.ffn = FeedForward(d_model, ff_dim) def forward(self, x, kv_cache, layer_idx, use_cache=False): # 自注意力 q = self.self_attn.query_proj(x) if use_cache and kv_cache.seq_len > 0: # 有缓存:只计算当前token的K,V k = self.self_attn.key_proj(x) v = self.self_attn.value_proj(x) # 从缓存获取历史K,V past_keys, past_values = kv_cache.get(layer_idx) # 合并历史K,V和当前K,V keys = torch.cat([past_keys, k], dim=1) values = torch.cat([past_values, v], dim=1) # 更新缓存 kv_cache.update(layer_idx, k, v) else: # 无缓存:计算所有token的K,V q, k, v = self.self_attn.qkv_proj(x) keys, values = k, v if use_cache: kv_cache.update(layer_idx, k, v) # 计算注意力 attn_output = self.self_attn(q, keys, values) # FFN output = self.ffn(attn_output) return output

七、KV Cache 的挑战与解决方案

挑战1:内存占用大

解决方案

  1. 量化:FP16 → INT8,内存减半

  2. 压缩:稀疏化、低秩近似

  3. 选择性缓存:只缓存重要token

挑战2:长序列生成慢

解决方案

  1. FlashAttention:优化注意力计算

  2. 增量解码:仅计算新token

  3. 并行采样:一次生成多个候选

挑战3:批处理效率

解决方案

  1. 连续批处理:动态调整batch size

  2. vLLM的PagedAttention:高效内存管理

八、性能对比

方法内存占用推理速度实现复杂度
无KV Cache极慢(O(n²))简单
基础KV Cache高(O(n))快(O(n))中等
PagedAttention中等很快复杂
MQA/GQA很快中等

九、总结

KV Cache 是现代大语言模型推理的核心优化技术

  1. 工作原理:缓存历史token的Key和Value向量,避免重复计算

  2. 核心价值:将推理复杂度从 O(n²) 降到 O(n)

  3. 内存代价:需要额外存储所有历史token的K和V

  4. 优化方向:量化、压缩、高效内存管理

  5. 实际影响:决定了模型的最大上下文长度和推理速度

简单来说:KV Cache 就像是你读书时做的笔记。第一次读时做笔记(计算K,V),第二次需要时直接看笔记(从缓存读取),而不是重新读整本书(重新计算所有K,V)。

这就是为什么在 llama.cpp 等推理框架中,保持进程运行(KV Cache在内存中)比每次重启加载要快得多的根本原因。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/13 8:28:39

番茄小说下载工具终极指南:打造完美离线阅读体验

番茄小说下载工具终极指南:打造完美离线阅读体验 【免费下载链接】fanqienovel-downloader 下载番茄小说 项目地址: https://gitcode.com/gh_mirrors/fa/fanqienovel-downloader 还在为网络不稳定而错过精彩小说情节烦恼吗?这款强大的小说下载工具…

作者头像 李华
网站建设 2026/4/14 17:39:18

喜马拉雅音频获取工具的技术实现与用户体验分析

喜马拉雅音频获取工具的技术实现与用户体验分析 【免费下载链接】xmly-downloader-qt5 喜马拉雅FM专辑下载器. 支持VIP与付费专辑. 使用GoQt5编写(Not Qt Binding). 项目地址: https://gitcode.com/gh_mirrors/xm/xmly-downloader-qt5 从用户需求到技术实现 在日常的数…

作者头像 李华
网站建设 2026/4/15 6:27:12

Prometheus

Prometheus:现代监控系统的全方位解析与实践指南 一、初识 Prometheus 什么是 Prometheus? Prometheus(普罗米修斯)是一款开源监控系统,以多维数据模型(指标名称和键值对标识)和基于 HTTP 的…

作者头像 李华
网站建设 2026/4/8 15:01:30

PyTorch-CUDA-v2.9镜像支持Toxic Comment Classification有毒评论检测吗?

PyTorch-CUDA-v2.9镜像支持Toxic Comment Classification有毒评论检测吗? 在当今社交媒体与用户生成内容(UGC)爆炸式增长的背景下,网络空间中的负面言论——如侮辱、仇恨、威胁和恶意攻击——正以前所未有的速度蔓延。平台方面临巨…

作者头像 李华
网站建设 2026/4/14 13:58:15

Video2X:革命性AI视频增强技术的深度解析与应用指南

Video2X:革命性AI视频增强技术的深度解析与应用指南 【免费下载链接】video2x A lossless video/GIF/image upscaler achieved with waifu2x, Anime4K, SRMD and RealSR. Started in Hack the Valley II, 2018. 项目地址: https://gitcode.com/gh_mirrors/vi/vide…

作者头像 李华
网站建设 2026/4/14 19:45:07

ComfyUI Impact Pack 完全指南:AI图像处理与面部细节增强神器

ComfyUI Impact Pack 完全指南:AI图像处理与面部细节增强神器 【免费下载链接】ComfyUI-Impact-Pack 项目地址: https://gitcode.com/gh_mirrors/co/ComfyUI-Impact-Pack ComfyUI Impact Pack 是一个功能强大的AI图像处理扩展包,专为ComfyUI用户…

作者头像 李华