AstrBot 4.3.3__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +18 -4
- astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
- astrbot/core/astr_agent_context.py +1 -0
- astrbot/core/astrbot_config_mgr.py +23 -51
- astrbot/core/config/default.py +139 -14
- astrbot/core/conversation_mgr.py +36 -1
- astrbot/core/core_lifecycle.py +24 -5
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/vec_db/base.py +33 -2
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
- astrbot/core/file_token_service.py +6 -1
- astrbot/core/initial_loader.py +6 -3
- astrbot/core/knowledge_base/chunking/__init__.py +11 -0
- astrbot/core/knowledge_base/chunking/base.py +24 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
- astrbot/core/knowledge_base/chunking/recursive.py +155 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
- astrbot/core/knowledge_base/kb_helper.py +348 -0
- astrbot/core/knowledge_base/kb_mgr.py +287 -0
- astrbot/core/knowledge_base/models.py +114 -0
- astrbot/core/knowledge_base/parsers/__init__.py +15 -0
- astrbot/core/knowledge_base/parsers/base.py +50 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +273 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +61 -21
- astrbot/core/pipeline/process_stage/utils.py +80 -0
- astrbot/core/pipeline/scheduler.py +1 -1
- astrbot/core/platform/astr_message_event.py +8 -7
- astrbot/core/platform/manager.py +4 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
- astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
- astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
- astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
- astrbot/core/platform/sources/satori/satori_event.py +270 -77
- astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
- astrbot/core/provider/manager.py +14 -9
- astrbot/core/provider/provider.py +67 -0
- astrbot/core/provider/sources/anthropic_source.py +4 -4
- astrbot/core/provider/sources/dashscope_source.py +10 -9
- astrbot/core/provider/sources/dify_source.py +6 -8
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_source.py +18 -15
- astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
- astrbot/core/star/context.py +3 -0
- astrbot/core/star/star.py +6 -0
- astrbot/core/star/star_manager.py +13 -7
- astrbot/core/umop_config_router.py +81 -0
- astrbot/core/updator.py +1 -1
- astrbot/core/utils/io.py +23 -12
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/config.py +137 -9
- astrbot/dashboard/routes/knowledge_base.py +1065 -0
- astrbot/dashboard/routes/plugin.py +24 -5
- astrbot/dashboard/routes/tools.py +14 -0
- astrbot/dashboard/routes/update.py +1 -1
- astrbot/dashboard/server.py +6 -0
- astrbot/dashboard/utils.py +161 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/METADATA +91 -55
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/RECORD +83 -50
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
"""检索管理器
|
|
2
|
+
|
|
3
|
+
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import time
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from typing import List
|
|
10
|
+
|
|
11
|
+
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
|
12
|
+
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
|
13
|
+
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
|
14
|
+
from astrbot.core.provider.provider import RerankProvider
|
|
15
|
+
from astrbot.core.db.vec_db.base import Result
|
|
16
|
+
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
|
17
|
+
from ..kb_helper import KBHelper
|
|
18
|
+
from astrbot import logger
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class RetrievalResult:
    """A single retrieval hit returned by RetrievalManager.

    Carries the chunk text together with identifying metadata about the
    document and knowledge base it came from.
    """

    chunk_id: str  # unique id of the retrieved chunk
    doc_id: str  # id of the source document
    doc_name: str  # human-readable name of the source document
    kb_id: str  # id of the knowledge base containing the document
    kb_name: str  # human-readable name of the knowledge base
    content: str  # chunk text
    score: float  # relevance score (RRF fused score, or rerank score after reranking)
    metadata: dict  # extra info; populated with chunk_index and char_count
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class RetrievalManager:
    """Retrieval coordinator.

    Responsibilities:
    - Coordinate dense (vector) retrieval, sparse (BM25) retrieval and reranking
    - Fuse and order results from both retrieval paths
    """

    def __init__(
        self,
        sparse_retriever: SparseRetriever,
        rank_fusion: RankFusion,
        kb_db: KBSQLiteDatabase,
    ):
        """Initialize the retrieval manager.

        Args:
            sparse_retriever: Sparse (BM25) retriever.
            rank_fusion: Rank-fusion component (RRF).
            kb_db: Knowledge-base metadata database.
        """
        self.sparse_retriever = sparse_retriever
        self.rank_fusion = rank_fusion
        self.kb_db = kb_db

    async def retrieve(
        self,
        query: str,
        kb_ids: List[str],
        kb_id_helper_map: dict[str, KBHelper],
        top_k_fusion: int = 20,
        top_m_final: int = 5,
    ) -> List[RetrievalResult]:
        """Hybrid retrieval.

        Pipeline:
        1. Dense retrieval (vector similarity)
        2. Sparse retrieval (BM25)
        3. Rank fusion (RRF)
        4. Rerank (only when a configured rerank provider is found)

        Args:
            query: Query text.
            kb_ids: IDs of the knowledge bases to search.
            kb_id_helper_map: Mapping from kb_id to its live KBHelper instance.
            top_k_fusion: Number of fused candidates kept before reranking.
            top_m_final: Number of results ultimately returned.

        Returns:
            List[RetrievalResult]: Final retrieval results, best first.
        """
        if not kb_ids:
            return []

        # Resolve per-KB retrieval options; KBs without a live helper are skipped.
        kb_options: dict = {}
        new_kb_ids = []
        for kb_id in kb_ids:
            kb_helper = kb_id_helper_map.get(kb_id)
            if kb_helper:
                kb = kb_helper.kb
                kb_options[kb_id] = {
                    "top_k_dense": kb.top_k_dense or 50,
                    "top_k_sparse": kb.top_k_sparse or 50,
                    "top_m_final": kb.top_m_final or 5,
                    "vec_db": kb_helper.vec_db,
                    "rerank_provider_id": kb.rerank_provider_id,
                }
                new_kb_ids.append(kb_id)
            else:
                logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")

        kb_ids = new_kb_ids

        # 1. Dense retrieval
        time_start = time.time()
        dense_results = await self._dense_retrieve(
            query=query,
            kb_ids=kb_ids,
            kb_options=kb_options,
        )
        time_end = time.time()
        logger.debug(
            f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
        )

        # 2. Sparse retrieval
        time_start = time.time()
        sparse_results = await self.sparse_retriever.retrieve(
            query=query,
            kb_ids=kb_ids,
            kb_options=kb_options,
        )
        time_end = time.time()
        logger.debug(
            f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
        )

        # 3. Rank fusion
        time_start = time.time()
        fused_results = await self.rank_fusion.fuse(
            dense_results=dense_results,
            sparse_results=sparse_results,
            top_k=top_k_fusion,
        )
        time_end = time.time()
        logger.debug(
            f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
        )

        # 4. Attach document/KB metadata. Fused hits whose document record is
        # missing from kb_db are dropped.
        retrieval_results = []
        for fr in fused_results:
            metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
            if metadata_dict:
                retrieval_results.append(
                    RetrievalResult(
                        chunk_id=fr.chunk_id,
                        doc_id=fr.doc_id,
                        doc_name=metadata_dict["document"].doc_name,
                        kb_id=fr.kb_id,
                        kb_name=metadata_dict["knowledge_base"].kb_name,
                        content=fr.content,
                        score=fr.score,
                        metadata={
                            "chunk_index": fr.chunk_index,
                            "char_count": len(fr.content),
                        },
                    )
                )

        # 5. Rerank, using the first KB whose vector store carries a rerank
        # provider matching that KB's configured rerank_provider_id.
        first_rerank = None
        for kb_id in kb_ids:
            vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
            rerank_pi = kb_options[kb_id]["rerank_provider_id"]
            if (
                vec_db
                and vec_db.rerank_provider
                and rerank_pi
                and rerank_pi == vec_db.rerank_provider.meta().id
            ):
                first_rerank = vec_db.rerank_provider
                break
        if first_rerank and retrieval_results:
            retrieval_results = await self._rerank(
                query=query,
                results=retrieval_results,
                top_k=top_m_final,
                rerank_provider=first_rerank,
            )

        return retrieval_results[:top_m_final]

    async def _dense_retrieve(
        self,
        query: str,
        kb_ids: List[str],
        kb_options: dict,
    ):
        """Dense (vector-similarity) retrieval.

        Queries each knowledge base's own vector store and merges the hits.

        Args:
            query: Query text.
            kb_ids: Knowledge-base IDs (expected to be keys of kb_options).
            kb_options: Per-KB options carrying "vec_db" and "top_k_dense".

        Returns:
            List[Result]: Merged hits from all KBs, sorted by similarity
            (best first).
        """
        all_results: list[Result] = []
        for kb_id in kb_ids:
            if kb_id not in kb_options:
                continue
            try:
                vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
                dense_k = int(kb_options[kb_id]["top_k_dense"])
                vec_results = await vec_db.retrieve(
                    query=query,
                    k=dense_k,
                    fetch_k=dense_k * 2,
                    rerank=False,  # reranking happens later, on the fused results
                    metadata_filters={"kb_id": kb_id},
                )

                all_results.extend(vec_results)
            except Exception as e:
                # One failing KB must not abort the whole retrieval; use the
                # module-level logger (no local re-import needed).
                logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
                continue

        # Sort merged hits by similarity, best first.
        all_results.sort(key=lambda x: x.similarity, reverse=True)
        return all_results

    async def _rerank(
        self,
        query: str,
        results: List[RetrievalResult],
        top_k: int,
        rerank_provider: RerankProvider,
    ) -> List[RetrievalResult]:
        """Rerank results with the given provider.

        Args:
            query: Query text.
            results: Candidate results to rerank.
            top_k: Number of results to return.
            rerank_provider: Provider performing the reranking.

        Returns:
            List[RetrievalResult]: Reranked results, best first.
        """
        if not results:
            return []

        # Document texts in candidate order; the provider returns indices
        # into this list.
        docs = [r.content for r in results]

        rerank_results = await rerank_provider.rerank(
            query=query,
            documents=docs,
        )

        # Overwrite scores with the provider's relevance scores; indices
        # outside the candidate list are ignored defensively.
        reranked_list = []
        for rerank_result in rerank_results:
            idx = rerank_result.index
            if idx < len(results):
                result = results[idx]
                result.score = rerank_result.relevance_score
                reranked_list.append(result)

        reranked_list.sort(key=lambda x: x.score, reverse=True)

        return reranked_list[:top_k]
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""检索结果融合器
|
|
2
|
+
|
|
3
|
+
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
from astrbot.core.db.vec_db.base import Result
|
|
10
|
+
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
|
11
|
+
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class FusedResult:
    """A retrieval hit after rank fusion (RRF) of dense and sparse results."""

    chunk_id: str  # unique id of the chunk
    chunk_index: int  # position of the chunk within its source document
    doc_id: str  # id of the source document
    kb_id: str  # id of the knowledge base the chunk belongs to
    content: str  # chunk text
    score: float  # RRF fusion score
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RankFusion:
    """Fuses dense and sparse retrieval results.

    Responsibilities:
    - Merge the hits of dense (vector) and sparse (BM25) retrieval
    - Score candidates with the Reciprocal Rank Fusion (RRF) algorithm
    """

    def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
        """Initialize the fusion component.

        Args:
            kb_db: Knowledge-base database instance.
            k: RRF smoothing constant applied to every rank.
        """
        self.kb_db = kb_db
        self.k = k

    async def fuse(
        self,
        dense_results: list[Result],
        sparse_results: list[SparseResult],
        top_k: int = 20,
    ) -> list[FusedResult]:
        """Fuse dense and sparse retrieval results.

        RRF formula:
            score(doc) = sum(1 / (k + rank_i))

        Args:
            dense_results: Hits from dense (vector) retrieval.
            sparse_results: Hits from sparse (BM25) retrieval.
            top_k: Number of fused results to return.

        Returns:
            list[FusedResult]: Fused results, best RRF score first.
        """
        # 1-based rank position per retrieval path. In the vector store the
        # "doc_id" field actually holds the chunk id.
        dense_pos = {
            hit.data["doc_id"]: pos for pos, hit in enumerate(dense_results, start=1)
        }
        sparse_pos = {
            hit.chunk_id: pos for pos, hit in enumerate(sparse_results, start=1)
        }

        # Index the raw hits by chunk id and collect the candidate set
        # (sparse first, then dense, matching insertion order).
        candidates: set[str] = set()
        sparse_by_chunk: dict[str, SparseResult] = {}
        dense_by_chunk: dict[str, Result] = {}

        for hit in sparse_results:
            candidates.add(hit.chunk_id)
            sparse_by_chunk[hit.chunk_id] = hit

        for hit in dense_results:
            cid = hit.data["doc_id"]
            candidates.add(cid)
            dense_by_chunk[cid] = hit

        # RRF score: each path where the chunk appears contributes
        # 1 / (k + rank).
        smoothing = self.k
        rrf_scores: dict[str, float] = {}
        for cid in candidates:
            contribution = 0.0
            for ranks in (dense_pos, sparse_pos):
                pos = ranks.get(cid)
                if pos is not None:
                    contribution += 1.0 / (smoothing + pos)
            rrf_scores[cid] = contribution

        # Keep the top_k highest-scoring candidates.
        top_ids = sorted(rrf_scores, key=rrf_scores.__getitem__, reverse=True)[:top_k]

        # Materialize FusedResult objects, preferring the sparse hit (it
        # already carries full chunk metadata) and falling back to the dense
        # hit's JSON metadata.
        fused: list[FusedResult] = []
        for cid in top_ids:
            sparse_hit = sparse_by_chunk.get(cid)
            if sparse_hit is not None:
                fused.append(
                    FusedResult(
                        chunk_id=sparse_hit.chunk_id,
                        chunk_index=sparse_hit.chunk_index,
                        doc_id=sparse_hit.doc_id,
                        kb_id=sparse_hit.kb_id,
                        content=sparse_hit.content,
                        score=rrf_scores[cid],
                    )
                )
                continue
            dense_hit = dense_by_chunk.get(cid)
            if dense_hit is not None:
                chunk_md = json.loads(dense_hit.data["metadata"])
                fused.append(
                    FusedResult(
                        chunk_id=cid,
                        chunk_index=chunk_md["chunk_index"],
                        doc_id=chunk_md["kb_doc_id"],
                        kb_id=chunk_md["kb_id"],
                        content=dense_hit.data["text"],
                        score=rrf_scores[cid],
                    )
                )

        return fused
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""稀疏检索器
|
|
2
|
+
|
|
3
|
+
使用 BM25 算法进行基于关键词的文档检索
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import jieba
|
|
7
|
+
import os
|
|
8
|
+
import json
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from rank_bm25 import BM25Okapi
|
|
11
|
+
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
|
12
|
+
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass
class SparseResult:
    """A single hit produced by the BM25 sparse retriever."""

    chunk_index: int  # position of the chunk within its source document
    chunk_id: str  # unique id of the chunk
    doc_id: str  # id of the source document
    kb_id: str  # id of the knowledge base the chunk belongs to
    content: str  # chunk text
    score: float  # BM25 relevance score
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class SparseRetriever:
    """BM25 sparse retriever.

    Responsibilities:
    - Keyword-based chunk retrieval
    - Relevance scoring with the BM25 (Okapi) algorithm
    """

    def __init__(self, kb_db: KBSQLiteDatabase):
        """Initialize the sparse retriever.

        Args:
            kb_db: Knowledge-base database instance.
        """
        self.kb_db = kb_db
        # NOTE(review): reserved for BM25 index caching but currently unused —
        # the index is rebuilt on every retrieve() call.
        self._index_cache = {}

        # Load the HIT stopword list shipped next to this module.
        stopwords_path = os.path.join(os.path.dirname(__file__), "hit_stopwords.txt")
        with open(stopwords_path, encoding="utf-8") as f:
            self.hit_stopwords = {
                word.strip() for word in f.read().splitlines() if word.strip()
            }

    async def retrieve(
        self,
        query: str,
        kb_ids: list[str],
        kb_options: dict,
    ) -> list[SparseResult]:
        """Run sparse (BM25) retrieval over the given knowledge bases.

        Args:
            query: Query text.
            kb_ids: Knowledge-base IDs to search.
            kb_options: Per-KB options carrying "vec_db" and "top_k_sparse".

        Returns:
            list[SparseResult]: Hits sorted by BM25 score (best first),
            truncated to the sum of each KB's top_k_sparse.
        """
        # 1. Collect every chunk stored in each KB's vector-store document
        # table; the overall result budget is the sum of per-KB top_k_sparse.
        top_k_sparse = 0
        chunks: list[dict] = []
        for kb_id in kb_ids:
            vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
            if not vec_db:
                continue
            docs = await vec_db.document_storage.get_documents(
                metadata_filters={}, limit=None, offset=None
            )
            for doc in docs:
                chunk_md = json.loads(doc["metadata"])
                chunks.append(
                    {
                        "chunk_id": doc["doc_id"],
                        "chunk_index": chunk_md["chunk_index"],
                        "doc_id": chunk_md["kb_doc_id"],
                        "kb_id": kb_id,
                        "text": doc["text"],
                    }
                )
            top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)

        if not chunks:
            return []

        # 2. Tokenize the corpus in one pass: segment with jieba and drop
        # stopwords.
        stopwords = self.hit_stopwords
        tokenized_corpus = [
            [tok for tok in jieba.cut(chunk["text"]) if tok not in stopwords]
            for chunk in chunks
        ]

        # 3. Build the BM25 index for this corpus.
        bm25 = BM25Okapi(tokenized_corpus)

        # 4. Score every chunk against the tokenized query.
        tokenized_query = [tok for tok in jieba.cut(query) if tok not in stopwords]
        scores = bm25.get_scores(tokenized_query)

        # 5. Sort by score and return the combined per-KB top-k.
        results = [
            SparseResult(
                chunk_id=chunk["chunk_id"],
                chunk_index=chunk["chunk_index"],
                doc_id=chunk["doc_id"],
                kb_id=chunk["kb_id"],
                content=chunk["text"],
                score=float(score),
            )
            for chunk, score in zip(chunks, scores)
        ]
        results.sort(key=lambda r: r.score, reverse=True)
        return results[:top_k_sparse]
|