"""Agent 上下文压缩与缓存加速完整实现 (Agent context compression and cache acceleration)."""
import time
import json
import hashlib
import secrets
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import numpy as np
from collections import OrderedDict, defaultdict
import statistics
import re
class CompressionStrategy(Enum):
    """Strategies available for shrinking an LLM context."""

    LLMLINGUA = "llmlingua"              # LLMLingua-style importance filtering
    SEMANTIC = "semantic"                # stop-word / redundancy removal
    TOKEN_SELECTION = "token_selection"  # positional token sampling
    SLIDING_WINDOW = "sliding_window"    # keep only the most recent tokens
    ATTENTION_SINK = "attention_sink"    # pinned sink tokens + recent window
class CacheStrategy(Enum):
    """Strategies available for the response cache."""

    KV_CACHE = "kv_cache"              # KV cache
    SEMANTIC_CACHE = "semantic_cache"  # semantic (fuzzy) cache
    LRU = "lru"                        # least-recently-used eviction
    LFU = "lfu"                        # least-frequently-used eviction
    TTL = "ttl"                        # time-to-live expiry
@dataclass
class ContextConfig:
    """Tuning knobs shared by the compressor and the cache."""

    max_context_length: int                    # hard cap on context tokens
    compression_strategy: CompressionStrategy  # how to shrink the context
    compression_ratio: float                   # target fraction of tokens kept
    enable_cache: bool                         # master switch for caching
    cache_size: int                            # max number of cache entries
    cache_strategy: CacheStrategy              # eviction / lookup policy
    cache_ttl: int                             # default entry lifetime, seconds
    enable_sliding_window: bool
    window_size: int                           # tokens kept by the sliding window
    enable_attention_sink: bool
    sink_tokens: int                           # leading tokens pinned as sinks
@dataclass
class ContextWindow:
    """A context window together with its compressed form."""

    window_id: str                # unique identifier for this window
    content: str                  # original text
    tokens: List[str]             # original tokens
    compressed_content: str       # text after compression
    compressed_tokens: List[str]  # tokens after compression
    compression_ratio: float      # compressed/original token ratio
    original_length: int
    compressed_length: int
    timestamp: datetime = field(default_factory=datetime.now)
@dataclass
class CacheEntry:
    """A single cache record plus its access statistics."""

    entry_id: str              # unique id for bookkeeping
    key: str                   # original lookup text (uncompressed key)
    value: Any                 # cached payload
    access_count: int          # number of hits on this entry
    last_access: datetime      # time of the most recent hit
    created_at: datetime = field(default_factory=datetime.now)
    ttl: Optional[int] = None  # lifetime in seconds; None = no expiry
@dataclass
class CompressionResult:
    """Outcome and savings estimates of one compression run."""

    original_text: str
    compressed_text: str
    original_tokens: int        # whitespace-token count before compression
    compressed_tokens: int      # whitespace-token count after compression
    compression_ratio: float    # compressed/original ratio
    tokens_saved: int
    time_saved: float           # estimated seconds saved
    cost_saved: float           # estimated USD saved
    semantic_similarity: float  # estimated similarity to the original, 0-1
    timestamp: datetime = field(default_factory=datetime.now)
@dataclass
class CacheReport:
    """Snapshot of cache health and tuning suggestions."""

    report_id: str
    total_entries: int          # entries currently resident
    hit_rate: float             # hits / (hits + misses)
    miss_rate: float
    memory_used: float          # estimated MB
    evictions: int              # total evictions so far
    avg_ttl: float              # mean TTL of resident entries, seconds
    recommendations: List[str]  # human-readable tuning advice
    generated_at: datetime = field(default_factory=datetime.now)
