AstrBot 4.3.5__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  2. astrbot/core/astrbot_config_mgr.py +23 -51
  3. astrbot/core/config/default.py +92 -12
  4. astrbot/core/conversation_mgr.py +36 -1
  5. astrbot/core/core_lifecycle.py +24 -5
  6. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  7. astrbot/core/db/vec_db/base.py +33 -2
  8. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  9. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  10. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  11. astrbot/core/file_token_service.py +6 -1
  12. astrbot/core/initial_loader.py +6 -3
  13. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  14. astrbot/core/knowledge_base/chunking/base.py +24 -0
  15. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  16. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  17. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  18. astrbot/core/knowledge_base/kb_helper.py +348 -0
  19. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  20. astrbot/core/knowledge_base/models.py +114 -0
  21. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  22. astrbot/core/knowledge_base/parsers/base.py +50 -0
  23. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  24. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  25. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  26. astrbot/core/knowledge_base/parsers/util.py +13 -0
  27. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  28. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  29. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  30. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  31. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  32. astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
  33. astrbot/core/pipeline/process_stage/utils.py +80 -0
  34. astrbot/core/platform/astr_message_event.py +8 -7
  35. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  36. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  37. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  38. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  39. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  40. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  41. astrbot/core/platform/sources/satori/satori_event.py +270 -99
  42. astrbot/core/provider/manager.py +14 -9
  43. astrbot/core/provider/provider.py +67 -0
  44. astrbot/core/provider/sources/anthropic_source.py +4 -4
  45. astrbot/core/provider/sources/dashscope_source.py +10 -9
  46. astrbot/core/provider/sources/dify_source.py +6 -8
  47. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  48. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  49. astrbot/core/provider/sources/openai_source.py +18 -15
  50. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  51. astrbot/core/star/context.py +3 -0
  52. astrbot/core/star/star.py +6 -0
  53. astrbot/core/star/star_manager.py +13 -7
  54. astrbot/core/umop_config_router.py +81 -0
  55. astrbot/core/updator.py +1 -1
  56. astrbot/core/utils/io.py +23 -12
  57. astrbot/dashboard/routes/__init__.py +2 -0
  58. astrbot/dashboard/routes/config.py +137 -9
  59. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  60. astrbot/dashboard/routes/plugin.py +24 -5
  61. astrbot/dashboard/routes/update.py +1 -1
  62. astrbot/dashboard/server.py +6 -0
  63. astrbot/dashboard/utils.py +161 -0
  64. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/METADATA +29 -13
  65. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/RECORD +68 -44
  66. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
  67. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
  68. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1065 @@
