Agent Test Framework: Core Implementation
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional

import pytest
class EvalMetric(Enum):
    """Evaluation metric types."""
    EXACT_MATCH = "exact_match"
    SEMANTIC_SIMILARITY = "semantic_similarity"
    LLM_JUDGE = "llm_judge"
    TOOL_CALL_ACCURACY = "tool_call_accuracy"
    RESPONSE_LATENCY = "response_latency"
@dataclass
class TestCase:
    """Test-case data model."""
    id: str
    name: str
    input: str
    expected_output: Optional[str] = None
    expected_tools: Optional[List[str]] = None
    metadata: Optional[Dict[str, Any]] = None

@dataclass
class TestResult:
    """Test-result data model."""
    test_case: TestCase
    actual_output: str
    passed: bool
    metrics: Dict[str, float]
    error: Optional[str] = None

class AgentTestFramework:
    """Core class of the agent test framework."""

    def __init__(self, agent, eval_model=None):
        """
        Initialize the test framework.

        Args:
            agent: the Agent under test; assumed to expose run(input) and
                return a response with .output, .tools_called, and .latency_ms
            eval_model: optional model used for LLM-as-a-Judge scoring
        """
        self.agent = agent
        self.eval_model = eval_model
        self.results: List[TestResult] = []
        # Lazily initialized embedding model for semantic-similarity scoring.
        self._embedding_model = None

    def run_test(self, test_case: TestCase) -> TestResult:
        """
        Run a single test case.

        Args:
            test_case: the test case to run

        Returns:
            TestResult: the test result
        """
        try:
            # 1. Run the agent.
            response = self.agent.run(test_case.input)
            # 2. Compute the evaluation metrics.
            metrics = self._evaluate_metrics(
                response=response,
                test_case=test_case
            )
            # 3. Decide whether the test passed.
            passed = self._determine_pass(metrics, test_case)
            return TestResult(
                test_case=test_case,
                actual_output=response.output,
                passed=passed,
                metrics=metrics
            )
        except Exception as e:
            # Any agent or metric failure is recorded as a failed result
            # instead of aborting the whole suite.
            return TestResult(
                test_case=test_case,
                actual_output="",
                passed=False,
                metrics={},
                error=str(e)
            )

    def _evaluate_metrics(
        self,
        response: Any,
        test_case: TestCase
    ) -> Dict[str, float]:
        """Compute all applicable metrics for a response."""
        metrics = {}
        if test_case.expected_output:
            # 1. Exact match against the expected output.
            metrics['exact_match'] = float(
                response.output == test_case.expected_output
            )
            # 2. Embedding-based semantic similarity.
            metrics['semantic_similarity'] = self._calculate_semantic_similarity(
                response.output,
                test_case.expected_output
            )
            # 3. LLM-as-a-Judge score (only if a judge model was supplied).
            if self.eval_model:
                metrics['llm_judge_score'] = self._llm_judge_eval(
                    response.output,
                    test_case.expected_output,
                    test_case.input
                )
        # 4. Tool-call accuracy.
        if test_case.expected_tools:
            metrics['tool_call_accuracy'] = self._calculate_tool_accuracy(
                response.tools_called,
                test_case.expected_tools
            )
        # 5. Response latency (assumed to be reported by the agent in ms).
        metrics['response_latency'] = response.latency_ms
        return metrics

    def _calculate_semantic_similarity(
        self,
        text1: str,
        text2: str
    ) -> float:
        """Compute semantic similarity as embedding cosine similarity."""
        # Imported lazily so these heavy dependencies are only required
        # when this metric is actually used.
        from sentence_transformers import SentenceTransformer
        from sklearn.metrics.pairwise import cosine_similarity

        # Cache the model: reloading it on every call is prohibitively slow.
        if self._embedding_model is None:
            self._embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        embeddings = self._embedding_model.encode([text1, text2])
        similarity = cosine_similarity([embeddings[0]], [embeddings[1]])[0][0]
        return float(similarity)

    def _llm_judge_eval(
        self,
        actual: str,
        expected: str,
        input_text: str
    ) -> float:
        """LLM-as-a-Judge evaluation, normalized to [0, 1]."""
        prompt = f"""
Please rate the quality of the following agent answer.

User input: {input_text}
Expected output: {expected}
Actual output: {actual}

Score it on these dimensions (0-10):
1. Accuracy: is the answer correct?
2. Completeness: does it cover all the key points?
3. Relevance: does it address the question?

Return only a single numeric score from 0 to 10.
"""
        # invoke() matches the LangChain chat-model interface used by the
        # ChatOpenAI fixture below; the reply text is in .content.
        reply = self.eval_model.invoke(prompt)
        # Judges sometimes wrap the score in extra text, so extract the first
        # number instead of calling float() on the raw reply.
        match = re.search(r"\d+(?:\.\d+)?", reply.content)
        if match is None:
            raise ValueError(f"Judge returned no numeric score: {reply.content!r}")
        score = float(match.group())
        return min(10.0, max(0.0, score)) / 10.0

    def _calculate_tool_accuracy(
        self,
        actual_tools: List[str],
        expected_tools: List[str]
    ) -> float:
        """Fraction of expected tools that were actually called (recall)."""
        if not expected_tools:
            return 1.0
        expected = set(expected_tools)
        correct = len(set(actual_tools) & expected)
        return correct / len(expected)

    def _determine_pass(
        self,
        metrics: Dict[str, float],
        test_case: TestCase
    ) -> bool:
        """Decide whether a test passed."""
        # Default rule: pass on semantic similarity >= 0.7, LLM judge
        # score >= 0.6, or an exact match with the expected output.
        if metrics.get('semantic_similarity', 0.0) >= 0.7:
            return True
        if metrics.get('llm_judge_score', 0.0) >= 0.6:
            return True
        if test_case.expected_output and metrics.get('exact_match', 0) == 1.0:
            return True
        return False

    def run_suite(self, test_cases: List[TestCase]) -> List[TestResult]:
        """Run a test suite."""
        self.results = []
        for test_case in test_cases:
            result = self.run_test(test_case)
            self.results.append(result)
        return self.results

    def generate_report(self) -> Dict[str, Any]:
        """Generate a test report."""
        if not self.results:
            return {"error": "No test results"}
        total = len(self.results)
        passed = sum(1 for r in self.results if r.passed)
        failed = total - passed
        # Average each metric over the results that reported it.
        avg_metrics = {}
        metric_keys = set()
        for result in self.results:
            metric_keys.update(result.metrics.keys())
        for key in metric_keys:
            values = [
                r.metrics[key]
                for r in self.results
                if key in r.metrics
            ]
            if values:
                avg_metrics[key] = sum(values) / len(values)
        return {
            "summary": {
                "total": total,
                "passed": passed,
                "failed": failed,
                "pass_rate": passed / total if total > 0 else 0
            },
            "metrics": avg_metrics,
            "results": [
                {
                    "id": r.test_case.id,
                    "name": r.test_case.name,
                    "passed": r.passed,
                    "metrics": r.metrics,
                    "error": r.error
                }
                for r in self.results
            ]
        }
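
# The framework assumes agent.run() returns an object exposing .output,
# .tools_called, and .latency_ms; the concrete response type is not shown in
# this section. The dataclass below is a minimal sketch of that assumed
# interface (the name AgentResponse is hypothetical), handy for stub agents.
@dataclass
class AgentResponse:
    """Minimal sketch of the response shape the framework expects (assumption)."""
    output: str
    tools_called: List[str]
    latency_ms: float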

# Usage examples
@pytest.fixture
def agent():
    """Agent fixture for the tests."""
    from my_agent import CustomerSupportAgent
    return CustomerSupportAgent()

@pytest.fixture
def eval_model():
    """Judge-LLM fixture."""
    from langchain_openai import ChatOpenAI
    return ChatOpenAI(model="gpt-4-turbo")

def test_customer_support_basic(agent, eval_model):
    """Test the customer-support agent's basic behavior."""
    framework = AgentTestFramework(agent, eval_model)
    test_case = TestCase(
        id="cs-001",
        name="Return-policy inquiry",
        input="I want to return an item. What is the process?",
        expected_output="The return process: 1. Log in to your account 2. Find the order 3. Click Return...",
        metadata={"category": "return_policy"}
    )
    result = framework.run_test(test_case)
    assert result.passed, f"Test failed: {result.error}"
    assert result.metrics['semantic_similarity'] > 0.7
    assert result.metrics['response_latency'] < 2000

def test_tool_calling_accuracy(agent, eval_model):
    """Test tool-call accuracy."""
    framework = AgentTestFramework(agent, eval_model)
    test_case = TestCase(
        id="tool-001",
        name="Order lookup",
        input="Has my order #12345 shipped?",
        expected_tools=["order_lookup", "shipping_status"],
        metadata={"category": "order_inquiry"}
    )
    result = framework.run_test(test_case)
    assert result.metrics['tool_call_accuracy'] == 1.0

def test_batch_evaluation(agent, eval_model):
    """Batch evaluation over a generated suite."""
    framework = AgentTestFramework(agent, eval_model)
    test_cases = [
        TestCase(
            id=f"cs-{i:03d}",
            name=f"Test case {i}",
            input=f"Test input {i}",
            expected_output=f"Expected output {i}"
        )
        for i in range(100)
    ]
    results = framework.run_suite(test_cases)
    report = framework.generate_report()
    print(f"Pass rate: {report['summary']['pass_rate']:.2%}")
    print(f"Mean semantic similarity: {report['metrics'].get('semantic_similarity', 0):.2f}")
    print(f"Mean response latency: {report['metrics'].get('response_latency', 0):.0f}ms")