AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,1244 @@
1
+ """知识库管理 API 路由"""
2
+
3
+ import asyncio
4
+ import os
5
+ import traceback
6
+ import uuid
7
+
8
+ import aiofiles
9
+ from quart import request
10
+
11
+ from astrbot.core import logger
12
+ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
13
+ from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
14
+
15
+ from ..utils import generate_tsne_visualization
16
+ from .route import Response, Route, RouteContext
17
+
18
+
19
+ class KnowledgeBaseRoute(Route):
20
+ """知识库管理路由
21
+
22
+ 提供知识库、文档、检索、会话配置等 API 接口
23
+ """
24
+
25
def __init__(
    self,
    context: RouteContext,
    core_lifecycle: AstrBotCoreLifecycle,
) -> None:
    """Set up per-route state and register the knowledge-base endpoints.

    Args:
        context: Route context forwarded to the ``Route`` base class.
        core_lifecycle: Core lifecycle object; ``_get_kb_manager`` reads its
            ``kb_manager`` attribute on every request.
    """
    super().__init__(context)
    self.core_lifecycle = core_lifecycle

    # Lazily-initialised handles; all start out as None.
    self.kb_manager = None
    self.kb_db = None
    self.session_config_db = None  # session configuration database
    self.retrieval_manager = None

    # task_id -> {status, file_index, file_total, stage, current, total, ...}
    self.upload_progress = {}
    # task_id -> {"status", "result", "error"} for background upload tasks
    self.upload_tasks = {}

    knowledge_base_routes = {
        "/kb/list": ("GET", self.list_kbs),
        "/kb/create": ("POST", self.create_kb),
        "/kb/get": ("GET", self.get_kb),
        "/kb/update": ("POST", self.update_kb),
        "/kb/delete": ("POST", self.delete_kb),
        "/kb/stats": ("GET", self.get_kb_stats),
    }
    document_routes = {
        "/kb/document/list": ("GET", self.list_documents),
        "/kb/document/upload": ("POST", self.upload_document),
        "/kb/document/upload/url": ("POST", self.upload_document_from_url),
        "/kb/document/upload/progress": ("GET", self.get_upload_progress),
        "/kb/document/get": ("GET", self.get_document),
        "/kb/document/delete": ("POST", self.delete_document),
    }
    chunk_routes = {
        "/kb/chunk/list": ("GET", self.list_chunks),
        "/kb/chunk/delete": ("POST", self.delete_chunk),
    }
    # Media endpoints (/kb/media/list, /kb/media/delete) are not registered.
    retrieval_routes = {
        "/kb/retrieve": ("POST", self.retrieve),
    }
    session_config_routes = {
        "/kb/session/config/get": ("GET", self.get_session_kb_config),
        "/kb/session/config/set": ("POST", self.set_session_kb_config),
        "/kb/session/config/delete": ("POST", self.delete_session_kb_config),
    }

    # Dict unpacking keeps insertion order, so routes are registered in the
    # order listed above.
    self.routes = {
        **knowledge_base_routes,
        **document_routes,
        **chunk_routes,
        **retrieval_routes,
        **session_config_routes,
    }
    self.register_routes()
69
+
70
def _get_kb_manager(self):
    """Return the knowledge-base manager held by the core lifecycle.

    The attribute is read on every call (nothing is cached here), so callers
    always get whatever ``core_lifecycle.kb_manager`` currently is.
    """
    lifecycle = self.core_lifecycle
    return lifecycle.kb_manager
72
+
73
async def _background_upload_task(
    self,
    task_id: str,
    kb_helper,
    files_to_upload: list,
    chunk_size: int,
    chunk_overlap: int,
    batch_size: int,
    tasks_limit: int,
    max_retries: int,
):
    """Upload a batch of files into a knowledge base in the background.

    Live progress is published in ``self.upload_progress[task_id]`` and the
    final outcome in ``self.upload_tasks[task_id]``. A failure on one file is
    recorded in the result's ``failed`` list and does not stop the remaining
    files; only an unexpected error outside the per-file handling marks the
    whole task as ``"failed"``.

    Args:
        task_id: Key under which progress and result are stored.
        kb_helper: Object with an async ``upload_document(...)`` whose return
            value provides ``model_dump()``.
        files_to_upload: Dicts with ``file_name``, ``file_content`` and
            ``file_type`` keys.
        chunk_size, chunk_overlap, batch_size, tasks_limit, max_retries:
            Forwarded unchanged to ``kb_helper.upload_document``.
    """

    def _update_progress(**fields) -> None:
        # Write into the progress entry only if it still exists. The entry
        # may be removed while this task is running (presumably by whatever
        # consumes ``upload_progress`` - not visible here). An unguarded write
        # would raise KeyError and record a file, or after the loop the whole
        # task, as failed even though the upload itself succeeded.
        progress = self.upload_progress.get(task_id)
        if progress is not None:
            progress.update(fields)

    try:
        # Initial task state.
        self.upload_tasks[task_id] = {
            "status": "processing",
            "result": None,
            "error": None,
        }
        self.upload_progress[task_id] = {
            "status": "processing",
            "file_index": 0,
            "file_total": len(files_to_upload),
            "stage": "waiting",
            "current": 0,
            "total": 100,
        }

        uploaded_docs = []
        failed_docs = []

        for file_idx, file_info in enumerate(files_to_upload):
            file_name = file_info["file_name"]
            try:
                # Overall progress: a new file is starting.
                _update_progress(
                    status="processing",
                    file_index=file_idx,
                    file_name=file_name,
                    stage="parsing",
                    current=0,
                    total=100,
                )

                # Fine-grained progress reported by the helper for this file.
                async def progress_callback(stage, current, total):
                    _update_progress(
                        status="processing",
                        file_index=file_idx,
                        file_name=file_name,
                        stage=stage,
                        current=current,
                        total=total,
                    )

                doc = await kb_helper.upload_document(
                    file_name=file_name,
                    file_content=file_info["file_content"],
                    file_type=file_info["file_type"],
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    batch_size=batch_size,
                    tasks_limit=tasks_limit,
                    max_retries=max_retries,
                    progress_callback=progress_callback,
                )

                uploaded_docs.append(doc.model_dump())
            except Exception as e:
                logger.error(f"上传文档 {file_name} 失败: {e}")
                logger.error(traceback.format_exc())
                failed_docs.append({"file_name": file_name, "error": str(e)})

        # Final result for the task.
        result = {
            "task_id": task_id,
            "uploaded": uploaded_docs,
            "failed": failed_docs,
            "total": len(files_to_upload),
            "success_count": len(uploaded_docs),
            "failed_count": len(failed_docs),
        }

        self.upload_tasks[task_id] = {
            "status": "completed",
            "result": result,
            "error": None,
        }
        _update_progress(status="completed")

    except Exception as e:
        logger.error(f"后台上传任务 {task_id} 失败: {e}")
        logger.error(traceback.format_exc())
        self.upload_tasks[task_id] = {
            "status": "failed",
            "result": None,
            "error": str(e),
        }
        _update_progress(status="failed")
