AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,276 @@
1
+ """检索管理器
2
+
3
+ 协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
4
+ """
5
+
6
+ import time
7
+ from dataclasses import dataclass
8
+
9
+ from astrbot import logger
10
+ from astrbot.core.db.vec_db.base import Result
11
+ from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
12
+ from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
13
+ from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
14
+ from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
15
+ from astrbot.core.provider.provider import RerankProvider
16
+
17
+ from ..kb_helper import KBHelper
18
+
19
+
20
@dataclass
class RetrievalResult:
    """A single retrieved chunk, enriched with its document and knowledge-base names."""

    chunk_id: str  # ID of the retrieved chunk
    doc_id: str  # ID of the knowledge-base document the chunk belongs to
    doc_name: str  # name of that document
    kb_id: str  # ID of the knowledge base the chunk came from
    kb_name: str  # name of that knowledge base
    content: str  # chunk text
    score: float  # fused ranking score, replaced by the rerank score when reranked
    metadata: dict  # extra per-chunk information (e.g. chunk_index, char_count)
32
+
33
+
34
class RetrievalManager:
    """Retrieval manager.

    Responsibilities:
    - coordinate dense retrieval, sparse retrieval and rerank
    - fuse and order the results
    """

    def __init__(
        self,
        sparse_retriever: SparseRetriever,
        rank_fusion: RankFusion,
        kb_db: KBSQLiteDatabase,
    ):
        """Initialize the retrieval manager.

        Args:
            sparse_retriever: Sparse (BM25) retriever.
            rank_fusion: Fuser that merges dense and sparse results.
            kb_db: Knowledge-base database, used to look up document and
                knowledge-base names for retrieved chunks.

        """
        self.sparse_retriever = sparse_retriever
        self.rank_fusion = rank_fusion
        self.kb_db = kb_db

    async def retrieve(
        self,
        query: str,
        kb_ids: list[str],
        kb_id_helper_map: dict[str, KBHelper],
        top_k_fusion: int = 20,
        top_m_final: int = 5,
    ) -> list[RetrievalResult]:
        """Hybrid retrieval across one or more knowledge bases.

        Pipeline:
        1. Dense retrieval (vector similarity)
        2. Sparse retrieval (BM25)
        3. Result fusion (RRF)
        4. Optional rerank, when a knowledge base's configured rerank provider
           matches the provider attached to its vector database

        Args:
            query: Query text.
            kb_ids: IDs of the knowledge bases to search. IDs with no entry in
                ``kb_id_helper_map`` are skipped with a warning; repeated IDs
                are searched once.
            kb_id_helper_map: Mapping from knowledge-base ID to its helper,
                which supplies the KB settings and its vector database.
            top_k_fusion: Number of results kept after rank fusion.
            top_m_final: Maximum number of results returned.

        Returns:
            list[RetrievalResult]: Retrieved chunks, best first. Fused chunks
                whose document metadata cannot be found are dropped.

        """
        if not kb_ids:
            return []

        kb_options, kb_ids = self._collect_kb_options(kb_ids, kb_id_helper_map)
        if not kb_ids:
            # None of the requested knowledge bases is available.
            return []

        # 1. Dense retrieval
        time_start = time.perf_counter()
        dense_results = await self._dense_retrieve(
            query=query,
            kb_ids=kb_ids,
            kb_options=kb_options,
        )
        logger.debug(
            f"Dense retrieval across {len(kb_ids)} bases took {time.perf_counter() - time_start:.2f}s and returned {len(dense_results)} results.",
        )

        # 2. Sparse retrieval
        time_start = time.perf_counter()
        sparse_results = await self.sparse_retriever.retrieve(
            query=query,
            kb_ids=kb_ids,
            kb_options=kb_options,
        )
        logger.debug(
            f"Sparse retrieval across {len(kb_ids)} bases took {time.perf_counter() - time_start:.2f}s and returned {len(sparse_results)} results.",
        )

        # 3. Rank fusion
        time_start = time.perf_counter()
        fused_results = await self.rank_fusion.fuse(
            dense_results=dense_results,
            sparse_results=sparse_results,
            top_k=top_k_fusion,
        )
        logger.debug(
            f"Rank fusion took {time.perf_counter() - time_start:.2f}s and returned {len(fused_results)} results.",
        )

        # 4. Attach document / knowledge-base metadata
        retrieval_results = await self._build_retrieval_results(fused_results)

        # 5. Optional rerank
        rerank_provider = self._select_rerank_provider(kb_ids, kb_options)
        if rerank_provider and retrieval_results:
            retrieval_results = await self._rerank(
                query=query,
                results=retrieval_results,
                top_k=top_m_final,
                rerank_provider=rerank_provider,
            )

        return retrieval_results[:top_m_final]

    @staticmethod
    def _collect_kb_options(
        kb_ids: list[str],
        kb_id_helper_map: dict[str, KBHelper],
    ) -> tuple[dict, list[str]]:
        """Build per-KB retrieval options and the list of searchable KB IDs.

        Returns:
            tuple[dict, list[str]]: ``(kb_options, available_kb_ids)``, where
                ``available_kb_ids`` preserves input order without duplicates.

        """
        kb_options: dict = {}
        available_kb_ids: list[str] = []
        for kb_id in kb_ids:
            if kb_id in kb_options:
                # Already collected; searching the same KB twice would only
                # duplicate its results.
                continue
            kb_helper = kb_id_helper_map.get(kb_id)
            if not kb_helper:
                logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
                continue
            kb = kb_helper.kb
            kb_options[kb_id] = {
                "top_k_dense": kb.top_k_dense or 50,
                "top_k_sparse": kb.top_k_sparse or 50,
                # NOTE(review): collected but not read by RetrievalManager;
                # the caller-supplied ``top_m_final`` is used instead.
                "top_m_final": kb.top_m_final or 5,
                "vec_db": kb_helper.vec_db,
                "rerank_provider_id": kb.rerank_provider_id,
            }
            available_kb_ids.append(kb_id)
        return kb_options, available_kb_ids

    async def _build_retrieval_results(self, fused_results: list) -> list[RetrievalResult]:
        """Convert fused results into RetrievalResults with document/KB names."""
        # Several chunks usually share a document, so look each document up
        # once per call instead of once per chunk.
        metadata_by_doc: dict = {}
        retrieval_results: list[RetrievalResult] = []
        for fr in fused_results:
            if fr.doc_id not in metadata_by_doc:
                metadata_by_doc[fr.doc_id] = await self.kb_db.get_document_with_metadata(
                    fr.doc_id,
                )
            metadata_dict = metadata_by_doc[fr.doc_id]
            if not metadata_dict:
                continue
            retrieval_results.append(
                RetrievalResult(
                    chunk_id=fr.chunk_id,
                    doc_id=fr.doc_id,
                    doc_name=metadata_dict["document"].doc_name,
                    kb_id=fr.kb_id,
                    kb_name=metadata_dict["knowledge_base"].kb_name,
                    content=fr.content,
                    score=fr.score,
                    metadata={
                        "chunk_index": fr.chunk_index,
                        "char_count": len(fr.content),
                    },
                ),
            )
        return retrieval_results

    @staticmethod
    def _select_rerank_provider(
        kb_ids: list[str],
        kb_options: dict,
    ) -> "RerankProvider | None":
        """Return the first KB's rerank provider whose ID matches the KB's configured ID, else None."""
        for kb_id in kb_ids:
            vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
            rerank_pi = kb_options[kb_id]["rerank_provider_id"]
            if (
                vec_db
                and vec_db.rerank_provider
                and rerank_pi
                and rerank_pi == vec_db.rerank_provider.meta().id
            ):
                return vec_db.rerank_provider
        return None

    async def _dense_retrieve(
        self,
        query: str,
        kb_ids: list[str],
        kb_options: dict,
    ):
        """Dense (vector-similarity) retrieval.

        Each knowledge base is searched in its own vector database and the
        results are merged. A failure in one knowledge base is logged and
        skipped so the others still contribute.

        Args:
            query: Query text.
            kb_ids: Knowledge-base IDs to search; IDs missing from
                ``kb_options`` are skipped.
            kb_options: Per-KB options (``vec_db`` and ``top_k_dense`` are
                read here).

        Returns:
            list[Result]: Merged results sorted by descending similarity.

        """
        all_results: list[Result] = []
        for kb_id in kb_ids:
            options = kb_options.get(kb_id)
            if options is None:
                continue
            try:
                vec_db: FaissVecDB = options["vec_db"]
                dense_k = int(options["top_k_dense"])
                vec_results = await vec_db.retrieve(
                    query=query,
                    k=dense_k,
                    fetch_k=dense_k * 2,
                    rerank=False,  # rerank runs once, after fusion
                    metadata_filters={"kb_id": kb_id},
                )
            except Exception as e:
                logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
                continue
            all_results.extend(vec_results)

        all_results.sort(key=lambda x: x.similarity, reverse=True)
        return all_results

    async def _rerank(
        self,
        query: str,
        results: list[RetrievalResult],
        top_k: int,
        rerank_provider: RerankProvider,
    ) -> list[RetrievalResult]:
        """Rerank results with a rerank provider.

        The ``score`` of each returned result is overwritten in place with the
        provider's relevance score.

        Args:
            query: Query text.
            results: Results to rerank.
            top_k: Maximum number of results returned.
            rerank_provider: Provider whose ``rerank`` returns items with
                ``index`` (into ``results``) and ``relevance_score``.

        Returns:
            list[RetrievalResult]: Reranked results, best first.

        """
        if not results:
            return []

        rerank_results = await rerank_provider.rerank(
            query=query,
            documents=[r.content for r in results],
        )

        reranked_list: list[RetrievalResult] = []
        seen_indices: set = set()
        for rerank_result in rerank_results:
            idx = rerank_result.index
            # Skip out-of-range indices (a negative one would otherwise wrap
            # around in Python) and indices the provider repeats.
            if not 0 <= idx < len(results) or idx in seen_indices:
                continue
            seen_indices.add(idx)
            result = results[idx]
            result.score = rerank_result.relevance_score
            reranked_list.append(result)

        reranked_list.sort(key=lambda x: x.score, reverse=True)
        return reranked_list[:top_k]
