"""向量数据库与 RAG 系统完整实现 (complete vector-database and RAG system implementation)."""
import hashlib
import heapq
import json
import math
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
class DistanceMetric(Enum):
    """Similarity/distance measures supported by the vector index."""

    COSINE = "cosine"        # cosine similarity
    EUCLIDEAN = "euclidean"  # Euclidean (L2) distance
    DOT_PRODUCT = "dot"      # raw dot product
@dataclass
class VectorDocument:
    """A piece of text paired with its embedding vector and metadata."""

    id: str
    content: str
    embedding: np.ndarray
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (embedding becomes a plain list)."""
        return {
            "id": self.id,
            "content": self.content,
            "embedding": self.embedding.tolist(),
            "metadata": self.metadata,
            "timestamp": self.timestamp.isoformat(),
        }

    @classmethod
    def from_dict(cls, payload: Dict[str, Any]) -> 'VectorDocument':
        """Rebuild a document from the dict shape produced by ``to_dict``."""
        return cls(
            id=payload["id"],
            content=payload["content"],
            embedding=np.array(payload["embedding"]),
            metadata=payload.get("metadata", {}),
            timestamp=datetime.fromisoformat(payload["timestamp"]),
        )
class VectorIndex:
    """
    In-memory vector index with exhaustive nearest-neighbour search.

    Stores raw vectors in a dict and scans all of them per query, so search
    is O(n) — adequate for small corpora; a real ANN structure (HNSW/IVF)
    should replace it at scale. (The original docstring called this a
    "simplified HNSW", which it is not.)
    """

    def __init__(self, dimension: int, metric: DistanceMetric = DistanceMetric.COSINE):
        self.dimension = dimension                  # required embedding length
        self.metric = metric                        # scoring metric
        self.vectors: Dict[str, np.ndarray] = {}    # doc_id -> embedding
        self.doc_ids: List[str] = []                # unique ids, insertion order

    def _cosine_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Cosine similarity in [-1, 1]; 0.0 when either vector is all-zero."""
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(np.dot(a, b) / (norm_a * norm_b))

    def _euclidean_distance(self, a: np.ndarray, b: np.ndarray) -> float:
        """Euclidean (L2) distance."""
        return float(np.linalg.norm(a - b))

    def _dot_product(self, a: np.ndarray, b: np.ndarray) -> float:
        """Plain dot product."""
        return float(np.dot(a, b))

    def _compute_similarity(self, a: np.ndarray, b: np.ndarray) -> float:
        """Score two vectors; larger always means more similar.

        Euclidean distance is negated so every metric shares the
        "bigger is better" convention used by ``search``.

        Raises:
            ValueError: for an unrecognized metric.
        """
        if self.metric == DistanceMetric.COSINE:
            return self._cosine_similarity(a, b)
        elif self.metric == DistanceMetric.EUCLIDEAN:
            return -self._euclidean_distance(a, b)
        elif self.metric == DistanceMetric.DOT_PRODUCT:
            return self._dot_product(a, b)
        else:
            raise ValueError(f"Unknown metric: {self.metric}")

    def insert(self, doc_id: str, embedding: np.ndarray):
        """Insert a vector, or overwrite the vector of an existing id.

        Raises:
            ValueError: if the embedding length does not match the index.
        """
        if len(embedding) != self.dimension:
            raise ValueError(f"Embedding dimension {len(embedding)} != {self.dimension}")
        # Bug fix: re-inserting an existing id previously appended a duplicate
        # entry to doc_ids, corrupting delete() bookkeeping over time.
        if doc_id not in self.vectors:
            self.doc_ids.append(doc_id)
        self.vectors[doc_id] = embedding

    def search(self, query_embedding: np.ndarray, k: int = 10) -> List[Tuple[str, float]]:
        """Return up to k (doc_id, score) pairs, best score first.

        Brute-force scan over all stored vectors.

        Raises:
            ValueError: if the query length does not match the index.
        """
        if len(query_embedding) != self.dimension:
            raise ValueError(f"Query embedding dimension {len(query_embedding)} != {self.dimension}")
        scored = [
            (doc_id, self._compute_similarity(query_embedding, vector))
            for doc_id, vector in self.vectors.items()
        ]
        # Top-k selection in O(n log k) instead of sorting the whole list.
        return heapq.nlargest(k, scored, key=lambda pair: pair[1])

    def delete(self, doc_id: str) -> bool:
        """Remove a vector; return True if it existed."""
        if doc_id in self.vectors:
            del self.vectors[doc_id]
            self.doc_ids.remove(doc_id)
            return True
        return False

    def size(self) -> int:
        """Number of stored vectors."""
        return len(self.vectors)