178
+
179
async def list_kbs(self):
    """Return every knowledge base as a list of dicts.

    Query parameters:
        page: echoed back in the response (default 1).
        page_size: echoed back in the response (default 20).

    NOTE(review): ``page`` and ``page_size`` are not applied; all knowledge
    bases returned by ``kb_manager.list_kbs()`` are included. A
    ``refresh_stats`` parameter is not read by this handler.
    """
    try:
        manager = self._get_kb_manager()
        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 20, type=int)

        knowledge_bases = await manager.list_kbs()
        items = [entry.model_dump() for entry in knowledge_bases]

        payload = {"items": items, "page": page, "page_size": page_size}
        return Response().ok(payload).__dict__
    except ValueError as err:
        return Response().error(str(err)).__dict__
    except Exception as err:
        logger.error(f"获取知识库列表失败: {err}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取知识库列表失败: {err!s}").__dict__
210
+
211
async def create_kb(self):
    """Create a knowledge base after verifying its model providers.

    JSON body:
        kb_name: knowledge base name (required).
        embedding_provider_id: embedding provider ID (required). A probe
            embedding of "astrbot" is computed, and its length must equal
            the provider's ``get_dim()``.
        rerank_provider_id: rerank provider ID (optional). When given, it
            must resolve to a ``RerankProvider``, and a probe rerank call
            must return a non-empty result.
        description, emoji: optional metadata.
        chunk_size, chunk_overlap, top_k_dense, top_k_sparse, top_m_final:
            optional; forwarded to ``kb_manager.create_kb`` unchanged
            (``None`` when omitted - any defaulting happens there, which is
            not visible in this section).

    Returns:
        ``Response.__dict__`` with the new knowledge base's ``model_dump()``
        on success, or an error message.
    """
    try:
        kb_manager = self._get_kb_manager()
        data = await request.json
        # A missing or non-object body would otherwise fail later with an
        # opaque AttributeError on ``data.get``.
        if not isinstance(data, dict):
            return Response().error("请求体必须为 JSON 对象").__dict__

        kb_name = data.get("kb_name")
        if not kb_name:
            return Response().error("知识库名称不能为空").__dict__

        description = data.get("description")
        emoji = data.get("emoji")
        embedding_provider_id = data.get("embedding_provider_id")
        rerank_provider_id = data.get("rerank_provider_id")
        chunk_size = data.get("chunk_size")
        chunk_overlap = data.get("chunk_overlap")
        top_k_dense = data.get("top_k_dense")
        top_k_sparse = data.get("top_k_sparse")
        top_m_final = data.get("top_m_final")

        # Probe the embedding provider and check its declared dimension.
        if not embedding_provider_id:
            return Response().error("缺少参数 embedding_provider_id").__dict__
        prv = await kb_manager.provider_manager.get_provider_by_id(
            embedding_provider_id,
        )  # type: ignore
        if not prv or not isinstance(prv, EmbeddingProvider):
            return (
                Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__
            )
        try:
            vec = await prv.get_embedding("astrbot")
            if len(vec) != prv.get_dim():
                raise ValueError(
                    f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}",
                )
        except Exception as e:
            return Response().error(f"测试嵌入模型失败: {e!s}").__dict__

        # Probe the optional rerank provider.
        if rerank_provider_id:
            rerank_prv = await kb_manager.provider_manager.get_provider_by_id(
                rerank_provider_id,
            )  # type: ignore
            if not rerank_prv:
                return Response().error("重排序模型不存在").__dict__
            # Mirror the embedding check: a provider of the wrong kind would
            # otherwise only fail somewhere inside the probe call below.
            if not isinstance(rerank_prv, RerankProvider):
                return (
                    Response()
                    .error(f"重排序模型类型错误({type(rerank_prv)})")
                    .__dict__
                )
            try:
                res = await rerank_prv.rerank(
                    query="astrbot",
                    documents=["astrbot knowledge base"],
                )
                if not res:
                    raise ValueError("重排序模型返回结果异常")
            except Exception as e:
                # The user-facing message points to the console log, so the
                # failure must actually be logged there.
                logger.error(f"测试重排序模型失败: {e}")
                logger.error(traceback.format_exc())
                return (
                    Response()
                    .error(f"测试重排序模型失败: {e!s},请检查控制台日志输出。")
                    .__dict__
                )

        kb_helper = await kb_manager.create_kb(
            kb_name=kb_name,
            description=description,
            emoji=emoji,
            embedding_provider_id=embedding_provider_id,
            rerank_provider_id=rerank_provider_id,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            top_k_dense=top_k_dense,
            top_k_sparse=top_k_sparse,
            top_m_final=top_m_final,
        )
        kb = kb_helper.kb

        return Response().ok(kb.model_dump(), "创建知识库成功").__dict__

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"创建知识库失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"创建知识库失败: {e!s}").__dict__
307
+
308
async def get_kb(self):
    """Fetch a single knowledge base by ID.

    Query parameters:
        kb_id: ID of the knowledge base to fetch (required).

    Returns:
        ``Response.__dict__`` with the knowledge base's ``model_dump()``, or
        an error when ``kb_id`` is missing or unknown.
    """
    try:
        manager = self._get_kb_manager()
        target_id = request.args.get("kb_id")
        if not target_id:
            return Response().error("缺少参数 kb_id").__dict__

        helper = await manager.get_kb(target_id)
        if not helper:
            return Response().error("知识库不存在").__dict__

        return Response().ok(helper.kb.model_dump()).__dict__
    except ValueError as err:
        return Response().error(str(err)).__dict__
    except Exception as err:
        logger.error(f"获取知识库详情失败: {err}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取知识库详情失败: {err!s}").__dict__
