# Agent debugger core implementation (Agent 调试器核心实现)
import asyncio
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
from datetime import datetime
import json
@dataclass
class DebugSession:
    """One debugging session over a single recorded trace.

    Every container field defaults to its own fresh, empty list or dict, so
    two sessions never share mutable state.
    """

    # Unique identifier of the session.
    session_id: str
    # Identifier of the trace being debugged.
    trace_id: str
    # Moment the session was created.
    start_time: datetime
    # Breakpoint specifications, one dict per breakpoint.
    breakpoints: List[Dict[str, Any]] = field(default_factory=list)
    # Flattened variable-name -> value map.
    variables: Dict[str, Any] = field(default_factory=dict)
    # Ordered call-stack entries, one dict per entry.
    call_stack: List[Dict[str, Any]] = field(default_factory=list)
    # Log of debugger events (dicts), appended over the session's lifetime.
    events: List[Dict[str, Any]] = field(default_factory=list)
class AgentDebugger:
    """Core Agent debugger.

    Loads recorded traces from ``trace_store``, lets a caller step through a
    trace's spans and stop at breakpoints, and can replay a trace against
    ``agent`` (optionally with a modified input) to compare the results.

    Sessions live in memory in ``self.sessions``. The paused position of each
    session is an index into its ``call_stack``, kept in ``self._cursors``.
    """

    def __init__(self, agent, trace_store):
        """
        Initialise the debugger.

        Args:
            agent: Agent instance under debug; ``replay_trace`` awaits
                ``agent.run(input)`` on it.
            trace_store: Trace storage; ``get_trace(trace_id)`` is awaited
                and a falsy result is treated as "trace not found".
        """
        self.agent = agent
        self.trace_store = trace_store
        self.sessions: Dict[str, "DebugSession"] = {}
        # Not read anywhere in this class; kept as a public attribute for
        # callers that may populate it.
        self.breakpoint_handlers = {}
        # session_id -> index in session.call_stack of the span execution is
        # paused at. Sessions missing from this map have not moved yet and
        # are implicitly paused at index 0.
        self._cursors: Dict[str, int] = {}

    async def start_debug_session(
        self,
        trace_id: str,
        breakpoints: Optional[List[Dict[str, Any]]] = None
    ) -> "DebugSession":
        """
        Start a debug session for a stored trace.

        Args:
            trace_id: ID of the trace to debug.
            breakpoints: Breakpoint specs. Supported shapes are
                ``{'type': 'span_name', 'name': ...}`` and
                ``{'type': 'condition', 'condition': <python expression>}``.

        Returns:
            DebugSession: the new session, paused at the first span.

        Raises:
            ValueError: if the trace store has no such trace.
        """
        # 1. Load the trace data.
        trace_data = await self.trace_store.get_trace(trace_id)
        if not trace_data:
            raise ValueError(f"Trace {trace_id} not found")

        # 2. Create the session; one timestamp feeds both the ID and start time.
        now = datetime.now()
        session = DebugSession(
            session_id=f"debug_{trace_id}_{now.timestamp()}",
            trace_id=trace_id,
            start_time=now,
            # Copy so later edits to the caller's list don't leak into the session.
            breakpoints=list(breakpoints or []),
            call_stack=self._build_call_stack(trace_data),
            variables=self._extract_variables(trace_data),
        )

        # 3. Register it; drop any stale cursor left by a colliding session ID.
        self.sessions[session.session_id] = session
        self._cursors.pop(session.session_id, None)
        return session

    def _build_call_stack(self, trace_data: Dict) -> List[Dict]:
        """Flatten ``trace_data['spans']`` into call-stack entries, in trace order.

        ``span_id``, ``name`` and ``start_time`` are required on each span;
        ``end_time``, ``status`` (default ``'unknown'``) and ``attributes``
        (default ``{}``) are optional.
        """
        return [
            {
                'span_id': span['span_id'],
                'name': span['name'],
                'start_time': span['start_time'],
                'end_time': span.get('end_time'),
                'status': span.get('status', 'unknown'),
                'attributes': span.get('attributes', {}),
            }
            for span in trace_data.get('spans') or []
        ]

    def _extract_variables(self, trace_data: Dict) -> Dict[str, Any]:
        """Collect span attributes and event payloads into a flat variable map.

        Keys are ``"<span name>.<attribute>"`` for attributes and
        ``"event.<event name>"`` for events. On key collisions (e.g. two spans
        emitting an event with the same name) the later span wins.
        """
        variables: Dict[str, Any] = {}
        for span in trace_data.get('spans') or []:
            # Span attributes become variables namespaced by span name.
            for key, value in (span.get('attributes') or {}).items():
                variables[f"{span['name']}.{key}"] = value
            # Event payloads are exposed under the event name.
            for event in span.get('events') or []:
                variables[f"event.{event['name']}"] = event.get('attributes', {})
        return variables

    def _get_session(self, session_id: str) -> "DebugSession":
        """Return the session with ``session_id`` or raise ``ValueError``."""
        session = self.sessions.get(session_id)
        if session is None:
            raise ValueError(f"Session {session_id} not found")
        return session

    def _cursor(self, session) -> int:
        """Index in ``session.call_stack`` of the span the session is paused at."""
        return self._cursors.get(session.session_id, 0)

    def _find_current_span(self, session) -> Optional[Dict[str, Any]]:
        """Span the session is paused at, or ``None`` once execution has finished."""
        index = self._cursor(session)
        return session.call_stack[index] if index < len(session.call_stack) else None

    def _find_next_span(self, session) -> Optional[Dict[str, Any]]:
        """Span right after the current one, or ``None`` if the current one is last."""
        index = self._cursor(session) + 1
        return session.call_stack[index] if index < len(session.call_stack) else None

    async def step_over(self, session_id: str) -> Dict[str, Any]:
        """
        Step over the current span and pause at the next one.

        Args:
            session_id: debug session ID.

        Returns:
            ``{'status': 'paused', 'current_span', 'variables', 'call_stack'}``
            after a successful step, otherwise a ``'completed'`` status dict.

        Raises:
            ValueError: if the session does not exist.
        """
        session = self._get_session(session_id)

        current_span = self._find_current_span(session)
        if current_span is None:
            return {'status': 'completed', 'message': 'Trace execution completed'}

        next_span = self._find_next_span(session)
        if next_span is None:
            # Move past the end so later calls report the trace as finished.
            self._cursors[session.session_id] = len(session.call_stack)
            return {'status': 'completed', 'message': 'No more spans to execute'}

        # Advance the cursor and log the step.
        self._cursors[session.session_id] = self._cursor(session) + 1
        session.events.append({
            'type': 'step_over',
            'timestamp': datetime.now().isoformat(),
            'from_span': current_span['span_id'],
            'to_span': next_span['span_id'],
        })
        return {
            'status': 'paused',
            'current_span': next_span,
            'variables': session.variables,
            'call_stack': session.call_stack,
        }

    async def continue_execution(self, session_id: str) -> Dict[str, Any]:
        """
        Continue from the current position until the next breakpoint.

        A fresh session starts scanning at its first span; otherwise scanning
        starts after the span the session is paused at, so the same
        breakpoint is not reported twice in a row.

        Args:
            session_id: debug session ID.

        Returns:
            ``{'status': 'breakpoint_hit', 'breakpoint', 'current_span',
            'variables'}`` when a breakpoint matches, otherwise a
            ``'completed'`` status dict.

        Raises:
            ValueError: if the session does not exist.
        """
        session = self._get_session(session_id)

        fresh = session.session_id not in self._cursors
        start = self._cursor(session) if fresh else self._cursor(session) + 1

        for index in range(start, len(session.call_stack)):
            span = session.call_stack[index]
            matched = self._get_matching_breakpoint(session, span)
            if matched is None:
                continue
            # Pause here and record the hit so the debug report can count it.
            self._cursors[session.session_id] = index
            session.events.append({
                'type': 'breakpoint_hit',
                'timestamp': datetime.now().isoformat(),
                'span_id': span['span_id'],
                'breakpoint': matched,
            })
            return {
                'status': 'breakpoint_hit',
                'breakpoint': matched,
                'current_span': span,
                'variables': session.variables,
            }

        # Ran off the end of the trace.
        self._cursors[session.session_id] = len(session.call_stack)
        return {'status': 'completed', 'message': 'Execution completed without hitting breakpoints'}

    def _check_breakpoint(
        self,
        session: "DebugSession",
        span: Dict
    ) -> bool:
        """Return whether any of the session's breakpoints matches ``span``."""
        return self._get_matching_breakpoint(session, span) is not None

    def _get_matching_breakpoint(
        self,
        session: "DebugSession",
        span: Dict
    ) -> Optional[Dict[str, Any]]:
        """Return the first breakpoint (in session order) matching ``span``, or ``None``."""
        for bp in session.breakpoints:
            bp_type = bp.get('type')
            # Match by span name.
            if bp_type == 'span_name' and bp.get('name') == span['name']:
                return bp
            # Match by condition expression.
            if bp_type == 'condition' and self._evaluate_condition(
                bp.get('condition'), span, session.variables
            ):
                return bp
        return None

    def _evaluate_condition(
        self,
        condition: str,
        span: Dict,
        variables: Dict
    ) -> bool:
        """
        Evaluate a conditional-breakpoint expression against a span.

        Names available to the expression, later sources overriding earlier
        ones:
        - the session variables;
        - the span's attributes, so ``latency > 1000`` works when the span has
          a ``latency`` attribute;
        - the span's own fields, such as ``status`` and ``name``;
        - ``span``, the whole call-stack entry.

        Returns False for an empty condition or any evaluation error.
        """
        if not condition:
            return False
        namespace: Dict[str, Any] = dict(variables)
        namespace.update(span.get('attributes') or {})
        namespace.update((k, v) for k, v in span.items() if k != 'attributes')
        namespace['span'] = span
        try:
            # SECURITY: ``eval`` with empty ``__builtins__`` is NOT a sandbox
            # (attribute walks such as ``().__class__`` still escape). Only
            # accept conditions from trusted users; use a restricted
            # expression evaluator before exposing this to untrusted input.
            return bool(eval(condition, {"__builtins__": {}}, namespace))
        except Exception:
            # Best-effort by design: a condition that references a missing
            # name or fails to evaluate simply does not match.
            return False

    async def inspect_variable(
        self,
        session_id: str,
        variable_name: str
    ) -> Any:
        """
        Look up a variable's value.

        Args:
            session_id: debug session ID.
            variable_name: flattened variable name, e.g. ``"<span>.<attr>"``.

        Returns:
            The variable's value, or ``None`` if it does not exist.

        Raises:
            ValueError: if the session does not exist.
        """
        return self._get_session(session_id).variables.get(variable_name)

    async def replay_trace(
        self,
        trace_id: str,
        modify_input: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Re-run the agent on a stored trace's input and compare the outcome.

        Args:
            trace_id: Trace ID.
            modify_input: keys merged over the trace's ``input`` before
                replaying (e.g. ``{'temperature': 0.5}``).

        Returns:
            A dict with these keys:
            - ``original_trace``: the stored trace, not modified by this call;
            - ``replayed_trace``: the agent response;
            - ``replay_input``: the input actually sent to the agent;
            - ``comparison``: the comparison result;
            - ``execution_time``: the replay duration in seconds.

        Raises:
            ValueError: if the trace store has no such trace.
        """
        import time  # monotonic clock for measuring the replay duration

        # 1. Load the original trace.
        original_trace = await self.trace_store.get_trace(trace_id)
        if not original_trace:
            raise ValueError(f"Trace {trace_id} not found")

        # 2. Build the replay input WITHOUT mutating the stored trace, so the
        #    returned 'original_trace' still shows the original input.
        replay_input = original_trace.get('input', {})
        if modify_input:
            replay_input = {**(replay_input or {}), **modify_input}

        # 3. Re-run the agent.
        started = time.perf_counter()
        response = await self.agent.run(replay_input)
        elapsed = time.perf_counter() - started

        # 4. Compare with the original run.
        comparison = self._compare_traces(
            original_trace, response, replay_latency_ms=elapsed * 1000
        )
        return {
            'original_trace': original_trace,
            'replayed_trace': response,
            'replay_input': replay_input,
            'comparison': comparison,
            'execution_time': elapsed,
        }

    def _compare_traces(
        self,
        original: Dict,
        replayed: Dict,
        replay_latency_ms: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Compare an original trace with a replay result.

        Args:
            original: the stored trace.
            replayed: the replay response (dict-like); ``None`` is treated as
                empty.
            replay_latency_ms: measured replay latency, used when ``replayed``
                carries no ``latency_ms`` of its own.
        """
        replayed = replayed if replayed is not None else {}
        replayed_latency = replayed.get('latency_ms')
        if replayed_latency is None:
            replayed_latency = replay_latency_ms
        return {
            'output_similarity': self._calculate_similarity(
                original.get('output', ''),
                replayed.get('output', '')
            ),
            # Missing/None latencies count as 0.
            'latency_diff': (replayed_latency or 0) - (original.get('latency_ms') or 0),
            'tool_calls_diff': self._compare_tool_calls(
                original.get('tool_calls') or [],
                replayed.get('tool_calls') or []
            ),
        }

    def _calculate_similarity(self, text1: str, text2: str) -> float:
        """Similarity ratio in [0, 1] via ``difflib.SequenceMatcher``; ``None`` counts as ``''``."""
        from difflib import SequenceMatcher
        first = '' if text1 is None else text1
        second = '' if text2 is None else text2
        return SequenceMatcher(None, first, second).ratio()

    def _compare_tool_calls(
        self,
        original: List[Dict],
        replayed: List[Dict]
    ) -> Dict[str, Any]:
        """Compare tool-call counts and the set of tool names used."""
        original = original or []
        replayed = replayed or []
        return {
            'original_count': len(original),
            'replayed_count': len(replayed),
            # Calls without a 'name' contribute None instead of raising.
            'same_tools': {t.get('name') for t in original}
                          == {t.get('name') for t in replayed},
        }

    def generate_debug_report(self, session_id: str) -> Dict[str, Any]:
        """
        Summarise a debug session.

        Raises:
            ValueError: if the session does not exist.
        """
        session = self._get_session(session_id)
        hits = sum(1 for event in session.events if event.get('type') == 'breakpoint_hit')
        return {
            'session_id': session.session_id,
            'trace_id': session.trace_id,
            'duration': (datetime.now() - session.start_time).total_seconds(),
            'total_spans': len(session.call_stack),
            'breakpoints_hit': hits,
            'variables_count': len(session.variables),
            'events': session.events,
        }
# 使用示例
async def debug_agent_example():
    """Walk through a typical debugging workflow end to end.

    The steps are:
    1. Start a session on a sample trace with two breakpoints.
    2. Step over one span.
    3. Inspect a variable.
    4. Continue to the next breakpoint.
    5. Replay the trace with a modified input.
    6. Print a session report.
    """
    # Project-local dependencies, imported inside the function so the module
    # itself does not require them.
    from my_agent import CustomerSupportAgent
    from trace_store import TraceStore

    debugger = AgentDebugger(CustomerSupportAgent(), TraceStore())

    target_trace = "trace_12345"
    watchpoints = [
        {'type': 'span_name', 'name': 'llm-completion'},
        {'type': 'condition', 'condition': 'latency > 1000'},
    ]

    # 1. Start a debug session.
    session = await debugger.start_debug_session(
        trace_id=target_trace,
        breakpoints=watchpoints,
    )
    sid = session.session_id
    print(f"调试会话启动：{sid}")
    print(f"调用栈：{len(session.call_stack)} 个 Span")

    # 2. Step over one span.
    step_state = await debugger.step_over(sid)
    print(f"当前状态：{step_state['status']}")

    # 3. Inspect a variable.
    llm_latency = await debugger.inspect_variable(sid, 'llm-completion.latency_ms')
    print(f"LLM 延迟：{llm_latency}ms")

    # 4. Continue until the next breakpoint.
    hit = await debugger.continue_execution(sid)
    if hit['status'] == 'breakpoint_hit':
        print(f"命中断点：{hit['breakpoint']}")

    # 5. Replay the trace with a lower temperature.
    replay = await debugger.replay_trace(
        trace_id=target_trace,
        modify_input={'temperature': 0.5},
    )
    similarity = replay['comparison']['output_similarity']
    print(f"回放相似度：{similarity:.2f}")

    # 6. Produce the debug report.
    report = debugger.generate_debug_report(sid)
    print(f"调试报告：{json.dumps(report, indent=2)}")