# Feedback-driven correction agent — complete implementation
import openai
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import json
class FeedbackType(Enum):
    """Categories of feedback the evaluation step can raise about an output."""
    FACTUAL_ERROR = "factual_error"  # statement contradicts facts
    LOGIC_ERROR = "logic_error"      # flaw in reasoning or inference
    COMPLETENESS = "completeness"    # required content is missing
    CLARITY = "clarity"              # wording is unclear or ambiguous
    STYLE = "style"                  # tone/format does not fit the task
    SAFETY = "safety"                # potentially unsafe or harmful content
@dataclass
class Feedback:
    """A single feedback item produced by the evaluation step."""
    type: FeedbackType
    severity: str     # one of: "critical", "major", "minor"
    description: str  # what the problem is
    suggestion: str   # how to fix it
    location: str     # where the problem occurs (e.g. "paragraph 2")
@dataclass
class CorrectionResult:
    """Final outcome of the feedback-driven correction loop."""
    original_output: str           # text before any correction
    feedback_list: List[Feedback]  # all feedback accumulated across iterations
    corrected_output: str          # text after the last applied correction
    improvement_score: float       # improvement score of the final round (0-1)
    iterations: int                # number of feedback/correction rounds run
class FeedbackDrivenCorrector:
    """
    Feedback-driven correction agent.

    Pipeline: generate feedback -> apply corrections -> evaluate improvement,
    repeated for up to ``max_iterations`` rounds or until a round's measured
    improvement falls below ``improvement_threshold``.
    """

    def __init__(self, model: str = "gpt-4",
                 max_iterations: int = 3,
                 improvement_threshold: float = 0.2):
        """
        Initialize the corrector.

        Args:
            model: LLM model name passed to the OpenAI API.
            max_iterations: Maximum number of feedback/correction rounds.
            improvement_threshold: Minimum per-round improvement score (0-1);
                iteration stops once a round improves less than this.
        """
        self.model = model
        self.max_iterations = max_iterations
        self.improvement_threshold = improvement_threshold

    def _call_llm(self, prompt: str, temperature: float = 0.7) -> str:
        """Send a single-user-message chat completion and return the reply text."""
        response = openai.ChatCompletion.create(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature
        )
        return response.choices[0].message.content

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding markdown code fence (``` or ```json) if present.

        LLMs frequently wrap JSON answers in fences even when instructed not
        to, which would otherwise make ``json.loads`` fail on the raw reply.
        """
        text = text.strip()
        if text.startswith("```"):
            # Drop the opening fence line (it may carry a language tag).
            text = text.split("\n", 1)[1] if "\n" in text else ""
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
        return text.strip()

    def generate_feedback(self, task: str, output: str,
                          criteria: List[str]) -> List[Feedback]:
        """
        Generate structured feedback on ``output``.

        Args:
            task: Task description.
            output: Candidate output to review.
            criteria: Evaluation criteria.
        Returns:
            List of Feedback items. Empty when the LLM reports no issues or
            when its response cannot be parsed as JSON (best-effort parsing;
            malformed items are skipped rather than crashing the loop).
        """
        criteria_text = "\n".join([f"- {c}" for c in criteria])
        prompt = f"""
任务:{task}
输出:
{output}
评估标准:
{criteria_text}
请生成详细反馈,识别输出中的问题。
输出格式(JSON 数组):
[
{{
"type": "factual_error|logic_error|completeness|clarity|style|safety",
"severity": "critical|major|minor",
"description": "问题描述",
"suggestion": "改进建议",
"location": "问题位置(如:第 2 段)"
}}
]
只输出 JSON 数组,不要其他内容。
"""
        response_text = self._call_llm(prompt, temperature=0.3)
        # Fix: the original fed the raw response straight to json.loads,
        # crashing on markdown-fenced or otherwise malformed LLM output.
        try:
            feedback_data = json.loads(self._strip_code_fence(response_text))
        except json.JSONDecodeError:
            return []
        feedback_list = []
        for item in feedback_data:
            try:
                feedback_list.append(Feedback(
                    type=FeedbackType(item["type"]),
                    severity=item["severity"],
                    description=item["description"],
                    suggestion=item["suggestion"],
                    location=item.get("location", "")
                ))
            except (KeyError, TypeError, ValueError):
                # Skip malformed items (unknown type value, missing field,
                # or a non-dict entry) instead of aborting the whole batch.
                continue
        return feedback_list

    def apply_correction(self, task: str, original_output: str,
                         feedback_list: List[Feedback]) -> str:
        """
        Apply the feedback to produce a corrected output.

        Args:
            task: Task description.
            original_output: Output to be corrected.
            feedback_list: Feedback items to address.
        Returns:
            The corrected output text from the LLM.
        """
        # Sort feedback by severity (critical first). Unknown severity
        # strings sort last rather than raising KeyError.
        severity_order = {"critical": 0, "major": 1, "minor": 2}
        sorted_feedback = sorted(
            feedback_list,
            key=lambda f: severity_order.get(f.severity, len(severity_order))
        )
        # Render the feedback items as a numbered list for the prompt.
        feedback_text = ""
        for i, fb in enumerate(sorted_feedback, 1):
            feedback_text += f"{i}. [{fb.severity.upper()}] {fb.type.value}\n"
            feedback_text += f" 问题:{fb.description}\n"
            feedback_text += f" 位置:{fb.location}\n"
            feedback_text += f" 建议:{fb.suggestion}\n\n"
        prompt = f"""
