# 思维树推理器完整实现 (Tree-of-Thoughts reasoner — complete implementation)
import openai
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
import json
import math
class ThoughtStatus(Enum):
    """Evaluator verdict for a single thought; values match the strings the LLM is asked to emit."""

    PROMISING = "promising"  # worth expanding further
    NEUTRAL = "neutral"      # kept in the tree but never expanded
    DEAD_END = "dead_end"    # pruned: never expanded
@dataclass
class ThoughtNode:
    """One node of the thought tree."""

    id: int                    # key of this node in TreeOfThoughtsReasoner.thoughts
    content: str               # thought text ("ROOT" for the virtual root)
    depth: int                 # 0 for the root, parent depth + 1 otherwise
    parent_id: Optional[int]   # None only for the root
    children: List[int]        # ids of child nodes, in creation order
    status: ThoughtStatus      # evaluator verdict
    value: float               # heuristic score, expected in [0, 1]
    visits: int = 0            # number of times this node was expanded
class TreeOfThoughtsReasoner:
    """
    Tree-of-Thoughts reasoner.

    Builds a tree of candidate "thoughts" with an LLM and searches it
    best-first. Each iteration expands the highest-valued PROMISING node that
    has not been expanded yet and is shallower than ``max_depth``. Every new
    thought is scored by a separate LLM call. NEUTRAL and DEAD_END thoughts
    are never expanded, which prunes the tree. Candidates are drawn from the
    whole tree, so the search falls back to other branches when the current
    one stops producing PROMISING children.
    """

    def __init__(self, model: str = "gpt-4",
                 max_depth: int = 5,
                 branch_factor: int = 3,
                 max_iterations: int = 20,
                 temperature: float = 0.7):
        """
        Configure the reasoner.

        Args:
            model: chat model name passed to the OpenAI API.
            max_depth: maximum tree depth; nodes at this depth are not expanded.
            branch_factor: number of candidate thoughts requested per expansion
                (longer LLM answers are truncated to this many).
            max_iterations: maximum number of expansion iterations per ``solve``.
            temperature: sampling temperature for thought generation
                (evaluation always uses 0.3 for more stable scores).
        """
        self.model = model
        self.max_depth = max_depth
        self.branch_factor = branch_factor
        self.max_iterations = max_iterations
        self.temperature = temperature
        self._client = None  # lazily created openai>=1.0 client, see _call_llm
        self._reset()

    def _reset(self) -> None:
        """Clear the search tree, the node-id counter and the iteration counter."""
        self.thoughts: Dict[int, ThoughtNode] = {}
        self.root_id: Optional[int] = None
        self.current_id = 0
        self.iteration = 0

    def _call_llm(self, prompt: str, temperature: Optional[float] = None) -> str:
        """
        Send ``prompt`` as a single user message and return the reply text.

        Uses the openai>=1.0 client when available. Otherwise it falls back to
        the legacy ``openai.ChatCompletion`` API, which was removed in 1.0.

        Args:
            prompt: user message content.
            temperature: sampling temperature; defaults to ``self.temperature``.

        Returns:
            The first choice's message content ("" if the API returned none).
        """
        if temperature is None:
            temperature = self.temperature
        messages = [{"role": "user", "content": prompt}]
        if hasattr(openai, "OpenAI"):
            client = getattr(self, "_client", None)
            if client is None:
                client = openai.OpenAI()
                self._client = client
            response = client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
            )
        else:
            response = openai.ChatCompletion.create(
                model=self.model,
                messages=messages,
                temperature=temperature,
            )
        return response.choices[0].message.content or ""

    @staticmethod
    def _parse_json(text: str):
        """
        Parse the JSON payload of an LLM reply.

        Models often wrap JSON in a Markdown code fence or add a sentence
        around it, and plain ``json.loads`` rejects both. This method strips a
        surrounding ``` fence and tries a direct parse. If that fails, it
        decodes the first JSON array/object that appears anywhere in the text.

        Raises:
            json.JSONDecodeError: if no JSON value can be decoded.
        """
        raw = (text or "").strip()
        cleaned = raw
        if cleaned.startswith("```"):
            lines = cleaned.splitlines()[1:]  # drop the opening ``` / ```json line
            if lines and lines[-1].strip().startswith("```"):
                lines = lines[:-1]
            cleaned = "\n".join(lines).strip()
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError as exc:
            first_error = exc
        decoder = json.JSONDecoder()
        for index, char in enumerate(raw):
            if char in "[{":
                try:
                    value, _ = decoder.raw_decode(raw, index)
                except json.JSONDecodeError:
                    continue
                return value
        raise first_error

    def _coerce_thoughts(self, data) -> List[str]:
        """
        Validate a parsed list of thoughts.

        Keeps non-empty items (as stripped strings) and truncates to
        ``branch_factor`` so a verbose model cannot widen the tree.

        Raises:
            ValueError: if ``data`` is not a JSON array.
        """
        if not isinstance(data, list):
            raise ValueError(
                f"expected a JSON array of thoughts, got {type(data).__name__}")
        thoughts = [str(item).strip() for item in data]
        return [t for t in thoughts if t][:max(self.branch_factor, 0)]

    def _format_example(self, label: str) -> str:
        """Render the prompt's JSON example array with ``branch_factor`` placeholder items."""
        count = max(self.branch_factor, 1)
        return json.dumps([f"{label} {i}" for i in range(1, count + 1)],
                          ensure_ascii=False)

    def generate_initial_thoughts(self, problem: str) -> List[str]:
        """
        Ask the LLM for up to ``branch_factor`` distinct starting strategies.

        Raises:
            ValueError: if the reply contains no JSON array.
        """
        prompt = f"""
问题:{problem}
请生成{self.branch_factor}个不同的初始思考方向。
每个思考方向应该:
1. 代表一种可能的解决策略
2. 简洁明了(1-2 句话)
3. 彼此不同
输出格式(JSON 数组):
{self._format_example("思考方向")}
"""
        response_text = self._call_llm(prompt, temperature=self.temperature)
        return self._coerce_thoughts(self._parse_json(response_text))

    def expand_thought(self, thought: str, problem: str,
                       context: str) -> List[str]:
        """
        Ask the LLM for up to ``branch_factor`` next steps following ``thought``.

        Args:
            thought: the thought being expanded.
            problem: the original problem statement.
            context: the ancestor thoughts of ``thought``, one per line.

        Raises:
            ValueError: if the reply contains no JSON array.
        """
        prompt = f"""
问题:{problem}
当前思考:{thought}
上下文:{context}
请生成{self.branch_factor}个可能的下一步思考。
每个思考应该:
1. 是当前思考的自然延续
2. 推进问题解决
3. 简洁明了(1-2 句话)
输出格式(JSON 数组):
{self._format_example("下一步思考")}
"""
        response_text = self._call_llm(prompt, temperature=self.temperature)
        return self._coerce_thoughts(self._parse_json(response_text))

    def evaluate_thought(self, thought: str, problem: str,
                         context: str) -> "Tuple[ThoughtStatus, float]":
        """
        Ask the LLM to score ``thought`` given the reasoning path before it.

        An unknown status string is treated as NEUTRAL. The value is coerced to
        a float clamped to [0, 1]; a missing, non-numeric or NaN value becomes 0.0.

        Returns:
            (status, value).

        Raises:
            ValueError: if the reply contains no JSON object.
        """
        prompt = f"""
问题:{problem}
上下文:{context}
当前思考:{thought}
请评估这个思考的质量:
1. 是否有助于解决问题?
2. 是否有逻辑错误?
3. 是否接近解决方案?
输出格式(JSON):
{{
"status": "promising" | "neutral" | "dead_end",
"value": 0.0-1.0,
"reason": "评估理由"
}}
"""
        response_text = self._call_llm(prompt, temperature=0.3)
        eval_result = self._parse_json(response_text)
        if not isinstance(eval_result, dict):
            raise ValueError(
                f"expected a JSON object, got {type(eval_result).__name__}")

        raw_status = str(eval_result.get("status", "")).strip().lower()
        try:
            status = ThoughtStatus(raw_status)
        except ValueError:
            status = ThoughtStatus.NEUTRAL

        try:
            value = float(eval_result.get("value", 0.0))
        except (TypeError, ValueError):
            value = 0.0
        if math.isnan(value):
            value = 0.0
        value = min(1.0, max(0.0, value))
        return status, value

    def _evaluate_or_neutral(self, thought: str, problem: str,
                             context: str) -> "Tuple[ThoughtStatus, float]":
        """Run ``evaluate_thought``; on an unparseable reply, score the thought (NEUTRAL, 0.0)."""
        try:
            return self.evaluate_thought(thought, problem, context)
        except ValueError as exc:
            print(f" 评估失败,按中性处理:{exc}")
            return ThoughtStatus.NEUTRAL, 0.0

    def _add_node(self, content: str, parent_id: Optional[int], depth: int,
                  status: "ThoughtStatus", value: float) -> int:
        """Create a node with the next free id, link it to its parent and return the id."""
        node_id = self.current_id
        self.thoughts[node_id] = ThoughtNode(
            id=node_id,
            content=content,
            depth=depth,
            parent_id=parent_id,
            children=[],
            status=status,
            value=value,
        )
        if parent_id is not None:
            self.thoughts[parent_id].children.append(node_id)
        self.current_id += 1
        return node_id

    def solve(self, problem: str) -> Optional[str]:
        """
        Search for a solution to ``problem`` with the thought tree.

        Each call starts from an empty tree. The initial thoughts become
        children of a virtual ROOT node. Then, up to ``max_iterations`` times,
        the highest-valued unexpanded PROMISING node shallower than
        ``max_depth`` is expanded and its children are evaluated. A child
        scored PROMISING with value > 0.9 that passes ``_extract_solution`` is
        a solution candidate; the highest-valued candidate is returned.

        Args:
            problem: problem statement.

        Returns:
            The reasoning path (one thought per line) of the best solution,
            or None if no node qualified.
        """
        print(f"开始解决问题:{problem[:50]}...")
        print("="*70 + "\n")

        # Without a reset, a second call would inherit old nodes and an
        # iteration counter already at max_iterations.
        self._reset()

        initial_thoughts = self.generate_initial_thoughts(problem)

        # Virtual root; NEUTRAL so it is never selected for expansion.
        self.root_id = self._add_node("ROOT", None, 0, ThoughtStatus.NEUTRAL, 1.0)
        for thought_text in initial_thoughts:
            status, value = self._evaluate_or_neutral(thought_text, problem, "")
            self._add_node(thought_text, self.root_id, 1, status, value)
        print(f"生成 {len(initial_thoughts)} 个初始思维")
        print()

        # Best-first search with heuristic pruning.
        best_solution = None
        best_value = 0.0
        while self.iteration < self.max_iterations:
            self.iteration += 1

            # visits == 0 means "not expanded yet". A node whose expansion
            # produced no children must not be selected again.
            candidates = [
                (tid, node) for tid, node in self.thoughts.items()
                if node.status == ThoughtStatus.PROMISING
                and node.depth < self.max_depth
                and node.visits == 0
            ]
            if not candidates:
                print("没有更多有希望的节点可扩展")
                break

            # max() returns the first of equal values (lowest id), like a stable sort.
            current_id, current_node = max(candidates, key=lambda item: item[1].value)
            current_node.visits += 1

            print(f"迭代 {self.iteration}: 扩展节点 {current_id}")
            print(f" 思维:{current_node.content}")
            print(f" 评分:{current_node.value:.3f}")
            print(f" 深度:{current_node.depth}")
            print()

            context = self._get_context(current_id)
            try:
                sub_thoughts = self.expand_thought(current_node.content, problem, context)
            except ValueError as exc:
                print(f" 扩展失败,跳过该节点:{exc}")
                continue

            # Children are evaluated against the full path down to their parent.
            path = self._path_text(current_id)
            for sub_thought_text in sub_thoughts:
                status, value = self._evaluate_or_neutral(sub_thought_text, problem, path)
                sub_id = self._add_node(sub_thought_text, current_id,
                                        current_node.depth + 1, status, value)
                if status == ThoughtStatus.PROMISING and value > 0.9:
                    solution = self._extract_solution(sub_id)
                    if solution and value > best_value:
                        best_solution = solution
                        best_value = value

        print("="*70)
        explored = sum(1 for node_id in self.thoughts if node_id != self.root_id)
        print(f"\n搜索完成,共探索 {explored} 个思维节点")
        print(f"最佳解决方案评分:{best_value:.3f}")
        return best_solution

    def _get_context(self, node_id: int) -> str:
        """
        Return the ancestor thoughts of ``node_id``, oldest first, one per line.

        The node itself and the virtual root are excluded. A depth-1 node
        therefore has an empty context.
        """
        parts = []
        node = self.thoughts[node_id]
        while node.parent_id is not None:
            node = self.thoughts[node.parent_id]
            if node.parent_id is not None:  # skip the virtual root
                parts.append(node.content)
        return "\n".join(reversed(parts))

    def _path_text(self, node_id: int) -> str:
        """Return the reasoning path from the first real thought down to and including ``node_id``."""
        context = self._get_context(node_id)
        content = self.thoughts[node_id].content
        return f"{context}\n{content}" if context else content

    def _extract_solution(self, node_id: int) -> Optional[str]:
        """
        Return the reasoning path ending at ``node_id`` if it qualifies as a solution.

        Heuristic: the node must score above 0.85 and be at least 3 levels deep.
        When ``max_depth`` < 3, reaching ``max_depth`` is enough, so shallow
        trees can still produce a solution. Returns None otherwise.
        """
        node = self.thoughts[node_id]
        min_depth = min(3, self.max_depth)
        if node.depth >= min_depth and node.value > 0.85:
            return self._path_text(node_id)
        return None

    def get_statistics(self) -> Dict:
        """
        Summarise the current tree, excluding the virtual ROOT node.

        Returns:
            dict with keys total_nodes, promising, neutral, dead_end,
            avg_value (0 when empty) and max_depth (0 when empty).
        """
        nodes = [node for node_id, node in self.thoughts.items()
                 if node_id != self.root_id]
        total = len(nodes)

        def count(status: "ThoughtStatus") -> int:
            return sum(1 for n in nodes if n.status == status)

        avg_value = sum(n.value for n in nodes) / total if total > 0 else 0
        return {
            "total_nodes": total,
            "promising": count(ThoughtStatus.PROMISING),
            "neutral": count(ThoughtStatus.NEUTRAL),
            "dead_end": count(ThoughtStatus.DEAD_END),
            "avg_value": avg_value,
            "max_depth": max((n.depth for n in nodes), default=0),
        }
