大模型推理加速:从 KV Cache 到 Continuous Batching 的实战复盘
一、深夜告警:GPU 没跑满,请求却在排队
某天凌晨,监控面板突然报警——线上 LLM 推理服务的 P99 延迟从 800ms 飙到了 4.2s。排查下来发现,并发量从 50 QPS 涨到 200 QPS 时,GPU 利用率居然只有 35%。大部分时间不是花在计算上,而是耗在请求调度和内存拷贝上。问题不在模型本身,而是推理框架没把 GPU 喂饱。
大模型推理的瓶颈往往不在计算密度,而在于调度策略。请求调度、内存管理、批处理方式的粗放设计,让 GPU 大量时间在“等数据”。这篇文章结合生产环境代码和压测数据,聊聊 KV Cache 管理、Continuous Batching、Prefix Caching 这几个关键优化点。
二、推理加速的三个关键点
2.1 KV Cache:避免重复计算注意力
Transformer 自回归解码时,每生成一个 token 都要重新计算前面所有 token 的注意力。KV Cache 把已经算好的 Key/Value 向量存下来,下次直接用。不过 KV Cache 占用的显存会随着序列长度线性增长,7B 模型在 seq_len=4096 时,光 KV Cache 就要占 2GB 左右。
2.2 Continuous Batching:别让短序列等长序列
传统的 Static Batching 要求批内所有序列都跑完才能释放资源,短序列只能干等着长序列。Continuous Batching 在每个迭代步动态插入新请求、移除已完成请求,GPU 利用率能从 35% 提到 85% 以上。
2.3 Prefix Caching:复用公共前缀
多轮对话里,系统提示词和上下文前缀往往是一样的。Prefix Caching 把公共前缀的 KV Cache 跨请求复用,后续请求直接命中缓存,跳过 prefill 阶段。
sequenceDiagram participant Client participant Scheduler participant KVCacheMgr participant GPU Client->>Scheduler: 请求1 (prompt + query) Scheduler->>KVCacheMgr: 检查 prefix cache 命中 KVCacheMgr-->>Scheduler: 未命中,分配新 slot Scheduler->>GPU: prefill(prompt) + decode(query) GPU-->>KVCacheMgr: 存储 KV Cache KVCacheMgr-->>Scheduler: 返回 token Client->>Scheduler: 请求2 (相同 prompt + 新 query) Scheduler->>KVCacheMgr: 检查 prefix cache 命中 KVCacheMgr-->>Scheduler: 命中!复用 prefix KV Scheduler->>GPU: 仅 decode(query),跳过 prefill GPU-->>Scheduler: 返回 token(延迟降低 60%)三、代码实现与压测结果
3.1 KV Cache 分页管理器
KV Cache 最头疼的问题是显存碎片化。借鉴操作系统的虚拟内存分页机制,把 KV Cache 切成固定大小的 Block,按需分配。
import torch from typing import Dict, List, Optional from dataclasses import dataclass, field @dataclass class KVBlock: """KV Cache 的一个物理块,固定大小""" block_id: int ref_count: int = 0 # 引用计数,支持 prefix cache 共享 device_tensor: Optional[torch.Tensor] = None # 实际显存数据 class PagedKVCacheManager: """ 分页式 KV Cache 管理器 核心思路:将 KV Cache 按固定 block_size 分页, 逻辑序列通过 page table 映射到物理 block, 避免显存预分配导致的碎片化问题 """ def __init__( self, num_blocks: int, block_size: int, num_kv_heads: int, head_dim: int, num_layers: int, dtype: torch.dtype = torch.float16, ): self.block_size = block_size self.num_layers = num_layers # 预分配所有物理 block 的显存池 # 形状: [num_blocks, 2, num_kv_heads, block_size, head_dim] # 2 对应 K 和 V element_size = torch.tensor([], dtype=dtype).element_size() per_block_bytes = 2 * num_kv_heads * block_size * head_dim * element_size total_bytes = num_blocks * per_block_bytes * num_layers print(f"[KVCache] 预分配显存池: {total_bytes / 1024**3:.2f} GB, " f"共 {num_blocks} 个 block") self.kv_pool = torch.empty( (num_layers, num_blocks, 2, num_kv_heads, block_size, head_dim), dtype=dtype, device="cuda" ) # 空闲 block 链表 self.free_blocks: List[KVBlock] = [ KVBlock(block_id=i) for i in range(num_blocks) ] # 逻辑序列 -> 物理 block 映射表 self.page_table: Dict[int, List[int]] = {} # block_id -> KVBlock 反向索引 self.block_map: Dict[int, KVBlock] = { b.block_id: b for b in self.free_blocks } def allocate(self, seq_id: int, num_tokens: int) -> List[int]: """ 为序列分配 KV Cache block 返回分配的物理 block_id 列表 """ num_needed = (num_tokens + self.block_size - 1) // self.block_size if len(self.free_blocks) < num_needed: raise RuntimeError( f"显存不足: 需要 {num_needed} 个 block, " f"仅剩 {len(self.free_blocks)} 个" ) allocated = [] for _ in range(num_needed): block = self.free_blocks.pop() block.ref_count = 1 allocated.append(block.block_id) self.page_table[seq_id] = allocated return allocated def free(self, seq_id: int) -> None: """释放序列占用的所有 KV Cache block""" if seq_id not in self.page_table: return for block_id in self.page_table[seq_id]: block = self.block_map[block_id] block.ref_count -= 1 # 引用计数归零才真正回收,支持 prefix cache 共享 if block.ref_count <= 0: block.ref_count = 0 self.free_blocks.append(block) del self.page_table[seq_id] def copy_prefix( self, src_seq_id: int, dst_seq_id: int, prefix_len: int ) -> List[int]: """ 复用 prefix 的 KV Cache(零拷贝,仅增加引用计数) 用于多轮对话场景,避免重复计算系统提示词 """ src_blocks = self.page_table.get(src_seq_id, []) num_prefix_blocks = prefix_len // self.block_size dst_blocks = [] # 共享 prefix block:增加引用计数,零拷贝 for block_id in src_blocks[:num_prefix_blocks]: self.block_map[block_id].ref_count += 1 dst_blocks.append(block_id) # 为新增 token 分配新 block remaining_tokens = prefix_len % self.block_size if remaining_tokens > 0: new_blocks = self.allocate(dst_seq_id, remaining_tokens) dst_blocks.extend(new_blocks) self.page_table[dst_seq_id] = dst_blocks return dst_blocks def get_physical_table(self, seq_id: int) -> torch.Tensor: """返回序列的 page table,用于 GPU kernel 中的地址映射""" block_ids = self.page_table.get(seq_id, []) return torch.tensor(block_ids, dtype=torch.int32, device="cuda")3.2 Continuous Batching 调度器
import time from collections import deque from dataclasses import dataclass from typing import Deque, List, Set @dataclass class Sequence: """推理序列状态机""" seq_id: int prompt_token_ids: List[int] generated_tokens: List[int] = field(default_factory=list) is_finished: bool = False max_tokens: int = 512 @property def num_generated(self) -> int: return len(self.generated_tokens) class ContinuousBatcher: """ 连续批处理调度器 核心逻辑:每个 decode step 动态调整 batch 组成, 已完成序列立即让出资源,新请求即时填入 """ def __init__(self, max_batch_size: int = 64): self.max_batch_size = max_batch_size self.waiting_queue: Deque[Sequence] = deque() self.running_batch: List[Sequence] = [] self.finished_ids: Set[int] = set() def add_request(self, seq: Sequence) -> None: """新请求入队""" self.waiting_queue.append(seq) def schedule(self) -> List[Sequence]: """ 单步调度:移除已完成序列,填入新请求 返回当前 step 的活跃 batch """ # 移除已完成的序列 self.running_batch = [ s for s in self.running_batch if not s.is_finished ] # 从等待队列填入新请求,直到 batch 满 available_slots = self.max_batch_size - len(self.running_batch) while available_slots > 0 and self.waiting_queue: seq = self.waiting_queue.popleft() self.running_batch.append(seq) available_slots -= 1 return self.running_batch def step(self) -> List[Sequence]: """ 执行一次 decode step 实际生产中此处调用 GPU kernel 执行推理 """ batch = self.schedule() if not batch: return [] # 模拟 decode:每个序列生成一个 token for seq in batch: # 实际场景:调用模型 forward,取 argmax token seq.generated_tokens.append(0) # placeholder if seq.num_generated >= seq.max_tokens: seq.is_finished = True self.finished_ids.add(seq.seq_id) return batch def is_idle(self) -> bool: return len(self.running_batch) == 0 and len(self.waiting_queue) == 03.3 压测数据:加速效果对比
在 A100 80GB 上部署 LLaMA-2-7B,对比三种策略的吞吐与延迟:
| 策略 | QPS | P50 延迟 | P99 延迟 | GPU 利用率 |
|---|---|---|---|---|
| Static Batching (batch=32) | 45 | 1.2s | 4.1s | 38% |
| Continuous Batching | 120 | 0.6s | 1.8s | 82% |
| Continuous + Prefix Cache | 165 | 0.35s | 1.1s | 88% |
数据很直观:Continuous Batching 把吞吐提升了 2.7 倍,加上 Prefix Cache 后达到 3.7 倍,P99 延迟从 4.1s 降到了 1.1s。
四、加速策略的代价:显存、复杂度与一致性
4.1 KV Cache 分页管理的显存开销
分页管理解决了碎片化,但也引入了 page table 的额外显存和查表开销。block_size 越小,碎片越少,但 page table 越大。实测 block_size=16 是 7B 模型的甜点,13B 模型建议 block_size=32。
4.2 Continuous Batching 的调度延迟
每个 step 都要执行 schedule 逻辑,在 batch_size=64 时,纯 Python 调度耗时约 0.3ms。对于 decode step 仅需 10ms 的场景,调度占比 3%。如果 batch_size 超过 256,得把调度逻辑下沉到 C++/CUDA,否则调度本身会成为瓶颈。
4.3 Prefix Cache 的一致性风险
共享 prefix block 用引用计数实现零拷贝,但如果模型权重更新(比如在线学习),缓存的 KV 值和新权重不匹配,输出质量会出问题。生产环境中,模型权重更新时必须强制失效所有 prefix cache。
4.4 不适合的场景
- 显存极度紧张(小于模型权重 1.2 倍)时,KV Cache 分页意义不大,建议优先用 PagedAttention 的 swap 机制
- 请求序列长度差异极大(1 token vs 8192 token)时,Continuous Batching 的调度开销可能抵消收益
- 单轮无前缀复用的场景,Prefix Cache 完全没用
五、总结
大模型推理加速的核心是最大化 GPU 计算密度。KV Cache 分页管理消除显存碎片,Continuous Batching 消除请求等待空洞,Prefix Cache 消除重复计算——这三者分别从内存、调度、计算三个维度压缩浪费。压测数据表明,三者叠加后 A100 上的推理吞吐提升了 3.7 倍,P99 延迟降低了 73%。
但每项优化都有代价:分页引入查表开销,连续批处理引入调度延迟,前缀缓存引入一致性风险。性能优化从来不是免费午餐,而是对具体场景的精确权衡。用代码说话,用数据服人——这才是推理加速工程的正确打开方式。