"""向量数据库统一查询实现 (unified vector-database query implementation)."""
import numpy as np
from typing import List, Dict, Optional, Union
import sqlite3
import json
class UnifiedVectorDatabase:
    """Unified vector database.

    Stores structured document metadata in a relational table and
    unstructured vector embeddings in a companion table, and supports
    unified queries that combine SQL filtering with cosine-similarity
    ranking over the embeddings.
    """

    # Columns callers may filter on in similarity_search(); anything else is
    # rejected, because column names cannot be bound as SQL parameters and
    # interpolating arbitrary keys would allow SQL injection.
    _FILTERABLE_COLUMNS = frozenset({"id", "title", "content", "doc_type", "created_at"})

    def __init__(self, db_path: str = "unified.db", embedding_dim: int = 768):
        """Initialize the database and create tables if they do not exist.

        Args:
            db_path: Path to the SQLite database file.
            embedding_dim: Expected dimensionality of stored embeddings.
        """
        self.db_path = db_path
        self.embedding_dim = embedding_dim
        self._init_db()

    def _init_db(self) -> None:
        """Create the documents/embeddings tables and their indexes."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # Documents table: structured metadata.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS documents (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    title TEXT NOT NULL,
                    content TEXT,
                    doc_type TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    metadata JSON
                )
            ''')
            # Embeddings table: unstructured vectors stored as raw bytes.
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    doc_id INTEGER,
                    embedding BLOB,
                    FOREIGN KEY (doc_id) REFERENCES documents(id)
                )
            ''')
            # Indexes for metadata filtering and for the join on doc_id.
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_doc_type ON documents(doc_type)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_doc_id ON embeddings(doc_id)')
            conn.commit()
        finally:
            # Close the connection even if table creation fails.
            conn.close()

    def insert_document(self, title: str, content: str, doc_type: str,
                        embedding: np.ndarray, metadata: Optional[Dict] = None) -> int:
        """Insert a document together with its vector embedding.

        Both rows are written in a single transaction, so a failure while
        storing the embedding does not leave an orphaned document row.

        Args:
            title: Document title.
            content: Document body text.
            doc_type: Document category (e.g. 'article').
            embedding: Embedding vector; must contain exactly
                ``embedding_dim`` values.
            metadata: Optional extra metadata, stored as JSON.

        Returns:
            The id of the newly inserted document row.

        Raises:
            ValueError: If the embedding size does not match ``embedding_dim``.
        """
        vec = np.asarray(embedding, dtype=np.float32).ravel()
        # Validate up front: a wrong-sized vector would otherwise be stored
        # silently and corrupt every later similarity computation.
        if vec.size != self.embedding_dim:
            raise ValueError(
                f"embedding has {vec.size} values, expected {self.embedding_dim}"
            )

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            # Check `is not None` (not truthiness) so an empty dict
            # round-trips as {} instead of NULL.
            metadata_json = json.dumps(metadata) if metadata is not None else None
            cursor.execute('''
                INSERT INTO documents (title, content, doc_type, metadata)
                VALUES (?, ?, ?, ?)
            ''', (title, content, doc_type, metadata_json))
            doc_id = cursor.lastrowid
            # Store the embedding as raw float32 bytes (native byte order).
            cursor.execute('''
                INSERT INTO embeddings (doc_id, embedding)
                VALUES (?, ?)
            ''', (doc_id, vec.tobytes()))
            conn.commit()
            return doc_id
        finally:
            conn.close()

    @staticmethod
    def _cosine_from_bytes(query_vec: np.ndarray, query_norm: float,
                           embedding_bytes: bytes) -> float:
        """Cosine similarity between a query vector and stored float32 bytes.

        Returns 0.0 when either vector has zero norm (cosine is undefined).
        """
        stored = np.frombuffer(embedding_bytes, dtype=np.float32)
        stored_norm = np.linalg.norm(stored)
        if query_norm > 0 and stored_norm > 0:
            return float(np.dot(query_vec, stored) / (query_norm * stored_norm))
        return 0.0

    def similarity_search(self, query_embedding: np.ndarray,
                          top_k: int = 10,
                          filters: Optional[Dict] = None) -> List[Dict]:
        """Brute-force cosine-similarity search with optional metadata filters.

        Args:
            query_embedding: Query vector.
            top_k: Maximum number of results to return.
            filters: Optional equality filters on document columns,
                e.g. {'doc_type': 'article'}.

        Returns:
            Result dicts sorted by descending similarity.

        Raises:
            ValueError: If a filter key is not a known document column.
        """
        where_clause = ""
        params: List = []
        if filters:
            conditions = []
            for key, value in filters.items():
                # Whitelist column names; values are bound as parameters.
                if key not in self._FILTERABLE_COLUMNS:
                    raise ValueError(f"unsupported filter column: {key}")
                conditions.append(f"d.{key} = ?")
                params.append(value)
            where_clause = "WHERE " + " AND ".join(conditions)

        # Fetch every (metadata, embedding) pair that passes the filter;
        # similarity is computed in Python (no vector index in SQLite).
        query = f'''
            SELECT d.id, d.title, d.content, d.doc_type, d.metadata, e.embedding
            FROM documents d
            JOIN embeddings e ON d.id = e.doc_id
            {where_clause}
        '''
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(query, params)
            rows = cursor.fetchall()
        finally:
            conn.close()

        query_vec = np.asarray(query_embedding, dtype=np.float32).ravel()
        query_norm = float(np.linalg.norm(query_vec))
        results = []
        for doc_id, title, content, doc_type, metadata_json, embedding_bytes in rows:
            results.append({
                "id": doc_id,
                "title": title,
                "content": content,
                "doc_type": doc_type,
                "metadata": json.loads(metadata_json) if metadata_json else None,
                "similarity": self._cosine_from_bytes(query_vec, query_norm,
                                                      embedding_bytes),
            })
        # Rank by similarity and keep the top_k best matches.
        results.sort(key=lambda r: r["similarity"], reverse=True)
        return results[:top_k]

    def hybrid_query(self, sql_condition: str, query_embedding: np.ndarray,
                     top_k: int = 10) -> List[Dict]:
        """Hybrid query: raw SQL WHERE condition + vector similarity ranking.

        SECURITY NOTE: ``sql_condition`` is interpolated verbatim into the
        SQL statement. It must come from trusted code, never from user
        input; untrusted callers should use similarity_search() with
        ``filters`` instead.

        Args:
            sql_condition: SQL WHERE condition, e.g.
                "doc_type = 'article' AND created_at > '2025-01-01'".
            query_embedding: Query vector.
            top_k: Maximum number of results to return.

        Returns:
            Result dicts (content truncated to a 200-char preview),
            sorted by descending similarity.
        """
        query = f'''
            SELECT d.id, d.title, d.content, d.doc_type, d.metadata,
                   d.created_at, e.embedding
            FROM documents d
            JOIN embeddings e ON d.id = e.doc_id
            WHERE {sql_condition}
        '''
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute(query)
            rows = cursor.fetchall()
        finally:
            conn.close()

        query_vec = np.asarray(query_embedding, dtype=np.float32).ravel()
        query_norm = float(np.linalg.norm(query_vec))
        results = []
        for (doc_id, title, content, doc_type, metadata_json,
             created_at, embedding_bytes) in rows:
            # Guard against NULL content before calling len()/slicing.
            if content and len(content) > 200:
                preview = content[:200] + "..."
            else:
                preview = content
            results.append({
                "id": doc_id,
                "title": title,
                "content": preview,
                "doc_type": doc_type,
                "metadata": json.loads(metadata_json) if metadata_json else None,
                "created_at": created_at,
                "similarity": self._cosine_from_bytes(query_vec, query_norm,
                                                      embedding_bytes),
            })
        results.sort(key=lambda r: r["similarity"], reverse=True)
        return results[:top_k]
# Demo: insert sample documents, then run both query modes.
if __name__ == "__main__":
    db = UnifiedVectorDatabase()

    # Mock embeddings; a real system would produce them with BERT etc.
    # Fixed seed keeps the demo output reproducible.
    np.random.seed(42)

    separator = "=" * 70
    print("向量数据库统一查询示例:")
    print(separator + "\n")

    # Sample documents as (title, content, doc_type) triples; embeddings
    # are drawn in insertion order so the RNG sequence is deterministic.
    doc_specs = [
        ("人工智能发展史", "人工智能从 1956 年达特茅斯会议开始...", "article"),
        ("机器学习入门", "机器学习是人工智能的核心分支...", "tutorial"),
        ("深度学习实战", "深度学习基于神经网络模型...", "book"),
    ]

    print("1. 插入文档及向量嵌入:")
    for doc_title, doc_content, doc_kind in doc_specs:
        new_id = db.insert_document(
            title=doc_title,
            content=doc_content,
            doc_type=doc_kind,
            embedding=np.random.randn(768).astype(np.float32),
            metadata={"author": "AI Lab", "year": 2025},
        )
        print(f" ✓ 插入文档 ID={new_id}: {doc_title}")
    print("\n" + separator + "\n")

    # Pure vector similarity search against a random query embedding.
    query_embedding = np.random.randn(768).astype(np.float32)
    print("2. 纯向量相似度搜索:")
    for rank, hit in enumerate(db.similarity_search(query_embedding, top_k=2), 1):
        print(f" {rank}. {hit['title']} (相似度:{hit['similarity']:.4f})")
    print("\n" + separator + "\n")

    # Hybrid query: SQL filtering combined with similarity ranking.
    print("3. SQL+ 向量混合查询:")
    print(" 条件:doc_type = 'article' OR doc_type = 'tutorial'")
    hybrid_hits = db.hybrid_query(
        sql_condition="d.doc_type IN ('article', 'tutorial')",
        query_embedding=query_embedding,
        top_k=2,
    )
    for rank, hit in enumerate(hybrid_hits, 1):
        print(f" {rank}. {hit['title']} [类型:{hit['doc_type']}] (相似度:{hit['similarity']:.4f})")
    print("\n" + separator)

    print("\n关键观察:")
    print("1. 结构化元数据(标题、类型、时间)存储在关系表")
    print("2. 非结构化向量嵌入存储在向量表")
    print("3. 支持纯向量相似度搜索")
    print("4. 支持 SQL+ 向量混合查询(统一查询)")
    print("5. 实现结构化与非结构化数据的统一处理")