1
+ """知识库管理 API 路由"""
2
+
3
+ import uuid
4
+ import aiofiles
5
+ import os
6
+ import traceback
7
+ import asyncio
8
+ from quart import request
9
+ from astrbot.core import logger
10
+ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
11
+ from .route import Route, Response, RouteContext
12
+ from ..utils import generate_tsne_visualization
13
+ from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
14
+
15
+
16
+ class KnowledgeBaseRoute(Route):
17
+ """知识库管理路由
18
+
19
+ 提供知识库、文档、检索、会话配置等 API 接口
20
+ """
21
+
22
def __init__(
    self,
    context: RouteContext,
    core_lifecycle: AstrBotCoreLifecycle,
) -> None:
    """Set up state and register the knowledge-base HTTP endpoints.

    Args:
        context: Route context forwarded to the ``Route`` base class.
        core_lifecycle: Lifecycle object whose ``kb_manager`` attribute is
            read on demand by ``_get_kb_manager``.
    """
    super().__init__(context)
    self.core_lifecycle = core_lifecycle

    # Placeholders; nothing in this class assigns them after construction.
    self.kb_manager = None
    self.kb_db = None
    self.session_config_db = None
    self.retrieval_manager = None

    # Upload bookkeeping, both keyed by task_id.
    # upload_progress: {status, file_index, file_total, stage, current, total}
    # upload_tasks:    {"status", "result", "error"}
    self.upload_progress = {}
    self.upload_tasks = {}

    # Knowledge-base CRUD.
    kb_routes = {
        "/kb/list": ("GET", self.list_kbs),
        "/kb/create": ("POST", self.create_kb),
        "/kb/get": ("GET", self.get_kb),
        "/kb/update": ("POST", self.update_kb),
        "/kb/delete": ("POST", self.delete_kb),
        "/kb/stats": ("GET", self.get_kb_stats),
    }
    # Document management.
    document_routes = {
        "/kb/document/list": ("GET", self.list_documents),
        "/kb/document/upload": ("POST", self.upload_document),
        "/kb/document/upload/progress": ("GET", self.get_upload_progress),
        "/kb/document/get": ("GET", self.get_document),
        "/kb/document/delete": ("POST", self.delete_document),
    }
    # Chunk management. Media routes (/kb/media/list, /kb/media/delete)
    # are not registered.
    chunk_routes = {
        "/kb/chunk/list": ("GET", self.list_chunks),
        "/kb/chunk/delete": ("POST", self.delete_chunk),
    }
    # Retrieval.
    retrieval_routes = {
        "/kb/retrieve": ("POST", self.retrieve),
    }
    # Per-session knowledge-base configuration.
    session_routes = {
        "/kb/session/config/get": ("GET", self.get_session_kb_config),
        "/kb/session/config/set": ("POST", self.set_session_kb_config),
        "/kb/session/config/delete": ("POST", self.delete_session_kb_config),
    }

    # Merge in the same order the routes were originally declared.
    self.routes = {
        **kb_routes,
        **document_routes,
        **chunk_routes,
        **retrieval_routes,
        **session_routes,
    }
    self.register_routes()
65
+
66
def _get_kb_manager(self):
    """Return the knowledge-base manager held by the core lifecycle."""
    lifecycle = self.core_lifecycle
    return lifecycle.kb_manager
68
+
69
async def _background_upload_task(
    self,
    task_id: str,
    kb_helper,
    files_to_upload: list,
    chunk_size: int,
    chunk_overlap: int,
    batch_size: int,
    tasks_limit: int,
    max_retries: int,
):
    """Upload the given files one by one and record progress and outcome.

    State is written to ``self.upload_tasks[task_id]`` (status, result,
    error) and ``self.upload_progress[task_id]`` (per-file progress).
    A failure of a single file is recorded in the result's ``failed`` list
    and does not stop the remaining files; any other exception marks the
    whole task as ``failed``.

    Args:
        task_id: Key under which task state is stored.
        kb_helper: Object whose ``upload_document`` coroutine is awaited
            for every file.
        files_to_upload: Dicts with ``file_name``, ``file_content`` and
            ``file_type`` keys.
        chunk_size, chunk_overlap, batch_size, tasks_limit, max_retries:
            Forwarded unchanged to ``kb_helper.upload_document``.
    """
    try:
        file_total = len(files_to_upload)

        # Mark the task as running and seed the progress record.
        self.upload_tasks[task_id] = {
            "status": "processing",
            "result": None,
            "error": None,
        }
        self.upload_progress[task_id] = {
            "status": "processing",
            "file_index": 0,
            "file_total": file_total,
            "stage": "waiting",
            "current": 0,
            "total": 100,
        }

        succeeded = []
        failures = []

        for file_idx, file_info in enumerate(files_to_upload):
            try:
                # Reset the per-file progress before parsing starts.
                self.upload_progress[task_id].update(
                    {
                        "status": "processing",
                        "file_index": file_idx,
                        "file_name": file_info["file_name"],
                        "stage": "parsing",
                        "current": 0,
                        "total": 100,
                    }
                )

                async def progress_callback(stage, current, total):
                    # Progress reporter handed to the helper for this file.
                    if task_id not in self.upload_progress:
                        return
                    self.upload_progress[task_id].update(
                        {
                            "status": "processing",
                            "file_index": file_idx,
                            "file_name": file_info["file_name"],
                            "stage": stage,
                            "current": current,
                            "total": total,
                        }
                    )

                uploaded = await kb_helper.upload_document(
                    file_name=file_info["file_name"],
                    file_content=file_info["file_content"],
                    file_type=file_info["file_type"],
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    batch_size=batch_size,
                    tasks_limit=tasks_limit,
                    max_retries=max_retries,
                    progress_callback=progress_callback,
                )
                succeeded.append(uploaded.model_dump())
            except Exception as e:
                logger.error(f"上传文档 {file_info['file_name']} 失败: {e}")
                failures.append(
                    {"file_name": file_info["file_name"], "error": str(e)}
                )

        # All files attempted: publish the summary.
        summary = {
            "task_id": task_id,
            "uploaded": succeeded,
            "failed": failures,
            "total": len(files_to_upload),
            "success_count": len(succeeded),
            "failed_count": len(failures),
        }
        self.upload_tasks[task_id] = {
            "status": "completed",
            "result": summary,
            "error": None,
        }
        self.upload_progress[task_id]["status"] = "completed"

    except Exception as e:
        logger.error(f"后台上传任务 {task_id} 失败: {e}")
        logger.error(traceback.format_exc())
        self.upload_tasks[task_id] = {
            "status": "failed",
            "result": None,
            "error": str(e),
        }
        if task_id in self.upload_progress:
            self.upload_progress[task_id]["status"] = "failed"
