AstrBot 4.3.5__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  2. astrbot/core/astrbot_config_mgr.py +23 -51
  3. astrbot/core/config/default.py +92 -12
  4. astrbot/core/conversation_mgr.py +36 -1
  5. astrbot/core/core_lifecycle.py +24 -5
  6. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  7. astrbot/core/db/vec_db/base.py +33 -2
  8. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  9. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  10. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  11. astrbot/core/file_token_service.py +6 -1
  12. astrbot/core/initial_loader.py +6 -3
  13. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  14. astrbot/core/knowledge_base/chunking/base.py +24 -0
  15. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  16. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  17. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  18. astrbot/core/knowledge_base/kb_helper.py +348 -0
  19. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  20. astrbot/core/knowledge_base/models.py +114 -0
  21. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  22. astrbot/core/knowledge_base/parsers/base.py +50 -0
  23. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  24. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  25. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  26. astrbot/core/knowledge_base/parsers/util.py +13 -0
  27. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  28. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  29. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  30. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  31. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  32. astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
  33. astrbot/core/pipeline/process_stage/utils.py +80 -0
  34. astrbot/core/platform/astr_message_event.py +8 -7
  35. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  36. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  37. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  38. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  39. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  40. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  41. astrbot/core/platform/sources/satori/satori_event.py +270 -99
  42. astrbot/core/provider/manager.py +14 -9
  43. astrbot/core/provider/provider.py +67 -0
  44. astrbot/core/provider/sources/anthropic_source.py +4 -4
  45. astrbot/core/provider/sources/dashscope_source.py +10 -9
  46. astrbot/core/provider/sources/dify_source.py +6 -8
  47. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  48. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  49. astrbot/core/provider/sources/openai_source.py +18 -15
  50. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  51. astrbot/core/star/context.py +3 -0
  52. astrbot/core/star/star.py +6 -0
  53. astrbot/core/star/star_manager.py +13 -7
  54. astrbot/core/umop_config_router.py +81 -0
  55. astrbot/core/updator.py +1 -1
  56. astrbot/core/utils/io.py +23 -12
  57. astrbot/dashboard/routes/__init__.py +2 -0
  58. astrbot/dashboard/routes/config.py +137 -9
  59. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  60. astrbot/dashboard/routes/plugin.py +24 -5
  61. astrbot/dashboard/routes/update.py +1 -1
  62. astrbot/dashboard/server.py +6 -0
  63. astrbot/dashboard/utils.py +161 -0
  64. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/METADATA +29 -13
  65. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/RECORD +68 -44
  66. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
  67. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
  68. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,273 @@
