"""长期记忆存储与检索系统完整实现 (complete implementation of a long-term memory storage and retrieval system)."""
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import math
import hashlib
import json
from collections import defaultdict
import heapq
class MemoryType(Enum):
    """Kinds of long-term memory supported by the system."""

    SEMANTIC = "semantic"      # semantic memory: factual knowledge
    EPISODIC = "episodic"      # episodic memory: experienced events
    PROCEDURAL = "procedural"  # procedural memory: skills and methods
class IndexType(Enum):
    """Vector-index flavours the memory system can be configured with."""

    HNSW = "hnsw"      # HNSW graph index
    IVF = "ivf"        # inverted-file index
    FLAT = "flat"      # brute-force scan
    HYBRID = "hybrid"  # combined index
@dataclass
class MemoryNode:
    """A single long-term memory: content, embedding, and usage bookkeeping."""

    id: str                      # unique memory id
    content: str                 # memory content text
    memory_type: MemoryType      # semantic / episodic / procedural
    embedding: np.ndarray        # vector embedding of the content
    timestamp: datetime          # creation time
    metadata: Dict[str, Any] = field(default_factory=dict)
    access_count: int = 0        # number of successful retrievals
    # Fixed annotation: None is a valid default, so the type is Optional.
    last_accessed: Optional[datetime] = None  # filled from `timestamp` when None
    strength: float = 1.0        # memory strength (0-1)
    related_ids: List[str] = field(default_factory=list)  # associated memory ids

    def __post_init__(self):
        # A freshly created memory counts as accessed at creation time.
        if self.last_accessed is None:
            self.last_accessed = self.timestamp

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (inverse of from_dict)."""
        return {
            "id": self.id,
            "content": self.content,
            "memory_type": self.memory_type.value,
            "embedding": self.embedding.tolist(),
            "timestamp": self.timestamp.isoformat(),
            "metadata": self.metadata,
            "access_count": self.access_count,
            "last_accessed": self.last_accessed.isoformat(),
            "strength": self.strength,
            "related_ids": self.related_ids
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MemoryNode':
        """Rebuild a node from to_dict() output.

        All optional/bookkeeping keys fall back to sensible defaults so that
        records written by older versions (without access_count/last_accessed)
        still load instead of raising KeyError.
        """
        return cls(
            id=data["id"],
            content=data["content"],
            memory_type=MemoryType(data["memory_type"]),
            embedding=np.array(data["embedding"]),
            timestamp=datetime.fromisoformat(data["timestamp"]),
            metadata=data.get("metadata", {}),
            access_count=data.get("access_count", 0),
            # Missing last_accessed -> None, which __post_init__ replaces
            # with the creation timestamp.
            last_accessed=(datetime.fromisoformat(data["last_accessed"])
                           if data.get("last_accessed") else None),
            strength=data.get("strength", 1.0),
            related_ids=data.get("related_ids", [])
        )
class HNSWIndex:
    """
    Simplified HNSW-style (Hierarchical Navigable Small World) index.

    This implementation keeps a single-layer neighbour graph (a star rooted
    at the first inserted node) plus a flat id -> embedding map.  Queries are
    answered by an exact scan over all stored embeddings, so results are the
    exact top-k by cosine similarity; the graph structure is maintained so a
    true greedy HNSW traversal can be added later.
    """

    def __init__(self, M: int = 16, ef_construction: int = 200):
        self.M = M                              # max connections per node (reserved)
        self.ef_construction = ef_construction  # construction beam width (reserved)
        self.layers: List[Dict[str, List[str]]] = []       # layered adjacency lists
        self.entry_point: Optional[str] = None             # graph entry node id
        self.node_embeddings: Dict[str, np.ndarray] = {}   # id -> embedding

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 if either has zero norm."""
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    def _get_layer(self, node_id: str) -> int:
        """Return the lowest layer index containing `node_id`, or -1 if absent."""
        for i, layer in enumerate(self.layers):
            if node_id in layer:
                return i
        return -1

    def insert(self, node_id: str, embedding: np.ndarray):
        """Insert a node and link it bidirectionally to the entry point."""
        self.node_embeddings[node_id] = embedding
        # Single-layer simplification: lazily create layer 0.
        if not self.layers:
            self.layers.append({})
        layer0 = self.layers[0]
        layer0[node_id] = []
        if self.entry_point is not None:
            # Bidirectional edge to the entry point (star topology).
            layer0[node_id].append(self.entry_point)
            if node_id not in layer0[self.entry_point]:
                layer0[self.entry_point].append(node_id)
        else:
            # The very first node becomes the entry point.
            self.entry_point = node_id

    def search(self, query_embedding: np.ndarray, k: int = 10) -> List[Tuple[str, float]]:
        """
        Return the top-k (node_id, cosine_similarity) pairs, best first.

        Fix: the original "greedy" graph traversal pushed (node_id, similarity)
        tuples onto a min-heap, so the heap was ordered by node-id string
        rather than similarity, and its result was discarded in favour of an
        exhaustive scan anyway.  The dead, mis-ordered traversal is removed;
        the exact exhaustive scan (identical output) is kept.
        """
        if not self.layers or not self.entry_point:
            return []
        scored = (
            (node_id, self._cosine_similarity(query_embedding, embedding))
            for node_id, embedding in self.node_embeddings.items()
        )
        # nlargest is O(n log k) and documented as equivalent to a stable
        # reverse sort + slice, so tie ordering matches the original code.
        return heapq.nlargest(k, scored, key=lambda item: item[1])