174
+
175
async def list_kbs(self):
    """List all knowledge bases.

    Query parameters:
        - page: page number (default 1)
        - page_size: items per page (default 20)

    Note: ``page`` and ``page_size`` are only echoed back in the response;
    every knowledge base returned by the manager is included in ``items``.
    """
    try:
        manager = self._get_kb_manager()
        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 20, type=int)

        items = [kb.model_dump() for kb in await manager.list_kbs()]
        payload = {"items": items, "page": page, "page_size": page_size}
        return Response().ok(payload).__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取知识库列表失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取知识库列表失败: {e!s}").__dict__
206
+
207
async def create_kb(self):
    """Create a knowledge base after smoke-testing its providers.

    Body (JSON):
        - kb_name: knowledge base name (required)
        - description: description (optional)
        - emoji: icon (optional)
        - embedding_provider_id: embedding provider ID (required; the
          request is rejected without it)
        - rerank_provider_id: rerank provider ID (optional)
        - chunk_size, chunk_overlap: chunking settings (optional)
        - top_k_dense, top_k_sparse, top_m_final: retrieval settings
          (optional)

    Optional values are forwarded to ``kb_manager.create_kb`` as given
    (``None`` when absent).

    Returns:
        The ``Response`` dict: the dumped knowledge base on success, an
        error message otherwise.
    """
    try:
        kb_manager = self._get_kb_manager()
        # A request without a JSON body yields None; treat it as an empty
        # object so the caller gets the normal "missing field" error.
        data = (await request.json) or {}
        kb_name = data.get("kb_name")
        if not kb_name:
            return Response().error("知识库名称不能为空").__dict__

        description = data.get("description")
        emoji = data.get("emoji")
        embedding_provider_id = data.get("embedding_provider_id")
        rerank_provider_id = data.get("rerank_provider_id")
        chunk_size = data.get("chunk_size")
        chunk_overlap = data.get("chunk_overlap")
        top_k_dense = data.get("top_k_dense")
        top_k_sparse = data.get("top_k_sparse")
        top_m_final = data.get("top_m_final")

        # Pre-check: the embedding provider must exist, be an embedding
        # provider, and produce vectors of the configured dimension.
        if not embedding_provider_id:
            return Response().error("缺少参数 embedding_provider_id").__dict__
        prv = await kb_manager.provider_manager.get_provider_by_id(
            embedding_provider_id
        )  # type: ignore
        if not prv or not isinstance(prv, EmbeddingProvider):
            return (
                Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__
            )
        try:
            vec = await prv.get_embedding("astrbot")
            if len(vec) != prv.get_dim():
                raise ValueError(
                    f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}"
                )
        except Exception as e:
            return Response().error(f"测试嵌入模型失败: {str(e)}").__dict__

        # Pre-check: the rerank provider, when given, must exist, be a
        # rerank provider, and answer a trivial query.
        if rerank_provider_id:
            rerank_prv = await kb_manager.provider_manager.get_provider_by_id(
                rerank_provider_id
            )  # type: ignore
            if not rerank_prv:
                return Response().error("重排序模型不存在").__dict__
            # Same type guard as for the embedding provider, so an ID that
            # points at another provider kind is rejected explicitly.
            if not isinstance(rerank_prv, RerankProvider):
                return (
                    Response()
                    .error(f"重排序模型不存在或类型错误({type(rerank_prv)})")
                    .__dict__
                )
            try:
                res = await rerank_prv.rerank(
                    query="astrbot", documents=["astrbot knowledge base"]
                )
                if not res:
                    raise ValueError("重排序模型返回结果异常")
            except Exception as e:
                return (
                    Response()
                    .error(f"测试重排序模型失败: {str(e)},请检查控制台日志输出。")
                    .__dict__
                )

        kb_helper = await kb_manager.create_kb(
            kb_name=kb_name,
            description=description,
            emoji=emoji,
            embedding_provider_id=embedding_provider_id,
            rerank_provider_id=rerank_provider_id,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            top_k_dense=top_k_dense,
            top_k_sparse=top_k_sparse,
            top_m_final=top_m_final,
        )
        kb = kb_helper.kb

        return Response().ok(kb.model_dump(), "创建知识库成功").__dict__

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"创建知识库失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"创建知识库失败: {str(e)}").__dict__
302
+
303
async def get_kb(self):
    """Return the details of one knowledge base.

    Query parameters:
        - kb_id: knowledge base ID (required)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return fail("缺少参数 kb_id")

        helper = await manager.get_kb(kb_id)
        if not helper:
            return fail("知识库不存在")

        return Response().ok(helper.kb.model_dump()).__dict__

    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"获取知识库详情失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"获取知识库详情失败: {e!s}")
328
+
329
async def update_kb(self):
    """Update an existing knowledge base.

    Body (JSON):
        - kb_id: knowledge base ID (required)
        - kb_name, description, emoji: new metadata (optional)
        - embedding_provider_id, rerank_provider_id: new providers
          (optional)
        - chunk_size, chunk_overlap: chunking settings (optional)
        - top_k_dense, top_k_sparse, top_m_final: retrieval settings
          (optional)

    At least one optional field must be present (not null).
    """
    # Updatable fields, in the order they are passed on to the manager.
    field_names = (
        "kb_name",
        "description",
        "emoji",
        "embedding_provider_id",
        "rerank_provider_id",
        "chunk_size",
        "chunk_overlap",
        "top_k_dense",
        "top_k_sparse",
        "top_m_final",
    )
    try:
        manager = self._get_kb_manager()
        data = await request.json

        kb_id = data.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        changes = {name: data.get(name) for name in field_names}

        # Reject requests that carry no field to update.
        if not any(value is not None for value in changes.values()):
            return Response().error("至少需要提供一个更新字段").__dict__

        helper = await manager.update_kb(kb_id=kb_id, **changes)
        if not helper:
            return Response().error("知识库不存在").__dict__

        return Response().ok(helper.kb.model_dump(), "更新知识库成功").__dict__

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"更新知识库失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"更新知识库失败: {e!s}").__dict__
408
+
409
async def delete_kb(self):
    """Delete a knowledge base.

    Body (JSON):
        - kb_id: knowledge base ID (required)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        body = await request.json

        kb_id = body.get("kb_id")
        if not kb_id:
            return fail("缺少参数 kb_id")

        deleted = await manager.delete_kb(kb_id)
        if deleted:
            return Response().ok(message="删除知识库成功").__dict__
        return fail("知识库不存在")

    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"删除知识库失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"删除知识库失败: {e!s}")