1
+ """检索管理器
2
+
3
+ 协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
4
+ """
5
+
6
+ import time
7
+
8
+ from dataclasses import dataclass
9
+ from typing import List
10
+
11
+ from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
12
+ from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
13
+ from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
14
+ from astrbot.core.provider.provider import RerankProvider
15
+ from astrbot.core.db.vec_db.base import Result
16
+ from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
17
+ from ..kb_helper import KBHelper
18
+ from astrbot import logger
19
+
20
+
21
+ @dataclass
22
+ class RetrievalResult:
23
+ """检索结果"""
24
+
25
+ chunk_id: str
26
+ doc_id: str
27
+ doc_name: str
28
+ kb_id: str
29
+ kb_name: str
30
+ content: str
31
+ score: float
32
+ metadata: dict
33
+
34
+
35
+ class RetrievalManager:
36
+ """检索管理器
37
+
38
+ 职责:
39
+ - 协调稠密检索、稀疏检索和 Rerank
40
+ - 结果融合和排序
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ sparse_retriever: SparseRetriever,
46
+ rank_fusion: RankFusion,
47
+ kb_db: KBSQLiteDatabase,
48
+ ):
49
+ """初始化检索管理器
50
+
51
+ Args:
52
+ sparse_retriever: 稀疏检索器
54
+ rank_fusion: 结果融合器
55
+ kb_db: 知识库数据库实例
56
+ """
57
+ self.sparse_retriever = sparse_retriever
58
+ self.rank_fusion = rank_fusion
59
+ self.kb_db = kb_db
60
+
61
+ async def retrieve(
62
+ self,
63
+ query: str,
64
+ kb_ids: List[str],
65
+ kb_id_helper_map: dict[str, KBHelper],
66
+ top_k_fusion: int = 20,
67
+ top_m_final: int = 5,
68
+ ) -> List[RetrievalResult]:
69
+ """混合检索
70
+
71
+ 流程:
72
+ 1. 稠密检索 (向量相似度)
73
+ 2. 稀疏检索 (BM25)
74
+ 3. 结果融合 (RRF)
75
+ 4. Rerank 重排序
76
+
77
+ Args:
78
+ query: 查询文本
79
+ kb_ids: 知识库 ID 列表
80
+ kb_id_helper_map: 知识库 ID 到 KBHelper 实例的映射
+ top_k_fusion: 融合阶段保留的结果数量
+ top_m_final: 最终返回数量
82
+
83
+ Returns:
84
+ List[RetrievalResult]: 检索结果列表
85
+ """
86
+ if not kb_ids:
87
+ return []
88
+
89
+ kb_options: dict = {}
90
+ new_kb_ids = []
91
+ for kb_id in kb_ids:
92
+ kb_helper = kb_id_helper_map.get(kb_id)
93
+ if kb_helper:
94
+ kb = kb_helper.kb
95
+ kb_options[kb_id] = {
96
+ "top_k_dense": kb.top_k_dense or 50,
97
+ "top_k_sparse": kb.top_k_sparse or 50,
98
+ "top_m_final": kb.top_m_final or 5,
99
+ "vec_db": kb_helper.vec_db,
100
+ "rerank_provider_id": kb.rerank_provider_id,
101
+ }
102
+ new_kb_ids.append(kb_id)
103
+ else:
104
+ logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
105
+
106
+ kb_ids = new_kb_ids
107
+
108
+ # 1. 稠密检索
109
+ time_start = time.time()
110
+ dense_results = await self._dense_retrieve(
111
+ query=query,
112
+ kb_ids=kb_ids,
113
+ kb_options=kb_options,
114
+ )
115
+ time_end = time.time()
116
+ logger.debug(
117
+ f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
118
+ )
119
+
120
+ # 2. 稀疏检索
121
+ time_start = time.time()
122
+ sparse_results = await self.sparse_retriever.retrieve(
123
+ query=query,
124
+ kb_ids=kb_ids,
125
+ kb_options=kb_options,
126
+ )
127
+ time_end = time.time()
128
+ logger.debug(
129
+ f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
130
+ )
131
+
132
+ # 3. 结果融合
133
+ time_start = time.time()
134
+ fused_results = await self.rank_fusion.fuse(
135
+ dense_results=dense_results,
136
+ sparse_results=sparse_results,
137
+ top_k=top_k_fusion,
138
+ )
139
+ time_end = time.time()
140
+ logger.debug(
141
+ f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
142
+ )
143
+
144
+ # 4. 转换为 RetrievalResult (获取元数据)
145
+ retrieval_results = []
146
+ for fr in fused_results:
147
+ metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
148
+ if metadata_dict:
149
+ retrieval_results.append(
150
+ RetrievalResult(
151
+ chunk_id=fr.chunk_id,
152
+ doc_id=fr.doc_id,
153
+ doc_name=metadata_dict["document"].doc_name,
154
+ kb_id=fr.kb_id,
155
+ kb_name=metadata_dict["knowledge_base"].kb_name,
156
+ content=fr.content,
157
+ score=fr.score,
158
+ metadata={
159
+ "chunk_index": fr.chunk_index,
160
+ "char_count": len(fr.content),
161
+ },
162
+ )
163
+ )
164
+
165
+ # 5. Rerank
166
+ first_rerank = None
167
+ for kb_id in kb_ids:
168
+ vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
169
+ rerank_pi = kb_options[kb_id]["rerank_provider_id"]
170
+ if (
171
+ vec_db
172
+ and vec_db.rerank_provider
173
+ and rerank_pi
174
+ and rerank_pi == vec_db.rerank_provider.meta().id
175
+ ):
176
+ first_rerank = vec_db.rerank_provider
177
+ break
178
+ if first_rerank and retrieval_results:
179
+ retrieval_results = await self._rerank(
180
+ query=query,
181
+ results=retrieval_results,
182
+ top_k=top_m_final,
183
+ rerank_provider=first_rerank,
184
+ )
185
+
186
+ return retrieval_results[:top_m_final]
187
+
188
+ async def _dense_retrieve(
189
+ self,
190
+ query: str,
191
+ kb_ids: List[str],
192
+ kb_options: dict,
193
+ ):
194
+ """稠密检索 (向量相似度)
195
+
196
+ 为每个知识库使用独立的向量数据库进行检索,然后合并结果。
197
+
198
+ Args:
199
+ query: 查询文本
200
+ kb_ids: 知识库 ID 列表
201
+ kb_options: 每个知识库的检索选项(含 vec_db 与 top_k_dense)
202
+
203
+ Returns:
204
+ List[Result]: 检索结果列表
205
+ """
206
+ all_results: list[Result] = []
207
+ for kb_id in kb_ids:
208
+ if kb_id not in kb_options:
209
+ continue
210
+ try:
211
+ vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
212
+ dense_k = int(kb_options[kb_id]["top_k_dense"])
213
+ vec_results = await vec_db.retrieve(
214
+ query=query,
215
+ k=dense_k,
216
+ fetch_k=dense_k * 2,
217
+ rerank=False, # 稠密检索阶段不进行 rerank
218
+ metadata_filters={"kb_id": kb_id},
219
+ )
220
+
221
+ all_results.extend(vec_results)
222
+ except Exception as e:
223
+ from astrbot.core import logger
224
+
225
+ logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
226
+ continue
227
+
228
+ # 按相似度排序并返回 top_k
229
+ all_results.sort(key=lambda x: x.similarity, reverse=True)
230
+ # return all_results[: len(all_results) // len(kb_ids)]
231
+ return all_results
232
+
233
+ async def _rerank(
234
+ self,
235
+ query: str,
236
+ results: List[RetrievalResult],
237
+ top_k: int,
238
+ rerank_provider: RerankProvider,
239
+ ) -> List[RetrievalResult]:
240
+ """Rerank 重排序
241
+
242
+ Args:
243
+ query: 查询文本
244
+ results: 检索结果列表
245
+ top_k: 返回结果数量
+ rerank_provider: 用于重排序的 Rerank 提供商
246
+
247
+ Returns:
248
+ List[RetrievalResult]: 重排序后的结果列表
249
+ """
250
+ if not results:
251
+ return []
252
+
253
+ # 准备文档列表
254
+ docs = [r.content for r in results]
255
+
256
+ # 调用 Rerank Provider
257
+ rerank_results = await rerank_provider.rerank(
258
+ query=query,
259
+ documents=docs,
260
+ )
261
+
262
+ # 更新分数并重新排序
263
+ reranked_list = []
264
+ for rerank_result in rerank_results:
265
+ idx = rerank_result.index
266
+ if idx < len(results):
267
+ result = results[idx]
268
+ result.score = rerank_result.relevance_score
269
+ reranked_list.append(result)
270
+
271
+ reranked_list.sort(key=lambda x: x.score, reverse=True)
272
+
273
+ return reranked_list[:top_k]
@@ -0,0 +1,138 @@
1
+ """检索结果融合器
2
+
3
+ 使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
4
+ """
5
+
6
+ import json
7
+ from dataclasses import dataclass
8
+
9
+ from astrbot.core.db.vec_db.base import Result
10
+ from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
11
+ from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
12
+
13
+
14
+ @dataclass
15
+ class FusedResult:
16
+ """融合后的检索结果"""
17
+
18
+ chunk_id: str
19
+ chunk_index: int
20
+ doc_id: str
21
+ kb_id: str
22
+ content: str
23
+ score: float
24
+
25
+
26
+ class RankFusion:
27
+ """检索结果融合器
28
+
29
+ 职责:
30
+ - 融合稠密检索和稀疏检索的结果
31
+ - 使用 Reciprocal Rank Fusion (RRF) 算法
32
+ """
33
+
34
+ def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
35
+ """初始化结果融合器
36
+
37
+ Args:
38
+ kb_db: 知识库数据库实例
39
+ k: RRF 参数,用于平滑排名
40
+ """
41
+ self.kb_db = kb_db
42
+ self.k = k
43
+
44
+ async def fuse(
45
+ self,
46
+ dense_results: list[Result],
47
+ sparse_results: list[SparseResult],
48
+ top_k: int = 20,
49
+ ) -> list[FusedResult]:
50
+ """融合稠密和稀疏检索结果
51
+
52
+ RRF 公式:
53
+ score(doc) = sum(1 / (k + rank_i))
54
+
55
+ Args:
56
+ dense_results: 稠密检索结果
57
+ sparse_results: 稀疏检索结果
58
+ top_k: 返回结果数量
59
+
60
+ Returns:
61
+ List[FusedResult]: 融合后的结果列表
62
+ """
63
+ # 1. 构建排名映射
64
+ dense_ranks = {
65
+ r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
66
+ } # 这里的 doc_id 实际上是 chunk_id
67
+ sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
68
+
69
+ # 2. 收集所有唯一的 ID
70
+ # 需要统一为 chunk_id
71
+ all_chunk_ids = set()
72
+ vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
73
+ chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
74
+
75
+ # 处理稀疏检索结果
76
+ for r in sparse_results:
77
+ all_chunk_ids.add(r.chunk_id)
78
+ chunk_id_to_sparse[r.chunk_id] = r
79
+
80
+ # 处理稠密检索结果 (需要转换 vec_doc_id 到 chunk_id)
81
+ for r in dense_results:
82
+ vec_doc_id = r.data["doc_id"]
83
+ all_chunk_ids.add(vec_doc_id)
84
+ vec_doc_id_to_dense[vec_doc_id] = r
85
+
86
+ # 3. 计算 RRF 分数
87
+ rrf_scores: dict[str, float] = {}
88
+
89
+ for identifier in all_chunk_ids:
90
+ score = 0.0
91
+
92
+ # 来自稠密检索的贡献
93
+ if identifier in dense_ranks:
94
+ score += 1.0 / (self.k + dense_ranks[identifier])
95
+
96
+ # 来自稀疏检索的贡献
97
+ if identifier in sparse_ranks:
98
+ score += 1.0 / (self.k + sparse_ranks[identifier])
99
+
100
+ rrf_scores[identifier] = score
101
+
102
+ # 4. 排序
103
+ sorted_ids = sorted(
104
+ rrf_scores.keys(), key=lambda cid: rrf_scores[cid], reverse=True
105
+ )[:top_k]
106
+
107
+ # 5. 构建融合结果
108
+ fused_results = []
109
+ for identifier in sorted_ids:
110
+ # 优先从稀疏检索获取完整信息
111
+ if identifier in chunk_id_to_sparse:
112
+ sr = chunk_id_to_sparse[identifier]
113
+ fused_results.append(
114
+ FusedResult(
115
+ chunk_id=sr.chunk_id,
116
+ chunk_index=sr.chunk_index,
117
+ doc_id=sr.doc_id,
118
+ kb_id=sr.kb_id,
119
+ content=sr.content,
120
+ score=rrf_scores[identifier],
121
+ )
122
+ )
123
+ elif identifier in vec_doc_id_to_dense:
124
+ # 从向量检索获取信息,需要从数据库获取块的详细信息
125
+ vec_result = vec_doc_id_to_dense[identifier]
126
+ chunk_md = json.loads(vec_result.data["metadata"])
127
+ fused_results.append(
128
+ FusedResult(
129
+ chunk_id=identifier,
130
+ chunk_index=chunk_md["chunk_index"],
131
+ doc_id=chunk_md["kb_doc_id"],
132
+ kb_id=chunk_md["kb_id"],
133
+ content=vec_result.data["text"],
134
+ score=rrf_scores[identifier],
135
+ )
136
+ )
137
+
138
+ return fused_results
@@ -0,0 +1,130 @@
1
+ """稀疏检索器
2
+
3
+ 使用 BM25 算法进行基于关键词的文档检索
4
+ """
5
+
6
+ import jieba
7
+ import os
8
+ import json
9
+ from dataclasses import dataclass
10
+ from rank_bm25 import BM25Okapi
11
+ from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
12
+ from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
13
+
14
+
15
+ @dataclass
16
+ class SparseResult:
17
+ """稀疏检索结果"""
18
+
19
+ chunk_index: int
20
+ chunk_id: str
21
+ doc_id: str
22
+ kb_id: str
23
+ content: str
24
+ score: float
25
+
26
+
27
+ class SparseRetriever:
28
+ """BM25 稀疏检索器
29
+
30
+ 职责:
31
+ - 基于关键词的文档检索
32
+ - 使用 BM25 算法计算相关度
33
+ """
34
+
35
+ def __init__(self, kb_db: KBSQLiteDatabase):
36
+ """初始化稀疏检索器
37
+
38
+ Args:
39
+ kb_db: 知识库数据库实例
40
+ """
41
+ self.kb_db = kb_db
42
+ self._index_cache = {} # 缓存 BM25 索引
43
+
44
+ with open(
45
+ os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
46
+ encoding="utf-8",
47
+ ) as f:
48
+ self.hit_stopwords = {
49
+ word.strip() for word in set(f.read().splitlines()) if word.strip()
50
+ }
51
+
52
+ async def retrieve(
53
+ self,
54
+ query: str,
55
+ kb_ids: list[str],
56
+ kb_options: dict,
57
+ ) -> list[SparseResult]:
58
+ """执行稀疏检索
59
+
60
+ Args:
61
+ query: 查询文本
62
+ kb_ids: 知识库 ID 列表
63
+ kb_options: 每个知识库的检索选项
64
+
65
+ Returns:
66
+ List[SparseResult]: 检索结果列表
67
+ """
68
+ # 1. 获取所有相关块
69
+ top_k_sparse = 0
70
+ chunks = []
71
+ for kb_id in kb_ids:
72
+ vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
73
+ if not vec_db:
74
+ continue
75
+ result = await vec_db.document_storage.get_documents(
76
+ metadata_filters={}, limit=None, offset=None
77
+ )
78
+ chunk_mds = [json.loads(doc["metadata"]) for doc in result]
79
+ result = [
80
+ {
81
+ "chunk_id": doc["doc_id"],
82
+ "chunk_index": chunk_md["chunk_index"],
83
+ "doc_id": chunk_md["kb_doc_id"],
84
+ "kb_id": kb_id,
85
+ "text": doc["text"],
86
+ }
87
+ for doc, chunk_md in zip(result, chunk_mds)
88
+ ]
89
+ chunks.extend(result)
90
+ top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
91
+
92
+ if not chunks:
93
+ return []
94
+
95
+ # 2. 准备文档和索引
96
+ corpus = [chunk["text"] for chunk in chunks]
97
+ tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
98
+ tokenized_corpus = [
99
+ [word for word in doc if word not in self.hit_stopwords]
100
+ for doc in tokenized_corpus
101
+ ]
102
+
103
+ # 3. 构建 BM25 索引
104
+ bm25 = BM25Okapi(tokenized_corpus)
105
+
106
+ # 4. 执行检索
107
+ tokenized_query = list(jieba.cut(query))
108
+ tokenized_query = [
109
+ word for word in tokenized_query if word not in self.hit_stopwords
110
+ ]
111
+ scores = bm25.get_scores(tokenized_query)
112
+
113
+ # 5. 排序并返回 Top-K
114
+ results = []
115
+ for idx, score in enumerate(scores):
116
+ chunk = chunks[idx]
117
+ results.append(
118
+ SparseResult(
119
+ chunk_id=chunk["chunk_id"],
120
+ chunk_index=chunk["chunk_index"],
121
+ doc_id=chunk["doc_id"],
122
+ kb_id=chunk["kb_id"],
123
+ content=chunk["text"],
124
+ score=float(score),
125
+ )
126
+ )
127
+
128
+ results.sort(key=lambda x: x.score, reverse=True)
129
+ # return results[: len(results) // len(kb_ids)]
130
+ return results[:top_k_sparse]
@@ -7,7 +7,7 @@ import copy
7
7
  import json
8
8
  import traceback
9
9
  from datetime import timedelta
10
- from typing import AsyncGenerator, Union
10
+ from collections.abc import AsyncGenerator
11
11
  from astrbot.core.conversation_mgr import Conversation
12
12
  from astrbot.core import logger
13
13
  from astrbot.core.message.components import Image
@@ -33,6 +33,7 @@ from astrbot.core.star.star_handler import EventType
33
33
  from astrbot.core.utils.metrics import Metric
34
34
  from ...context import PipelineContext, call_event_hook, call_handler
35
35
  from ..stage import Stage
36
+ from ..utils import inject_kb_context
36
37
  from astrbot.core.provider.register import llm_tools
37
38
  from astrbot.core.star.star_handler import star_map
38
39
  from astrbot.core.astr_agent_context import AstrAgentContext
@@ -44,7 +45,7 @@ except (ModuleNotFoundError, ImportError):
44
45
 
45
46
 
46
47
  AgentContextWrapper = ContextWrapper[AstrAgentContext]
47
- AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
48
+ AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
48
49
 
49
50
 
50
51
  class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@@ -102,7 +103,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
102
103
 
103
104
  request = ProviderRequest(
104
105
  prompt=input_,
105
- system_prompt=tool.description,
106
+ system_prompt=tool.description or "",
106
107
  image_urls=[], # 暂时不传递原始 agent 的上下文
107
108
  contexts=[], # 暂时不传递原始 agent 的上下文
108
109
  func_tool=toolset,
@@ -239,7 +240,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
239
240
  yield res
240
241
 
241
242
 
242
- class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
243
+ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
243
244
  async def on_agent_done(self, run_context, llm_response):
244
245
  # 执行事件钩子
245
246
  await call_event_hook(
@@ -337,7 +338,7 @@ class LLMRequestSubStage(Stage):
337
338
 
338
339
  self.conv_manager = ctx.plugin_manager.context.conversation_manager
339
340
 
340
- def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
341
+ def _select_provider(self, event: AstrMessageEvent):
341
342
  """选择使用的 LLM 提供商"""
342
343
  sel_provider = event.get_extra("selected_provider")
343
344
  _ctx = self.ctx.plugin_manager.context
@@ -367,7 +368,7 @@ class LLMRequestSubStage(Stage):
367
368
 
368
369
  async def process(
369
370
  self, event: AstrMessageEvent, _nested: bool = False
370
- ) -> Union[None, AsyncGenerator[None, None]]:
371
+ ) -> None | AsyncGenerator[None, None]:
371
372
  req: ProviderRequest | None = None
372
373
 
373
374
  if not self.ctx.astrbot_config["provider_settings"]["enable"]:
@@ -382,6 +383,9 @@ class LLMRequestSubStage(Stage):
382
383
  provider = self._select_provider(event)
383
384
  if provider is None:
384
385
  return
386
+ if not isinstance(provider, Provider):
387
+ logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
388
+ return
385
389
 
386
390
  if event.get_extra("provider_request"):
387
391
  req = event.get_extra("provider_request")
@@ -416,6 +420,14 @@ class LLMRequestSubStage(Stage):
416
420
  if not req.prompt and not req.image_urls:
417
421
  return
418
422
 
423
+ # 应用知识库
424
+ try:
425
+ await inject_kb_context(
426
+ umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
427
+ )
428
+ except Exception as e:
429
+ logger.error(f"调用知识库时遇到问题: {e}")
430
+
419
431
  # 执行请求 LLM 前事件钩子。
420
432
  if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
421
433
  return
@@ -480,6 +492,9 @@ class LLMRequestSubStage(Stage):
480
492
  new_tool_set.add_tool(tool)
481
493
  req.func_tool = new_tool_set
482
494
 
495
+ # 备份 req.contexts
496
+ backup_contexts = copy.deepcopy(req.contexts)
497
+
483
498
  # run agent
484
499
  agent_runner = AgentRunner()
485
500
  logger.debug(
@@ -517,8 +532,10 @@ class LLMRequestSubStage(Stage):
517
532
  chain = (
518
533
  MessageChain().message(final_llm_resp.completion_text).chain
519
534
  )
520
- else:
535
+ elif final_llm_resp.result_chain:
521
536
  chain = final_llm_resp.result_chain.chain
537
+ else:
538
+ chain = MessageChain().chain
522
539
  event.set_result(
523
540
  MessageEventResult(
524
541
  chain=chain,
@@ -529,6 +546,9 @@ class LLMRequestSubStage(Stage):
529
546
  async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
530
547
  yield
531
548
 
549
+ # 恢复备份的 contexts
550
+ req.contexts = backup_contexts
551
+
532
552
  await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
533
553
 
534
554
  # 异步处理 WebChat 特殊情况
@@ -547,6 +567,8 @@ class LLMRequestSubStage(Stage):
547
567
  self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
548
568
  ):
549
569
  """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
570
+ if not req.conversation:
571
+ return
550
572
  conversation = await self.conv_manager.get_conversation(
551
573
  event.unified_msg_origin, req.conversation.cid
552
574
  )