"""Complete implementation of a working-memory and attention system (工作记忆与注意力系统)."""
import numpy as np
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import math
class MemoryComponent(Enum):
    """Components of working memory (Baddeley's multi-component model)."""

    PHONOLOGICAL = "phonological"    # phonological loop: verbal/auditory store
    VISUOSPATIAL = "visuospatial"    # visuospatial sketchpad: visual-spatial store
    EPISODIC_BUFFER = "episodic"     # episodic buffer: multi-modal integration
    CENTRAL_EXECUTIVE = "executive"  # central executive: attention control
class AttentionType(Enum):
    """Kinds of attention the system can deploy."""

    SELECTIVE = "selective"  # selective attention: focus on one target
    DIVIDED = "divided"      # divided attention: split across targets
    SUSTAINED = "sustained"  # sustained attention: hold focus over time
    EXECUTIVE = "executive"  # executive attention: top-down control
@dataclass
class MemoryChunk:
    """A single chunk of information held in working memory."""

    id: str
    content: Any
    component: MemoryComponent
    activation: float = 1.0    # activation level, nominally in [0, 1]
    decay_rate: float = 0.01   # exponential decay constant
    timestamp: datetime = field(default_factory=datetime.now)
    rehearsal_count: int = 0   # how many times this chunk has been rehearsed

    def decay(self, time_delta: float) -> float:
        """Apply exponential decay over ``time_delta`` seconds and return the new activation."""
        self.activation *= math.exp(-self.decay_rate * time_delta)
        return self.activation

    def rehearse(self) -> None:
        """Refresh the chunk: bump rehearsal count, reset the timestamp, and boost activation (capped at 1.0)."""
        self.rehearsal_count += 1
        self.timestamp = datetime.now()
        self.activation = min(1.0, self.activation + 0.3)
@dataclass
class AttentionFocus:
    """One focus of attention directed at a memory chunk."""

    target_id: str                 # id of the targeted MemoryChunk
    attention_type: AttentionType  # which kind of attention is deployed
    intensity: float = 1.0         # attention intensity in [0, 1]
    duration: float = 0.0          # how long the focus has been held (seconds)
    priority: float = 0.5          # scheduling priority in [0, 1]