@@ -0,0 +1,142 @@
1
+ """检索结果融合器
2
+
3
+ 使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
4
+ """
5
+
6
+ import json
7
+ from dataclasses import dataclass
8
+
9
+ from astrbot.core.db.vec_db.base import Result
10
+ from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
11
+ from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
12
+
13
+
14
@dataclass
class FusedResult:
    """One chunk after rank fusion, together with its fused score."""

    chunk_id: str  # chunk identifier shared by dense and sparse results
    chunk_index: int  # chunk index taken from the stored chunk metadata
    doc_id: str  # knowledge-base document the chunk belongs to
    kb_id: str  # knowledge base the chunk belongs to
    content: str  # chunk text
    score: float  # fused (RRF) score
24
+
25
+
26
class RankFusion:
    """Fuse dense and sparse retrieval results.

    Uses Reciprocal Rank Fusion (RRF): every result list contributes
    ``1 / (k + rank)`` for each chunk it contains, and chunks are ordered by
    the summed contribution.
    """

    def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
        """Initialize the fuser.

        Args:
            kb_db: Knowledge-base database instance (stored on the instance;
                ``fuse`` does not query it).
            k: RRF smoothing constant; a larger ``k`` narrows the score gap
                between top-ranked and lower-ranked chunks.

        """
        self.kb_db = kb_db
        self.k = k

    async def fuse(
        self,
        dense_results: list[Result],
        sparse_results: list[SparseResult],
        top_k: int = 20,
    ) -> list[FusedResult]:
        """Fuse dense and sparse results with RRF.

        RRF formula::

            score(chunk) = sum(1 / (k + rank_i))

        where ``rank_i`` is the 1-based rank of the chunk in result list ``i``.
        Lists that do not contain the chunk contribute nothing. If a chunk
        appears more than once in a list, only its first (best) rank counts.

        Ties are broken deterministically. The stable sort keeps chunks in
        first-seen order: dense order first, then sparse-only chunks in
        sparse order.

        Args:
            dense_results: Dense results, best first. ``data["doc_id"]`` is
                the chunk id, and ``data["text"]`` is the chunk text.
                ``data["metadata"]`` is a JSON string with ``chunk_index``,
                ``kb_doc_id`` and ``kb_id``.
            sparse_results: Sparse results, best first.
            top_k: Maximum number of fused results returned.

        Returns:
            list[FusedResult]: Fused results ordered by descending RRF score.

        """
        # 1. Rank maps (first occurrence wins). In dense results the
        #    vector-store ``doc_id`` is actually the chunk id.
        dense_ranks: dict[str, int] = {}
        dense_by_id: dict[str, Result] = {}
        for rank, r in enumerate(dense_results, start=1):
            chunk_id = r.data["doc_id"]
            if chunk_id not in dense_ranks:
                dense_ranks[chunk_id] = rank
                dense_by_id[chunk_id] = r

        sparse_ranks: dict[str, int] = {}
        sparse_by_id: dict[str, SparseResult] = {}
        for rank, r in enumerate(sparse_results, start=1):
            if r.chunk_id not in sparse_ranks:
                sparse_ranks[r.chunk_id] = rank
                sparse_by_id[r.chunk_id] = r

        # 2. RRF scores. An insertion-ordered dict (not a set) keeps the
        #    candidate order independent of string hash randomization.
        rrf_scores: dict[str, float] = {}
        for ranks in (dense_ranks, sparse_ranks):
            for chunk_id, rank in ranks.items():
                rrf_scores[chunk_id] = rrf_scores.get(chunk_id, 0.0) + 1.0 / (
                    self.k + rank
                )

        # 3. Sort. The sort is stable, even with reverse=True, so equal scores
        #    keep insertion order. Then truncate.
        top_ids = sorted(rrf_scores, key=lambda cid: rrf_scores[cid], reverse=True)[
            :top_k
        ]

        # 4. Build the fused results.
        fused_results: list[FusedResult] = []
        for chunk_id in top_ids:
            score = rrf_scores[chunk_id]
            sparse_hit = sparse_by_id.get(chunk_id)
            if sparse_hit is not None:
                # Prefer the sparse result: it already carries every chunk field.
                fused_results.append(
                    FusedResult(
                        chunk_id=sparse_hit.chunk_id,
                        chunk_index=sparse_hit.chunk_index,
                        doc_id=sparse_hit.doc_id,
                        kb_id=sparse_hit.kb_id,
                        content=sparse_hit.content,
                        score=score,
                    ),
                )
                continue

            # Dense-only hit: chunk fields come from the JSON metadata stored
            # alongside the vector.
            vec_result = dense_by_id[chunk_id]
            chunk_md = json.loads(vec_result.data["metadata"])
            fused_results.append(
                FusedResult(
                    chunk_id=chunk_id,
                    chunk_index=chunk_md["chunk_index"],
                    doc_id=chunk_md["kb_doc_id"],
                    kb_id=chunk_md["kb_id"],
                    content=vec_result.data["text"],
                    score=score,
                ),
            )

        return fused_results
