数据库组件

本章节详细介绍了 HRAG 系统中的数据库组件。

数据库介绍

本项目支持 Elasticsearch、Milvus、MySQL、Neo4j 四种数据库操作。具体安装流程见 数据库安装

数据库操作

下文详细介绍四种数据库的连接、数据插入删除、查找等操作。

Elastic

from src.database.operations.elastic_operations import upload_es, search_top, delete
from src.database.db_connection import es_connection

# 创建连接
client = es_connection()

# 指定 index 名称(要与 elastic_operations.py 中的一致)
index_name = "knowledge_test"

# 删除旧的 index(如果存在)
delete(index_name)

# 上传新的 pkl 文件(请提前确认路径下的 .pkl 文件存在)
pkl_path = "src/pkl_files/es_test"
upload_es(pkl_path, client)

# 执行搜索测试
result = search_top("东数西算", 3, index_name, client)

print("\n=== 检索结果 ===")
for item in result:
    print(item)

下载 Elastic 所需要的测试数据,es_test_data.pkl。数据格式为 List[dict],每个数据条目为包含以下字段的字典:

数据字段说明

字段

说明

filename

原始PDF文件名

chunk_id

文本块唯一标识

text

解析的文本内容(包含原文格式和换行)

Milvus

from pymilvus import utility, Collection
from src.database.db_connection import milvus_connection
from src.database.operations.milvus_operations import (
    create_collection,
    pkl_insert,
    search,
    delete_collection
)

# 初始化连接
milvus_connection()

# 配置参数
collection_name = "world_trade_report"
pkl_path = "src/pkl_files/vector_db.pkl"  # 确保此路径存在一个有效的.pkl
embedding_dim = 1024 # 根据 Embedding 模型确定
image_embedding = False  # 是否包含图片向量

# 删除旧集合(如果存在)
if utility.has_collection(collection_name):
    delete_collection(collection_name)

# 创建新集合
collection = create_collection(collection_name, embedding_dim)

# 插入数据
pkl_insert(collection, pkl_path, image_embedding=image_embedding)

question = "2020年世界贸易报告的主要内容是什么?"
# 测试检索
search(collection_name, question, image_embedding=image_embedding, result_type="text")

# (可选)删除集合
# delete_collection(collection_name)

milvus 插入的数据格式见 数据格式

Tip

embedding_dim 具体由对应的 Embedding 模型确定。本项目使用的纯文本 Embedding 模型 bge-reranker-large ,向量化后的向量为1024维;使用图片文本多模态 QwenVL 模型编码器,向量化后的向量为1536维。

MySQL

from tabulate import tabulate
from src.database.db_connection import mysql_connnection
from src.database.operations.mysql_operations import (
    create_database,
    create_table,
    import_pkl_to_mysql,
    query_all_records,
    query_with_conditions,
    query_aggregate,
    delete_table_alldata,
    drop_table
)

# 数据库和表配置
database_name = "mysqldb_test"
table_name = "pkltomysql"
pkl_path = "src/pkl_files/vector_db.pkl"

# 建立连接
connection = mysql_connnection(database_name=database_name)

# 可选:创建数据库(如未手动创建)
# create_database(connection, database_name)

# 创建表(如不存在)
create_table(connection, table_name)

# 导入数据
import_pkl_to_mysql(connection, table_name, pkl_path)

# 查询前5条记录
print("\n=== 查询前5条记录 ===")
records = query_all_records(connection, table_name, limit=5)
print(tabulate(records, headers="keys", tablefmt="grid"))

# 条件查询
print("\n=== 条件查询 page_number = 5 且 block_type = 'text' ===")
conditions = {'page_number': 5, 'block_type': 'text'}
result = query_with_conditions(connection, table_name, conditions)
print(tabulate(result, headers="keys", tablefmt="grid"))

# 聚合查询
print("\n=== 按 block_type 分组统计 ===")
stats = query_aggregate(connection, table_name, 'block_type')
print(tabulate(stats, headers="keys", tablefmt="grid"))

# ✅ 关闭连接
if connection:
    connection.close()
    print("\n✅ MySQL连接已关闭")

Neo4j

from src.database.db_connection import neo4j_connection_driver, neo4j_connection
from src.database.operations.neo4j_operation import (
    import_csv_to_neo4j,
    delete,
    Key_search_bytoken
)

# CSV 数据路径(可选导入构图)
csv_file_path = "src/resources/temp/database/all_data.csv"

# 用户词典(与 CSV 一致)
user_dict = "src/resources/temp/database/all_data.csv"

# 初始化 Neo4j 连接(graph 和 driver)
graph = neo4j_connection()
driver = neo4j_connection_driver()

# (可选)清空现有图数据
# delete(graph)

# (可选)导入 CSV 数据建图
# import_csv_to_neo4j(csv_file_path, graph)

# 设定用户问题和 top-k 返回数量
question = "在银屑病治疗过程中,糖皮质激素的作用是什么?"
top_k = 5

# 初始化搜索器并运行 pipeline
key_searcher = Key_search_bytoken(driver, question, top_k, user_dict)
print("Neo4j 图谱节点列表:", key_searcher.neo4j_nodes)

result = key_searcher.pipeline()

# 打印输出结果
print("\n=== Neo4j Top-k 结果关系 ===")
for item, score in result:
    print(f"{item} | Score: {score}")

# 关闭连接
driver.close()

下载 Neo4j 所需要的测试数据,neo4j_data.csv