# Usage example
if __name__ == "__main__":
    # Small search budget: depth 4, 3 candidates per step, 15 expansions.
    reasoner = TreeOfThoughtsReasoner(
        max_depth=4,
        branch_factor=3,
        max_iterations=15
    )

    # Example problem: the 24 game
    problem = """
使用数字 3, 8, 8, 9 和基本运算 (+, -, *, /, 括号),
如何得到 24?每个数字必须使用且只能使用一次。
"""
    print("思维树推理器示例:24 点游戏")
    print("="*70 + "\n")
    print(f"问题:{problem.strip()}\n")

    solution = reasoner.solve(problem)
    if solution:
        print("\n找到的解决方案:")
        print("-"*70)
        print(solution)
        print("-"*70)
    else:
        print("\n未找到满意解决方案")

    stats = reasoner.get_statistics()
    print("\n搜索统计:")
    print(f" 总节点数:{stats['total_nodes']}")
    print(f" 有希望:{stats['promising']}")
    print(f" 中性:{stats['neutral']}")
    print(f" 死胡同:{stats['dead_end']}")
    print(f" 平均评分:{stats['avg_value']:.3f}")
    print(f" 最大深度:{stats['max_depth']}")
    print("\n" + "="*70)

    # Configuration values are read from the reasoner so the summary cannot
    # drift from the constructor arguments above.
    print("\n关键观察:")
    print(f"1. 思维树支持多路径探索(分支因子={reasoner.branch_factor})")
    print("2. 启发式评估剪枝(dead_end 节点不再扩展)")
    print("3. 回溯机制(当一条路径走不通时尝试其他路径)")
    print(f"4. 深度限制防止无限扩展(max_depth={reasoner.max_depth})")
    print(f"5. 迭代限制控制计算成本(max_iterations={reasoner.max_iterations})")