在认知科学中,人类记忆并非单一容器,而是由感觉记忆、短时记忆和长时记忆构成的分层系统。计算机架构师也早已深谙此道——从L1缓存到内存再到磁盘,逐级扩展容量,每一层都平衡着速度与成本。如今,这个思想正在大语言模型领域焕发新生,帮助Transformer突破上下文窗口的限制。这就是本文要探讨的核心:分层记忆缓冲(Hierarchical Memory Buffer)。
本文不仅会讲解概念,还会给出可落地的PyTorch代码,覆盖工作记忆、情节记忆、压缩记忆及自主调度等完整实现。
1. 为什么大模型需要一个记忆系统?
标准Transformer的自注意力复杂度随序列长度平方增长,即便有了FlashAttention等优化,处理百万级token的长文档时,仍面临两大顽疾:
- 遗忘首部信息:过长输入会超出位置编码有效范围,模型“看了后面忘前面”。
- 推理成本爆炸:KV Cache线性增长,内存和计算不堪重负。
一种自然的思路是:我们不把全部历史压进同一个注意力窗口,而是让模型学会分层存储和召回信息。这正是分层记忆缓冲的出发点。
2. 分层记忆缓冲的通用蓝图
在神经网络中,分层记忆通常抽象为三层结构:
| 层级 | 类比 | 容量 | 读写速度 | 典型实现 |
|---|---|---|---|---|
| 工作记忆 | L1 缓存 / 短时记忆 | 几k tokens | 极高(直接注意力) | 当前窗口的KV Cache |
| 情节记忆 | 内存 / 长时记忆 | 几十万tokens | 中等(检索/前馈) | 外部键值库、kNN索引 |
| 语义记忆 | 磁盘 / 永久知识 | 近乎无限 | 较慢(压缩/参数化) | 模型参数、向量数据库、摘要树 |
推理时,模型就像一位带着笔记本的学者:工作记忆是当前段落;情节记忆是手边快速查阅的索引卡片;语义记忆是大脑中长期内化的知识。
下面我们逐层用代码实现。
3. 工作记忆:当前窗口的KV Cache
任何Transformer推理都离不开KV Cache。在分层记忆中,工作记忆就是当前正在处理的片段对应的缓存,通过限制长度来模拟容量上限。
importtorchimporttorch.nnasnnclassWorkingMemory(nn.Module):def__init__(self,num_layers,num_heads,head_dim,max_len=4096):super().__init__()self.num_layers=num_layers self.num_heads=num_heads self.head_dim=head_dim self.max_len=max_len# 每一层维护K和V的缓存,初始为空self.k_cache=Noneself.v_cache=Nonedefupdate(self,new_k,new_v,layer_idx):"""将新片段的KV追加到缓存,并截断至最大长度"""ifself.k_cacheisNone:self.k_cache=[None]*self.num_layers self.v_cache=[None]*self.num_layersifself.k_cache[layer_idx]isNone:self.k_cache[layer_idx]=new_k self.v_cache[layer_idx]=new_velse:self.k_cache[layer_idx]=torch.cat([self.k_cache[layer_idx],new_k],dim=1)self.v_cache[layer_idx]=torch.cat([self.v_cache[layer_idx],new_v],dim=1)# 截断,保证工作记忆不溢出ifself.k_cache[layer_idx].size(1)>self.max_len:self.k_cache[layer_idx]=self.k_cache[layer_idx][:,-self.max_len:]self.v_cache[layer_idx]=self.v_cache[layer_idx][:,-self.max_len:]在实际注意力计算时,query不仅关注当前片段的KV,还会关注工作记忆中的KV。这正是标准自回归生成流程,此处不再赘述。
4. 情节记忆:kNN增强的外部记忆(Memorizing Transformers)
Google的Memorizing Transformers将过去所有token的Key-Value存入kNN索引,作为情节记忆。我们使用faiss实现一个简化版。
4.1 构建外部记忆库
importfaissimportnumpyasnpclassEpisodicMemory:def__init__(self,key_dim,capacity=100000):self.key_dim=key_dim self.capacity=capacity self.keys=[]# 存储所有过去的Keyself.values=[]# 存储所有过去的Valueself.index=faiss.IndexFlatIP(key_dim)# 内积相似度,与注意力点积对齐defadd(self,keys,values):"""keys: [seq_len, key_dim], values: [seq_len, value_dim]"""self.keys.extend(keys.detach().cpu().numpy())self.values.extend(values.detach().cpu().numpy())# 保持容量限制iflen(self.keys)>self.capacity:self.keys=self.keys[-self.capacity:]self.values=self.values[-self.capacity:]# 重建索引(实际可使用增量索引,此处简化)self.index=faiss.IndexFlatIP(self.key_dim)iflen(self.keys)>0:self.index.add(np.array(self.keys).astype(np.float32))defsearch(self,query,top_k=32):"""query: [batch*heads, q_len, key_dim]"""orig_shape=query.shape query_np=query.reshape(-1,self.key_dim).detach().cpu().numpy().astype(np.float32)scores,indices=self.index.search(query_np,top_k)# 根据索引取出对应的valueretrieved_vals=[]foridx_rowinindices:row_vals=[self.values[i]foriinidx_row]retrieved_vals.append(torch.tensor(np.array(row_vals)))retrieved_vals=torch.stack(retrieved_vals).view(*orig_shape[:-1],top_k,-1)returnretrieved_vals,torch.tensor(scores).view(*orig_shape[:-1],top_k)4.2 将外部记忆融入注意力
修改注意力计算,将检索到的记忆值通过softmax融合,并使用可学习的门控与本地注意力结合。
defattention_with_memory(query,key,value,episodic_memory,top_k=32):# 1. 正常局部注意力attn_scores=torch.matmul(query,key.transpose(-2,-1))/math.sqrt(query.size(-1))attn_probs=torch.softmax(attn_scores,dim=-1)local_output=torch.matmul(attn_probs,value)# 2. 从情节记忆检索mem_values,mem_scores=episodic_memory.search(query,top_k)mem_scores=mem_scores/math.sqrt(query.size(-1))mem_probs=torch.softmax(mem_scores,dim=-1)mem_output=torch.matmul(mem_probs.unsqueeze(-2),mem_values).squeeze(-2)# 3. 可学习门控融合(此处简化为固定值,实际可训练)gate=torch.sigmoid(torch.tensor(0.5))output=gate*local_output+(1-gate)*mem_outputreturnoutput推理一个片段后,将该片段的K、V存入情节记忆:
episodic_memory.add(layer_k[0],layer_v[0])# batch中第一个样本5. 递归压缩记忆:用摘要向量传递(AutoCompressor / Infini-Transformer)
另一种路径是将长序列压缩为固定数量的“记忆token”。这些token作为下一片段的前缀,扮演情节记忆。
5.1 记忆压缩模块
classCompressiveMemory(nn.Module):def__init__(self,dim,num_memory_tokens=16):super().__init__()# 可学习的记忆查询向量,负责从片段中提取信息self.memory_queries=nn.Parameter(torch.randn(num_memory_tokens,dim))self.cross_attn=nn.MultiheadAttention(dim,num_heads=8,batch_first=True)defforward(self,segment_hidden):""" segment_hidden: [batch, seg_len, dim] 返回压缩后的记忆: [batch, num_memory_tokens, dim] """queries=self.memory_queries.unsqueeze(0).expand(segment_hidden.size(0),-1,-1)compressed,_=self.cross_attn(queries,segment_hidden,segment_hidden)returncompressed5.2 片段间记忆传递
处理长文档时,将前一片段的压缩记忆拼接到当前段embedding之前,实现记忆的递归传递。
classHierarchicalTransformer(nn.Module):def__init__(self,base_transformer,num_memory_tokens=16):super().__init__()self.transformer=base_transformer self.memory_compressor=CompressiveMemory(base_transformer.d_model,num_memory_tokens)self.memory=None# 上一片段的压缩记忆defforward(self,input_ids,segment_length=2048):segments=input_ids.split(segment_length,dim=1)outputs=[]forseginsegments:ifself.memoryisnotNone:seg_emb=self.transformer.embedding(seg)seg_emb=torch.cat([self.memory,seg_emb],dim=1)# 记忆作为前缀else:seg_emb=self.transformer.embedding(seg)hidden=self.transformer(seg_emb)# 简化,实际需处理maskoutputs.append(hidden)# 压缩当前段最后一部分作为新记忆self.memory=self.memory_compressor(hidden[:,-segment_length:])returntorch.cat(outputs,dim=1)这种设计使得记忆规模恒定,不会随时间增长。
6. 操作系统式记忆:LLM自主管理读写(MemGPT)
MemGPT让LLM通过函数调用显式管理外部记忆。我们可借助OpenAI Function Calling的风格实现。
6.1 定义记忆工具
importjsonclassMemoryStore:def__init__(self):self.storage={}self.conversation_history=[]defread(self,key):returnself.storage.get(key,"Memory not found.")defwrite(self,key,content):self.storage[key]=contentreturnf"Stored '{key}'."defsearch(self,query):results={k:vfork,vinself.storage.items()ifqueryinv}returnjson.dumps(results)# 工具定义(符合OpenAI function calling格式)tools=[{"name":"read_memory","description":"Read content from external memory by key.","parameters":{"type":"object","properties":{"key":{"type":"string"}},"required":["key"]}},{"name":"write_memory","description":"Write a key-content pair to external memory.","parameters":{"type":"object","properties":{"key":{"type":"string"},"content":{"type":"string"}},"required":["key","content"]}},{"name":"search_memory","description":"Search memory for a query string.","parameters":{"type":"object","properties":{"query":{"type":"string"}},"required":["query"]}}]6.2 自主记忆调度
与LLM交互时,让模型决定何时读写记忆。
defllm_with_memory(user_message,model):messages=[{"role":"system","content":"You have an external memory. Use read/write/search_memory to manage it."},{"role":"user","content":user_message}]response=model.chat(messages,tools=tools)ifresponse.tool_calls:fortool_callinresponse.tool_calls:func_name=tool_call.function.name args=json.loads(tool_call.function.arguments)iffunc_name=="read_memory":result=memory_store.read(args["key"])eliffunc_name=="write_memory":result=memory_store.write(args["key"],args["content"])eliffunc_name=="search_memory":result=memory_store.search(args["query"])messages.append({"role":"tool","content":result,"name":func_name})final_response=model.chat(messages)returnfinal_response.contentelse:returnresponse.content模型可以自行将不重要的内容换出,需要时再检索,实现动态上下文扩展。
7. 训练分层记忆:让梯度流过记忆边界
要让模型学会何时写入、如何压缩,记忆操作必须可微或采用强化学习。
7.1 可微的近似检索
在训练时,用全部过去key的softmax近似替代kNN硬检索,使梯度能够回传。
defdifferentiable_memory_retrieval(query,all_past_keys,all_past_values,top_k=32):scores=torch.matmul(query,all_past_keys.transpose(-2,-1))/math.sqrt(query.size(-1))topk_scores,topk_indices=torch.topk(scores,top_k,dim=-1)topk_probs=torch.softmax(topk_scores,dim=-1)retrieved_values=torch.gather(all_past_values,1,topk_indices.unsqueeze(-1).expand(-1,-1,-1,all_past_values.size(-1)))returntorch.matmul(topk_probs.unsqueeze(-2),retrieved_values).squeeze(-2)7.2 压缩记忆的自监督损失
对于压缩记忆,可以要求模型从压缩向量重建原始片段,作为辅助损失。
defcompression_loss(compressed_memory,original_segment,decoder):reconstructed=decoder(compressed_memory)loss=nn.CrossEntropyLoss()(reconstructed.view(-1,vocab_size),original_segment.view(-1))returnloss联合主任务损失一起优化,迫使压缩记忆保留足够细节。
8. 最小可行示例:串起整个系统
下面代码演示了一个极简的分层记忆LLM,结合了工作记忆(KV Cache)和情节记忆(外部存储)。
classMiniHierarchicalLLM:def__init__(self,transformer,episodic_memory_capacity=10000):self.model=transformer self.working_memory=WorkingMemory(num_layers=transformer.num_layers,num_heads=transformer.num_heads,head_dim=transformer.head_dim,max_len=4096)self.episodic=EpisodicMemory(key_dim=transformer.d_model,capacity=episodic_memory_capacity)self.max_seg_len=2048defgenerate(self,input_ids,max_new_tokens=100):# 分段处理输入,更新记忆segments=input_ids.split(self.max_seg_len,dim=1)forseginsegments:hidden=self.model(seg,use_cache=True,past_key_values=self.working_memory.k_cache)# 更新工作记忆self.working_memory.k_cache=hidden.past_key_values# 将当前段的K,V存入情节记忆(取最后一层)last_k,last_v=hidden.past_key_values[-1]self.episodic.add(last_k.squeeze(0),last_v.squeeze(0))generated=[]current=input_ids[:,-1:]# 从最后一个token开始自回归生成for_inrange(max_new_tokens):output=self.model(current,use_cache=True,past_key_values=self.working_memory.k_cache,episodic_memory=self.episodic# 需要自行修改forward支持)next_token=output.logits[:,-1:].argmax(dim=-1)generated.append(next_token)current=next_token# 工作记忆的缓存在模型内部自动更新returntorch.cat(generated,dim=1)你可以从最简单的外部向量存储+检索开始,逐步加入压缩、自主调度和可微训练,让你的模型拥有真正的长时记忆。
9. 挑战与未来
分层记忆缓冲已在代码库理解、终生对话代理等任务上展现潜力,但仍面临挑战:
- 记忆冗余与遗忘:如何优雅地淘汰旧信息?能否模拟记忆的“再巩固”过程?
- 跨层级重组:能否增加离线阶段,自动将情节记忆提炼进语义记忆(模型参数)?
- 隐私与安全:外部记忆可能包含敏感信息,选择性遗忘机制至关重要。
- 多模态统一记忆:能否将文本、图像、音频映射到同一套键值空间?
10. 结语
分层记忆缓冲并非要让大模型变成笨重的数据库系统,而是赋予它一种组织自身经验的能力。正如记忆术中的“记忆宫殿”——将信息放置在熟悉的空间结构中,逐层导引,随时提取。
本文给出的代码片段为你提供了构建记忆系统的基石。无论你是想为聊天机器人增加长期记忆,还是让代码助手理解整个仓库,都可以从这里开始。随着我们向通用人工智能迈进,记忆的架构可能比模型本身更能定义其思考的深度与连贯性。
延伸阅读:
- Memorizing Transformers (Wu et al., 2022)
- MemGPT: Towards LLMs as Operating Systems (Packer et al., 2023)
- Infini-Transformer: Infinite Context with Compressive Memory
- AutoCompressor: Long Context Compression via Summary Tokens