Agent Inference Performance and Token Efficiency Optimization: A Complete Implementation
import time
import hashlib
import secrets
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
import numpy as np
from collections import defaultdict
import statistics
class OptimizationStrategy(Enum):
"""优化策略"""
    SPECULATIVE_DECODING = "speculative_decoding"
    KV_CACHE_OPTIMIZATION = "kv_cache_optimization"
    PROMPT_COMPRESSION = "prompt_compression"
    BATCH_OPTIMIZATION = "batch_optimization"
    QUANTIZATION = "quantization"
    DYNAMIC_BATCHING = "dynamic_batching"
class PerformanceMetric(Enum):
"""性能指标"""
THROUGHPUT = "throughput" # 吞吐量 (tokens/s)
LATENCY = "latency" # 延迟 (ms)
TTFT = "ttft" # 首 Token 时间 (ms)
TPOT = "tpot" # 每 Token 时间 (ms/token)
MEMORY_USAGE = "memory_usage" # 内存使用 (MB)
COST_PER_TOKEN = "cost_per_token" # 每 Token 成本
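    # Note: for streamed generation, end-to-end latency is approximately
    # TTFT + TPOT * (tokens_generated - 1).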
@dataclass
class PerformanceConfig:
"""性能配置"""
max_batch_size: int
max_context_length: int
enable_speculative_decoding: bool
speculative_draft_model: Optional[str]
enable_kv_cache_optimization: bool
kv_cache_size: int
enable_prompt_compression: bool
compression_ratio: float
    quantization_bits: int  # 4, 8, or 16
dynamic_batching: bool
    max_wait_time: float  # max wait before flushing a dynamic batch (ms)
@dataclass
class InferenceRequest:
"""推理请求"""
request_id: str
prompt: str
max_tokens: int
temperature: float
top_p: float
timestamp: datetime = field(default_factory=datetime.now)
@dataclass
class InferenceResponse:
"""推理响应"""
request_id: str
generated_text: str
tokens_generated: int
tokens_used: int
    inference_time: float  # seconds
    ttft: float            # time to first token (ms)
    tpot: float            # time per output token (ms/token)
    throughput: float      # tokens/s
    memory_used: float     # MB
    cost: float            # USD
timestamp: datetime = field(default_factory=datetime.now)
@dataclass
class PerformanceReport:
"""性能报告"""
report_id: str
total_requests: int
avg_throughput: float
avg_latency: float
avg_ttft: float
avg_tpot: float
total_tokens: int
total_cost: float
optimization_savings: Dict[str, float]
recommendations: List[str]
generated_at: datetime = field(default_factory=datetime.now)
class PerformanceOptimizer:
"""
性能优化器
支持:
1. 推测解码
2. KV 缓存优化
3. 提示压缩
4. 批处理优化
"""
def __init__(self, config: PerformanceConfig):
self.config = config
self.kv_cache = {}
self.request_history = []
self.performance_metrics = defaultdict(list)
def compress_prompt(self, prompt: str) -> Tuple[str, float]:
"""提示压缩"""
if not self.config.enable_prompt_compression:
return prompt, 1.0
        # Simplified prompt compression (a production system would use LLMLingua or similar)
words = prompt.split()
original_length = len(words)
        # Remove redundant words (stop words)
stop_words = {'the', 'a', 'an', 'is', 'are', 'was', 'were',
'be', 'been', 'being', 'have', 'has', 'had', 'do'}
compressed_words = [w for w in words if w.lower() not in stop_words]
        # Apply the configured compression ratio
target_length = int(len(compressed_words) * self.config.compression_ratio)
compressed_words = compressed_words[:target_length]
compressed_prompt = ' '.join(compressed_words)
compression_ratio = len(compressed_words) / original_length if original_length > 0 else 1.0
return compressed_prompt, compression_ratio
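    # Speculative decoding in a nutshell: a small draft model cheaply proposes
    # several tokens, then the large target model verifies them in one forward
    # pass and accepts a prefix, so each verification step can yield multiple
    # tokens instead of one.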
def speculative_decode(self, prompt: str, draft_tokens: int = 5) -> Tuple[List[str], int, int]:
"""推测解码模拟"""
if not self.config.enable_speculative_decoding:
            # Standard autoregressive decoding
return self._standard_decode(prompt), 0, 0
        # Speculative decoding: the small draft model proposes tokens
draft_sequence = self._draft_decode(prompt, draft_tokens)
        # The large target model verifies the draft tokens
accepted_tokens, rejected_count = self._verify_draft(prompt, draft_sequence)
        return accepted_tokens, len(draft_sequence), rejected_count
def _standard_decode(self, prompt: str) -> List[str]:
"""标准自回归解码模拟"""
# 模拟生成 tokens
return [f"token_{i}" for i in range(10)]
def _draft_decode(self, prompt: str, num_tokens: int) -> List[str]:
"""Draft 解码模拟"""
return [f"draft_token_{i}" for i in range(num_tokens)]
def _verify_draft(self, prompt: str, draft_sequence: List[str]) -> Tuple[List[str], int]:
"""验证 draft tokens"""
# 模拟验证过程(假设 80% 接受率)
accepted = []
rejected = 0
for token in draft_sequence:
            if np.random.random() < 0.8:  # 80% acceptance rate
accepted.append(token)
else:
rejected += 1
return accepted, rejected
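    # KV cache reuse: attention keys/values for an already-seen context can be
    # reused instead of recomputed during prefill. This simulation keys the
    # cache by an MD5 hash of the full context string.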
def optimize_kv_cache(self, request_id: str, context: str) -> bool:
"""KV 缓存优化"""
if not self.config.enable_kv_cache_optimization:
return False
        # Check for an existing cache entry
        context_hash = hashlib.md5(context.encode()).hexdigest()
        if context_hash in self.kv_cache:
            # Cache hit: re-insert the entry so eviction below stays LRU
            self.kv_cache[context_hash] = self.kv_cache.pop(context_hash)
            return True
        # Cache miss: store a new entry
        if len(self.kv_cache) >= self.config.kv_cache_size:
            # Evict the least recently used entry (dicts preserve insertion order)
            oldest_key = next(iter(self.kv_cache))
            del self.kv_cache[oldest_key]
self.kv_cache[context_hash] = {
'context': context,
'timestamp': datetime.now(),
'size': len(context)
}
return False
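    # Dynamic batching packs concurrent requests into shared forward passes,
    # bounded by max_batch_size and the context window. A greedy first-fit
    # pass over the queue approximates it here.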
def dynamic_batching(self, requests: List[InferenceRequest]) -> List[List[InferenceRequest]]:
"""动态批处理"""
if not self.config.dynamic_batching:
return [[req] for req in requests]
batches = []
current_batch = []
current_tokens = 0
for request in requests:
estimated_tokens = len(request.prompt.split()) + request.max_tokens
if (len(current_batch) < self.config.max_batch_size and
current_tokens + estimated_tokens <= self.config.max_context_length):
current_batch.append(request)
current_tokens += estimated_tokens
else:
if current_batch:
batches.append(current_batch)
current_batch = [request]
current_tokens = estimated_tokens
if current_batch:
batches.append(current_batch)
return batches
def estimate_cost(self, tokens: int, model_size: str = '7b') -> float:
"""估算成本"""
# 简化成本模型(实际应基于云厂商定价)
cost_per_1k_tokens = {
'7b': 0.0001,
'13b': 0.0002,
'70b': 0.0008,
'405b': 0.003
}
return (tokens / 1000) * cost_per_1k_tokens.get(model_size, 0.0001)
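    # End-to-end pipeline for one request: compress the prompt, probe the KV
    # cache, decode speculatively, then derive latency/throughput/cost metrics.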
def infer(self, request: InferenceRequest) -> InferenceResponse:
"""执行推理"""
start_time = time.time()
        # Prompt compression
compressed_prompt, compression_ratio = self.compress_prompt(request.prompt)
tokens_saved_compression = int(len(request.prompt.split()) * (1 - compression_ratio))
        # KV cache optimization
cache_hit = self.optimize_kv_cache(request.request_id, compressed_prompt)
        # Speculative decoding
        generated_tokens, draft_count, rejected_count = self.speculative_decode(compressed_prompt)
        # Compute metrics
inference_time = time.time() - start_time
tokens_generated = len(generated_tokens)
tokens_used = int(len(compressed_prompt.split()) + tokens_generated)
        # Latency metrics (simulated); a KV-cache hit skips prefill, so TTFT drops
        ttft = np.random.uniform(5, 15) if cache_hit else np.random.uniform(10, 50)
tpot = (inference_time * 1000) / tokens_generated if tokens_generated > 0 else 0
throughput = tokens_generated / inference_time if inference_time > 0 else 0
        # Simulated memory usage
memory_used = np.random.uniform(2000, 8000) # MB
        # Cost estimate
cost = self.estimate_cost(tokens_used)
        # Tokens saved by optimizations: compression trims prompt tokens, and each
        # accepted draft token skips one target-model decode step
        tokens_saved_speculative = max(draft_count - rejected_count, 0)
        total_tokens_saved = tokens_saved_compression + tokens_saved_speculative
response = InferenceResponse(
request_id=request.request_id,
generated_text=' '.join(generated_tokens),
tokens_generated=tokens_generated,
tokens_used=tokens_used,
inference_time=inference_time,
ttft=ttft,
tpot=tpot,
throughput=throughput,
memory_used=memory_used,
cost=cost
)
        # Record history
        self.request_history.append(response)
        self.performance_metrics['throughput'].append(throughput)
        self.performance_metrics['latency'].append(inference_time * 1000)
        self.performance_metrics['ttft'].append(ttft)
        self.performance_metrics['tpot'].append(tpot)
        self.performance_metrics['tokens_saved'].append(total_tokens_saved)
return response
def generate_performance_report(self) -> PerformanceReport:
"""生成性能报告"""
report_id = f"perf_report_{secrets.token_hex(16)}"
if not self.request_history:
return PerformanceReport(
report_id=report_id,
total_requests=0,
avg_throughput=0,
avg_latency=0,
avg_ttft=0,
avg_tpot=0,
total_tokens=0,
total_cost=0,
optimization_savings={},
recommendations=[]
)
        # Average metrics
avg_throughput = statistics.mean(self.performance_metrics['throughput'])
avg_latency = statistics.mean(self.performance_metrics['latency'])
avg_ttft = statistics.mean(self.performance_metrics['ttft'])
avg_tpot = statistics.mean(self.performance_metrics['tpot'])
total_tokens = sum(r.tokens_used for r in self.request_history)
total_cost = sum(r.cost for r in self.request_history)
        # Estimated savings per optimization (simple token-count heuristics)
        optimization_savings = {
            'prompt_compression': total_tokens * (1 - self.config.compression_ratio),
            'speculative_decoding': total_tokens * 0.15,
            'kv_cache_optimization': total_tokens * 0.05,
            'batch_optimization': total_tokens * 0.10
        }
        # Recommendations
        recommendations = []
        if avg_latency > 1000:
            recommendations.append("High latency; consider enabling speculative decoding and dynamic batching")
        if avg_throughput < 50:
            recommendations.append("Low throughput; consider increasing the batch size")
        if avg_tpot > 50:
            recommendations.append("High time per token; consider optimizing the KV cache")
        if not recommendations:
            recommendations.append("Performance looks healthy; keep the current configuration")
report = PerformanceReport(
report_id=report_id,
total_requests=len(self.request_history),
avg_throughput=avg_throughput,
avg_latency=avg_latency,
avg_ttft=avg_ttft,
avg_tpot=avg_tpot,
total_tokens=total_tokens,
total_cost=total_cost,
optimization_savings=optimization_savings,
recommendations=recommendations
)
return report
def get_statistics(self) -> Dict[str, Any]:
"""获取统计信息"""
return {
"total_requests": len(self.request_history),
"avg_throughput": statistics.mean(self.performance_metrics['throughput']) if self.performance_metrics['throughput'] else 0,
"avg_latency": statistics.mean(self.performance_metrics['latency']) if self.performance_metrics['latency'] else 0,
"cache_size": len(self.kv_cache),
"timestamp": datetime.now().isoformat()
}
# Usage example
if __name__ == "__main__":
print("=== Agent 推理性能与 Token 效率优化 ===\n")
print("=== 创建性能优化器 ===")
# 性能配置
config = PerformanceConfig(
max_batch_size=32,
max_context_length=8192,
enable_speculative_decoding=True,
speculative_draft_model='small_model',
enable_kv_cache_optimization=True,
kv_cache_size=1000,
enable_prompt_compression=True,
compression_ratio=0.7,
quantization_bits=8,
dynamic_batching=True,
max_wait_time=10.0
)
optimizer = PerformanceOptimizer(config)
print(f"批处理大小:{config.max_batch_size}")
print(f"上下文长度:{config.max_context_length}")
print(f"推测解码:{config.enable_speculative_decoding}")
print(f"KV 缓存优化:{config.enable_kv_cache_optimization}")
print(f"提示压缩:{config.enable_prompt_compression}")
print(f"压缩率:{config.compression_ratio}")
print(f"量化位数:{config.quantization_bits}bit\n")
    # Test requests
test_prompts = [
"Explain quantum computing in simple terms",
"Write a Python function to sort a list",
"What are the benefits of exercise?",
"Generate a haiku about nature",
"Translate 'Hello, how are you?' to French",
]
print("=== 测试性能优化 ===\n")
for i, prompt in enumerate(test_prompts):
request = InferenceRequest(
request_id=f"req_{i}",
prompt=prompt,
max_tokens=50,
temperature=0.7,
top_p=0.9
)
print(f"请求 {i+1}: {prompt[:50]}...")
# 执行推理
response = optimizer.infer(request)
print(f" 请求 ID: {response.request_id}")
print(f" 生成 Token 数:{response.tokens_generated}")
print(f" 使用 Token 数:{response.tokens_used}")
print(f" 推理时间:{response.inference_time:.3f}s")
print(f" 首 Token 时间:{response.ttft:.1f}ms")
print(f" 每 Token 时间:{response.tpot:.1f}ms")
print(f" 吞吐量:{response.throughput:.1f} tokens/s")
print(f" 内存使用:{response.memory_used:.0f}MB")
print(f" 成本:${response.cost:.6f}")
print()
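    # Demonstrate dynamic batching (a sketch: grouping depends on the rough
    # token estimates inside dynamic_batching, so exact batch counts can vary
    # with config and prompt lengths)
    batch_requests = [
        InferenceRequest(request_id=f"batch_req_{i}", prompt=p,
                         max_tokens=50, temperature=0.7, top_p=0.9)
        for i, p in enumerate(test_prompts)
    ]
    batches = optimizer.dynamic_batching(batch_requests)
    print("=== Dynamic Batching ===")
    print(f"Packed {len(batch_requests)} requests into {len(batches)} batch(es)\n")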
print("=== 性能报告 ===")
report = optimizer.generate_performance_report()
print(f"总请求数:{report.total_requests}")
print(f"平均吞吐量:{report.avg_throughput:.1f} tokens/s")
print(f"平均延迟:{report.avg_latency:.1f}ms")
print(f"平均首 Token 时间:{report.avg_ttft:.1f}ms")
print(f"平均每 Token 时间:{report.avg_tpot:.1f}ms")
print(f"总 Token 数:{report.total_tokens}")
print(f"总成本:${report.total_cost:.6f}")
print(f"\n优化节省:")
for strategy, savings in report.optimization_savings.items():
print(f" {strategy}: {savings:.2f} tokens")
print(f"\n建议:")
for rec in report.recommendations:
print(f" - {rec}")
print(f"\n关键观察:")
print("1. 推理性能:吞吐量、延迟、首 Token 时间、每 Token 时间")
print("2. Token 效率:提示压缩、上下文优化、Token 选择")
print("3. 推理加速:推测解码、KV 缓存优化、批处理")
print("4. 系统优化:vLLM/TGI、量化、分布式、弹性伸缩")
print("5. 高效优化:性能 + 效率 + 加速 + 系统 = 可信赖")
print("\n高效优化的使命:让 AI 推理更快、更省、更高效")