class WorkingMemorySystem:
    """
    Working-memory system based on Baddeley's multi-component model:

    1. Phonological loop: stores verbal/phonological information
    2. Visuospatial sketchpad: stores visual-spatial information
    3. Episodic buffer: integrates multi-modal information
    4. Central executive: attention control and resource coordination
    """

    def __init__(self,
                 capacity: int = 7,          # Miller's 7±2
                 decay_time: float = 30.0):  # decay horizon (seconds)
        # Per-component chunk stores.
        self.memory_stores: Dict[MemoryComponent, List[MemoryChunk]] = {
            comp: [] for comp in MemoryComponent
        }
        # Currently active attention focuses.
        self.attention_focuses: List[AttentionFocus] = []
        # System parameters.
        self.capacity = capacity      # total chunk capacity across all components
        self.decay_time = decay_time  # decay time constant (kept for callers; not used internally here)
        # Monotonic per-component counters for chunk IDs. Using len(store)
        # (as before) recycled IDs after an eviction, which made
        # _find_chunk / rehearse_chunk / focus_attention ambiguous.
        self._id_counters: Dict[MemoryComponent, int] = {
            comp: 0 for comp in MemoryComponent
        }
        # Usage statistics.
        self.stats = {
            "total_encoded": 0,
            "total_rehearsed": 0,
            "total_forgotten": 0,
            "attention_shifts": 0
        }

    def encode_chunk(self,
                     content: Any,
                     component: MemoryComponent,
                     priority: float = 0.5) -> Optional[MemoryChunk]:
        """
        Encode a new memory chunk.

        If total capacity is already reached, the chunk with the lowest
        activation anywhere in the system is evicted first to make room.

        Args:
            content: payload to remember
            component: memory component to store the chunk in
            priority: encoding priority in [0, 1]; scales initial activation

        Returns:
            The newly created chunk.
        """
        # Evict the weakest chunk when at capacity.
        total_chunks = sum(len(store) for store in self.memory_stores.values())
        if total_chunks >= self.capacity:
            self._clear_lowest_activation()
        # Unique, monotonically increasing ID per component (never reused).
        chunk_id = f"{component.value}_{self._id_counters[component]}"
        self._id_counters[component] += 1
        chunk = MemoryChunk(
            id=chunk_id,
            content=content,
            component=component,
            # Priority linearly boosts initial activation from 0.5 up to 1.0.
            activation=min(1.0, 0.5 + priority * 0.5)
        )
        self.memory_stores[component].append(chunk)
        self.stats["total_encoded"] += 1
        return chunk

    def _clear_lowest_activation(self) -> None:
        """Evict the chunk with the lowest activation across all components."""
        all_chunks = [
            (chunk, comp)
            for comp, chunks in self.memory_stores.items()
            for chunk in chunks
        ]
        if not all_chunks:
            return
        # min() is O(n) and, like the previous stable sort, keeps the first
        # of any tied chunks.
        lowest_chunk, lowest_comp = min(all_chunks, key=lambda x: x[0].activation)
        self.memory_stores[lowest_comp].remove(lowest_chunk)
        self.stats["total_forgotten"] += 1

    def focus_attention(self,
                        target_id: str,
                        attention_type: AttentionType,
                        intensity: float = 1.0,
                        priority: float = 0.5) -> bool:
        """
        Focus attention on a chunk, boosting its activation.

        Args:
            target_id: ID of the target chunk
            attention_type: kind of attention to deploy
            intensity: attention intensity in [0, 1]
            priority: focus priority in [0, 1]

        Returns:
            True if the target chunk exists and attention was focused.
        """
        target_chunk = self._find_chunk(target_id)
        if target_chunk is None:
            return False
        self.attention_focuses.append(AttentionFocus(
            target_id=target_id,
            attention_type=attention_type,
            intensity=intensity,
            priority=priority
        ))
        # Attention boosts the target's activation, capped at 1.0.
        target_chunk.activation = min(1.0, target_chunk.activation + intensity * 0.3)
        self.stats["attention_shifts"] += 1
        return True

    def _find_chunk(self, chunk_id: str) -> Optional[MemoryChunk]:
        """Return the chunk with the given ID, or None if not stored."""
        for chunks in self.memory_stores.values():
            for chunk in chunks:
                if chunk.id == chunk_id:
                    return chunk
        return None

    def rehearse_chunk(self, chunk_id: str) -> bool:
        """
        Rehearse a chunk (refresh its memory trace).

        Args:
            chunk_id: ID of the chunk to rehearse

        Returns:
            True if the chunk exists and was rehearsed.
        """
        chunk = self._find_chunk(chunk_id)
        if chunk is None:
            return False
        chunk.rehearse()
        self.stats["total_rehearsed"] += 1
        return True

    def apply_decay(self, time_delta: float) -> Dict[MemoryComponent, int]:
        """
        Apply activation decay to every chunk and forget the weak ones.

        Args:
            time_delta: elapsed time (seconds)

        Returns:
            Number of chunks forgotten per component.
        """
        forgotten = {comp: 0 for comp in MemoryComponent}
        for comp, chunks in self.memory_stores.items():
            to_remove = []
            for chunk in chunks:
                # Chunks falling below the 0.1 activation threshold are forgotten.
                if chunk.decay(time_delta) < 0.1:
                    to_remove.append(chunk)
            # Remove after iteration so we never mutate the list being iterated.
            for chunk in to_remove:
                chunks.remove(chunk)
                forgotten[comp] += 1
                self.stats["total_forgotten"] += 1
        return forgotten

    def get_active_chunks(self,
                          min_activation: float = 0.5,
                          component: Optional[MemoryComponent] = None) -> List[MemoryChunk]:
        """
        Return chunks whose activation is at or above a threshold.

        Args:
            min_activation: minimum activation level to include
            component: restrict to one component (all components if None)

        Returns:
            Matching chunks, sorted by activation (highest first).
        """
        components = [component] if component else list(MemoryComponent)
        active = [
            chunk
            for comp in components
            for chunk in self.memory_stores[comp]
            if chunk.activation >= min_activation
        ]
        active.sort(key=lambda c: c.activation, reverse=True)
        return active

    def get_attention_distribution(self) -> Dict[str, float]:
        """Return the normalized attention weight (intensity * priority) per target ID."""
        distribution: Dict[str, float] = {}
        for focus in self.attention_focuses:
            weight = focus.intensity * focus.priority
            distribution[focus.target_id] = distribution.get(focus.target_id, 0.0) + weight
        # Normalize so weights sum to 1 (when any attention is deployed).
        total = sum(distribution.values())
        if total > 0:
            distribution = {k: v / total for k, v in distribution.items()}
        return distribution

    def shift_attention(self,
                        from_id: str,
                        to_id: str,
                        shift_speed: float = 1.0) -> bool:
        """
        Shift attention from one chunk to another.

        Args:
            from_id: ID of the currently focused chunk
            to_id: ID of the new focus target
            shift_speed: scales the intensity of the new focus

        Returns:
            True if the new target exists and attention was shifted.
        """
        # Drop every focus on the old target.
        self.attention_focuses = [
            f for f in self.attention_focuses if f.target_id != from_id
        ]
        # Deploy selective attention on the new target.
        return self.focus_attention(to_id, AttentionType.SELECTIVE, 1.0 * shift_speed)

    def get_cognitive_load(self) -> float:
        """
        Compute the current cognitive load.

        Returns:
            Capacity utilization in [0, 1].
        """
        total_chunks = sum(len(store) for store in self.memory_stores.values())
        return min(1.0, total_chunks / self.capacity)

    def get_stats(self) -> Dict[str, Any]:
        """Return cumulative statistics plus a snapshot of the current state."""
        return {
            **self.stats,
            "total_chunks": sum(len(store) for store in self.memory_stores.values()),
            "chunk_distribution": {
                comp.value: len(chunks)
                for comp, chunks in self.memory_stores.items()
            },
            "attention_focuses": len(self.attention_focuses),
            "cognitive_load": self.get_cognitive_load()
        }
