A Complete Implementation of an Episodic Memory and Experience Replay System
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import math
import hashlib
import json
from collections import deque
import random
import heapq
class MemoryType(Enum):
"""记忆类型"""
EPISODIC = "episodic" # 情景记忆
TEMPORAL = "temporal" # 时序记忆
SEMANTIC = "semantic" # 语义记忆
@dataclass
class Episode:
"""情景记忆片段"""
id: str
state: np.ndarray
action: int
reward: float
next_state: np.ndarray
done: bool
timestamp: datetime
context: Dict[str, Any] = field(default_factory=dict)
    priority: float = 1.0  # Priority (used for prioritized replay)
def to_dict(self) -> Dict[str, Any]:
"""转换为字典"""
return {
"id": self.id,
"state": self.state.tolist(),
"action": self.action,
"reward": self.reward,
"next_state": self.next_state.tolist(),
"done": self.done,
"timestamp": self.timestamp.isoformat(),
"context": self.context,
"priority": self.priority
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Episode':
"""从字典创建"""
return cls(
id=data["id"],
state=np.array(data["state"]),
action=data["action"],
reward=data["reward"],
next_state=np.array(data["next_state"]),
done=data["done"],
timestamp=datetime.fromisoformat(data["timestamp"]),
context=data.get("context", {}),
priority=data.get("priority", 1.0)
)
def temporal_difference_error(self, q_current: float, q_target: float) -> float:
"""计算 TD 误差(用于优先级)"""
td_error = abs(q_target - q_current)
self.priority = td_error + 1e-6 # 避免为零
return td_error
class EpisodicMemoryBuffer:
"""
情景记忆缓冲区
支持:
1. 均匀采样:随机采样经验
2. 优先采样:基于优先级采样
3. 时序采样:按时间顺序采样
"""
def __init__(self, capacity: int = 10000):
self.capacity = capacity
self.buffer: deque = deque(maxlen=capacity)
self.position = 0
self.is_full = False
        # Prioritized replay parameters
        self.priorities = np.zeros(capacity)
        self.alpha = 0.6              # Priority exponent
        self.beta = 0.4               # Importance-sampling exponent
        self.beta_increment = 0.001   # Per-sample annealing step for beta
def push(self, episode: Episode):
"""存储经验"""
if len(self.buffer) < self.capacity:
self.buffer.append(episode)
else:
self.buffer[self.position] = episode
        # New experiences get the current maximum priority so they are replayed at least once
max_priority = self.priorities.max() if self.priorities.any() else 1.0
self.priorities[self.position] = max_priority
        # Advance the circular write position
self.position = (self.position + 1) % self.capacity
if self.position == 0:
self.is_full = True
def sample_uniform(self, batch_size: int) -> List[Episode]:
"""均匀采样"""
buffer_size = len(self.buffer)
indices = np.random.choice(buffer_size, batch_size, replace=False)
return [self.buffer[i] for i in indices]
def sample_prioritized(self, batch_size: int) -> Tuple[List[Episode], np.ndarray, np.ndarray]:
"""
优先采样
Returns:
episodes: 采样的经验列表
indices: 采样索引
weights: 重要性采样权重
"""
buffer_size = len(self.buffer)
        # Convert priorities into sampling probabilities
        priorities = self.priorities[:buffer_size]
        probs = priorities ** self.alpha
        probs /= probs.sum()
        # Draw a batch according to those probabilities
        indices = np.random.choice(buffer_size, batch_size, p=probs, replace=False)
        episodes = [self.buffer[i] for i in indices]
        # Importance-sampling weights correct the bias introduced by non-uniform sampling
        weights = (buffer_size * probs[indices]) ** (-self.beta)
        weights /= weights.max()  # Normalize so the largest weight is 1
        # Anneal beta towards 1
self.beta = min(1.0, self.beta + self.beta_increment)
return episodes, indices, weights
def update_priorities(self, indices: np.ndarray, priorities: np.ndarray):
"""更新优先级"""
for idx, priority in zip(indices, priorities):
self.priorities[idx] = priority
    def sample_temporal(self, sequence_length: int) -> List[List[Episode]]:
        """Temporal sampling: return one contiguous, time-ordered sequence of experiences."""
        buffer_size = len(self.buffer)
        if buffer_size < sequence_length:
            return [list(self.buffer)]
        # Pick a random starting point and return the contiguous slice as a single sequence
        start_idx = np.random.randint(0, buffer_size - sequence_length + 1)
        sequence = [self.buffer[i] for i in range(start_idx, start_idx + sequence_length)]
        return [sequence]
def __len__(self) -> int:
return len(self.buffer)
class SequentialMemory:
"""
时序记忆系统
支持:
1. LSTM 风格记忆
2. 注意力机制
3. 长期依赖建模
"""
    def __init__(self, hidden_size: int = 128, sequence_length: int = 50):
        self.hidden_size = hidden_size
        self.sequence_length = sequence_length
        # Hidden and cell states
        self.hidden_state = np.zeros(hidden_size)
        self.cell_state = np.zeros(hidden_size)
        # Stored sequence of (input, hidden, cell) snapshots
        self.sequence: deque = deque(maxlen=sequence_length)
        # Attention weights from the most recent query
        self.attention_weights = None
        # Fixed gate weights and query projection, created lazily once input sizes are known
        self.gate_weights = None
        self.query_projection = None
    def update(self, input_vector: np.ndarray):
        """Update the temporal memory with a simplified LSTM step."""
        if self.gate_weights is None:
            # Initialize fixed random gate weights once the input dimension is known
            rng = np.random.default_rng(0)
            dim = len(input_vector)
            self.gate_weights = {
                name: rng.standard_normal((dim, self.hidden_size)) / np.sqrt(dim)
                for name in ("forget", "input", "output", "cell")
            }
        forget_gate = self._sigmoid(input_vector @ self.gate_weights["forget"])
        input_gate = self._sigmoid(input_vector @ self.gate_weights["input"])
        output_gate = self._sigmoid(input_vector @ self.gate_weights["output"])
        cell_candidate = np.tanh(input_vector @ self.gate_weights["cell"])
        # Update the cell state
        self.cell_state = forget_gate * self.cell_state + input_gate * cell_candidate
        # Update the hidden state
        self.hidden_state = output_gate * np.tanh(self.cell_state)
        # Record the step in the sequence store
        self.sequence.append({
            "input": input_vector.copy(),
            "hidden": self.hidden_state.copy(),
            "cell": self.cell_state.copy(),
            "timestamp": datetime.now()
        })
def _sigmoid(self, x):
return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
    def get_attention(self, query: np.ndarray) -> np.ndarray:
        """Compute attention weights over the stored sequence."""
        if len(self.sequence) == 0:
            return np.array([])
        # Project the query into the hidden space if its dimension differs
        if len(query) != self.hidden_size:
            if self.query_projection is None or self.query_projection.shape[0] != len(query):
                rng = np.random.default_rng(1)
                self.query_projection = rng.standard_normal((len(query), self.hidden_size)) / np.sqrt(len(query))
            query = query @ self.query_projection
        # Dot-product attention scores against each stored hidden state
        scores = np.array([np.dot(query, item["hidden"]) for item in self.sequence])
        # Softmax (shifted by the max for numerical stability)
        exp_scores = np.exp(scores - np.max(scores))
        self.attention_weights = exp_scores / exp_scores.sum()
        return self.attention_weights
def get_weighted_context(self, query: np.ndarray) -> np.ndarray:
"""获取加权上下文"""
weights = self.get_attention(query)
if len(weights) == 0:
return np.zeros(self.hidden_size)
context = np.zeros(self.hidden_size)
for weight, item in zip(weights, self.sequence):
context += weight * item["hidden"]
return context
def get_sequence(self) -> List[Dict]:
"""获取完整序列"""
return list(self.sequence)
class ExperienceReplaySystem:
"""
经验回放系统
整合:
1. 情景记忆存储
2. 时序记忆建模
3. 优先经验回放
4. 强化学习优化
"""
def __init__(self,
buffer_capacity: int = 10000,
hidden_size: int = 128,
sequence_length: int = 50,
batch_size: int = 32,
gamma: float = 0.99):
        # Episodic memory buffer
        self.episodic_buffer = EpisodicMemoryBuffer(buffer_capacity)
        # Temporal memory system
        self.temporal_memory = SequentialMemory(hidden_size, sequence_length)
        # Hyperparameters
        self.batch_size = batch_size
        self.gamma = gamma  # Discount factor
        # Running statistics
self.stats = {
"total_episodes": 0,
"total_replays": 0,
"average_reward": 0.0,
"best_reward": -float('inf')
}
def store_experience(self,
state: np.ndarray,
action: int,
reward: float,
next_state: np.ndarray,
done: bool,
context: Dict[str, Any] = None) -> Episode:
"""
存储经验
Args:
state: 当前状态
action: 动作
reward: 奖励
next_state: 下一状态
done: 是否结束
context: 上下文信息
Returns:
存储的经验片段
"""
        # Build the experience fragment
episode = Episode(
id=hashlib.md5(f"{datetime.now().isoformat()}{state.sum()}".encode()).hexdigest()[:16],
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
timestamp=datetime.now(),
context=context or {}
)
        # Store in episodic memory
        self.episodic_buffer.push(episode)
        # Update the temporal memory with the concatenated (state, action, reward) step
        input_vector = np.concatenate([state, [action, reward]])
        self.temporal_memory.update(input_vector)
        # Update running statistics
self.stats["total_episodes"] += 1
self.stats["average_reward"] = (
(self.stats["average_reward"] * (self.stats["total_episodes"] - 1) + reward)
/ self.stats["total_episodes"]
)
self.stats["best_reward"] = max(self.stats["best_reward"], reward)
return episode
def replay_experience(self,
use_prioritized: bool = True,
q_network: callable = None,
target_network: callable = None) -> Dict[str, Any]:
"""
经验回放
Args:
use_prioritized: 是否使用优先回放
q_network: Q 网络
target_network: 目标网络
Returns:
回放结果
"""
if len(self.episodic_buffer) < self.batch_size:
return {"success": False, "reason": "Insufficient experiences"}
        # Sample a batch of experiences
if use_prioritized:
episodes, indices, weights = self.episodic_buffer.sample_prioritized(self.batch_size)
else:
episodes = self.episodic_buffer.sample_uniform(self.batch_size)
indices = None
weights = None
        # Unpack the batch into arrays
states = np.array([ep.state for ep in episodes])
actions = np.array([ep.action for ep in episodes])
rewards = np.array([ep.reward for ep in episodes])
next_states = np.array([ep.next_state for ep in episodes])
dones = np.array([ep.done for ep in episodes])
        # Compute TD targets and errors (when Q networks are provided)
        td_errors = None
        if q_network and target_network:
            # Current Q values for the taken actions
            q_current = q_network(states)[np.arange(self.batch_size), actions]
            # Bootstrapped target Q values
            next_q = target_network(next_states).max(axis=1)
            q_target = rewards + self.gamma * next_q * (1.0 - dones.astype(np.float32))
            # TD errors
            td_errors = q_target - q_current
            # Refresh priorities with the new absolute TD errors
            if use_prioritized and indices is not None:
                priorities = np.abs(td_errors) + 1e-6
self.episodic_buffer.update_priorities(indices, priorities)
self.stats["total_replays"] += 1
return {
"success": True,
"batch_size": len(episodes),
"states": states,
"actions": actions,
"rewards": rewards,
"next_states": next_states,
"dones": dones,
"td_errors": td_errors,
"weights": weights
}
def get_temporal_context(self, query_state: np.ndarray) -> np.ndarray:
"""获取时序上下文"""
return self.temporal_memory.get_weighted_context(query_state)
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
**self.stats,
"buffer_size": len(self.episodic_buffer),
"buffer_capacity": self.episodic_buffer.capacity,
"average_priority": self.episodic_buffer.priorities[:len(self.episodic_buffer)].mean() if len(self.episodic_buffer) > 0 else 0
}
def export_experiences(self, filepath: str):
"""导出经验到文件"""
data = {
"episodes": [ep.to_dict() for ep in self.episodic_buffer.buffer],
"stats": self.stats
}
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
def import_experiences(self, filepath: str):
"""从文件导入经验"""
with open(filepath, 'r', encoding='utf-8') as f:
data = json.load(f)
# 清空缓冲区
self.episodic_buffer.buffer.clear()
self.episodic_buffer.position = 0
self.episodic_buffer.is_full = False
# 导入经验
for ep_dict in data.get("episodes", []):
episode = Episode.from_dict(ep_dict)
self.episodic_buffer.buffer.append(episode)
# 恢复统计
self.stats.update(data.get("stats", {}))
# Usage example
if __name__ == "__main__":
    print("=== Episodic memory, temporal memory, and experience replay ===\n")
    # Create the experience replay system