435
+
436
async def get_kb_stats(self):
    """Return counters and timestamps for one knowledge base.

    Query parameters:
        - kb_id: knowledge base ID (required)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return fail("缺少参数 kb_id")

        helper = await manager.get_kb(kb_id)
        if not helper:
            return fail("知识库不存在")

        record = helper.kb
        # Timestamps are serialised as ISO-8601 strings.
        return (
            Response()
            .ok(
                {
                    "kb_id": record.kb_id,
                    "kb_name": record.kb_name,
                    "doc_count": record.doc_count,
                    "chunk_count": record.chunk_count,
                    "created_at": record.created_at.isoformat(),
                    "updated_at": record.updated_at.isoformat(),
                }
            )
            .__dict__
        )

    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"获取知识库统计失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"获取知识库统计失败: {e!s}")
470
+
471
+ # ===== 文档管理 API =====
472
+
473
async def list_documents(self):
    """List the documents of a knowledge base, one page at a time.

    Query parameters:
        - kb_id: knowledge base ID (required)
        - page: page number (default 1)
        - page_size: items per page (default 100)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return fail("缺少参数 kb_id")
        helper = await manager.get_kb(kb_id)
        if not helper:
            return fail("知识库不存在")

        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 100, type=int)

        # Translate the 1-based page number into an offset/limit window.
        documents = await helper.list_documents(
            offset=(page - 1) * page_size, limit=page_size
        )
        items = [document.model_dump() for document in documents]

        payload = {"items": items, "page": page, "page_size": page_size}
        return Response().ok(payload).__dict__

    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"获取文档列表失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"获取文档列表失败: {e!s}")