333
+
334
async def update_kb(self):
    """Update selected fields of an existing knowledge base.

    JSON body:
        kb_id: ID of the knowledge base (required).
        kb_name, description, emoji, embedding_provider_id,
        rerank_provider_id, chunk_size, chunk_overlap, top_k_dense,
        top_k_sparse, top_m_final: optional; absent or null keys are
        passed to ``kb_manager.update_kb`` as ``None``.

    At least one of the optional fields must be non-null. A falsy return
    from ``kb_manager.update_kb`` is reported as "knowledge base not found".
    """
    try:
        manager = self._get_kb_manager()
        payload = await request.json

        kb_id = payload.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        optional_keys = (
            "kb_name",
            "description",
            "emoji",
            "embedding_provider_id",
            "rerank_provider_id",
            "chunk_size",
            "chunk_overlap",
            "top_k_dense",
            "top_k_sparse",
            "top_m_final",
        )
        changes = {key: payload.get(key) for key in optional_keys}

        # Reject requests that carry nothing to update.
        if all(value is None for value in changes.values()):
            return Response().error("至少需要提供一个更新字段").__dict__

        helper = await manager.update_kb(kb_id=kb_id, **changes)
        if not helper:
            return Response().error("知识库不存在").__dict__

        return Response().ok(helper.kb.model_dump(), "更新知识库成功").__dict__
    except ValueError as err:
        return Response().error(str(err)).__dict__
    except Exception as err:
        logger.error(f"更新知识库失败: {err}")
        logger.error(traceback.format_exc())
        return Response().error(f"更新知识库失败: {err!s}").__dict__
413
+
414
async def delete_kb(self):
    """Delete a knowledge base.

    JSON body:
        kb_id: ID of the knowledge base to delete (required).

    A falsy return from ``kb_manager.delete_kb`` is reported as
    "knowledge base not found".
    """
    try:
        manager = self._get_kb_manager()
        payload = await request.json

        target_id = payload.get("kb_id")
        if not target_id:
            return Response().error("缺少参数 kb_id").__dict__

        deleted = await manager.delete_kb(target_id)
        if deleted:
            return Response().ok(message="删除知识库成功").__dict__
        return Response().error("知识库不存在").__dict__
    except ValueError as err:
        return Response().error(str(err)).__dict__
    except Exception as err:
        logger.error(f"删除知识库失败: {err}")
        logger.error(traceback.format_exc())
        return Response().error(f"删除知识库失败: {err!s}").__dict__
440
+
441
async def get_kb_stats(self):
    """Return summary statistics for one knowledge base.

    Query parameters:
        kb_id: ID of the knowledge base (required).

    The payload holds ``kb_id``, ``kb_name``, ``doc_count``, ``chunk_count``
    and ``created_at`` / ``updated_at`` rendered via ``isoformat()``.
    """
    try:
        manager = self._get_kb_manager()
        target_id = request.args.get("kb_id")
        if not target_id:
            return Response().error("缺少参数 kb_id").__dict__

        helper = await manager.get_kb(target_id)
        if not helper:
            return Response().error("知识库不存在").__dict__

        info = helper.kb
        stats = dict(
            kb_id=info.kb_id,
            kb_name=info.kb_name,
            doc_count=info.doc_count,
            chunk_count=info.chunk_count,
            created_at=info.created_at.isoformat(),
            updated_at=info.updated_at.isoformat(),
        )
        return Response().ok(stats).__dict__
    except ValueError as err:
        return Response().error(str(err)).__dict__
    except Exception as err:
        logger.error(f"获取知识库统计失败: {err}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取知识库统计失败: {err!s}").__dict__