class ContextCompressor:
    """Compresses LLM context text according to a configured strategy.

    Supported strategies:
      1. LLMLingua-style sentence filtering
      2. Semantic (stop-word) compression
      3. Token selection (positional sampling)
      4. Sliding window / attention sink
    """

    def __init__(self, config: ContextConfig):
        self.config = config
        # One CompressionResult per compress() call, in order.
        self.compression_history: List[CompressionResult] = []

    def llmlingua_compress(self, text: str, target_ratio: float) -> Tuple[str, float]:
        """LLMLingua-style compression (simplified stand-in).

        Scores each sentence by keyword presence and length, keeps the
        top-scoring fraction, and re-emits the survivors in their
        ORIGINAL order so the result stays coherent.

        Args:
            text: source text.
            target_ratio: fraction of sentences to keep (0-1].

        Returns:
            (compressed_text, actual token-level compression ratio).
        """
        # A production version would call the LLMLingua library; this
        # simulates small-model importance scoring with heuristics.
        sentences = [s.strip() for s in re.split(r'[.!?。!?]', text) if s.strip()]
        important_keywords = ['关键', '重要', '核心', '主要', '必须', 'essential',
                              'important', 'critical', 'key', 'main']
        scored = []  # (original_index, sentence, score)
        for idx, sent in enumerate(sentences):
            lowered = sent.lower()
            score = sum(2 for kw in important_keywords if kw in lowered)
            # Favour mid-length sentences (neither fragments nor run-ons).
            if 5 <= len(sent.split()) <= 30:
                score += 1
            scored.append((idx, sent, score))
        # Keep the highest-scoring sentences (stable sort preserves source
        # order among ties) ...
        scored.sort(key=lambda item: item[2], reverse=True)
        target_count = max(1, int(len(scored) * target_ratio))
        selected = scored[:target_count]
        # ... then restore source order.  BUG FIX: the original joined
        # sentences in score order, scrambling the narrative flow.
        selected.sort(key=lambda item: item[0])
        compressed_text = '. '.join(item[1] for item in selected) + '.'
        source_words = text.split()
        actual_ratio = (len(compressed_text.split()) / len(source_words)
                        if source_words else 1.0)
        return compressed_text, actual_ratio

    def semantic_compress(self, text: str) -> Tuple[str, float]:
        """Semantic compression (simplified): drop common stop words.

        Returns (compressed_text, fraction of words kept).
        """
        stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were',
                      'be', 'been', 'being', 'have', 'has', 'had', 'do',
                      '的', '了', '是', '在', '和', '与', '或', '一个'}
        words = text.split()
        kept = [w for w in words if w.lower() not in stop_words]
        ratio = len(kept) / len(words) if words else 1.0
        return ' '.join(kept), ratio

    def token_selection(self, tokens: List[str], target_count: int) -> List[str]:
        """Pick ~target_count tokens, biased toward the head and tail.

        A third of the budget goes to the start, a third to the end, and
        the remainder is sampled uniformly from the middle.
        """
        if len(tokens) <= target_count:
            return tokens
        head_count = target_count // 3
        tail_count = target_count // 3
        middle_count = target_count - head_count - tail_count
        selected = list(tokens[:head_count])
        if middle_count > 0 and len(tokens) > head_count + tail_count:
            middle = tokens[head_count:len(tokens) - tail_count]
            step = len(middle) / middle_count
            for i in range(middle_count):
                idx = int(i * step)
                if idx < len(middle):
                    selected.append(middle[idx])
        if tail_count:
            # BUG FIX: tokens[-0:] is the whole list, which ballooned the
            # result whenever target_count < 3.
            selected.extend(tokens[-tail_count:])
        return selected

    def sliding_window_compress(self, text: str, window_size: int) -> str:
        """Keep only the trailing window_size whitespace tokens."""
        tokens = text.split()
        if len(tokens) <= window_size:
            return text
        return ' '.join(tokens[-window_size:])

    def compress(self, text: str) -> CompressionResult:
        """Compress *text* with the configured strategy and log the result.

        Returns:
            CompressionResult with the compressed text plus rough estimates
            of tokens/time/cost saved and semantic similarity.
        """
        original_tokens = text.split()
        original_count = len(original_tokens)
        strategy = self.config.compression_strategy

        if strategy == CompressionStrategy.LLMLINGUA:
            compressed_text, actual_ratio = self.llmlingua_compress(
                text, self.config.compression_ratio
            )
        elif strategy == CompressionStrategy.SEMANTIC:
            compressed_text, actual_ratio = self.semantic_compress(text)
        elif strategy == CompressionStrategy.TOKEN_SELECTION:
            # BUG FIX: this strategy previously fell through to the no-op
            # default branch even though token_selection() exists.
            target = max(1, int(original_count * self.config.compression_ratio))
            compressed_text = ' '.join(self.token_selection(original_tokens, target))
            actual_ratio = (len(compressed_text.split()) / original_count
                            if original_count else 1.0)
        elif strategy == CompressionStrategy.SLIDING_WINDOW:
            compressed_text = self.sliding_window_compress(
                text, self.config.window_size
            )
            actual_ratio = (len(compressed_text.split()) / original_count
                            if original_count else 1.0)
        elif strategy == CompressionStrategy.ATTENTION_SINK:
            # BUG FIX: previously unhandled.  StreamingLLM-style: pin the
            # first sink_tokens and keep the most recent remainder of the
            # window.
            window = self.config.window_size
            if original_count <= window:
                compressed_text = text
            else:
                sink = original_tokens[:self.config.sink_tokens]
                recent_budget = max(window - self.config.sink_tokens, 0)
                recent = original_tokens[-recent_budget:] if recent_budget else []
                compressed_text = ' '.join(sink + recent)
            actual_ratio = (len(compressed_text.split()) / original_count
                            if original_count else 1.0)
        else:
            compressed_text = text
            actual_ratio = 1.0

        compressed_count = len(compressed_text.split())
        tokens_saved = original_count - compressed_count
        # Rough savings estimates: 1 ms and $0.0001 per token saved.
        time_saved = tokens_saved * 0.001
        cost_saved = tokens_saved * 0.0001
        # Crude similarity proxy: heavier compression -> lower similarity,
        # floored at 0.7.
        semantic_similarity = max(0.7, 1.0 - (1.0 - actual_ratio) * 0.5)

        result = CompressionResult(
            original_text=text,
            compressed_text=compressed_text,
            original_tokens=original_count,
            compressed_tokens=compressed_count,
            compression_ratio=actual_ratio,
            tokens_saved=tokens_saved,
            time_saved=time_saved,
            cost_saved=cost_saved,
            semantic_similarity=semantic_similarity
        )
        self.compression_history.append(result)
        return result

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize compression history (counts, ratios, tokens saved)."""
        if not self.compression_history:
            return {"total_compressions": 0}
        ratios = [r.compression_ratio for r in self.compression_history]
        saved = [r.tokens_saved for r in self.compression_history]
        return {
            "total_compressions": len(self.compression_history),
            "avg_compression_ratio": statistics.mean(ratios),
            "avg_tokens_saved": statistics.mean(saved),
            "total_tokens_saved": sum(saved),
            "timestamp": datetime.now().isoformat()
        }
class SemanticCache:
    """Response cache with exact and fuzzy (word-overlap) lookup.

    Features:
      1. exact-key lookup (md5 of the text)
      2. semantic lookup via Jaccard word similarity
      3. LRU or LFU eviction (per config.cache_strategy)
      4. per-entry TTL expiry
    """

    def __init__(self, config: ContextConfig):
        self.config = config
        # Recency-ordered: least recently used entry sits at the front.
        self.cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self.hits = 0
        self.misses = 0
        self.evictions = 0

    def _generate_key(self, text: str) -> str:
        """Deterministic cache key.  md5 is acceptable: keys are not security-sensitive."""
        return hashlib.md5(text.encode()).hexdigest()

    def _semantic_similarity(self, text1: str, text2: str) -> float:
        """Jaccard similarity over lowercase word sets (embedding stand-in)."""
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())
        if not words1 or not words2:
            return 0.0
        union = words1 | words2
        return len(words1 & words2) / len(union) if union else 0.0

    def _is_expired(self, entry: CacheEntry) -> bool:
        """True if the entry carries a TTL and it has elapsed."""
        if not entry.ttl:
            return False
        return (datetime.now() - entry.created_at).total_seconds() > entry.ttl

    def _record_hit(self, key: str, entry: CacheEntry) -> Any:
        """Update recency/frequency bookkeeping for a hit; return the value."""
        entry.access_count += 1
        entry.last_access = datetime.now()
        self.cache.move_to_end(key)
        self.hits += 1
        return entry.value

    def get(self, text: str, threshold: float = 0.9) -> Optional[Any]:
        """Look up *text*: exact key first, then fuzzy match above *threshold*.

        Expired entries are evicted on sight.  Returns the cached value,
        or None on a miss.
        """
        if not self.config.enable_cache:
            return None
        query_key = self._generate_key(text)

        # Exact match.
        if query_key in self.cache:
            entry = self.cache[query_key]
            if self._is_expired(entry):
                self._evict(query_key)
                return None
            return self._record_hit(query_key, entry)

        # Fuzzy path.  Iterate over a snapshot so expired entries can be
        # evicted mid-scan without invalidating the iterator.
        for key, entry in list(self.cache.items()):
            if self._is_expired(entry):
                # BUG FIX: TTL was only enforced on exact matches, so a
                # stale entry could still be served via the fuzzy path.
                self._evict(key)
                continue
            if self._semantic_similarity(text, entry.key) >= threshold:
                return self._record_hit(key, entry)

        self.misses += 1
        return None

    def put(self, text: str, value: Any, ttl: Optional[int] = None):
        """Insert (or overwrite) a cached value, evicting if at capacity."""
        if not self.config.enable_cache:
            return
        key = self._generate_key(text)
        # Overwriting an existing key needs no eviction.
        if key not in self.cache:
            while len(self.cache) >= self.config.cache_size:
                # BUG FIX: eviction previously ignored cache_strategy and
                # always dropped the LRU entry (LFU path was unreachable).
                if self.config.cache_strategy == CacheStrategy.LFU:
                    self._evict_lfu()
                else:
                    self._evict_oldest()
        self.cache[key] = CacheEntry(
            entry_id=f"entry_{secrets.token_hex(8)}",
            key=text,
            value=value,
            access_count=1,
            last_access=datetime.now(),
            ttl=ttl or self.config.cache_ttl
        )
        self.cache.move_to_end(key)  # an overwrite counts as a fresh use

    def _evict(self, key: str):
        """Remove *key* (if present) and count the eviction."""
        if key in self.cache:
            del self.cache[key]
            self.evictions += 1

    def _evict_oldest(self):
        """Evict the least-recently-used entry (front of the OrderedDict)."""
        if self.cache:
            self._evict(next(iter(self.cache)))

    def _evict_lfu(self):
        """Evict the entry with the lowest access_count."""
        if self.cache:
            self._evict(min(self.cache, key=lambda k: self.cache[k].access_count))

    def get_hit_rate(self) -> float:
        """Hits / (hits + misses); 0.0 before any lookups."""
        total = self.hits + self.misses
        return self.hits / total if total > 0 else 0.0

    def generate_report(self) -> CacheReport:
        """Build a CacheReport with stats and heuristic tuning advice."""
        hit_rate = self.get_hit_rate()
        # Rough footprint estimate: assume ~0.5 MB per entry.
        memory_used = len(self.cache) * 0.5
        ttls = [e.ttl for e in self.cache.values() if e.ttl]
        avg_ttl = statistics.mean(ttls) if ttls else 0.0

        recommendations = []
        if hit_rate < 0.5:
            recommendations.append("缓存命中率较低,建议增加缓存容量或优化缓存策略")
        if hit_rate > 0.9:
            recommendations.append("缓存命中率很高,可以考虑减少缓存容量以节省内存")
        if self.evictions > len(self.cache) * 2:
            recommendations.append("淘汰次数过多,建议增加缓存容量")
        if not recommendations:
            recommendations.append("缓存性能良好,继续保持当前配置")

        return CacheReport(
            report_id=f"cache_report_{secrets.token_hex(16)}",
            total_entries=len(self.cache),
            hit_rate=hit_rate,
            miss_rate=1.0 - hit_rate,
            memory_used=memory_used,
            evictions=self.evictions,
            avg_ttl=avg_ttl,
            recommendations=recommendations
        )
# Usage example / demo
if __name__ == "__main__":
    print("=== Agent 上下文压缩与缓存加速 ===\n")
    print("=== 创建上下文压缩与缓存系统 ===")
    # Shared configuration for the compressor and the cache.
    config = ContextConfig(
        max_context_length=8192,
        compression_strategy=CompressionStrategy.LLMLINGUA,
        compression_ratio=0.6,
        enable_cache=True,
        cache_size=100,
        # BUG FIX: CacheStrategy has no SEMANTIC member; the original raised
        # AttributeError here.  The semantic policy is SEMANTIC_CACHE.
        cache_strategy=CacheStrategy.SEMANTIC_CACHE,
        cache_ttl=3600,
        enable_sliding_window=False,
        window_size=2048,
        enable_attention_sink=True,
        sink_tokens=4
    )
    compressor = ContextCompressor(config)
    cache = SemanticCache(config)
    print(f"最大上下文长度:{config.max_context_length}")
    print(f"压缩策略:{config.compression_strategy.value}")
    print(f"目标压缩率:{config.compression_ratio}")
    print(f"缓存启用:{config.enable_cache}")
    print(f"缓存大小:{config.cache_size}")
    print(f"缓存策略:{config.cache_strategy.value}\n")

    # Sample inputs (mixed English and Chinese).
    test_texts = [
        "The quick brown fox jumps over the lazy dog. This is a very important sentence that contains key information. The essential point is that we need to understand the main idea.",
        "人工智能是当今科技发展的核心驱动力。关键技术包括机器学习、深度学习和自然语言处理。重要应用涵盖智能助手、自动驾驶和医疗诊断。",
        "In the realm of large language models, context management is critical. The important aspects include compression, caching, and optimization. Key techniques involve token selection and semantic preservation.",
    ]

    print("=== 测试上下文压缩 ===\n")
    for i, text in enumerate(test_texts):
        print(f"文本 {i+1}:")
        print(f" 原文:{text[:80]}...")
        print(f" 原文 Token 数:{len(text.split())}")
        result = compressor.compress(text)
        print(f" 压缩后:{result.compressed_text[:80]}...")
        print(f" 压缩后 Token 数:{result.compressed_tokens}")
        print(f" 压缩率:{result.compression_ratio:.2%}")
        print(f" 节省 Token 数:{result.tokens_saved}")
        print(f" 节省时间:{result.time_saved*1000:.1f}ms")
        print(f" 节省成本:${result.cost_saved:.6f}")
        print(f" 语义相似度:{result.semantic_similarity:.2%}")
        print()
        # Prime the cache, keyed by the compressed text.
        cache.put(result.compressed_text, f"response_{i}")

    print("=== 测试缓存系统 ===\n")
    # Re-compress the first two texts; compression is deterministic, so the
    # compressed text reproduces the cache key and should hit.
    for i, text in enumerate(test_texts[:2]):
        result = compressor.compress(text)
        cached_response = cache.get(result.compressed_text)
        print(f"查询 {i+1}:")
        print(f" 缓存命中:{cached_response is not None}")
        if cached_response:
            print(f" 缓存响应:{cached_response}")
        print()

    print("=== 缓存报告 ===")
    report = cache.generate_report()
    print(f"总条目数:{report.total_entries}")
    print(f"命中率:{report.hit_rate:.2%}")
    print(f"未命中率:{report.miss_rate:.2%}")
    print(f"内存使用:{report.memory_used:.1f}MB")
    print(f"淘汰次数:{report.evictions}")
    print(f"平均 TTL: {report.avg_ttl:.0f}s")
    print(f"\n建议:")
    for rec in report.recommendations:
        print(f" - {rec}")

    print(f"\n压缩统计:")
    stats = compressor.get_statistics()
    print(f" 总压缩次数:{stats.get('total_compressions', 0)}")
    print(f" 平均压缩率:{stats.get('avg_compression_ratio', 0):.2%}")
    print(f" 平均节省 Token: {stats.get('avg_tokens_saved', 0):.1f}")
    print(f" 总节省 Token: {stats.get('total_tokens_saved', 0)}")
    print(f"\n关键观察:")
    print("1. 上下文管理:窗口优化、密度提升、关键保留")
    print("2. 上下文压缩:LLMLingua、语义压缩、Token 选择")
    print("3. 缓存加速:KV 缓存、语义缓存、LRU/LFU")
    print("4. 系统优化:滑动窗口、Attention Sink、弹性缓存")
    print("5. 精简高效:管理 + 压缩 + 缓存 + 优化 = 可信赖")
    print("\n精简高效的使命:让 AI 上下文更精简、更高效、更智能")