class RAGSystem:
    """
    Retrieval-Augmented Generation (RAG) system.

    Pipeline:
      1. Index:    embed documents and store them.
      2. Retrieve: find documents relevant to a query vector.
      3. Augment:  inject the retrieved text into the prompt context.
      4. Generate: have an LLM answer from the augmented context.
    """

    def __init__(self,
                 dimension: int = 768,
                 metric: DistanceMetric = DistanceMetric.COSINE,
                 top_k: int = 5):
        # Vector index used for similarity search.
        self.index = VectorIndex(dimension, metric)
        # Full document store keyed by document id.
        self.documents: Dict[str, VectorDocument] = {}
        # Default number of documents returned by retrieve().
        self.top_k = top_k
        # Usage counters (indexing / retrieval / generation).
        self.stats = {
            "total_indexed": 0,
            "total_queries": 0,
            "total_generations": 0
        }

    def index_document(self,
                       content: str,
                       embedding: np.ndarray,
                       metadata: Optional[Dict[str, Any]] = None,
                       doc_id: Optional[str] = None) -> VectorDocument:
        """
        Index a single document.

        Args:
            content: Document text.
            embedding: Embedding vector for the text.
            metadata: Optional metadata dict.
            doc_id: Optional document id; derived from a hash of the content
                plus the current time when omitted.

        Returns:
            The stored VectorDocument.
        """
        if doc_id is None:
            # Content + timestamp keeps ids unique even for duplicate text.
            doc_id = hashlib.md5(
                f"{content}{datetime.now().isoformat()}".encode()
            ).hexdigest()[:16]
        doc = VectorDocument(
            id=doc_id,
            content=content,
            embedding=embedding,
            metadata=metadata or {}
        )
        self.documents[doc_id] = doc
        self.index.insert(doc_id, embedding)
        self.stats["total_indexed"] += 1
        return doc

    def retrieve(self,
                 query_embedding: np.ndarray,
                 k: Optional[int] = None,
                 filter_fn: Optional[Callable[[VectorDocument], bool]] = None) -> List[VectorDocument]:
        """
        Retrieve the documents most similar to the query vector.

        Args:
            query_embedding: Query vector.
            k: Number of documents to return (defaults to ``self.top_k``).
            filter_fn: Optional predicate; documents for which it returns
                False are dropped.

        Returns:
            Up to k matching documents, most similar first. Each returned
            document gets a transient ``similarity`` attribute attached.
        """
        k = k or self.top_k
        # Over-fetch so post-filtering can still fill k slots.
        results = self.index.search(query_embedding, k * 2)
        docs = []
        for doc_id, similarity in results:
            doc = self.documents[doc_id]
            if filter_fn and not filter_fn(doc):
                continue
            # Attach the score so callers can report it.
            doc.similarity = similarity
            docs.append(doc)
        self.stats["total_queries"] += 1
        return docs[:k]

    def build_context(self,
                      query: str,
                      retrieved_docs: List[VectorDocument],
                      max_tokens: int = 2000) -> str:
        """
        Assemble the augmented prompt: retrieved snippets followed by the query.

        Args:
            query: User question.
            retrieved_docs: Documents from retrieve(), best first.
            max_tokens: Rough budget, estimated at ~4 characters per token.

        Returns:
            The prompt string to feed the LLM.
        """
        context_parts = []
        context_parts.append("以下是与问题相关的背景信息:")
        context_parts.append("=" * 50)
        total_length = 0
        for i, doc in enumerate(retrieved_docs, 1):
            doc_text = f"\n[文档{i}] {doc.content}"
            # Crude token estimate (1 token ≈ 4 chars); stop when over budget.
            if total_length + len(doc_text) > max_tokens * 4:
                break
            context_parts.append(doc_text)
            total_length += len(doc_text)
        context_parts.append("\n" + "=" * 50)
        context_parts.append(f"问题:{query}")
        return "\n".join(context_parts)

    def generate_response(self,
                          context: str,
                          llm_fn: Optional[Callable[[str], str]] = None) -> str:
        """
        Produce an answer from the augmented context.

        Args:
            context: Prompt built by build_context().
            llm_fn: Real LLM callable; when omitted a simulated answer is used.

        Returns:
            The generated answer text.
        """
        if llm_fn:
            response = llm_fn(context)
        else:
            # Placeholder; production code should call a real LLM API here.
            response = self._simulate_llm(context)
        self.stats["total_generations"] += 1
        return response

    def _simulate_llm(self, context: str) -> str:
        """Toy stand-in for an LLM: echoes the first retrieved snippet."""
        # Pull the question back out of the prompt.
        if "问题:" in context:
            question = context.split("问题:")[-1].strip()
        else:
            question = "未知问题"
        # Collect the retrieved document snippets.
        docs = []
        lines = context.split("\n")
        for line in lines:
            if line.startswith("[文档"):
                content = line.split("]", 1)[-1].strip()
                docs.append(content)
        if docs:
            return f"根据检索到的信息,{question} 的答案是:{docs[0][:200]}..."
        else:
            return f"抱歉,未找到与'{question}'相关的信息。"

    def query(self,
              query_text: str,
              query_embedding: np.ndarray,
              llm_fn: Optional[Callable[[str], str]] = None) -> Dict[str, Any]:
        """
        Run the full RAG pipeline: retrieve -> build context -> generate.

        Args:
            query_text: Query as text (used in the prompt).
            query_embedding: Query as a vector (used for retrieval).
            llm_fn: Optional real LLM callable.

        Returns:
            Dict with the query, retrieved docs (id/content/similarity),
            the built context, and the generated response.
        """
        retrieved_docs = self.retrieve(query_embedding)
        context = self.build_context(query_text, retrieved_docs)
        response = self.generate_response(context, llm_fn)
        return {
            "query": query_text,
            "retrieved_docs": [
                {
                    "id": doc.id,
                    "content": doc.content,
                    "similarity": getattr(doc, 'similarity', 0)
                }
                for doc in retrieved_docs
            ],
            "context": context,
            "response": response
        }

    def get_stats(self) -> Dict[str, Any]:
        """Usage counters plus current store/index sizes."""
        return {
            **self.stats,
            "total_documents": len(self.documents),
            "index_size": self.index.size()
        }

    def export_index(self, filepath: str):
        """Write all documents and counters to a JSON file."""
        data = {
            "documents": [doc.to_dict() for doc in self.documents.values()],
            "stats": self.stats
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    def import_index(self, filepath: str):
        """
        Load documents previously written by export_index().

        Replaces the current document store and rebuilds the vector index
        from scratch.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.documents.clear()
        # Bug fix: previously vectors were inserted into the existing index
        # without clearing it, so stale entries from before the import
        # remained searchable. Rebuild the index instead.
        self.index = VectorIndex(self.index.dimension, self.index.metric)
        for doc_dict in data.get("documents", []):
            doc = VectorDocument.from_dict(doc_dict)
            self.documents[doc.id] = doc
            self.index.insert(doc.id, doc.embedding)
        # Restore counters.
        self.stats.update(data.get("stats", {}))
# Usage example
if __name__ == "__main__":
    print("=== 向量数据库与记忆检索增强 ===\n")

    # Build a small RAG system (real deployments use 768/1024-dim vectors).
    rag_system = RAGSystem(
        dimension=256,
        metric=DistanceMetric.COSINE,
        top_k=3
    )

    def generate_embedding(text: str) -> np.ndarray:
        # Deterministic pseudo-embedding: seed the RNG from the text's hash,
        # then draw and normalize a random vector.
        digest = int(hashlib.md5(text.encode()).hexdigest(), 16)
        np.random.seed(digest % (2**32))
        vec = np.random.randn(256)
        vec /= np.linalg.norm(vec)
        return vec

    print("=== 索引文档 ===")
    corpus = [
        ("Python 是一种高级编程语言", {"category": "programming"}),
        ("机器学习是人工智能的分支", {"category": "ai"}),
        ("深度学习使用神经网络", {"category": "ai"}),
        ("向量数据库用于相似检索", {"category": "database"}),
        ("RAG 是检索增强生成技术", {"category": "ai"}),
    ]
    for text, meta in corpus:
        entry = rag_system.index_document(
            content=text,
            embedding=generate_embedding(text),
            metadata=meta
        )
        print(f"索引文档:{entry.content}")

    print(f"\n=== 系统统计 ===")
    stats = rag_system.get_stats()
    print(f"总索引:{stats['total_indexed']}")
    print(f"总文档:{stats['total_documents']}")

    print(f"\n=== RAG 查询 ===")
    query = "什么是 RAG 技术?"
    result = rag_system.query(
        query_text=query,
        query_embedding=generate_embedding(query)
    )
    print(f"查询:{result['query']}")
    print(f"\n检索到的文档:")
    for rank, hit in enumerate(result['retrieved_docs'], 1):
        print(f"\n{rank}. {hit['content']}")
        print(f" 相似度:{hit['similarity']:.4f}")
    print(f"\n生成的回答:")
    print(result['response'])

    print(f"\n=== 最终统计 ===")
    final_stats = rag_system.get_stats()
    print(f"总查询:{final_stats['total_queries']}")
    print(f"总生成:{final_stats['total_generations']}")

    print(f"\n关键观察:")
    print("1. 向量数据库:存储和检索向量嵌入")
    print("2. 相似检索:基于余弦相似度找到相关文档")
    print("3. RAG 流程:检索→构建上下文→生成回答")
    print("4. 上下文增强:将检索结果注入 LLM 提示")
    print("5. 幻觉消除:基于事实生成,减少胡编乱造")
    print("\nRAG 的核心:检索 + 增强 + 生成 = 准确可靠")