512
+
513
async def upload_document(self):
    """Accept uploaded files and schedule their ingestion in the background.

    Only ``multipart/form-data`` requests are accepted; any other declared
    Content-Type is rejected.

    Form fields:
        - kb_id: knowledge base ID (required)
        - chunk_size: default 512
        - chunk_overlap: default 50
        - batch_size: default 32
        - tasks_limit: default 3
        - max_retries: default 3
        - files: one or more file parts whose field name starts with
          ``file`` (``file``, ``file1``, ``files[]``, ...), at most 10

    Returns:
        The ``Response`` dict. On success its data holds ``task_id`` (to be
        polled through the upload-progress endpoint), ``file_count`` and
        ``message``.
    """
    try:
        kb_manager = self._get_kb_manager()

        content_type = request.content_type
        if content_type and "multipart/form-data" not in content_type:
            return (
                Response().error("Content-Type 须为 multipart/form-data").__dict__
            )
        form_data = await request.form
        files = await request.files

        kb_id = form_data.get("kb_id")
        # Non-numeric values raise ValueError, reported by the handler below.
        chunk_size = int(form_data.get("chunk_size", 512))
        chunk_overlap = int(form_data.get("chunk_overlap", 50))
        batch_size = int(form_data.get("batch_size", 32))
        tasks_limit = int(form_data.get("tasks_limit", 3))
        max_retries = int(form_data.get("max_retries", 3))
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        # Collect every part whose field name starts with "file"; this
        # covers file, file1, file2, ... and files[].
        file_list = []
        for key in files.keys():
            if key.startswith("file"):
                file_list.extend(files.getlist(key))

        if not file_list:
            return Response().error("缺少文件").__dict__

        if len(file_list) > 10:
            return Response().error("最多只能上传10个文件").__dict__

        # Resolve the knowledge base before touching the disk, so an
        # unknown kb_id costs no file I/O.
        kb_helper = await kb_manager.get_kb(kb_id)
        if not kb_helper:
            return Response().error("知识库不存在").__dict__

        temp_dir = os.path.join("data", "temp")
        os.makedirs(temp_dir, exist_ok=True)

        files_to_upload = []
        for file in file_list:
            # The filename is client-controlled: keep only its last path
            # component so it cannot escape the temp directory.
            file_name = os.path.basename(file.filename or "")
            if not file_name:
                return Response().error("文件名不能为空").__dict__

            temp_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_{file_name}")
            try:
                await file.save(temp_file_path)
                async with aiofiles.open(temp_file_path, "rb") as f:
                    file_content = await f.read()
            finally:
                # Remove the temp file even when saving or reading failed.
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)

            # File type is the lower-cased extension, or "" when absent.
            file_type = (
                file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
            )
            files_to_upload.append(
                {
                    "file_name": file_name,
                    "file_content": file_content,
                    "file_type": file_type,
                }
            )

        task_id = str(uuid.uuid4())

        # Visible to the progress endpoint until the task starts running.
        self.upload_tasks[task_id] = {
            "status": "pending",
            "result": None,
            "error": None,
        }

        background_task = asyncio.create_task(
            self._background_upload_task(
                task_id=task_id,
                kb_helper=kb_helper,
                files_to_upload=files_to_upload,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                batch_size=batch_size,
                tasks_limit=tasks_limit,
                max_retries=max_retries,
            )
        )
        # The event loop keeps only weak references to tasks; hold a strong
        # one until completion so the upload cannot be garbage-collected
        # mid-run.
        running_tasks = getattr(self, "_upload_background_tasks", None)
        if running_tasks is None:
            running_tasks = self._upload_background_tasks = set()
        running_tasks.add(background_task)
        background_task.add_done_callback(running_tasks.discard)

        return (
            Response()
            .ok(
                {
                    "task_id": task_id,
                    "file_count": len(files_to_upload),
                    "message": "task created, processing in background",
                }
            )
            .__dict__
        )

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"上传文档失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"上传文档失败: {str(e)}").__dict__
654
+
655
async def get_upload_progress(self):
    """Report the state of an upload task.

    Query parameters:
        - task_id: task ID (required)

    The returned ``status`` is one of ``pending``, ``processing``,
    ``completed`` or ``failed``. ``progress`` is added while processing,
    ``result`` once completed and ``error`` once failed. Finished tasks are
    not removed from the in-memory tables.
    """
    try:
        task_id = request.args.get("task_id")
        if not task_id:
            return Response().error("缺少参数 task_id").__dict__

        try:
            task_info = self.upload_tasks[task_id]
        except KeyError:
            return Response().error("找不到该任务").__dict__

        status = task_info["status"]
        response_data = {"task_id": task_id, "status": status}

        # The three states are mutually exclusive; attach the extra field
        # that belongs to the current one.
        if status == "processing":
            if task_id in self.upload_progress:
                response_data["progress"] = self.upload_progress[task_id]
        elif status == "completed":
            response_data["result"] = task_info["result"]
        elif status == "failed":
            response_data["error"] = task_info["error"]

        return Response().ok(response_data).__dict__

    except Exception as e:
        logger.error(f"获取上传进度失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取上传进度失败: {e!s}").__dict__