475
+
476
+ # ===== 文档管理 API =====
477
+
478
async def list_documents(self):
    """Return one page of documents from a knowledge base.

    Query parameters:
        kb_id: knowledge base ID (required).
        page: 1-based page number (default 1; must be >= 1).
        page_size: documents per page (default 100; must be >= 1).

    Returns:
        ``Response.__dict__`` with ``items`` (each document's
        ``model_dump()``), ``page`` and ``page_size``, or an error message.
    """
    try:
        kb_manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__
        kb_helper = await kb_manager.get_kb(kb_id)
        if not kb_helper:
            return Response().error("知识库不存在").__dict__

        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 100, type=int)
        # Reject non-positive values. page <= 0 would produce a negative
        # offset and page_size <= 0 a zero/negative limit; how
        # ``kb_helper.list_documents`` treats those is not visible here.
        if page < 1 or page_size < 1:
            return Response().error("page 和 page_size 必须为正整数").__dict__

        offset = (page - 1) * page_size
        docs = await kb_helper.list_documents(offset=offset, limit=page_size)
        items = [doc.model_dump() for doc in docs]

        return (
            Response()
            .ok({"items": items, "page": page, "page_size": page_size})
            .__dict__
        )

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取文档列表失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取文档列表失败: {e!s}").__dict__
517
+
518
+ async def upload_document(self):
519
+ """上传文档
520
+
521
+ 支持两种方式:
522
+ 1. multipart/form-data 文件上传(支持多文件,最多10个)
523
+ 2. JSON 格式 base64 编码上传(支持多文件,最多10个)
524
+
525
+ Form Data (multipart/form-data):
526
+ - kb_id: 知识库 ID (必填)
527
+ - file: 文件对象 (必填,可多个,字段名为 file, file1, file2, ... 或 files[])
528
+
529
+ JSON Body (application/json):
530
+ - kb_id: 知识库 ID (必填)
531
+ - files: 文件数组 (必填)
532
+ - file_name: 文件名 (必填)
533
+ - file_content: base64 编码的文件内容 (必填)
534
+
535
+ 返回:
536
+ - task_id: 任务ID,用于查询上传进度和结果
537
+ """
538
+ try:
539
+ kb_manager = self._get_kb_manager()
540
+
541
+ # 检查 Content-Type
542
+ content_type = request.content_type
543
+ kb_id = None
544
+ chunk_size = None
545
+ chunk_overlap = None
546
+ batch_size = 32
547
+ tasks_limit = 3
548
+ max_retries = 3
549
+ files_to_upload = [] # 存储待上传的文件信息列表
550
+
551
+ if content_type and "multipart/form-data" not in content_type:
552
+ return (
553
+ Response().error("Content-Type 须为 multipart/form-data").__dict__
554
+ )
555
+ form_data = await request.form
556
+ files = await request.files
557
+
558
+ kb_id = form_data.get("kb_id")
559
+ chunk_size = int(form_data.get("chunk_size", 512))
560
+ chunk_overlap = int(form_data.get("chunk_overlap", 50))
561
+ batch_size = int(form_data.get("batch_size", 32))
562
+ tasks_limit = int(form_data.get("tasks_limit", 3))
563
+ max_retries = int(form_data.get("max_retries", 3))
564
+ if not kb_id:
565
+ return Response().error("缺少参数 kb_id").__dict__
566
+
567
+ # 收集所有文件
568
+ file_list = []
569
+ # 支持 file, file1, file2, ... 或 files[] 格式
570
+ for key in files.keys():
571
+ if key == "file" or key.startswith("file") or key == "files[]":
572
+ file_items = files.getlist(key)
573
+ file_list.extend(file_items)
574
+
575
+ if not file_list:
576
+ return Response().error("缺少文件").__dict__
577
+
578
+ # 限制文件数量
579
+ if len(file_list) > 10:
580
+ return Response().error("最多只能上传10个文件").__dict__
581
+
582
+ # 处理每个文件
583
+ for file in file_list:
584
+ file_name = file.filename
585
+
586
+ # 保存到临时文件
587
+ temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}"
588
+ await file.save(temp_file_path)
589
+
590
+ try:
591
+ # 异步读取文件内容
592
+ async with aiofiles.open(temp_file_path, "rb") as f:
593
+ file_content = await f.read()
594
+
595
+ # 提取文件类型
596
+ file_type = (
597
+ file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
598
+ )
599
+
600
+ files_to_upload.append(
601
+ {
602
+ "file_name": file_name,
603
+ "file_content": file_content,
604
+ "file_type": file_type,
605
+ },
606
+ )
607
+ finally:
608
+ # 清理临时文件
609
+ if os.path.exists(temp_file_path):
610
+ os.remove(temp_file_path)
611
+
612
+ # 获取知识库
613
+ kb_helper = await kb_manager.get_kb(kb_id)
614
+ if not kb_helper:
615
+ return Response().error("知识库不存在").__dict__
616
+
617
+ # 生成任务ID
618
+ task_id = str(uuid.uuid4())
619
+
620
+ # 初始化任务状态
621
+ self.upload_tasks[task_id] = {
622
+ "status": "pending",
623
+ "result": None,
624
+ "error": None,
625
+ }
626
+
627
+ # 启动后台任务
628
+ asyncio.create_task(
629
+ self._background_upload_task(
630
+ task_id=task_id,
631
+ kb_helper=kb_helper,
632
+ files_to_upload=files_to_upload,
633
+ chunk_size=chunk_size,
634
+ chunk_overlap=chunk_overlap,
635
+ batch_size=batch_size,
636
+ tasks_limit=tasks_limit,
637
+ max_retries=max_retries,
638
+ ),
639
+ )
640
+
641
+ return (
642
+ Response()
643
+ .ok(
644
+ {
645
+ "task_id": task_id,
646
+ "file_count": len(files_to_upload),
647
+ "message": "task created, processing in background",
648
+ },
649
+ )
650
+ .__dict__
651
+ )
652
+
653
+ except ValueError as e:
654
+ return Response().error(str(e)).__dict__
655
+ except Exception as e:
656
+ logger.error(f"上传文档失败: {e}")
657
+ logger.error(traceback.format_exc())
658
+ return Response().error(f"上传文档失败: {e!s}").__dict__
659
+
660
+ async def get_upload_progress(self):
661
+ """获取上传进度和结果
662
+
663
+ Query 参数:
664
+ - task_id: 任务 ID (必填)
665
+
666
+ 返回状态:
667
+ - pending: 任务待处理
668
+ - processing: 任务处理中
669
+ - completed: 任务完成
670
+ - failed: 任务失败
671
+ """
672
+ try:
673
+ task_id = request.args.get("task_id")
674
+ if not task_id:
675
+ return Response().error("缺少参数 task_id").__dict__
676
+
677
+ # 检查任务是否存在
678
+ if task_id not in self.upload_tasks:
679
+ return Response().error("找不到该任务").__dict__
680
+
681
+ task_info = self.upload_tasks[task_id]
682
+ status = task_info["status"]
683
+
684
+ # 构建返回数据
685
+ response_data = {
686
+ "task_id": task_id,
687
+ "status": status,
688
+ }
689
+
690
+ # 如果任务正在处理,返回进度信息
691
+ if status == "processing" and task_id in self.upload_progress:
692
+ response_data["progress"] = self.upload_progress[task_id]
693
+
694
+ # 如果任务完成,返回结果
695
+ if status == "completed":
696
+ response_data["result"] = task_info["result"]
697
+ # 清理已完成的任务
698
+ # del self.upload_tasks[task_id]
699
+ # if task_id in self.upload_progress:
700
+ # del self.upload_progress[task_id]
701
+
702
+ # 如果任务失败,返回错误信息
703
+ if status == "failed":
704
+ response_data["error"] = task_info["error"]
705
+
706
+ return Response().ok(response_data).__dict__
707
+
708
+ except Exception as e:
709
+ logger.error(f"获取上传进度失败: {e}")
710
+ logger.error(traceback.format_exc())
711
+ return Response().error(f"获取上传进度失败: {e!s}").__dict__
712
+
713
+ async def get_document(self):
714
+ """获取文档详情
715
+
716
+ Query 参数:
717
+ - doc_id: 文档 ID (必填)
718
+ """
719
+ try:
720
+ kb_manager = self._get_kb_manager()
721
+ kb_id = request.args.get("kb_id")
722
+ if not kb_id:
723
+ return Response().error("缺少参数 kb_id").__dict__
724
+ doc_id = request.args.get("doc_id")
725
+ if not doc_id:
726
+ return Response().error("缺少参数 doc_id").__dict__
727
+ kb_helper = await kb_manager.get_kb(kb_id)
728
+ if not kb_helper:
729
+ return Response().error("知识库不存在").__dict__
730
+
731
+ doc = await kb_helper.get_document(doc_id)
732
+ if not doc:
733
+ return Response().error("文档不存在").__dict__
734
+
735
+ return Response().ok(doc.model_dump()).__dict__
736
+
737
+ except ValueError as e:
738
+ return Response().error(str(e)).__dict__
739
+ except Exception as e:
740
+ logger.error(f"获取文档详情失败: {e}")
741
+ logger.error(traceback.format_exc())
742
+ return Response().error(f"获取文档详情失败: {e!s}").__dict__
743
+
744
+ async def delete_document(self):
745
+ """删除文档
746
+
747
+ Body:
748
+ - kb_id: 知识库 ID (必填)
749
+ - doc_id: 文档 ID (必填)
750
+ """
751
+ try:
752
+ kb_manager = self._get_kb_manager()
753
+ data = await request.json
754
+
755
+ kb_id = data.get("kb_id")
756
+ if not kb_id:
757
+ return Response().error("缺少参数 kb_id").__dict__
758
+ doc_id = data.get("doc_id")
759
+ if not doc_id:
760
+ return Response().error("缺少参数 doc_id").__dict__
761
+
762
+ kb_helper = await kb_manager.get_kb(kb_id)
763
+ if not kb_helper:
764
+ return Response().error("知识库不存在").__dict__
765
+
766
+ await kb_helper.delete_document(doc_id)
767
+ return Response().ok(message="删除文档成功").__dict__
768
+
769
+ except ValueError as e:
770
+ return Response().error(str(e)).__dict__
771
+ except Exception as e:
772
+ logger.error(f"删除文档失败: {e}")
773
+ logger.error(traceback.format_exc())
774
+ return Response().error(f"删除文档失败: {e!s}").__dict__
775
+
776
+ async def delete_chunk(self):
777
+ """删除文本块
778
+
779
+ Body:
780
+ - kb_id: 知识库 ID (必填)
781
+ - chunk_id: 块 ID (必填)
782
+ """
783
+ try:
784
+ kb_manager = self._get_kb_manager()
785
+ data = await request.json
786
+
787
+ kb_id = data.get("kb_id")
788
+ if not kb_id:
789
+ return Response().error("缺少参数 kb_id").__dict__
790
+ chunk_id = data.get("chunk_id")
791
+ if not chunk_id:
792
+ return Response().error("缺少参数 chunk_id").__dict__
793
+ doc_id = data.get("doc_id")
794
+ if not doc_id:
795
+ return Response().error("缺少参数 doc_id").__dict__
796
+
797
+ kb_helper = await kb_manager.get_kb(kb_id)
798
+ if not kb_helper:
799
+ return Response().error("知识库不存在").__dict__
800
+
801
+ await kb_helper.delete_chunk(chunk_id, doc_id)
802
+ return Response().ok(message="删除文本块成功").__dict__
803
+
804
+ except ValueError as e:
805
+ return Response().error(str(e)).__dict__
806
+ except Exception as e:
807
+ logger.error(f"删除文本块失败: {e}")
808
+ logger.error(traceback.format_exc())
809
+ return Response().error(f"删除文本块失败: {e!s}").__dict__
810
+
811
+ async def list_chunks(self):
812
+ """获取块列表
813
+
814
+ Query 参数:
815
+ - kb_id: 知识库 ID (必填)
816
+ - page: 页码 (默认 1)
817
+ - page_size: 每页数量 (默认 20)
818
+ """
819
+ try:
820
+ kb_manager = self._get_kb_manager()
821
+ kb_id = request.args.get("kb_id")
822
+ doc_id = request.args.get("doc_id")
823
+ page = request.args.get("page", 1, type=int)
824
+ page_size = request.args.get("page_size", 100, type=int)
825
+ if not kb_id:
826
+ return Response().error("缺少参数 kb_id").__dict__
827
+ if not doc_id:
828
+ return Response().error("缺少参数 doc_id").__dict__
829
+ kb_helper = await kb_manager.get_kb(kb_id)
830
+ offset = (page - 1) * page_size
831
+ limit = page_size
832
+ if not kb_helper:
833
+ return Response().error("知识库不存在").__dict__
834
+ chunk_list = await kb_helper.get_chunks_by_doc_id(
835
+ doc_id=doc_id,
836
+ offset=offset,
837
+ limit=limit,
838
+ )
839
+ return (
840
+ Response()
841
+ .ok(
842
+ data={
843
+ "items": chunk_list,
844
+ "page": page,
845
+ "page_size": page_size,
846
+ "total": await kb_helper.get_chunk_count_by_doc_id(doc_id),
847
+ },
848
+ )
849
+ .__dict__
850
+ )
851
+ except ValueError as e:
852
+ return Response().error(str(e)).__dict__
853
+ except Exception as e:
854
+ logger.error(f"获取块列表失败: {e}")
855
+ logger.error(traceback.format_exc())
856
+ return Response().error(f"获取块列表失败: {e!s}").__dict__
857
+
858
+ # ===== 检索 API =====
859
+
860
+ async def retrieve(self):
861
+ """检索知识库
862
+
863
+ Body:
864
+ - query: 查询文本 (必填)
865
+ - kb_ids: 知识库 ID 列表 (必填)
866
+ - top_k: 返回结果数量 (可选, 默认 5)
867
+ - debug: 是否启用调试模式,返回 t-SNE 可视化图片 (可选, 默认 False)
868
+ """
869
+ try:
870
+ kb_manager = self._get_kb_manager()
871
+ data = await request.json
872
+
873
+ query = data.get("query")
874
+ kb_names = data.get("kb_names")
875
+ debug = data.get("debug", False)
876
+
877
+ if not query:
878
+ return Response().error("缺少参数 query").__dict__
879
+ if not kb_names or not isinstance(kb_names, list):
880
+ return Response().error("缺少参数 kb_names 或格式错误").__dict__
881
+
882
+ top_k = data.get("top_k", 5)
883
+
884
+ results = await kb_manager.retrieve(
885
+ query=query,
886
+ kb_names=kb_names,
887
+ top_m_final=top_k,
888
+ )
889
+ result_list = []
890
+ if results:
891
+ result_list = results["results"]
892
+
893
+ response_data = {
894
+ "results": result_list,
895
+ "total": len(result_list),
896
+ "query": query,
897
+ }
898
+
899
+ # Debug 模式:生成 t-SNE 可视化
900
+ if debug:
901
+ try:
902
+ img_base64 = await generate_tsne_visualization(
903
+ query,
904
+ kb_names,
905
+ kb_manager,
906
+ )
907
+ if img_base64:
908
+ response_data["visualization"] = img_base64
909
+ except Exception as e:
910
+ logger.error(f"生成 t-SNE 可视化失败: {e}")
911
+ logger.error(traceback.format_exc())
912
+ response_data["visualization_error"] = str(e)
913
+
914
+ return Response().ok(response_data).__dict__
915
+
916
+ except ValueError as e:
917
+ return Response().error(str(e)).__dict__
918
+ except Exception as e:
919
+ logger.error(f"检索失败: {e}")
920
+ logger.error(traceback.format_exc())
921
+ return Response().error(f"检索失败: {e!s}").__dict__
922
+
923
+ # ===== 会话知识库配置 API =====
924
+
925
+ async def get_session_kb_config(self):
926
+ """获取会话的知识库配置
927
+
928
+ Query 参数:
929
+ - session_id: 会话 ID (必填)
930
+
931
+ 返回:
932
+ - kb_ids: 知识库 ID 列表
933
+ - top_k: 返回结果数量
934
+ - enable_rerank: 是否启用重排序
935
+ """
936
+ try:
937
+ from astrbot.core import sp
938
+
939
+ session_id = request.args.get("session_id")
940
+
941
+ if not session_id:
942
+ return Response().error("缺少参数 session_id").__dict__
943
+
944
+ # 从 SharedPreferences 获取配置
945
+ config = await sp.session_get(session_id, "kb_config", default={})
946
+
947
+ logger.debug(f"[KB配置] 读取到配置: session_id={session_id}")
948
+
949
+ # 如果没有配置,返回默认值
950
+ if not config:
951
+ config = {"kb_ids": [], "top_k": 5, "enable_rerank": True}
952
+
953
+ return Response().ok(config).__dict__
954
+
955
+ except Exception as e:
956
+ logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True)
957
+ return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__
958
+
959
+ async def set_session_kb_config(self):
960
+ """设置会话的知识库配置
961
+
962
+ Body:
963
+ - scope: 配置范围 (目前只支持 "session")
964
+ - scope_id: 会话 ID (必填)
965
+ - kb_ids: 知识库 ID 列表 (必填)
966
+ - top_k: 返回结果数量 (可选, 默认 5)
967
+ - enable_rerank: 是否启用重排序 (可选, 默认 true)
968
+ """
969
+ try:
970
+ from astrbot.core import sp
971
+
972
+ data = await request.json
973
+
974
+ scope = data.get("scope")
975
+ scope_id = data.get("scope_id")
976
+ kb_ids = data.get("kb_ids", [])
977
+ top_k = data.get("top_k", 5)
978
+ enable_rerank = data.get("enable_rerank", True)
979
+
980
+ # 验证参数
981
+ if scope != "session":
982
+ return Response().error("目前仅支持 session 范围的配置").__dict__
983
+
984
+ if not scope_id:
985
+ return Response().error("缺少参数 scope_id").__dict__
986
+
987
+ if not isinstance(kb_ids, list):
988
+ return Response().error("kb_ids 必须是列表").__dict__
989
+
990
+ # 验证知识库是否存在
991
+ kb_mgr = self._get_kb_manager()
992
+ invalid_ids = []
993
+ valid_ids = []
994
+ for kb_id in kb_ids:
995
+ kb_helper = await kb_mgr.get_kb(kb_id)
996
+ if kb_helper:
997
+ valid_ids.append(kb_id)
998
+ else:
999
+ invalid_ids.append(kb_id)
1000
+ logger.warning(f"[KB配置] 知识库不存在: {kb_id}")
1001
+
1002
+ if invalid_ids:
1003
+ logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}")
1004
+
1005
+ # 允许保存空列表,表示明确不使用任何知识库
1006
+ if kb_ids and not valid_ids:
1007
+ # 只有当用户提供了 kb_ids 但全部无效时才报错
1008
+ return Response().error(f"所有提供的知识库ID都无效: {kb_ids}").__dict__
1009
+
1010
+ # 如果 kb_ids 为空列表,表示用户想清空配置
1011
+ if not kb_ids:
1012
+ valid_ids = []
1013
+
1014
+ # 构建配置对象(只保存有效的ID)
1015
+ config = {
1016
+ "kb_ids": valid_ids,
1017
+ "top_k": top_k,
1018
+ "enable_rerank": enable_rerank,
1019
+ }
1020
+
1021
+ # 保存到 SharedPreferences
1022
+ await sp.session_put(scope_id, "kb_config", config)
1023
+
1024
+ # 立即验证是否保存成功
1025
+ verify_config = await sp.session_get(scope_id, "kb_config", default={})
1026
+
1027
+ if verify_config == config:
1028
+ return (
1029
+ Response()
1030
+ .ok(
1031
+ {"valid_ids": valid_ids, "invalid_ids": invalid_ids},
1032
+ "保存知识库配置成功",
1033
+ )
1034
+ .__dict__
1035
+ )
1036
+ logger.error("[KB配置] 配置保存失败,验证不匹配")
1037
+ return Response().error("配置保存失败").__dict__
1038
+
1039
+ except Exception as e:
1040
+ logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True)
1041
+ return Response().error(f"设置会话知识库配置失败: {e!s}").__dict__
1042
+
1043
+ async def delete_session_kb_config(self):
1044
+ """删除会话的知识库配置
1045
+
1046
+ Body:
1047
+ - scope: 配置范围 (目前只支持 "session")
1048
+ - scope_id: 会话 ID (必填)
1049
+ """
1050
+ try:
1051
+ from astrbot.core import sp
1052
+
1053
+ data = await request.json
1054
+
1055
+ scope = data.get("scope")
1056
+ scope_id = data.get("scope_id")
1057
+
1058
+ # 验证参数
1059
+ if scope != "session":
1060
+ return Response().error("目前仅支持 session 范围的配置").__dict__
1061
+
1062
+ if not scope_id:
1063
+ return Response().error("缺少参数 scope_id").__dict__
1064
+
1065
+ # 从 SharedPreferences 删除配置
1066
+ await sp.session_remove(scope_id, "kb_config")
1067
+
1068
+ return Response().ok(message="删除知识库配置成功").__dict__
1069
+
1070
+ except Exception as e:
1071
+ logger.error(f"删除会话知识库配置失败: {e}")
1072
+ logger.error(traceback.format_exc())
1073
+ return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__
1074
+
1075
+ async def upload_document_from_url(self):
1076
+ """从 URL 上传文档
1077
+
1078
+ Body:
1079
+ - kb_id: 知识库 ID (必填)
1080
+ - url: 要提取内容的网页 URL (必填)
1081
+ - chunk_size: 分块大小 (可选, 默认512)
1082
+ - chunk_overlap: 块重叠大小 (可选, 默认50)
1083
+ - batch_size: 批处理大小 (可选, 默认32)
1084
+ - tasks_limit: 并发任务限制 (可选, 默认3)
1085
+ - max_retries: 最大重试次数 (可选, 默认3)
1086
+
1087
+ 返回:
1088
+ - task_id: 任务ID,用于查询上传进度和结果
1089
+ """
1090
+ try:
1091
+ kb_manager = self._get_kb_manager()
1092
+ data = await request.json
1093
+
1094
+ kb_id = data.get("kb_id")
1095
+ if not kb_id:
1096
+ return Response().error("缺少参数 kb_id").__dict__
1097
+
1098
+ url = data.get("url")
1099
+ if not url:
1100
+ return Response().error("缺少参数 url").__dict__
1101
+
1102
+ chunk_size = data.get("chunk_size", 512)
1103
+ chunk_overlap = data.get("chunk_overlap", 50)
1104
+ batch_size = data.get("batch_size", 32)
1105
+ tasks_limit = data.get("tasks_limit", 3)
1106
+ max_retries = data.get("max_retries", 3)
1107
+ enable_cleaning = data.get("enable_cleaning", False)
1108
+ cleaning_provider_id = data.get("cleaning_provider_id")
1109
+
1110
+ # 获取知识库
1111
+ kb_helper = await kb_manager.get_kb(kb_id)
1112
+ if not kb_helper:
1113
+ return Response().error("知识库不存在").__dict__
1114
+
1115
+ # 生成任务ID
1116
+ task_id = str(uuid.uuid4())
1117
+
1118
+ # 初始化任务状态
1119
+ self.upload_tasks[task_id] = {
1120
+ "status": "pending",
1121
+ "result": None,
1122
+ "error": None,
1123
+ }
1124
+
1125
+ # 启动后台任务
1126
+ asyncio.create_task(
1127
+ self._background_upload_from_url_task(
1128
+ task_id=task_id,
1129
+ kb_helper=kb_helper,
1130
+ url=url,
1131
+ chunk_size=chunk_size,
1132
+ chunk_overlap=chunk_overlap,
1133
+ batch_size=batch_size,
1134
+ tasks_limit=tasks_limit,
1135
+ max_retries=max_retries,
1136
+ enable_cleaning=enable_cleaning,
1137
+ cleaning_provider_id=cleaning_provider_id,
1138
+ ),
1139
+ )
1140
+
1141
+ return (
1142
+ Response()
1143
+ .ok(
1144
+ {
1145
+ "task_id": task_id,
1146
+ "url": url,
1147
+ "message": "URL upload task created, processing in background",
1148
+ },
1149
+ )
1150
+ .__dict__
1151
+ )
1152
+
1153
+ except ValueError as e:
1154
+ return Response().error(str(e)).__dict__
1155
+ except Exception as e:
1156
+ logger.error(f"从URL上传文档失败: {e}")
1157
+ logger.error(traceback.format_exc())
1158
+ return Response().error(f"从URL上传文档失败: {e!s}").__dict__
1159
+
1160
+ async def _background_upload_from_url_task(
1161
+ self,
1162
+ task_id: str,
1163
+ kb_helper,
1164
+ url: str,
1165
+ chunk_size: int,
1166
+ chunk_overlap: int,
1167
+ batch_size: int,
1168
+ tasks_limit: int,
1169
+ max_retries: int,
1170
+ enable_cleaning: bool,
1171
+ cleaning_provider_id: str | None,
1172
+ ):
1173
+ """后台上传URL任务"""
1174
+ try:
1175
+ # 初始化任务状态
1176
+ self.upload_tasks[task_id] = {
1177
+ "status": "processing",
1178
+ "result": None,
1179
+ "error": None,
1180
+ }
1181
+ self.upload_progress[task_id] = {
1182
+ "status": "processing",
1183
+ "file_index": 0,
1184
+ "file_total": 1,
1185
+ "file_name": f"URL: {url}",
1186
+ "stage": "extracting",
1187
+ "current": 0,
1188
+ "total": 100,
1189
+ }
1190
+
1191
+ # 创建进度回调函数
1192
+ async def progress_callback(stage, current, total):
1193
+ if task_id in self.upload_progress:
1194
+ self.upload_progress[task_id].update(
1195
+ {
1196
+ "status": "processing",
1197
+ "file_index": 0,
1198
+ "file_name": f"URL: {url}",
1199
+ "stage": stage,
1200
+ "current": current,
1201
+ "total": total,
1202
+ },
1203
+ )
1204
+
1205
+ # 上传文档
1206
+ doc = await kb_helper.upload_from_url(
1207
+ url=url,
1208
+ chunk_size=chunk_size,
1209
+ chunk_overlap=chunk_overlap,
1210
+ batch_size=batch_size,
1211
+ tasks_limit=tasks_limit,
1212
+ max_retries=max_retries,
1213
+ progress_callback=progress_callback,
1214
+ enable_cleaning=enable_cleaning,
1215
+ cleaning_provider_id=cleaning_provider_id,
1216
+ )
1217
+
1218
+ # 更新任务完成状态
1219
+ result = {
1220
+ "task_id": task_id,
1221
+ "uploaded": [doc.model_dump()],
1222
+ "failed": [],
1223
+ "total": 1,
1224
+ "success_count": 1,
1225
+ "failed_count": 0,
1226
+ }
1227
+
1228
+ self.upload_tasks[task_id] = {
1229
+ "status": "completed",
1230
+ "result": result,
1231
+ "error": None,
1232
+ }
1233
+ self.upload_progress[task_id]["status"] = "completed"
1234
+
1235
+ except Exception as e:
1236
+ logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
1237
+ logger.error(traceback.format_exc())
1238
+ self.upload_tasks[task_id] = {
1239
+ "status": "failed",
1240
+ "result": None,
1241
+ "error": str(e),
1242
+ }
1243
+ if task_id in self.upload_progress:
1244
+ self.upload_progress[task_id]["status"] = "failed"