"""
记忆遗忘、更新与一致性维护系统完整实现

Complete implementation of a memory forgetting, update and consistency-maintenance system.
"""
import numpy as np
from typing import Dict, List, Any, Optional, Tuple, Set
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
import math
import hashlib
import json
from collections import defaultdict
import heapq
import copy
class MemoryStatus(Enum):
    """Lifecycle state of a memory node."""

    ACTIVE = "active"              # stored and in normal use
    CONSOLIDATED = "consolidated"  # strong, well-established memory
    WEAKENED = "weakened"          # strength has dropped noticeably
    FORGOTTEN = "forgotten"        # marked forgotten; the entry is kept but treated as inactive
class ConsistencyLevel(Enum):
    """Requested consistency guarantee for memory maintenance."""

    STRONG = "strong"      # strong consistency
    EVENTUAL = "eventual"  # eventual consistency
    WEAK = "weak"          # weak consistency
@dataclass
class MemoryNode:
    """A single stored memory: payload, embedding, scores, lifecycle status and versioning info."""

    id: str  # unique memory ID
    content: Any  # memory payload (any object)
    embedding: np.ndarray  # vector representation of the content
    importance: float  # importance score (0-1)
    strength: float  # memory strength (0-1)
    timestamp: datetime  # creation time
    last_accessed: datetime  # time of the most recent access / update
    access_count: int = 0  # number of recorded accesses
    version: int = 1  # version number
    status: MemoryStatus = MemoryStatus.ACTIVE  # lifecycle status
    metadata: Dict[str, Any] = field(default_factory=dict)  # free-form metadata
    dependencies: Set[str] = field(default_factory=set)  # IDs of other memories this one depends on
    constraints: List[str] = field(default_factory=list)  # names of consistency constraints that apply

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the node to a plain dictionary (e.g. for serialization).

        Content that is not a str/int/float/bool is stored as ``str(content)``
        (a lossy conversion). The embedding becomes a list, datetimes become
        ISO-8601 strings, and the status becomes its string value.

        Mutable fields (``metadata``, ``dependencies``, ``constraints``) are
        copied. Mutating the returned dict therefore cannot change this node.
        """
        content = self.content
        if not isinstance(content, (str, int, float, bool)):
            content = str(content)
        return {
            "id": self.id,
            "content": content,
            "embedding": self.embedding.tolist(),
            "importance": self.importance,
            "strength": self.strength,
            "timestamp": self.timestamp.isoformat(),
            "last_accessed": self.last_accessed.isoformat(),
            "access_count": self.access_count,
            "version": self.version,
            "status": self.status.value,
            # Copy so callers editing the dict do not silently edit the node.
            "metadata": copy.deepcopy(self.metadata),
            "dependencies": list(self.dependencies),
            "constraints": list(self.constraints),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> 'MemoryNode':
        """
        Rebuild a node from a dictionary shaped like the output of ``to_dict``.

        Optional keys fall back to the dataclass defaults. Container values
        are copied, so the new node does not share mutable state with ``data``.

        Raises:
            KeyError: if a required key (id, content, embedding, importance,
                strength, timestamp, last_accessed) is missing.
        """
        return cls(
            id=data["id"],
            content=data["content"],
            embedding=np.array(data["embedding"]),
            importance=data["importance"],
            strength=data["strength"],
            timestamp=datetime.fromisoformat(data["timestamp"]),
            last_accessed=datetime.fromisoformat(data["last_accessed"]),
            access_count=data.get("access_count", 0),
            version=data.get("version", 1),
            status=MemoryStatus(data.get("status", "active")),
            metadata=copy.deepcopy(data.get("metadata", {})),
            dependencies=set(data.get("dependencies", [])),
            constraints=list(data.get("constraints", [])),
        )
class ForgettingMechanism:
    """
    Forgetting mechanism.

    Provides:
    1. Time decay: an Ebbinghaus-style exponential forgetting curve
    2. Importance weighting: higher importance slows the decay
    3. Access reinforcement: frequently accessed memories receive a strength bonus
    4. Proactive forgetting: capacity-driven removal of low-priority memories

    ``interference_factor`` is accepted and stored, but no method in this class
    reads it yet. Interference between new and old memories is not modelled.
    """

    def __init__(self,
                 decay_rate: float = 0.1,
                 importance_factor: float = 0.5,
                 interference_factor: float = 0.3):
        self.decay_rate = decay_rate  # base decay rate, applied per hour of elapsed time
        self.importance_factor = importance_factor  # how strongly importance slows decay
        self.interference_factor = interference_factor  # interference weight (currently unused)

    def ebbinghaus_decay(self, time_elapsed: float, initial_strength: float = 1.0) -> float:
        """
        Apply the Ebbinghaus forgetting curve.

        Computes ``S * exp(-decay_rate * t / (S + 0.1))``. The initial strength
        ``S`` acts both as the amplitude and as the stability term, so stronger
        memories decay more slowly. The ``+ 0.1`` keeps the denominator
        positive when ``S`` is 0.

        Args:
            time_elapsed: elapsed time in hours. Negative values are treated as
                0 so the curve can never increase strength.
            initial_strength: strength before decay

        Returns:
            The decayed strength.
        """
        hours = max(0.0, time_elapsed)
        return initial_strength * math.exp(-self.decay_rate * hours / (initial_strength + 0.1))

    def calculate_forgetting(self,
                             memory: "MemoryNode",
                             current_time: datetime) -> float:
        """
        Compute a memory's strength after forgetting, without mutating the memory.

        Importance slows decay by shrinking the effective elapsed time by the
        factor ``1 - importance * importance_factor`` (floored at 0). Frequent
        access adds a bonus of ``0.1 * log10(access_count + 1)``. The result is
        capped at 1.0.

        Args:
            memory: the memory node; reads ``last_accessed``, ``strength``,
                ``importance`` and ``access_count``
            current_time: the reference "now"

        Returns:
            The new strength.
        """
        # Hours since the last access. Clamp at 0 so a future last_accessed
        # cannot amplify strength.
        hours = max(0.0, (current_time - memory.last_accessed).total_seconds() / 3600)
        # Higher importance means a smaller effective elapsed time, i.e. slower forgetting.
        # The previous formula multiplied the *result* by (1 - importance*factor).
        # That made important memories lose strength faster, even when no time had passed.
        slowdown = max(0.0, 1.0 - memory.importance * self.importance_factor)
        decayed = self.ebbinghaus_decay(hours * slowdown, memory.strength)
        # Frequently accessed memories are reinforced.
        access_bonus = math.log10(memory.access_count + 1) * 0.1
        return min(1.0, decayed + access_bonus)

    def apply_forgetting(self,
                         memories: Dict[str, "MemoryNode"],
                         current_time: datetime,
                         threshold: float = 0.15) -> List[str]:
        """
        Apply time-based forgetting, in place, to every memory not already FORGOTTEN.

        Each memory's ``strength`` is replaced by ``calculate_forgetting`` and
        its status is updated as follows:
        - below ``threshold`` -> FORGOTTEN
        - otherwise below 0.4 -> WEAKENED
        - above 0.7 -> CONSOLIDATED
        - any other strength keeps its previous status

        NOTE(review): elapsed time is measured from ``last_accessed`` and
        applied to the already-decayed ``strength``. Running this repeatedly
        without intervening accesses therefore compounds the decay over the
        same interval.

        Args:
            memories: memory ID -> memory node (mutated in place)
            current_time: the reference "now"
            threshold: strengths below this value mark the memory as forgotten

        Returns:
            IDs of the memories newly marked FORGOTTEN by this call.
        """
        forgotten_ids = []
        for memory_id, memory in memories.items():
            if memory.status == MemoryStatus.FORGOTTEN:
                continue
            new_strength = self.calculate_forgetting(memory, current_time)
            memory.strength = new_strength
            if new_strength < threshold:
                memory.status = MemoryStatus.FORGOTTEN
                forgotten_ids.append(memory_id)
            elif new_strength < 0.4:
                memory.status = MemoryStatus.WEAKENED
            elif new_strength > 0.7:
                memory.status = MemoryStatus.CONSOLIDATED
        return forgotten_ids

    def proactive_forgetting(self,
                             memories: Dict[str, "MemoryNode"],
                             max_capacity: int,
                             strategy: str = "importance",
                             current_time: Optional[datetime] = None) -> List[str]:
        """
        Capacity management: mark the lowest-priority active memories as forgotten.

        Only non-FORGOTTEN memories count toward ``max_capacity``. FORGOTTEN
        entries stay in ``memories`` and must not trigger further removals.

        Args:
            memories: memory ID -> memory node (mutated in place)
            max_capacity: maximum number of active memories to keep
            strategy: "importance", "age", or "combined". Any other value is
                treated as "combined".
            current_time: reference time for the "age" strategy; defaults to
                ``datetime.now()``

        Returns:
            IDs of the memories marked FORGOTTEN by this call.
        """
        if len(memories) <= max_capacity:
            return []
        active = [(mid, m) for mid, m in memories.items()
                  if m.status != MemoryStatus.FORGOTTEN]
        # Guard: a non-positive count would make the slice below drop from the
        # wrong end and forget almost everything.
        remove_count = len(active) - max_capacity
        if remove_count <= 0:
            return []
        now = current_time if current_time is not None else datetime.now()
        scored = [(mid, self._retention_score(m, strategy, now)) for mid, m in active]
        scored.sort(key=lambda item: item[1])  # lowest retention score first
        to_remove = [mid for mid, _ in scored[:remove_count]]
        for mid in to_remove:
            memories[mid].status = MemoryStatus.FORGOTTEN
        return to_remove

    @staticmethod
    def _retention_score(memory: "MemoryNode", strategy: str, now: datetime) -> float:
        """Priority for keeping ``memory``; lower scores are forgotten first."""
        if strategy == "importance":
            return memory.importance * 0.6 + memory.strength * 0.4
        if strategy == "age":
            # Older memories score lower. Clamp so a future timestamp cannot
            # feed log10 a non-positive value.
            age_hours = max(0.0, (now - memory.timestamp).total_seconds() / 3600)
            return 1.0 / (1.0 + math.log10(age_hours + 1))
        # "combined" (also the fallback for unrecognized strategies)
        return (memory.importance * 0.4
                + memory.strength * 0.3
                + memory.access_count / (memory.access_count + 10) * 0.3)
class MemoryUpdater:
    """
    Memory updater.

    Provides:
    1. Incremental update: only the supplied fields are changed
    2. Version management: a bounded history of previous versions per memory
    3. Conflict detection: dependency and constraint conflicts
    4. Rollback: restore a previously saved version
    """

    def __init__(self, max_versions: int = 5):
        self.max_versions = max_versions  # history snapshots kept per memory ID
        self.version_history: Dict[str, List[Dict]] = defaultdict(list)  # memory ID -> snapshots, oldest first
        self.conflicts: List[Dict] = []  # conflicts reported by detect_conflict

    def update_memory(self,
                      memory: "MemoryNode",
                      new_content: Any = None,
                      new_importance: Optional[float] = None,
                      update_metadata: Optional[Dict[str, Any]] = None) -> "MemoryNode":
        """
        Return an updated copy of ``memory`` after saving its current state to history.

        The copy has ``version`` and ``access_count`` incremented and
        ``last_accessed`` set to now. The input ``memory`` is not modified.

        Args:
            memory: the memory to update
            new_content: replacement content (None keeps the current content)
            new_importance: replacement importance (None keeps the current value)
            update_metadata: keys merged into the copy's metadata

        Returns:
            The updated copy.
        """
        self._save_version(memory)

        updated = copy.deepcopy(memory)
        updated.version += 1
        updated.last_accessed = datetime.now()
        updated.access_count += 1

        if new_content is not None:
            updated.content = new_content
        if new_importance is not None:
            updated.importance = new_importance
        if update_metadata:
            updated.metadata.update(update_metadata)
        return updated

    def _save_version(self, memory: "MemoryNode"):
        """Append a snapshot of ``memory`` to its history, trimming to ``max_versions``."""
        snapshot = {
            "version": memory.version,
            # Deep-copy the content rather than str() it, so rollback restores
            # the original object. Later mutations of the live memory also
            # cannot rewrite history.
            "content": copy.deepcopy(memory.content),
            "importance": memory.importance,
            "strength": memory.strength,
            "timestamp": memory.last_accessed.isoformat(),
            "metadata": copy.deepcopy(memory.metadata),
        }
        history = self.version_history[memory.id]
        history.append(snapshot)
        # Drop the oldest entries beyond the limit (a non-positive limit keeps none).
        overflow = len(history) - max(0, self.max_versions)
        if overflow > 0:
            del history[:overflow]

    def rollback(self,
                 memory: "MemoryNode",
                 target_version: int) -> Optional["MemoryNode"]:
        """
        Return a copy of ``memory`` restored to a saved version.

        Content, importance, strength and metadata come from the snapshot. The
        result gets a *new* version number (``memory.version + 1``) so version
        numbers keep increasing and never collide with existing history.

        The current state is saved to history first (unless a snapshot with
        its version already exists), so the rollback itself can be undone.

        Args:
            memory: the current memory
            target_version: the version number to restore

        Returns:
            The rolled-back copy, or None if ``target_version`` is not in history.
        """
        history = self.version_history.get(memory.id, [])
        target = next((snap for snap in history if snap["version"] == target_version), None)
        if target is None:
            return None

        if all(snap["version"] != memory.version for snap in history):
            self._save_version(memory)

        restored = copy.deepcopy(memory)
        restored.version = memory.version + 1
        restored.content = copy.deepcopy(target["content"])
        restored.importance = target["importance"]
        restored.strength = target["strength"]
        # Copy so edits to the restored node cannot corrupt the stored snapshot.
        restored.metadata = copy.deepcopy(target["metadata"])
        restored.last_accessed = datetime.now()
        return restored

    def detect_conflict(self,
                        memory1: "MemoryNode",
                        memory2: "MemoryNode",
                        constraint_checker: callable = None) -> Optional[Dict]:
        """
        Detect a conflict between two memories.

        The checks are:
        - If either memory depends on the other and their contents differ,
          this is a dependency conflict.
        - Otherwise, if ``constraint_checker(memory1, memory2)`` is falsy, this
          is a constraint violation.

        Detected conflicts are also appended to ``self.conflicts``.

        Args:
            memory1: first memory
            memory2: second memory
            constraint_checker: optional predicate ``(m1, m2) -> bool``

        Returns:
            A conflict description dict, or None when no conflict is found.
        """
        conflict = None
        related = memory1.id in memory2.dependencies or memory2.id in memory1.dependencies
        if related and memory1.content != memory2.content:
            conflict = {
                "type": "dependency_conflict",
                "memory1_id": memory1.id,
                "memory2_id": memory2.id,
                "description": "依赖记忆内容不一致"
            }
        elif constraint_checker and not constraint_checker(memory1, memory2):
            conflict = {
                "type": "constraint_violation",
                "memory1_id": memory1.id,
                "memory2_id": memory2.id,
                "description": "违反一致性约束"
            }
        if conflict is not None:
            self.conflicts.append(conflict)
        return conflict
class ConsistencyMaintainer:
    """
    Consistency maintainer.

    Responsibilities:
    1. Constraint validation: run the registered checks a memory opts into
    2. Conflict resolution: choose a resolution per strategy (currently simplified)
    3. Inconsistency repair: adjust a copy of a memory according to its violations
    4. Distributed consistency: ``consistency_level`` is stored for this, but no
       method below reads it
    """

    def __init__(self, consistency_level: ConsistencyLevel = ConsistencyLevel.EVENTUAL):
        self.consistency_level = consistency_level  # stored only; not consulted by the methods below
        self.constraints: Dict[str, callable] = {}  # name -> fn(memory, all_memories) returning truthy when satisfied
        self.conflict_log: List[Dict] = []  # every conflict handed to resolve_conflict, in order

    def register_constraint(self,
                            constraint_name: str,
                            constraint_fn: callable):
        """Register ``constraint_fn`` under ``constraint_name``, replacing any previous one."""
        self.constraints[constraint_name] = constraint_fn

    def validate_memory(self,
                        memory: MemoryNode,
                        all_memories: Dict[str, MemoryNode]) -> List[Dict]:
        """
        Check ``memory`` against the registered constraints it lists.

        Only constraints whose name appears in ``memory.constraints`` are run,
        in registration order. Each check is called as
        ``fn(memory, all_memories)``. A falsy result, or any exception raised
        while evaluating it, produces a violation record.

        Args:
            memory: the memory to validate
            all_memories: every memory, passed through to the checks

        Returns:
            A list of violation dicts with keys "constraint", "memory_id" and
            "description".
        """
        found: List[Dict] = []
        applicable = ((name, check) for name, check in self.constraints.items()
                      if name in memory.constraints)
        for name, check in applicable:
            try:
                failed = not check(memory, all_memories)
            except Exception as exc:
                found.append({
                    "constraint": name,
                    "memory_id": memory.id,
                    "description": f"约束检查异常:{str(exc)}"
                })
                continue
            if failed:
                found.append({
                    "constraint": name,
                    "memory_id": memory.id,
                    "description": f"违反约束:{name}"
                })
        return found

    def resolve_conflict(self,
                         conflict: Dict,
                         strategy: str = "latest") -> Dict[str, Any]:
        """
        Record ``conflict`` in ``conflict_log`` and return a resolution.

        The "latest" and "highest_importance" strategies are simplified: both
        report ``conflict["memory1_id"]`` as the winner without comparing the
        memories. Any other strategy returns an unresolved result that
        requires a human decision.

        Args:
            conflict: conflict description (e.g. from a conflict detector)
            strategy: "latest", "highest_importance", or anything else for manual

        Returns:
            The resolution dict.
        """
        self.conflict_log.append(conflict)

        if strategy == "latest":
            label, action = "latest", "keep_latest"
        elif strategy == "highest_importance":
            label, action = "highest_importance", "keep_highest_importance"
        else:
            return {
                "resolved": False,
                "strategy": "manual",
                "requires_human": True,
                "conflict": conflict
            }

        return {
            "resolved": True,
            "strategy": label,
            "winner": conflict.get("memory1_id"),
            "action": action
        }

    def repair_inconsistency(self,
                             memory: MemoryNode,
                             violations: List[Dict],
                             all_memories: Dict[str, MemoryNode]) -> MemoryNode:
        """
        Return a repaired deep copy of ``memory``; the input is left untouched.

        For each violation, based on its constraint name:
        - names containing "dependency" are skipped, since dependency
          synchronisation is not implemented;
        - otherwise, names containing "contradiction" scale the copy's
          strength by 0.8.

        Args:
            memory: the memory to repair
            violations: violation dicts (each needs a "constraint" key)
            all_memories: every memory (not currently used by the repair)

        Returns:
            The repaired copy.
        """
        fixed = copy.deepcopy(memory)
        for entry in violations:
            name = entry["constraint"]
            if "dependency" in name:
                # Placeholder: dependency repair does not modify anything yet.
                continue
            if "contradiction" in name:
                fixed.strength *= 0.8  # weaken the contradictory memory
        return fixed
class MemoryManagementSystem:
    """
    Memory management system.

    Combines:
    1. Forgetting: time decay plus capacity-driven proactive forgetting
    2. Memory updates: versioned updates via the updater component
    3. Consistency maintenance: constraint validation plus repair
    4. Capacity balancing: ``max_capacity`` bounds the number of active memories
    """

    def __init__(self,
                 max_capacity: int = 10000,
                 forgetting_decay: float = 0.1,
                 consistency_level: ConsistencyLevel = ConsistencyLevel.EVENTUAL):
        # Storage: every memory by ID, including ones marked FORGOTTEN.
        self.memories: Dict[str, MemoryNode] = {}

        # Components
        self.forgetting = ForgettingMechanism(decay_rate=forgetting_decay)
        self.updater = MemoryUpdater()
        self.consistency = ConsistencyMaintainer(consistency_level)

        # Configuration: maximum number of *active* memories.
        self.max_capacity = max_capacity

        # Running counters
        self.stats = {
            "total_stored": 0,
            "total_forgotten": 0,
            "total_updated": 0,
            "conflicts_resolved": 0,
            "consistency_violations": 0
        }

    def _new_memory_id(self, content: Any, now: datetime) -> str:
        """
        Derive a 16-hex-char ID from ``content`` and ``now``, avoiding existing IDs.

        The first candidate uses the same formula as before. On a collision
        (same content stored within the clock's resolution), a counter is
        appended and the hash is recomputed, instead of overwriting the
        existing memory.
        """
        base = f"{content}{now.isoformat()}"
        salt = 0
        while True:
            seed = base if salt == 0 else f"{base}#{salt}"
            memory_id = hashlib.md5(seed.encode()).hexdigest()[:16]
            if memory_id not in self.memories:
                return memory_id
            salt += 1

    def store_memory(self,
                     content: Any,
                     embedding: np.ndarray,
                     importance: float = 0.5,
                     metadata: Optional[Dict[str, Any]] = None) -> MemoryNode:
        """
        Store a new memory with full strength and return it.

        If the number of active (non-FORGOTTEN) memories then exceeds
        ``max_capacity``, proactive forgetting marks the lowest-priority ones
        as forgotten.

        Args:
            content: memory payload
            embedding: vector representation of the content
            importance: importance score (0-1)
            metadata: optional metadata dict

        Returns:
            The newly created memory node.
        """
        now = datetime.now()  # one timestamp shared by the ID, creation time and last access
        memory_id = self._new_memory_id(content, now)

        memory = MemoryNode(
            id=memory_id,
            content=content,
            embedding=embedding,
            importance=importance,
            strength=1.0,
            timestamp=now,
            last_accessed=now,
            metadata=metadata or {}
        )
        self.memories[memory_id] = memory
        self.stats["total_stored"] += 1

        # FORGOTTEN entries remain in self.memories, so capacity must be judged
        # on the active count. The cheap length check runs first.
        if (len(self.memories) > self.max_capacity
                and len(self.get_active_memories()) > self.max_capacity):
            forgotten = self.forgetting.proactive_forgetting(
                self.memories, self.max_capacity
            )
            self.stats["total_forgotten"] += len(forgotten)

        return memory

    def apply_forgetting_cycle(self) -> List[str]:
        """Run one time-decay forgetting pass over all memories; return newly forgotten IDs."""
        forgotten = self.forgetting.apply_forgetting(
            self.memories, datetime.now()
        )
        self.stats["total_forgotten"] += len(forgotten)
        return forgotten

    def update_memory(self,
                      memory_id: str,
                      new_content: Any = None,
                      new_importance: Optional[float] = None,
                      update_metadata: Optional[Dict[str, Any]] = None) -> Optional[MemoryNode]:
        """
        Update a stored memory, then validate it and repair any violations.

        Args:
            memory_id: ID of the memory to update
            new_content: replacement content (None keeps the current content)
            new_importance: replacement importance (None keeps the current value)
            update_metadata: keys merged into the memory's metadata

        Returns:
            The stored (possibly repaired) updated memory, or None if
            ``memory_id`` is unknown.
        """
        memory = self.memories.get(memory_id)
        if memory is None:
            return None

        updated = self.updater.update_memory(
            memory, new_content, new_importance, update_metadata
        )
        self.memories[memory_id] = updated
        self.stats["total_updated"] += 1

        violations = self.consistency.validate_memory(updated, self.memories)
        if violations:
            self.stats["consistency_violations"] += len(violations)
            updated = self.consistency.repair_inconsistency(
                updated, violations, self.memories
            )
            self.memories[memory_id] = updated

        return updated

    def get_active_memories(self) -> List[MemoryNode]:
        """Return all memories whose status is not FORGOTTEN."""
        return [
            m for m in self.memories.values()
            if m.status != MemoryStatus.FORGOTTEN
        ]

    def get_stats(self) -> Dict[str, Any]:
        """
        Return the running counters plus current memory counts.

        ``capacity_usage`` is active memories / ``max_capacity``. When
        ``max_capacity`` is not positive it is 0.0 with no active memories,
        and infinity otherwise.
        """
        active_count = len(self.get_active_memories())
        if self.max_capacity > 0:
            usage = active_count / self.max_capacity
        else:
            usage = float("inf") if active_count else 0.0
        return {
            **self.stats,
            "total_memories": len(self.memories),
            "active_memories": active_count,
            "forgotten_memories": len(self.memories) - active_count,
            "capacity_usage": usage
        }
# Usage example / demo
if __name__ == "__main__":
    print("=== 记忆遗忘、更新与一致性维护 ===\n")

    # A small system so the demo output stays readable.
    system = MemoryManagementSystem(
        max_capacity=100,
        forgetting_decay=0.05,
        consistency_level=ConsistencyLevel.EVENTUAL,
    )

    print("=== 存储记忆 ===")
    # Store 50 synthetic memories. The RNG call order per item (randn, then
    # uniform) is kept fixed so the seeded run is reproducible.
    np.random.seed(42)
    for idx in range(50):
        vec = np.random.randn(64)
        weight = np.random.uniform(0.3, 1.0)
        stored = system.store_memory(
            content=f"记忆内容_{idx}",
            embedding=vec,
            importance=weight,
            metadata={"category": f"类别_{idx % 5}", "source": "模拟生成"},
        )
        if idx % 10 == 0:
            print(f"存储记忆:{stored.id[:8]}... 重要性={stored.importance:.3f}")

    print("\n=== 系统统计 ===")
    snapshot = system.get_stats()
    print(f"总存储:{snapshot['total_stored']}")
    print(f"活跃记忆:{snapshot['active_memories']}")
    print(f"容量使用:{snapshot['capacity_usage']:.1%}")

    print("\n=== 应用遗忘 ===")
    # Simulate elapsed time: back-date every memory's last access by 1-100 hours.
    for node in system.memories.values():
        node.last_accessed = datetime.now() - timedelta(hours=np.random.uniform(1, 100))

    dropped = system.apply_forgetting_cycle()
    print(f"遗忘记忆数量:{len(dropped)}")
    print(f"遗忘后活跃记忆:{system.get_stats()['active_memories']}")

    print("\n=== 更新记忆 ===")
    # Bump content and importance on the first five still-active memories.
    for target in system.get_active_memories()[:5]:
        result = system.update_memory(
            memory_id=target.id,
            new_content=f"{target.content}_已更新",
            new_importance=min(1.0, target.importance + 0.1),
        )
        if result:
            print(f"更新记忆:{result.id[:8]}... 版本={result.version}")

    snapshot = system.get_stats()
    print(f"\n总更新次数:{snapshot['total_updated']}")
    print(f"一致性违规:{snapshot['consistency_violations']}")

    print("\n=== 最终统计 ===")
    final = system.get_stats()
    for label, key in (
        ("总记忆数", "total_memories"),
        ("活跃记忆", "active_memories"),
        ("已遗忘", "forgotten_memories"),
        ("总遗忘数", "total_forgotten"),
    ):
        print(f"{label}:{final[key]}")
    print(f"容量使用:{final['capacity_usage']:.1%}")

    print("\n关键观察:")
    takeaways = (
        "1. 遗忘机制:时间衰减 + 重要性筛选 + 主动遗忘",
        "2. 记忆更新:增量更新 + 版本管理 + 冲突检测",
        "3. 一致性维护:约束验证 + 冲突解决 + 一致性修复",
        "4. 动态平衡:遗忘 - 保留平衡 + 容量管理",
        "5. 智能管理:自适应调节 + 持续优化",
    )
    for line in takeaways:
        print(line)
    print("\n记忆管理的核心:遗忘 + 更新 + 一致 + 平衡 = 智能记忆")