@@ -0,0 +1,136 @@
1
+ """稀疏检索器
2
+
3
+ 使用 BM25 算法进行基于关键词的文档检索
4
+ """
5
+
6
+ import json
7
+ import os
8
+ from dataclasses import dataclass
9
+
10
+ import jieba
11
+ from rank_bm25 import BM25Okapi
12
+
13
+ from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
14
+ from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
15
+
16
+
17
@dataclass
class SparseResult:
    """One chunk scored by the BM25 sparse retriever."""

    chunk_index: int  # chunk index from the stored chunk metadata
    chunk_id: str  # vector-store document ID of the chunk
    doc_id: str  # knowledge-base document ID (``kb_doc_id`` in chunk metadata)
    kb_id: str  # knowledge base the chunk was loaded from
    content: str  # chunk text
    score: float  # BM25 score against the query
27
+
28
+
29
class SparseRetriever:
    """BM25 sparse retriever.

    Responsibilities:
    - keyword-based chunk retrieval
    - relevance scoring with BM25 (``BM25Okapi``) over jieba-segmented tokens
      with stop words removed
    """

    def __init__(self, kb_db: KBSQLiteDatabase):
        """Initialize the sparse retriever.

        Loads the stop-word list from ``hit_stopwords.txt`` next to this module.

        Args:
            kb_db: Knowledge-base database instance.

        """
        self.kb_db = kb_db
        # NOTE(review): intended as a BM25 index cache, but nothing in this
        # class reads or writes it; the index is rebuilt on every query.
        self._index_cache = {}

        stopwords_path = os.path.join(os.path.dirname(__file__), "hit_stopwords.txt")
        with open(stopwords_path, encoding="utf-8") as f:
            self.hit_stopwords = {
                word.strip() for word in f.read().splitlines() if word.strip()
            }

    def _tokenize(self, text: str) -> list[str]:
        """Segment ``text`` with jieba, dropping stop words and whitespace-only tokens."""
        # jieba can emit whitespace runs as tokens of their own. The stop-word
        # set never contains them (entries are stripped and blanks dropped),
        # so filter them explicitly; otherwise whitespace is scored as a term.
        return [
            word
            for word in jieba.cut(text)
            if word.strip() and word not in self.hit_stopwords
        ]

    async def _load_chunks(
        self,
        kb_ids: list[str],
        kb_options: dict,
    ) -> tuple[list[dict], int]:
        """Load every stored chunk of the given KBs and sum their sparse top-k budgets.

        Returns:
            tuple[list[dict], int]: ``(chunks, top_k_sparse)``. Each chunk dict
                has ``chunk_id``, ``chunk_index``, ``doc_id``, ``kb_id`` and
                ``text``.

        """
        top_k_sparse = 0
        chunks: list[dict] = []
        for kb_id in kb_ids:
            options = kb_options.get(kb_id, {})
            vec_db: FaissVecDB = options.get("vec_db")
            if not vec_db:
                continue
            docs = await vec_db.document_storage.get_documents(
                metadata_filters={},
                limit=None,
                offset=None,
            )
            for doc in docs:
                chunk_md = json.loads(doc["metadata"])
                chunks.append(
                    {
                        "chunk_id": doc["doc_id"],
                        "chunk_index": chunk_md["chunk_index"],
                        "doc_id": chunk_md["kb_doc_id"],
                        "kb_id": kb_id,
                        "text": doc["text"],
                    },
                )
            top_k_sparse += options.get("top_k_sparse", 50)
        return chunks, top_k_sparse

    async def retrieve(
        self,
        query: str,
        kb_ids: list[str],
        kb_options: dict,
    ) -> list[SparseResult]:
        """Run sparse (BM25) retrieval.

        Args:
            query: Query text.
            kb_ids: Knowledge-base IDs to search.
            kb_options: Per-KB retrieval options (``vec_db`` and
                ``top_k_sparse`` are read here).

        Returns:
            list[SparseResult]: Results sorted by descending BM25 score,
                truncated to the sum of the searched KBs' ``top_k_sparse``.
                Empty when the query has no searchable tokens or nothing is
                indexed.

        """
        # 1. Tokenize the query first. If it consists only of stop words or
        #    whitespace, every chunk would get the same zero score, and the
        #    "top" results would just be the first chunks in storage order.
        tokenized_query = self._tokenize(query)
        if not tokenized_query:
            return []

        # 2. Collect all candidate chunks.
        chunks, top_k_sparse = await self._load_chunks(kb_ids, kb_options)
        if not chunks:
            return []

        # 3. Tokenize the corpus and build the BM25 index. BM25Okapi averages
        #    IDF over the corpus vocabulary, so it cannot index a corpus in
        #    which every chunk tokenized to nothing.
        tokenized_corpus = [self._tokenize(chunk["text"]) for chunk in chunks]
        if not any(tokenized_corpus):
            return []
        bm25 = BM25Okapi(tokenized_corpus)

        # 4. Score, sort and keep the top-k.
        scores = bm25.get_scores(tokenized_query)
        results = [
            SparseResult(
                chunk_id=chunk["chunk_id"],
                chunk_index=chunk["chunk_index"],
                doc_id=chunk["doc_id"],
                kb_id=chunk["kb_id"],
                content=chunk["text"],
                score=float(score),
            )
            for chunk, score in zip(chunks, scores)
        ]
        results.sort(key=lambda x: x.score, reverse=True)
        return results[:top_k_sparse]