replay_system = ExperienceReplaySystem(
buffer_capacity=1000,
hidden_size=64,
sequence_length=30,
batch_size=16,
gamma=0.99
)
print("=== 存储经验 ===")
# 模拟经验存储
np.random.seed(42)
for i in range(100):
state = np.random.randn(10)
action = np.random.randint(0, 4)
reward = np.random.randn() * 0.5 + (1 if action == 0 else 0)
next_state = state + np.random.randn(10) * 0.1
        done = (i % 20 == 19)  # An episode terminates every 20 steps
episode = replay_system.store_experience(
state=state,
action=action,
reward=reward,
next_state=next_state,
done=done,
context={"step": i, "episode": i // 20}
)
if i % 20 == 0:
print(f"存储经验:step={i}, action={action}, reward={reward:.3f}")
print(f"\n=== 系统统计 ===")
stats = replay_system.get_stats()
print(f"总经验数:{stats['total_episodes']}")
print(f"平均奖励:{stats['average_reward']:.3f}")
print(f"最佳奖励:{stats['best_reward']:.3f}")
print(f"缓冲区大小:{stats['buffer_size']}/{stats['buffer_capacity']}")
print(f"\n=== 经验回放 ===")
# 模拟 Q 网络(简化)
def mock_q_network(states):
return np.random.randn(len(states), 4)
    # Uniform replay
result_uniform = replay_system.replay_experience(
use_prioritized=False,
q_network=mock_q_network,
target_network=mock_q_network
)
print(f"均匀回放:成功={result_uniform['success']}, 批次大小={result_uniform['batch_size']}")
if result_uniform['td_errors'] is not None:
print(f" 平均 TD 误差:{np.abs(result_uniform['td_errors']).mean():.4f}")
    # Prioritized replay
result_prioritized = replay_system.replay_experience(
use_prioritized=True,
q_network=mock_q_network,
target_network=mock_q_network
)
print(f"\n优先回放:成功={result_prioritized['success']}, 批次大小={result_prioritized['batch_size']}")
if result_prioritized['td_errors'] is not None:
print(f" 平均 TD 误差:{np.abs(result_prioritized['td_errors']).mean():.4f}")
print(f" 平均重要性权重:{result_prioritized['weights'].mean():.4f}")
print(f"\n=== 时序上下文 ===")
# 获取时序上下文
query_state = np.random.randn(10)
context = replay_system.get_temporal_context(query_state)
print(f"时序上下文维度:{context.shape}")
print(f"上下文均值:{context.mean():.4f}, 标准差:{context.std():.4f}")
print(f"\n=== 最终统计 ===")
final_stats = replay_system.get_stats()
print(f"总回放次数:{final_stats['total_replays']}")
print(f"平均优先级:{final_stats['average_priority']:.4f}")
print(f"\n关键观察:")
print("1. 情景记忆:记录 what-where-when 完整事件")
print("2. 时序记忆:建模时间依赖,捕捉因果关系")
print("3. 经验回放:存储经验,反复学习提升样本效率")
print("4. 优先回放:基于 TD 误差优先学习重要经验")
print("5. 时序上下文:基于注意力获取相关历史信息")
print("\n经验回放的核心:存储 + 重放 + 优先 + 时序 = 高效学习")