class LongTermMemorySystem:
    """
    Long-term memory system.

    Core capabilities:
      1. Storage: keeps MemoryNode objects in an in-memory store
      2. Indexing: one HNSWIndex per MemoryType for vector lookup
      3. Retrieval: similarity search weighted by strength and access frequency
      4. Consolidation: reinforces strength and builds associations
      5. Forgetting: time-decayed strength with cleanup of weak, unused memories
    """

    def __init__(self,
                 index_type: IndexType = IndexType.HNSW,
                 consolidation_threshold: float = 0.7,
                 forgetting_rate: float = 0.001):
        """
        Args:
            index_type: configured index flavour (informational; HNSWIndex is
                always used for the per-type indices)
            consolidation_threshold: reserved for future consolidation policies
            forgetting_rate: base strength decay applied by apply_forgetting()
        """
        # Primary store: memory id -> MemoryNode.
        self.memories: Dict[str, MemoryNode] = {}
        # One vector index per memory type.
        self.index_type = index_type
        self.indices: Dict[MemoryType, HNSWIndex] = {
            mt: HNSWIndex() for mt in MemoryType
        }
        # Tunable parameters.
        self.consolidation_threshold = consolidation_threshold
        self.forgetting_rate = forgetting_rate
        # Lifetime counters.
        self.stats = {
            "total_stored": 0,
            "total_retrieved": 0,
            "total_consolidated": 0,
            "total_forgotten": 0
        }

    def store_memory(self,
                     content: str,
                     memory_type: MemoryType,
                     embedding: np.ndarray,
                     metadata: Dict[str, Any] = None) -> MemoryNode:
        """
        Store a new memory.

        Args:
            content: memory content text
            memory_type: memory type
            embedding: vector embedding of the content
            metadata: optional metadata dict

        Returns:
            The created MemoryNode.
        """
        # Derive an id from content + creation time; md5 is used only as a
        # cheap fingerprint, not for security.
        memory_id = hashlib.md5(
            f"{content}{datetime.now().isoformat()}".encode()
        ).hexdigest()[:16]
        memory = MemoryNode(
            id=memory_id,
            content=content,
            memory_type=memory_type,
            embedding=embedding,
            timestamp=datetime.now(),
            metadata=metadata or {}
        )
        self.memories[memory_id] = memory
        # Index the embedding under its own type.
        self.indices[memory_type].insert(memory_id, embedding)
        self.stats["total_stored"] += 1
        return memory

    def retrieve(self,
                 query_embedding: np.ndarray,
                 memory_types: List[MemoryType] = None,
                 k: int = 10,
                 min_strength: float = 0.0) -> List[MemoryNode]:
        """
        Retrieve the memories most relevant to a query.

        Args:
            query_embedding: query vector
            memory_types: types to search (all types when None)
            k: maximum number of results
            min_strength: minimum memory strength to be considered

        Returns:
            Up to k matching MemoryNode objects, best first.
        """
        candidates = []
        for memory_type in (memory_types or list(MemoryType)):
            index = self.indices[memory_type]
            # Over-fetch (k * 2) so the strength filter still leaves enough.
            results = index.search(query_embedding, k * 2)
            for node_id, similarity in results:
                # Indexes may hold ids of already-forgotten memories; skip them.
                if node_id in self.memories:
                    memory = self.memories[node_id]
                    if memory.strength < min_strength:
                        continue
                    # Weighted score: similarity + strength + access frequency.
                    score = (
                        similarity * 0.6 +
                        memory.strength * 0.2 +
                        min(memory.access_count / 10, 1.0) * 0.2
                    )
                    candidates.append((score, memory))
        # Sort on the score only (MemoryNode itself is not orderable).
        candidates.sort(key=lambda x: x[0], reverse=True)
        # Mark the returned memories as accessed.
        for _, memory in candidates[:k]:
            memory.access_count += 1
            memory.last_accessed = datetime.now()
            self.stats["total_retrieved"] += 1
        return [memory for _, memory in candidates[:k]]

    def consolidate_memory(self,
                           memory_id: str,
                           related_ids: List[str] = None) -> bool:
        """
        Consolidate a memory: reinforce its strength and build associations.

        Args:
            memory_id: id of the memory to consolidate
            related_ids: ids of memories to associate with it

        Returns:
            True on success, False if `memory_id` is unknown.
        """
        if memory_id not in self.memories:
            return False
        memory = self.memories[memory_id]
        # Reinforce, capped at the documented maximum of 1.0.
        memory.strength = min(1.0, memory.strength + 0.1)
        # Build bidirectional associations with known, distinct memories.
        if related_ids:
            for related_id in related_ids:
                if related_id in self.memories and related_id != memory_id:
                    if related_id not in memory.related_ids:
                        memory.related_ids.append(related_id)
                    related_memory = self.memories[related_id]
                    if memory_id not in related_memory.related_ids:
                        related_memory.related_ids.append(memory_id)
        self.stats["total_consolidated"] += 1
        return True

    def apply_forgetting(self) -> Dict[MemoryType, int]:
        """
        Apply the forgetting mechanism: decay every memory's strength and
        drop memories that are both weak and never accessed.

        Returns:
            Number of forgotten memories per memory type.
        """
        forgotten = {mt: 0 for mt in MemoryType}
        to_remove = []
        for memory_id, memory in self.memories.items():
            age_days = (datetime.now() - memory.last_accessed).total_seconds() / 86400
            # NOTE(review): this factor shrinks with age, so recently-accessed
            # memories decay the fastest — confirm this inversion is intended.
            time_decay = 1.0 / (1.0 + age_days)
            # Fix: clamp at 0 so strength stays in its documented [0, 1] range
            # instead of drifting negative over repeated passes.
            memory.strength = max(0.0, memory.strength - self.forgetting_rate * time_decay)
            # Forget memories that are both weak and were never retrieved.
            if memory.strength < 0.1 and memory.access_count == 0:
                to_remove.append(memory_id)
                forgotten[memory.memory_type] += 1
        for memory_id in to_remove:
            # Known limitation: HNSWIndex has no delete, so the index keeps a
            # stale id; retrieve() filters those via the self.memories check.
            del self.memories[memory_id]
            self.stats["total_forgotten"] += 1
        return forgotten

    def build_knowledge_graph(self) -> Dict[str, List[str]]:
        """
        Build a knowledge graph from the stored association links.

        Returns:
            Adjacency-list representation (only memories with outgoing links
            appear as keys).
        """
        graph = defaultdict(list)
        for memory_id, memory in self.memories.items():
            for related_id in memory.related_ids:
                graph[memory_id].append(related_id)
        return dict(graph)

    def get_stats(self) -> Dict[str, Any]:
        """Return lifetime counters plus derived distribution/average stats."""
        return {
            **self.stats,
            "total_memories": len(self.memories),
            "memory_distribution": {
                mt.value: sum(1 for m in self.memories.values()
                              if m.memory_type == mt)
                for mt in MemoryType
            },
            "average_strength": np.mean([m.strength for m in self.memories.values()]) if self.memories else 0,
            "average_access_count": np.mean([m.access_count for m in self.memories.values()]) if self.memories else 0
        }

    def export_memories(self, filepath: str):
        """Export all memories and counters to a JSON file at `filepath`."""
        data = {
            "memories": [m.to_dict() for m in self.memories.values()],
            "stats": self.stats
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def import_memories(self, filepath: str):
        """Import memories from a JSON file, replacing the current store."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.memories.clear()
        # Fix: rebuild the indices from scratch — the original kept the old
        # index contents, leaving stale ids behind after an import.
        self.indices = {mt: HNSWIndex() for mt in MemoryType}
        for m_dict in data.get("memories", []):
            memory = MemoryNode.from_dict(m_dict)
            self.memories[memory.id] = memory
            self.indices[memory.memory_type].insert(memory.id, memory.embedding)
        # Restore the lifetime counters.
        self.stats.update(data.get("stats", {}))
# Usage example / demo script.
if __name__ == "__main__":
    print("=== 长期记忆存储、索引与检索优化 ===\n")
    # Create the long-term memory system.
    ltm_system = LongTermMemorySystem(
        index_type=IndexType.HNSW,
        consolidation_threshold=0.7,
        forgetting_rate=0.001
    )
    # Simulated embedding: a deterministic pseudo-random unit vector seeded
    # from the text's MD5 hash, so equal texts map to equal vectors.
    def generate_embedding(text: str) -> np.ndarray:
        """Generate a deterministic 256-d unit embedding for `text`."""
        hash_val = int(hashlib.md5(text.encode()).hexdigest(), 16)
        np.random.seed(hash_val % (2**32))
        embedding = np.random.randn(256)
        embedding /= np.linalg.norm(embedding)
        return embedding
    print("=== 存储记忆 ===")
    # Store a semantic memory (factual knowledge).
    m1 = ltm_system.store_memory(
        content="Python 是一种高级编程语言",
        memory_type=MemoryType.SEMANTIC,
        embedding=generate_embedding("Python 编程语言"),
        metadata={"category": "programming", "difficulty": "beginner"}
    )
    print(f"存储语义记忆:{m1.content}")
    print(f"记忆 ID: {m1.id}")
    # Store an episodic memory (experienced event).
    m2 = ltm_system.store_memory(
        content="2024 年学习了机器学习课程",
        memory_type=MemoryType.EPISODIC,
        embedding=generate_embedding("机器学习学习经历"),
        metadata={"year": 2024, "topic": "machine_learning"}
    )
    print(f"\n存储情景记忆:{m2.content}")
    # Store a procedural memory (skill/method).
    m3 = ltm_system.store_memory(
        content="如何实现快速排序算法",
        memory_type=MemoryType.PROCEDURAL,
        embedding=generate_embedding("快速排序算法"),
        metadata={"algorithm": "quicksort", "language": "python"}
    )
    print(f"\n存储程序记忆:{m3.content}")
    # Bulk-store additional semantic memories.
    for i in range(20):
        ltm_system.store_memory(
            content=f"知识点 {i+1}: 关于人工智能的某个概念",
            memory_type=MemoryType.SEMANTIC,
            embedding=generate_embedding(f"人工智能知识点{i+1}")
        )
    print(f"\n=== 记忆统计 ===")
    stats = ltm_system.get_stats()
    print(f"总存储:{stats['total_stored']}")
    print(f"记忆分布:{stats['memory_distribution']}")
    print(f"平均强度:{stats['average_strength']:.2f}")
    print(f"平均访问次数:{stats['average_access_count']:.2f}")
    print(f"\n=== 检索记忆 ===")
    # Build a query embedding and retrieve across two memory types.
    query_embedding = generate_embedding("Python 编程")
    results = ltm_system.retrieve(
        query_embedding=query_embedding,
        memory_types=[MemoryType.SEMANTIC, MemoryType.PROCEDURAL],
        k=5
    )
    print(f"查询:Python 编程")
    print(f"检索到 {len(results)} 条记忆:")
    for i, memory in enumerate(results, 1):
        print(f"\n{i}. {memory.content}")
        print(f" 类型:{memory.memory_type.value}")
        print(f" 强度:{memory.strength:.2f}")
        print(f" 访问次数:{memory.access_count}")
    print(f"\n=== 记忆巩固 ===")
    # Consolidate m1 and associate it with m3.
    ltm_system.consolidate_memory(m1.id, [m3.id])
    print(f"巩固记忆 {m1.content}")
    print(f"建立关联:{m1.id} <-> {m3.id}")
    # Inspect the resulting association list.
    memory = ltm_system.memories[m1.id]
    print(f"关联记忆:{memory.related_ids}")
    print(f"\n=== 构建知识图谱 ===")
    graph = ltm_system.build_knowledge_graph()
    print(f"知识图谱节点数:{len(graph)}")
    print(f"知识图谱边数:{sum(len(neighbors) for neighbors in graph.values())}")
    print(f"\n=== 应用遗忘 ===")
    forgotten = ltm_system.apply_forgetting()
    print(f"遗忘统计:{forgotten}")
    print(f"\n=== 最终统计 ===")
    final_stats = ltm_system.get_stats()
    print(f"总记忆数:{final_stats['total_memories']}")
    print(f"总检索:{final_stats['total_retrieved']}")
    print(f"总巩固:{final_stats['total_consolidated']}")
    print(f"总遗忘:{final_stats['total_forgotten']}")
    print(f"\n关键观察:")
    print("1. 持久化存储:向量数据库存储海量记忆")
    print("2. 高效索引:HNSW 索引实现毫秒级检索")
    print("3. 相似检索:基于余弦相似度的近似最近邻搜索")
    print("4. 记忆巩固:增强强度、建立关联")
    print("5. 遗忘管理:时间衰减、低强度清理")
    print("6. 知识图谱:基于关联关系构建图结构")
    print("\n长期记忆的核心:存储 + 索引 + 检索 + 优化 + 演化")