astrbot/core/log.py CHANGED
@@ -1,5 +1,4 @@
1
- """
2
- 日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
1
+ """日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
3
2
 
4
3
  const:
5
4
  CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
@@ -21,14 +20,14 @@ function:
21
20
  4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
22
21
  """
23
22
 
24
- import logging
25
- import colorlog
26
23
  import asyncio
24
+ import logging
27
25
  import os
28
26
  import sys
29
- from collections import deque
30
27
  from asyncio import Queue
31
- from typing import List
28
+ from collections import deque
29
+
30
+ import colorlog
32
31
 
33
32
  # 日志缓存大小
34
33
  CACHED_SIZE = 200
@@ -52,6 +51,7 @@ def is_plugin_path(pathname):
52
51
 
53
52
  Returns:
54
53
  bool: 如果路径来自插件目录,则返回 True,否则返回 False
54
+
55
55
  """
56
56
  if not pathname:
57
57
  return False
@@ -68,6 +68,7 @@ def get_short_level_name(level_name):
68
68
 
69
69
  Returns:
70
70
  str: 四个字母的日志级别缩写
71
+
71
72
  """
72
73
  level_map = {
73
74
  "DEBUG": "DBUG",
@@ -87,17 +88,16 @@ class LogBroker:
87
88
 
88
89
  def __init__(self):
89
90
  self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
90
- self.subscribers: List[Queue] = [] # 订阅者列表
91
+ self.subscribers: list[Queue] = [] # 订阅者列表
91
92
 
92
93
  def register(self) -> Queue:
93
94
  """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