任务:{task}
原始输出:
{original_output}
反馈与改进建议:
{feedback_text}
请根据反馈逐条修正输出,确保:
1. 修复所有 critical 和 major 问题
2. 尽可能改进 minor 问题
3. 保持原文风格和结构
4. 只输出修正后的完整内容
修正后输出:
"""
        return self._call_llm(prompt, temperature=0.7)

    def evaluate_improvement(self, task: str, original: str,
                             corrected: str, criteria: List[str]) -> float:
        """Rate the improvement of ``corrected`` over ``original``, clamped to 0-1."""
        prompt = f"""
任务:{task}
评估标准:{', '.join(criteria)}
原始输出:
{original}
修正后输出:
{corrected}
请评估修正后输出的改进幅度(0-1 分):
- 0.0: 毫无改进甚至更差
- 0.5: 中等改进
- 1.0: 显著改进,所有关键问题解决
考虑因素:
1. 反馈中的问题是否被修复?
2. 是否引入新问题?
3. 整体质量提升程度?
只输出一个数字(0.0-1.0):
"""
        response_text = self._call_llm(prompt, temperature=0.3)
        # Fix: original used a bare ``except:``; only a non-numeric reply
        # should fall back to the neutral score.
        try:
            score = float(response_text.strip())
        except ValueError:
            return 0.5
        return max(0.0, min(1.0, score))

    def correct(self, task: str, initial_output: str,
                criteria: List[str]) -> CorrectionResult:
        """
        Run the feedback-driven correction loop.

        Args:
            task: Task description.
            initial_output: Initial output to correct.
            criteria: Evaluation criteria.
        Returns:
            CorrectionResult with the final text, all accumulated feedback,
            the last round's improvement score, and the rounds actually run.
        """
        print(f"开始反馈驱动修正:{task[:50]}...")
        print("="*70 + "\n")

        current_output = initial_output
        all_feedback: List[Feedback] = []
        # Fix: track these explicitly — the original read ``iteration`` and
        # ``improvement`` after the loop (NameError when max_iterations == 0,
        # plus a fragile ``'improvement' in locals()`` hack).
        last_improvement = 0.0
        iterations_run = 0

        for iteration in range(self.max_iterations):
            iterations_run = iteration + 1
            print(f"迭代 {iteration + 1}/{self.max_iterations}")
            print("-"*50)

            # Generate feedback for the current draft.
            feedback_list = self.generate_feedback(task, current_output, criteria)
            print(f"生成 {len(feedback_list)} 条反馈")
            if not feedback_list:
                print("✓ 无需改进,输出已达标")
                break

            # Show a summary of the most important feedback.
            critical_count = len([f for f in feedback_list if f.severity == "critical"])
            major_count = len([f for f in feedback_list if f.severity == "major"])
            print(f" Critical: {critical_count}, Major: {major_count}")
            for fb in feedback_list[:3]:  # show the first 3 items
                print(f" - [{fb.severity}] {fb.description[:60]}...")

            # Apply the correction and score the improvement.
            corrected_output = self.apply_correction(
                task, current_output, feedback_list
            )
            improvement = self.evaluate_improvement(
                task, current_output, corrected_output, criteria
            )
            print(f"改进幅度:{improvement:.2f}")

            all_feedback.extend(feedback_list)
            last_improvement = improvement
            current_output = corrected_output

            # Stop once a round no longer improves enough (plateau).
            if improvement < self.improvement_threshold:
                print(f"⚠️ 改进幅度低于阈值,停止迭代")
                break
            print()

        print("="*70)
        print(f"\n修正完成,共{len(all_feedback)}条反馈,{iterations_run}次迭代")
        return CorrectionResult(
            original_output=initial_output,
            feedback_list=all_feedback,
            corrected_output=current_output,
            improvement_score=last_improvement,
            iterations=iterations_run
        )
# Usage example
if __name__ == "__main__":
    # Demo: feedback-driven correction of a deliberately weak business e-mail.
    banner = "=" * 70
    rule = "-" * 70

    # Build the corrector used for the demo run.
    corrector = FeedbackDrivenCorrector(
        max_iterations=3,
        improvement_threshold=0.15,
    )

    # Demo task: writing a business e-mail announcing a project delay.
    task = """
写一封商务邮件,主题是"项目延期通知"。
要求:
1. 语气专业且诚恳
2. 说明延期原因
3. 提供新的时间表
4. 表达歉意并承诺改进
5. 长度适中(200-300 字)
"""

    # First draft intentionally contains problems to exercise the loop.
    initial_output = """
主题:项目延期
你好,
项目要延期了。因为有些问题没解决。
新的时间下个月吧。
抱歉。
"""

    print("反馈驱动修正示例:商务邮件写作")
    print(banner + "\n")
    print(f"任务:{task.strip()}\n")
    print("初始输出:")
    print(rule)
    print(initial_output)
    print(rule + "\n")

    # Criteria the feedback step evaluates the draft against.
    criteria = [
        "语气专业且诚恳",
        "详细说明延期原因",
        "提供具体新时间表",
        "充分表达歉意",
        "承诺改进措施",
        "长度适中(200-300 字)",
        "格式规范(称呼、正文、结尾、签名)",
    ]

    # Run the correction loop and report the outcome.
    result = corrector.correct(task, initial_output, criteria)

    print("\n修正后输出:")
    print(rule)
    print(result.corrected_output)
    print(rule)
    print(f"\n改进统计:")
    print(f" 总反馈数:{len(result.feedback_list)}")
    print(f" 迭代次数:{result.iterations}")
    print(f" 改进幅度:{result.improvement_score:.2f}")

    print("\n" + banner)
    print("\n关键观察:")
    for observation in (
        "1. 反馈生成是修正的核心(具体、可操作、建设性)",
        "2. 按严重程度优先级处理(critical → major → minor)",
        "3. 多轮迭代持续改进(通常 2-3 轮达到平台期)",
        "4. 改进评估防止过度修正",
        "5. 适用于写作、代码、推理等多种任务",
    ):
        print(observation)