# Usage example / demo script
if __name__ == "__main__":
    print("=== 短期工作记忆与注意力机制 ===\n")

    # Build the working-memory system.
    wm_system = WorkingMemorySystem(capacity=7, decay_time=30.0)

    print("=== 编码记忆组块 ===")
    # Encode phonological information.
    chunk1 = wm_system.encode_chunk(
        content="电话号码:123-4567",
        component=MemoryComponent.PHONOLOGICAL,
        priority=0.8,
    )
    print(f"编码语音:{chunk1.content}")
    print(f"激活水平:{chunk1.activation:.2f}")

    # Encode visuospatial information.
    chunk2 = wm_system.encode_chunk(
        content="地图位置: (x=5, y=3)",
        component=MemoryComponent.VISUOSPATIAL,
        priority=0.6,
    )
    print(f"\n编码视觉:{chunk2.content}")
    print(f"激活水平:{chunk2.activation:.2f}")

    # Fill the episodic buffer with several task items.
    for item_no in range(1, 6):
        wm_system.encode_chunk(
            content=f"任务项 {item_no}: 完成报告第{item_no}部分",
            component=MemoryComponent.EPISODIC_BUFFER,
            priority=0.5,
        )

    print("\n=== 工作记忆统计 ===")
    stats = wm_system.get_stats()
    print(f"总组块数:{stats['total_chunks']}")
    print(f"组件分布:{stats['chunk_distribution']}")
    print(f"认知负荷:{stats['cognitive_load']:.2f}")

    print("\n=== 注意力聚焦 ===")
    # Focus selective attention on the phone number.
    wm_system.focus_attention(
        target_id=chunk1.id,
        attention_type=AttentionType.SELECTIVE,
        intensity=0.9,
        priority=0.8,
    )
    print(f"聚焦注意力到:{chunk1.content}")
    # Hold sustained attention on the map location.
    wm_system.focus_attention(
        target_id=chunk2.id,
        attention_type=AttentionType.SUSTAINED,
        intensity=0.7,
        priority=0.6,
    )
    print(f"持续注意:{chunk2.content}")

    # Show the normalized attention distribution.
    distribution = wm_system.get_attention_distribution()
    print(f"\n注意力分布:")
    for tid, weight in distribution.items():
        target = wm_system._find_chunk(tid)
        if target:
            print(f"  {target.content}: {weight:.2f}")

    print("\n=== 复述刷新 ===")
    # Rehearse the most important chunk.
    wm_system.rehearse_chunk(chunk1.id)
    print(f"复述后 {chunk1.content} 激活水平:{chunk1.activation:.2f}")

    print("\n=== 应用衰减 ===")
    # Simulate the passage of time.
    forgotten = wm_system.apply_decay(time_delta=10.0)
    print(f"10 秒后遗忘组块:{forgotten}")

    # List the chunks still active enough to matter.
    active_chunks = wm_system.get_active_chunks(min_activation=0.3)
    print(f"\n活跃组块 ({len(active_chunks)}个):")
    for ac in active_chunks:
        print(f"  - {ac.content} (激活:{ac.activation:.2f}, 复述:{ac.rehearsal_count})")

    print("\n=== 注意力切换 ===")
    # Shift attention between the two strongest chunks.
    if len(active_chunks) >= 2:
        wm_system.shift_attention(
            from_id=active_chunks[0].id,
            to_id=active_chunks[1].id,
            shift_speed=0.8,
        )
        print(f"注意力从 {active_chunks[0].content} 切换到 {active_chunks[1].content}")

    print("\n=== 最终统计 ===")
    final_stats = wm_system.get_stats()
    print(f"总编码:{final_stats['total_encoded']}")
    print(f"总复述:{final_stats['total_rehearsed']}")
    print(f"总遗忘:{final_stats['total_forgotten']}")
    print(f"注意力切换:{final_stats['attention_shifts']}")
    print(f"当前认知负荷:{final_stats['cognitive_load']:.2f}")

    print(f"\n关键观察:")
    print("1. 工作记忆容量有限:7±2 个组块")
    print("2. 多成分模型:语音 + 视觉 + 情景 + 执行")
    print("3. 注意力聚焦:选择性增强目标激活")
    print("4. 衰减机制:不复习就会遗忘")
    print("5. 复述刷新:主动复习维持记忆")
    print("6. 认知负荷:容量使用率决定负荷")
    print("\n工作记忆的核心:有限容量 + 注意选择 + 主动维持")