94
95
 
95
96
  Returns:
96
97
  Queue: 订阅者的队列, 可用于接收日志消息
98
+
97
99
  """
98
100
  q = Queue(maxsize=CACHED_SIZE + 10)
99
- for log in self.log_cache:
100
- q.put_nowait(log)
101
101
  self.subscribers.append(q)
102
102
  return q
103
103
 
@@ -106,6 +106,7 @@ class LogBroker:
106
106
 
107
107
  Args:
108
108
  q (Queue): 需要取消订阅的队列
109
+
109
110
  """
110
111
  self.subscribers.remove(q)
111
112
 
@@ -115,6 +116,7 @@ class LogBroker:
115
116
  Args:
116
117
  log_entry (dict): 日志消息, 包含日志级别和日志内容.
117
118
  example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
119
+
118
120
  """
119
121
  self.log_cache.append(log_entry)
120
122
  for q in self.subscribers:
@@ -140,6 +142,7 @@ class LogQueueHandler(logging.Handler):
140
142
 
141
143
  Args:
142
144
  record (logging.LogRecord): 日志记录对象, 包含日志信息
145
+
143
146
  """
144
147
  log_entry = self.format(record)
145
148
  self.log_broker.publish(
@@ -147,7 +150,7 @@ class LogQueueHandler(logging.Handler):
147
150
  "level": record.levelname,
148
151
  "time": record.asctime,
149
152
  "data": log_entry,
150
- }
153
+ },
151
154
  )