707
+
708
async def get_document(self):
    """Return the details of one document.

    Query parameters:
        - kb_id: knowledge base ID (required)
        - doc_id: document ID (required)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return fail("缺少参数 kb_id")
        doc_id = request.args.get("doc_id")
        if not doc_id:
            return fail("缺少参数 doc_id")

        helper = await manager.get_kb(kb_id)
        if not helper:
            return fail("知识库不存在")

        document = await helper.get_document(doc_id)
        if not document:
            return fail("文档不存在")

        return Response().ok(document.model_dump()).__dict__

    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"获取文档详情失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"获取文档详情失败: {e!s}")
738
+
739
async def delete_document(self):
    """Delete a document from a knowledge base.

    Body (JSON):
        - kb_id: knowledge base ID (required)
        - doc_id: document ID (required)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        body = await request.json

        kb_id = body.get("kb_id")
        if not kb_id:
            return fail("缺少参数 kb_id")
        doc_id = body.get("doc_id")
        if not doc_id:
            return fail("缺少参数 doc_id")

        helper = await manager.get_kb(kb_id)
        if not helper:
            return fail("知识库不存在")

        await helper.delete_document(doc_id)
        return Response().ok(message="删除文档成功").__dict__

    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"删除文档失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"删除文档失败: {e!s}")
