AGoT 自适应图推理器完整实现
import openai
from typing import List, Dict, Optional, Set
from dataclasses import dataclass
from enum import Enum
import json
class NodeType(Enum):
CHAIN = "chain" # 链式节点
TREE = "tree" # 树式节点
GRAPH = "graph" # 图式节点
@dataclass
class ReasoningNode:
"""推理节点"""
id: str
content: str
node_type: NodeType
dependencies: List[str] # 依赖的父节点 ID
children: List[str] # 子节点 ID
status: str = "pending" # pending, completed, expanded
value: float = 0.0 # 质量评分
class AdaptiveGraphReasoner:
"""
AGoT 自适应图推理器
动态构建 DAG,统一链/树/图结构
"""
def __init__(self, model: str = "gpt-4",
max_depth: int = 6,
complexity_threshold: float = 0.6):
"""
初始化
Args:
model: LLM 模型
max_depth: 最大深度
complexity_threshold: 复杂度阈值(决定使用何种结构)
"""
self.model = model
self.max_depth = max_depth
self.complexity_threshold = complexity_threshold
self.nodes: Dict[str, ReasoningNode] = {}
self.root_id: Optional[str] = None
self.node_counter = 0
def _call_llm(self, prompt: str, temperature: float = 0.7) -> str:
"""调用 LLM"""
response = openai.ChatCompletion.create(
model=self.model,
messages=[{"role": "user", "content": prompt}],
temperature=temperature
)
return response.choices[0].message.content
def analyze_complexity(self, query: str) -> float:
"""
分析问题复杂度(0-1)
决定使用链/树/图结构
"""
prompt = f"""
分析问题复杂度:{query}
请评估这个问题的复杂度(0-1 分):
- 0.0-0.3: 简单问题,适合链式推理(CoT)
- 0.3-0.6: 中等复杂,适合树式推理(ToT)
- 0.6-1.0: 高度复杂,适合图式推理(GoT)
考虑因素:
1. 是否需要多步推理?
2. 是否有多个子问题?
3. 子问题之间是否有复杂依赖?
4. 是否需要思维合并或反馈循环?
只输出一个数字(0.0-1.0):
"""
response_text = self._call_llm(prompt, temperature=0.3)
try:
score = float(response_text.strip())
return max(0.0, min(1.0, score))
except:
return 0.5
def determine_structure(self, complexity: float) -> NodeType:
"""根据复杂度决定结构类型"""
if complexity < 0.3:
return NodeType.CHAIN
elif complexity < 0.6:
return NodeType.TREE
else:
return NodeType.GRAPH
def decompose_query(self, query: str, depth: int = 0) -> Dict:
"""
递归分解查询为子问题
返回 DAG 结构
"""
prompt = f"""
主问题:{query}
当前深度:{depth}/{self.max_depth}
请将这个问题分解为相互依赖的子问题。
输出格式(JSON):
{{
"subproblems": [
{{
"id": "sub_1",
"description": "子问题 1 描述",
"dependencies": [], // 依赖的子问题 ID 列表
"complexity": 0.5 // 子问题复杂度
}},
...
],
"merge_strategy": "如何合并子问题答案"
}}
如果问题足够简单,返回空列表表示无需分解。
"""
response_text = self._call_llm(prompt, temperature=0.7)
return json.loads(response_text)
def create_node(self, content: str, node_type: NodeType,
dependencies: List[str]) -> str:
"""创建推理节点"""
node_id = f"node_{self.node_counter}"
self.node_counter += 1
self.nodes[node_id] = ReasoningNode(
id=node_id,
content=content,
node_type=node_type,
dependencies=dependencies,
children=[]
)
# 更新父节点的 children
for dep_id in dependencies:
if dep_id in self.nodes:
self.nodes[dep_id].children.append(node_id)
return node_id
def solve_subproblem(self, subproblem: Dict,
parent_answers: Dict[str, str]) -> str:
"""求解单个子问题"""
context = ""
for dep_id, answer in parent_answers.items():
context += f"依赖 {dep_id} 的答案:{answer}\n"
prompt = f"""
{context}
子问题:{subproblem['description']}
请基于依赖答案求解这个子问题。
"""
return self._call_llm(prompt, temperature=0.7)
def merge_answers(self, sub_answers: Dict[str, str],
merge_strategy: str, query: str) -> str:
"""合并子问题答案"""
answers_text = "\n".join([
f"{k}: {v}" for k, v in sub_answers.items()
])
prompt = f"""
主问题:{query}
合并策略:{merge_strategy}
子问题答案:
{answers_text}
请根据合并策略整合所有子问题答案,形成最终解决方案。
"""
return self._call_llm(prompt, temperature=0.7)
def reason(self, query: str) -> Dict:
"""
自适应图推理主流程
Args:
query: 查询问题
Returns:
result: 包含答案和推理图的字典
"""
print(f"开始 AGoT 推理:{query[:50]}...")
print("="*70 + "\n")
# 1. 分析复杂度
complexity = self.analyze_complexity(query)
structure = self.determine_structure(complexity)
print(f"问题复杂度:{complexity:.2f}")
print(f"选择结构:{structure.value}")
print()
# 2. 创建根节点
self.root_id = self.create_node(query, structure, [])
# 3. 递归分解与求解
result = self._recursive_solve(query, depth=0)
print("="*70)
print(f"\n推理完成,共{len(self.nodes)}个节点")
print(f"推理图结构:{self._get_graph_summary()}")
return result
def _recursive_solve(self, query: str, depth: int) -> Dict:
"""递归求解"""
if depth >= self.max_depth:
return {"answer": self._call_llm(f"直接回答:{query}"), "nodes": []}
# 分解查询
decomposition = self.decompose_query(query, depth)
if not decomposition["subproblems"]:
# 无需分解,直接求解
answer = self._call_llm(f"回答:{query}")
node_id = self.create_node(query, NodeType.CHAIN, [])
self.nodes[node_id].status = "completed"
self.nodes[node_id].value = 1.0
return {"answer": answer, "nodes": [node_id]}
print(f"深度 {depth}: 分解为 {len(decomposition['subproblems'])} 个子问题")
# 求解子问题(拓扑排序)
sub_answers = {}
all_nodes = []
# 简单拓扑排序(按依赖数排序)
sorted_subs = sorted(
decomposition["subproblems"],
key=lambda x: len(x["dependencies"])
)
for sub in sorted_subs:
# 获取依赖答案
parent_answers = {
dep_id: sub_answers[dep_id]
for dep_id in sub["dependencies"]
if dep_id in sub_answers
}
# 求解
answer = self.solve_subproblem(sub, parent_answers)
sub_answers[sub["id"]] = answer
# 创建节点
node_type = self.determine_structure(sub["complexity"])
node_id = self.create_node(
sub["description"],
node_type,
sub["dependencies"]
)
self.nodes[node_id].status = "completed"
self.nodes[node_id].value = 1.0 - sub["complexity"] * 0.3
all_nodes.append(node_id)
print(f" ✓ 子问题 {sub['id']}: {answer[:50]}...")
# 合并答案
final_answer = self.merge_answers(
sub_answers,
decomposition["merge_strategy"],
query
)
return {
"answer": final_answer,
"nodes": all_nodes,
"sub_answers": sub_answers
}
def _get_graph_summary(self) -> str:
"""获取图结构摘要"""
chain_count = len([n for n in self.nodes.values()
if n.node_type == NodeType.CHAIN])
tree_count = len([n for n in self.nodes.values()
if n.node_type == NodeType.TREE])
graph_count = len([n for n in self.nodes.values()
if n.node_type == NodeType.GRAPH])
return f"链式={chain_count}, 树式={tree_count}, 图式={graph_count}"
def visualize_graph(self) -> str:
"""生成图的可视化描述"""
lines = ["推理图结构:", "="*50]
for node_id, node in self.nodes.items():
deps = f" <- [{', '.join(node.dependencies)}]" if node.dependencies else ""
lines.append(f"{node_id} ({node.node_type.value}): {node.content[:40]}...{deps}")
return "\n".join(lines)
# 使用示例
if __name__ == "__main__":
# 初始化推理器
reasoner = AdaptiveGraphReasoner(max_depth=4)
# 示例问题:复杂逻辑推理
query = """
某公司有 5 个部门:研发、市场、销售、财务、人力。
已知:
1. 研发部门人数最多
2. 财务部门人数少于市场部门
3. 销售部门人数不是最少
4. 人力部门人数多于财务部门但少于研发部门
5. 市场部门人数不是最多
请问:各部门人数从多到少的排序是什么?
"""
print("AGoT 示例:复杂逻辑排序问题")
print("="*70 + "\n")
print(f"问题:{query.strip()}\n")
# 推理
result = reasoner.reason(query)
print("\n最终答案:")
print("-"*70)
print(result["answer"])
print("-"*70)
print("\n" + reasoner.visualize_graph())
print("\n" + "="*70)
print("\n关键观察:")
print("1. AGoT 自适应选择结构(根据复杂度)")
print("2. 递归分解为 DAG 子问题")
print("3. 拓扑排序求解(处理依赖关系)")
print("4. 统一链/树/图优势")
print("5. 无需训练,纯推理时优化")