# 分层记忆系统完整实现 (complete implementation of a hierarchical memory system)
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import json
import hashlib
from collections import deque
import heapq
class MemoryType(Enum):
    """Categories of memory supported by the hierarchical memory system."""

    SHORT_TERM = "short_term"   # working / short-term memory
    LONG_TERM = "long_term"     # consolidated long-term memory
    SEMANTIC = "semantic"       # factual knowledge
    EPISODIC = "episodic"       # personal experiences / episodes
    PROCEDURAL = "procedural"   # skills and procedures
class MemoryPriority(Enum):
    """Retention priority of a memory; lower value means more important."""

    CRITICAL = 1   # must never be forgotten
    HIGH = 2       # high priority
    NORMAL = 3     # default
    LOW = 4        # eligible for forgetting
@dataclass
class Memory:
    """A single memory unit: content plus embedding and bookkeeping state.

    Attributes:
        id: unique identifier (hash-derived by the encoder).
        content: natural-language payload of the memory.
        memory_type: which memory category this belongs to.
        embedding: dense vector used for similarity retrieval.
        timestamp: creation time.
        priority: retention priority (drives consolidation/forgetting).
        access_count: number of times the memory has been retrieved.
        last_accessed: time of last retrieval; defaults to `timestamp`.
        metadata: free-form extra data.
        strength: memory strength in [0, 1]; decays under forgetting.
    """

    id: str
    content: str
    memory_type: MemoryType
    embedding: np.ndarray
    timestamp: datetime
    priority: MemoryPriority = MemoryPriority.NORMAL
    access_count: int = 0
    # FIX: was annotated as plain `datetime` with a None default;
    # Optional[datetime] is the correct type for this field.
    last_accessed: Optional[datetime] = None
    metadata: Dict[str, Any] = field(default_factory=dict)
    strength: float = 1.0

    def __post_init__(self):
        # A freshly encoded memory counts as accessed at creation time.
        if self.last_accessed is None:
            self.last_accessed = self.timestamp

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict (enums, arrays and datetimes flattened)."""
        return {
            "id": self.id,
            "content": self.content,
            "memory_type": self.memory_type.value,
            "embedding": self.embedding.tolist(),
            "timestamp": self.timestamp.isoformat(),
            "priority": self.priority.value,
            "access_count": self.access_count,
            "last_accessed": self.last_accessed.isoformat(),
            "metadata": self.metadata,
            "strength": self.strength
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'Memory':
        """Reconstruct a Memory from the output of `to_dict`.

        `metadata` and `strength` are optional in the input for
        backward compatibility with older exports.
        """
        return cls(
            id=data["id"],
            content=data["content"],
            memory_type=MemoryType(data["memory_type"]),
            embedding=np.array(data["embedding"]),
            timestamp=datetime.fromisoformat(data["timestamp"]),
            priority=MemoryPriority(data["priority"]),
            access_count=data["access_count"],
            last_accessed=datetime.fromisoformat(data["last_accessed"]),
            metadata=data.get("metadata", {}),
            strength=data.get("strength", 1.0)
        )
@dataclass
class MemoryQuery:
    """Parameters of a memory retrieval request.

    Attributes:
        query: human-readable query text (informational only).
        query_embedding: vector compared against stored embeddings.
        memory_types: restrict the search to these types; None = all types.
        max_results: maximum number of memories returned.
        time_range: inclusive (start, end) filter on creation time; None = no filter.
        min_strength: drop memories weaker than this threshold.
    """

    query: str
    query_embedding: np.ndarray
    # FIX: these two fields default to None, so they must be Optional.
    memory_types: Optional[List[MemoryType]] = None
    max_results: int = 10
    time_range: Optional[Tuple[datetime, datetime]] = None
    min_strength: float = 0.0
class HierarchicalMemorySystem:
    """
    Hierarchical memory system.

    Core features:
    1. Short-term memory: bounded working memory (FIFO deque).
    2. Long-term memory: persistent storage, grouped by MemoryType.
    3. Semantic memory: factual knowledge.
    4. Episodic memory: personal experiences.
    5. Retrieval, consolidation, strength updates and forgetting.
    """

    def __init__(self,
                 short_term_capacity: int = 100,
                 consolidation_threshold: float = 0.7,
                 forgetting_rate: float = 0.01):
        """
        Args:
            short_term_capacity: max entries held in short-term memory.
            consolidation_threshold: strength at/above which a memory is
                promoted to long-term storage.
            forgetting_rate: per-call decay factor used by apply_forgetting.
        """
        # Short-term memory: bounded deque, oldest entry evicted first.
        self.short_term_memory: deque = deque(maxlen=short_term_capacity)
        # Long-term memory, grouped by type.
        self.long_term_memories: Dict[MemoryType, List[Memory]] = {
            mt: [] for mt in MemoryType
        }
        # Global id -> Memory index for O(1) lookup.
        self.memory_index: Dict[str, Memory] = {}
        # Tunables.
        self.consolidation_threshold = consolidation_threshold
        self.forgetting_rate = forgetting_rate
        # Running counters.
        self.stats = {
            "total_encoded": 0,
            "total_consolidated": 0,
            "total_forgotten": 0,
            "total_retrievals": 0
        }

    def encode_memory(self,
                      content: str,
                      memory_type: MemoryType,
                      embedding: np.ndarray,
                      priority: MemoryPriority = MemoryPriority.NORMAL,
                      metadata: Optional[Dict[str, Any]] = None) -> Memory:
        """
        Encode a new memory into short-term storage (and possibly
        consolidate it immediately).

        Args:
            content: memory content.
            memory_type: memory category.
            embedding: vector embedding for similarity retrieval.
            priority: retention priority.
            metadata: optional extra data.
        Returns:
            The created Memory unit.
        """
        # Unique id: hash of content + creation time.
        memory_id = hashlib.md5(
            f"{content}{datetime.now().isoformat()}".encode()
        ).hexdigest()[:16]
        memory = Memory(
            id=memory_id,
            content=content,
            memory_type=memory_type,
            embedding=embedding,
            timestamp=datetime.now(),
            priority=priority,
            metadata=metadata or {}
        )
        # BUGFIX: when the deque is full, append() silently drops the
        # oldest entry, which previously left an orphaned record in
        # memory_index. Evicted short-term memories were never
        # consolidated (consolidation removes them from the deque),
        # so they must be dropped from the index as well.
        if (self.short_term_memory.maxlen is not None
                and len(self.short_term_memory) == self.short_term_memory.maxlen):
            evicted = self.short_term_memory[0]
            self.memory_index.pop(evicted.id, None)
        self.short_term_memory.append(memory)
        self.memory_index[memory_id] = memory
        self.stats["total_encoded"] += 1
        # Promote immediately if it already qualifies for long-term storage.
        if self._should_consolidate(memory):
            self.consolidate_memory(memory_id)
        return memory

    def _should_consolidate(self, memory: Memory) -> bool:
        """Return True when a memory should be promoted to long-term storage."""
        # High-priority memories are consolidated unconditionally.
        if memory.priority in (MemoryPriority.CRITICAL, MemoryPriority.HIGH):
            return True
        # Otherwise decide by memory strength.
        return memory.strength >= self.consolidation_threshold

    def consolidate_memory(self, memory_id: str) -> bool:
        """
        Promote a memory from short-term to long-term storage.

        Args:
            memory_id: id of the memory to consolidate.
        Returns:
            True on success (including when already consolidated),
            False if the id is unknown.
        """
        if memory_id not in self.memory_index:
            return False
        memory = self.memory_index[memory_id]
        bucket = self.long_term_memories[memory.memory_type]
        # BUGFIX: consolidating the same memory twice used to append a
        # duplicate entry to long-term storage; treat it as a no-op.
        if any(existing is memory for existing in bucket):
            return True
        # Remove from short-term memory if still present.
        if memory in self.short_term_memory:
            self.short_term_memory.remove(memory)
        bucket.append(memory)
        self.stats["total_consolidated"] += 1
        return True

    def retrieve(self, query: MemoryQuery) -> List[Memory]:
        """
        Retrieve memories matching a query.

        Long-term candidates are ranked by a weighted blend of cosine
        similarity (0.6), strength (0.2) and access frequency (0.2);
        short-term candidates by similarity (0.8) + strength (0.2).

        Args:
            query: the query object.
        Returns:
            Up to query.max_results memories, best match first.
        """
        candidates: List[Tuple[float, Memory]] = []
        # --- Long-term search (optionally restricted to given types). ---
        for memory_type in (query.memory_types or list(MemoryType)):
            for memory in self.long_term_memories[memory_type]:
                # Creation-time window filter.
                if query.time_range:
                    start, end = query.time_range
                    if not (start <= memory.timestamp <= end):
                        continue
                # Strength floor.
                if memory.strength < query.min_strength:
                    continue
                similarity = self._cosine_similarity(
                    query.query_embedding, memory.embedding
                )
                # Blend similarity with strength and (capped) access frequency.
                weighted = (
                    similarity * 0.6
                    + memory.strength * 0.2
                    + min(memory.access_count / 10, 1.0) * 0.2
                )
                candidates.append((weighted, memory))
        # --- Short-term search. ---
        for memory in self.short_term_memory:
            if query.memory_types and memory.memory_type not in query.memory_types:
                continue
            # CONSISTENCY FIX: apply the same time-range and strength
            # filters as the long-term path (previously skipped here).
            if query.time_range:
                start, end = query.time_range
                if not (start <= memory.timestamp <= end):
                    continue
            if memory.strength < query.min_strength:
                continue
            similarity = self._cosine_similarity(
                query.query_embedding, memory.embedding
            )
            candidates.append((similarity * 0.8 + memory.strength * 0.2, memory))
        # Best match first; key on the score only (Memory is not orderable).
        candidates.sort(key=lambda pair: pair[0], reverse=True)
        top = candidates[:query.max_results]
        # Touch access bookkeeping on the returned memories.
        for _, memory in top:
            memory.access_count += 1
            memory.last_accessed = datetime.now()
            self.stats["total_retrievals"] += 1
        return [memory for _, memory in top]

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0.0 when either is a zero vector."""
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    def update_memory_strength(self,
                               memory_id: str,
                               delta: float) -> bool:
        """
        Adjust a memory's strength by `delta`, clamped to [0, 1].

        Args:
            memory_id: id of the memory.
            delta: signed strength change.
        Returns:
            True on success, False if the id is unknown.
        """
        if memory_id not in self.memory_index:
            return False
        memory = self.memory_index[memory_id]
        memory.strength = max(0.0, min(1.0, memory.strength + delta))
        return True

    def apply_forgetting(self) -> Dict[str, int]:
        """
        Apply time-based decay to long-term memories and drop weak,
        low-priority ones.

        Returns:
            Per-type counts of forgotten memories, keyed by the
            MemoryType value string (matches the declared return type;
            the original implementation returned enum keys).
        """
        forgotten_count = {mt.value: 0 for mt in MemoryType}
        for memory_type in MemoryType:
            memories = self.long_term_memories[memory_type]
            to_remove = []
            for memory in memories:
                # Decay strength; clamp at 0 so it stays in [0, 1].
                time_decay = self._calculate_time_decay(memory)
                memory.strength = max(
                    0.0, memory.strength - time_decay * self.forgetting_rate
                )
                # Only weak AND low-priority memories are forgotten.
                if memory.strength < 0.1 and memory.priority == MemoryPriority.LOW:
                    to_remove.append(memory.id)
            for memory_id in to_remove:
                if memory_id in self.memory_index:
                    memory = self.memory_index[memory_id]
                    memories.remove(memory)
                    del self.memory_index[memory_id]
                    forgotten_count[memory_type.value] += 1
                    self.stats["total_forgotten"] += 1
        return forgotten_count

    def _calculate_time_decay(self, memory: Memory) -> float:
        """Time-decay factor based on time since last access.

        Simplified Ebbinghaus forgetting curve: 1 / (1 + age_in_days).
        """
        age = datetime.now() - memory.last_accessed
        days = age.total_seconds() / 86400
        return 1.0 / (1.0 + days)

    def get_stats(self) -> Dict[str, Any]:
        """Return counters plus current per-tier memory counts."""
        return {
            **self.stats,
            "short_term_count": len(self.short_term_memory),
            "long_term_counts": {
                mt.value: len(memories)
                for mt, memories in self.long_term_memories.items()
            },
            "total_memories": len(self.memory_index)
        }

    def export_memories(self, filepath: str):
        """Export all memories and stats to a JSON file at `filepath`."""
        data = {
            "short_term": [m.to_dict() for m in self.short_term_memory],
            "long_term": {
                mt.value: [m.to_dict() for m in memories]
                for mt, memories in self.long_term_memories.items()
            },
            "stats": self.stats
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def import_memories(self, filepath: str):
        """Replace the current memory state with the contents of a JSON export."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        # BUGFIX: clear the index first — the import replaces all
        # memories, and stale ids previously lingered in the index.
        self.memory_index.clear()
        # Restore short-term memory.
        self.short_term_memory.clear()
        for m_dict in data.get("short_term", []):
            memory = Memory.from_dict(m_dict)
            self.short_term_memory.append(memory)
            self.memory_index[memory.id] = memory
        # Restore long-term memory, type by type.
        for mt in MemoryType:
            self.long_term_memories[mt] = []
            for m_dict in data.get("long_term", {}).get(mt.value, []):
                memory = Memory.from_dict(m_dict)
                self.long_term_memories[mt].append(memory)
                self.memory_index[memory.id] = memory
        # Restore counters.
        self.stats.update(data.get("stats", {}))
# Usage example / demo script.
if __name__ == "__main__":
    print("=== Agent 分层记忆架构与工作机理 ===\n")

    # Build the memory system under test.
    memory_system = HierarchicalMemorySystem(
        short_term_capacity=50,
        consolidation_threshold=0.7,
        forgetting_rate=0.01
    )

    # Simulated embedding (a real system would use an embedding model).
    def generate_embedding(text: str) -> np.ndarray:
        # Derive a deterministic pseudo-vector from the text's hash,
        # normalized to unit length.
        seed = int(hashlib.md5(text.encode()).hexdigest(), 16) % (2**32)
        np.random.seed(seed)
        vec = np.random.randn(128)
        return vec / np.linalg.norm(vec)

    print("=== 编码记忆 ===")
    # Encode an ordinary episodic memory.
    m1 = memory_system.encode_memory(
        content="用户喜欢喝拿铁咖啡",
        memory_type=MemoryType.EPISODIC,
        embedding=generate_embedding("用户喜欢喝拿铁咖啡"),
        priority=MemoryPriority.NORMAL,
        metadata={"user_id": "user_001", "category": "preference"}
    )
    print(f"编码记忆:{m1.content}")
    print(f"记忆 ID: {m1.id}")
    print(f"初始强度:{m1.strength:.2f}")

    # Encode a critical-priority semantic memory.
    m2 = memory_system.encode_memory(
        content="用户过敏:花生",
        memory_type=MemoryType.SEMANTIC,
        embedding=generate_embedding("用户过敏:花生"),
        priority=MemoryPriority.CRITICAL,
        metadata={"user_id": "user_001", "category": "health"}
    )
    print(f"\n编码关键记忆:{m2.content}")
    print(f"优先级:{m2.priority.value}")

    # Encode a batch of conversation-history memories.
    for idx in range(10):
        memory_system.encode_memory(
            content=f"对话历史 {idx}: 讨论了 AI 记忆系统",
            memory_type=MemoryType.EPISODIC,
            embedding=generate_embedding(f"对话历史 {idx}"),
            priority=MemoryPriority.NORMAL
        )

    print(f"\n=== 记忆统计 ===")
    stats = memory_system.get_stats()
    print(f"总编码:{stats['total_encoded']}")
    print(f"已巩固:{stats['total_consolidated']}")
    print(f"短时记忆:{stats['short_term_count']}")
    print(f"长时记忆:{stats['long_term_counts']}")

    print(f"\n=== 检索记忆 ===")
    # Build and run a retrieval query.
    query = MemoryQuery(
        query="用户偏好",
        query_embedding=generate_embedding("用户喜欢什么"),
        memory_types=[MemoryType.EPISODIC, MemoryType.SEMANTIC],
        max_results=5
    )
    results = memory_system.retrieve(query)
    print(f"查询:{query.query}")
    print(f"检索到 {len(results)} 条记忆:")
    for rank, mem in enumerate(results, 1):
        print(f"\n{rank}. {mem.content}")
        print(f"  类型:{mem.memory_type.value}")
        print(f"  强度:{mem.strength:.2f}")
        print(f"  访问次数:{mem.access_count}")
        print(f"  相似度:{memory_system._cosine_similarity(query.query_embedding, mem.embedding):.3f}")

    print(f"\n=== 记忆巩固 ===")
    # Manually consolidate one memory.
    if m1.id in memory_system.memory_index:
        memory_system.consolidate_memory(m1.id)
        print(f"已巩固记忆:{m1.content}")

    print(f"\n=== 应用遗忘 ===")
    forgotten = memory_system.apply_forgetting()
    print(f"遗忘统计:{forgotten}")

    print(f"\n=== 最终统计 ===")
    final_stats = memory_system.get_stats()
    print(f"总记忆数:{final_stats['total_memories']}")
    print(f"总检索:{final_stats['total_retrievals']}")
    print(f"总遗忘:{final_stats['total_forgotten']}")

    print(f"\n关键观察:")
    print("1. 分层架构:短时记忆 + 长时记忆(语义 + 情景)")
    print("2. 记忆编码:内容 + 向量嵌入 + 元数据")
    print("3. 记忆巩固:从短时到长时的自动迁移")
    print("4. 相似检索:基于向量相似度的高效检索")
    print("5. 遗忘机制:时间衰减 + 强度阈值")
    print("6. 记忆更新:强度调整 + 访问统计")
    print("\n记忆系统的核心:编码→存储→检索→更新→演化")