770
+
771
async def delete_chunk(self):
    """Delete a single text chunk of a document.

    Body (JSON):
        - kb_id: knowledge base ID (required)
        - chunk_id: chunk ID (required)
        - doc_id: ID of the document the chunk belongs to (required)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        body = await request.json

        # Validate the identifiers in the same order as they are reported.
        kb_id = body.get("kb_id")
        if not kb_id:
            return fail("缺少参数 kb_id")
        chunk_id = body.get("chunk_id")
        if not chunk_id:
            return fail("缺少参数 chunk_id")
        doc_id = body.get("doc_id")
        if not doc_id:
            return fail("缺少参数 doc_id")

        helper = await manager.get_kb(kb_id)
        if not helper:
            return fail("知识库不存在")

        await helper.delete_chunk(chunk_id, doc_id)
        return Response().ok(message="删除文本块成功").__dict__

    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"删除文本块失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"删除文本块失败: {e!s}")
805
+
806
async def list_chunks(self):
    """List the chunks of one document, one page at a time.

    Query parameters:
        - kb_id: knowledge base ID (required)
        - doc_id: document ID (required)
        - page: page number (default 1)
        - page_size: items per page (default 100)

    The response also carries ``total``, the chunk count of the document.
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        manager = self._get_kb_manager()
        args = request.args
        kb_id = args.get("kb_id")
        doc_id = args.get("doc_id")
        page = args.get("page", 1, type=int)
        page_size = args.get("page_size", 100, type=int)

        if not kb_id:
            return fail("缺少参数 kb_id")
        if not doc_id:
            return fail("缺少参数 doc_id")

        helper = await manager.get_kb(kb_id)
        if not helper:
            return fail("知识库不存在")

        # Fetch the requested window first, then the overall count.
        chunks = await helper.get_chunks_by_doc_id(
            doc_id=doc_id, offset=(page - 1) * page_size, limit=page_size
        )
        total = await helper.get_chunk_count_by_doc_id(doc_id)

        payload = {
            "items": chunks,
            "page": page,
            "page_size": page_size,
            "total": total,
        }
        return Response().ok(data=payload).__dict__
    except ValueError as e:
        return fail(str(e))
    except Exception as e:
        logger.error(f"获取块列表失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"获取块列表失败: {e!s}")
850
+
851
+ # ===== 检索 API =====
852
+
853
async def retrieve(self):
    """Run a retrieval query against one or more knowledge bases.

    Body (JSON):
        - query: query text (required)
        - kb_names: list of knowledge base names (required)
        - top_k: number of results to return (optional, default 5)
        - debug: when truthy, also try to attach a t-SNE visualisation
          (optional, default False)
    """
    try:
        manager = self._get_kb_manager()
        body = await request.json

        query = body.get("query")
        kb_names = body.get("kb_names")
        debug = body.get("debug", False)

        if not query:
            return Response().error("缺少参数 query").__dict__
        if not (kb_names and isinstance(kb_names, list)):
            return Response().error("缺少参数 kb_names 或格式错误").__dict__

        top_k = body.get("top_k", 5)

        outcome = await manager.retrieve(
            query=query,
            kb_names=kb_names,
            top_m_final=top_k,
        )
        hits = outcome["results"] if outcome else []

        response_data = {
            "results": hits,
            "total": len(hits),
            "query": query,
        }

        if debug:
            # A failing visualisation never fails the request; the error
            # text is returned next to the results.
            try:
                image = await generate_tsne_visualization(
                    query, kb_names, manager
                )
                if image:
                    response_data["visualization"] = image
            except Exception as e:
                logger.error(f"生成 t-SNE 可视化失败: {e}")
                logger.error(traceback.format_exc())
                response_data["visualization_error"] = str(e)

        return Response().ok(response_data).__dict__

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"检索失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"检索失败: {e!s}").__dict__
913
+
914
+ # ===== 会话知识库配置 API =====
915
+
916
async def get_session_kb_config(self):
    """Return the knowledge-base configuration stored for a session.

    Query parameters:
        - session_id: session ID (required)

    The stored ``kb_config`` value is returned as is. When nothing is
    stored, the defaults ``kb_ids=[]``, ``top_k=5`` and
    ``enable_rerank=True`` are returned.
    """
    try:
        from astrbot.core import sp

        session_id = request.args.get("session_id")
        if not session_id:
            return Response().error("缺少参数 session_id").__dict__

        stored = await sp.session_get(session_id, "kb_config", default={})

        logger.debug(f"[KB配置] 读取到配置: session_id={session_id}")

        # Fall back to defaults when no configuration has been saved.
        config = stored if stored else {
            "kb_ids": [],
            "top_k": 5,
            "enable_rerank": True,
        }
        return Response().ok(config).__dict__

    except Exception as e:
        logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True)
        return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__