152
155
 
153
156
 
@@ -166,6 +169,7 @@ class LogManager:
166
169
 
167
170
  Returns:
168
171
  logging.Logger: 返回配置好的日志记录器
172
+
169
173
  """
170
174
  logger = logging.getLogger(log_name)
171
175
  # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
@@ -173,10 +177,10 @@ class LogManager:
173
177
  return logger
174
178
  # 如果logger没有处理器
175
179
  console_handler = logging.StreamHandler(
176
- sys.stdout
180
+ sys.stdout,
177
181
  ) # 创建一个StreamHandler用于控制台输出
178
182
  console_handler.setLevel(
179
- logging.DEBUG
183
+ logging.DEBUG,
180
184
  ) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
181
185
 
182
186
  # 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
@@ -197,7 +201,8 @@ class LogManager:
197
201
 
198
202
  class FileNameFilter(logging.Filter):
199
203
  """文件名过滤器类, 用于修改日志记录的文件名格式
200
- 例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
204
+ 例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式
205
+ """
201
206
 
202
207
  # 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
203
208
  def filter(self, record):
@@ -233,6 +238,7 @@ class LogManager:
233
238
  Args:
234
239
  logger (logging.Logger): 日志记录器
235
240
  log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
241
+
236
242
  """
237
243
  handler = LogQueueHandler(log_broker)
238
244
  handler.setLevel(logging.DEBUG)
@@ -242,7 +248,7 @@ class LogManager:
242
248
  # 为队列处理器设置相同格式的formatter
243
249
  handler.setFormatter(
244
250
  logging.Formatter(
245
- "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
246
- )
251
+ "[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s",
252
+ ),
247
253
  )
248
254
  logger.addHandler(handler)