949
+
950
async def set_session_kb_config(self):
    """Store the knowledge-base configuration of a session.

    Body (JSON):
        - scope: configuration scope; only "session" is accepted
        - scope_id: session ID (required)
        - kb_ids: list of knowledge base IDs (default empty list)
        - top_k: number of results to return (optional, default 5)
        - enable_rerank: whether reranking is enabled (optional, default
          true)

    Only IDs that resolve to an existing knowledge base are saved. An
    empty ``kb_ids`` list is allowed and stores an empty selection; a
    non-empty list in which every ID is unknown is rejected.
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        from astrbot.core import sp

        body = await request.json

        scope = body.get("scope")
        scope_id = body.get("scope_id")
        kb_ids = body.get("kb_ids", [])
        top_k = body.get("top_k", 5)
        enable_rerank = body.get("enable_rerank", True)

        if scope != "session":
            return fail("目前仅支持 session 范围的配置")
        if not scope_id:
            return fail("缺少参数 scope_id")
        if not isinstance(kb_ids, list):
            return fail("kb_ids 必须是列表")

        # Split the requested IDs into existing and unknown ones.
        manager = self._get_kb_manager()
        valid_ids = []
        invalid_ids = []
        for kb_id in kb_ids:
            if await manager.get_kb(kb_id):
                valid_ids.append(kb_id)
                continue
            invalid_ids.append(kb_id)
            logger.warning(f"[KB配置] 知识库不存在: {kb_id}")

        if invalid_ids:
            logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}")

        # Error out only when IDs were supplied and none of them is valid;
        # an empty request list leaves valid_ids empty and is saved as is.
        if kb_ids and not valid_ids:
            return fail(f"所有提供的知识库ID都无效: {kb_ids}")

        config = {
            "kb_ids": valid_ids,
            "top_k": top_k,
            "enable_rerank": enable_rerank,
        }
        await sp.session_put(scope_id, "kb_config", config)

        # Read the value back to confirm that it was persisted.
        persisted = await sp.session_get(scope_id, "kb_config", default={})
        if persisted != config:
            logger.error("[KB配置] 配置保存失败,验证不匹配")
            return fail("配置保存失败")

        return (
            Response()
            .ok(
                {"valid_ids": valid_ids, "invalid_ids": invalid_ids},
                "保存知识库配置成功",
            )
            .__dict__
        )

    except Exception as e:
        logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True)
        return fail(f"设置会话知识库配置失败: {e!s}")
1034
+
1035
async def delete_session_kb_config(self):
    """Remove the knowledge-base configuration of a session.

    Body (JSON):
        - scope: configuration scope; only "session" is accepted
        - scope_id: session ID (required)
    """

    def fail(message):
        return Response().error(message).__dict__

    try:
        from astrbot.core import sp

        body = await request.json
        scope = body.get("scope")
        scope_id = body.get("scope_id")

        if scope != "session":
            return fail("目前仅支持 session 范围的配置")
        if not scope_id:
            return fail("缺少参数 scope_id")

        await sp.session_remove(scope_id, "kb_config")
        return Response().ok(message="删除知识库配置成功").__dict__

    except Exception as e:
        logger.error(f"删除会话知识库配置失败: {e}")
        logger.error(traceback.format_exc())
        return fail(f"删除会话知识库配置失败: {e!s}")