AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288) hide show
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,73 @@
1
+ import abc
2
+ from dataclasses import dataclass
3
+
4
+
5
@dataclass
class Result:
    """One retrieval hit: a similarity score paired with the matched document data."""

    # Similarity score of this hit (scale and direction are defined by the
    # concrete vector database implementation).
    similarity: float
    # Payload of the matched document as a plain dictionary.
    data: dict
9
+
10
+
11
class BaseVecDB:
    """Interface that vector database backends are expected to implement.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abc.abstractmethod`` markers below are not enforced at instantiation
    time -- confirm whether that is intentional.
    """

    async def initialize(self):
        """Initialize the vector database."""

    @abc.abstractmethod
    async def insert(
        self,
        content: str,
        metadata: dict | None = None,
        id: str | None = None,
    ) -> int:
        """Insert one text and its corresponding vector.

        An ID is generated automatically and the text and vector are kept
        consistent with each other.
        """
        ...

    @abc.abstractmethod
    async def insert_batch(
        self,
        contents: list[str],
        metadatas: list[dict] | None = None,
        ids: list[str] | None = None,
        batch_size: int = 32,
        tasks_limit: int = 3,
        max_retries: int = 3,
        progress_callback=None,
    ) -> int:
        """Insert a batch of texts and their corresponding vectors.

        IDs are generated automatically and texts and vectors are kept
        consistent with each other.

        Args:
            progress_callback: Progress callback, called with the arguments
                (current, total).

        """
        ...

    @abc.abstractmethod
    async def retrieve(
        self,
        query: str,
        top_k: int = 5,
        fetch_k: int = 20,
        rerank: bool = False,
        metadata_filters: dict | None = None,
    ) -> list[Result]:
        """Search for the most similar documents.

        Args:
            query (str): The query text.
            top_k (int): Number of most similar documents to return.

        Returns:
            List[Result]: The query results.

        """
        ...

    @abc.abstractmethod
    async def delete(self, doc_id: str) -> bool:
        """Delete the specified document.

        Args:
            doc_id (str): ID of the document to delete.

        Returns:
            bool: Whether the deletion succeeded.

        """
        ...

    @abc.abstractmethod
    async def close(self): ...
@@ -0,0 +1,3 @@
1
+ from .vec_db import FaissVecDB
2
+
3
+ __all__ = ["FaissVecDB"]
@@ -0,0 +1,392 @@
1
+ import json
2
+ import os
3
+ from contextlib import asynccontextmanager
4
+ from datetime import datetime
5
+
6
+ from sqlalchemy import Column, Text
7
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
8
+ from sqlalchemy.orm import sessionmaker
9
+ from sqlmodel import Field, MetaData, SQLModel, col, func, select, text
10
+
11
+ from astrbot.core import logger
12
+
13
+
14
class BaseDocModel(SQLModel, table=False):
    """Non-table base class for the document-storage models."""

    # Own MetaData instance; DocumentStorage.initialize() calls create_all on
    # it. Presumably this keeps these tables apart from the default SQLModel
    # metadata registry -- confirm.
    metadata = MetaData()
16
+
17
+
18
class Document(BaseDocModel, table=True):
    """SQLModel for documents table."""

    __tablename__ = "documents"  # type: ignore

    # Auto-incrementing integer primary key.
    id: int | None = Field(
        default=None,
        primary_key=True,
        sa_column_kwargs={"autoincrement": True},
    )
    # String identifier of the document (documented by the storage layer as a
    # UUID string).
    doc_id: str = Field(nullable=False)
    # Document text.
    text: str = Field(nullable=False)
    # Stored in the "metadata" column; the attribute carries a trailing
    # underscore because the class already defines a ``metadata`` attribute.
    # The storage layer writes a JSON-encoded string here.
    metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
    # Creation and last-update timestamps, set by the storage layer.
    created_at: datetime | None = Field(default=None)
    updated_at: datetime | None = Field(default=None)
33
+
34
+
35
class DocumentStorage:
    """Async SQLite-backed store for document text and JSON metadata.

    Rows are read and written through SQLModel/SQLAlchemy sessions on an
    ``aiosqlite`` engine. ``initialize()`` (or ``connect()``) must be awaited
    before any method that opens a session.
    """

    def __init__(self, db_path: str):
        """Record where the database lives; no connection is opened here.

        Args:
            db_path (str): Filesystem path of the SQLite database file.

        """
        self.db_path = db_path
        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
        self.engine: AsyncEngine | None = None
        self.async_session_maker: sessionmaker | None = None
        self.sqlite_init_path = os.path.join(
            os.path.dirname(__file__),
            "sqlite_init.sql",
        )

    async def initialize(self):
        """Initialize the SQLite database and create the documents table if it doesn't exist.

        After the model tables are created, the generated ``kb_doc_id`` and
        ``user_id`` columns and their indexes are added on a best-effort
        basis: each statement runs on its own, and a failure is logged at
        debug level instead of aborting initialization.
        """
        await self.connect()
        async with self.engine.begin() as conn:  # type: ignore
            # Create tables using SQLModel
            await conn.run_sync(BaseDocModel.metadata.create_all)

            schema_statements = (
                "ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
                "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED",
                "ALTER TABLE documents ADD COLUMN user_id TEXT "
                "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED",
                "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)",
                "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)",
            )
            for statement in schema_statements:
                # Run each statement separately so that one failing ALTER
                # (presumably "duplicate column" on a re-run) does not skip
                # the remaining statements. Only ``Exception`` is caught, so
                # task cancellation and interpreter exit still propagate.
                try:
                    await conn.execute(text(statement))
                except Exception as exc:
                    logger.debug(f"Skipped schema statement ({statement!r}): {exc}")

            await conn.commit()

    async def connect(self):
        """Connect to the SQLite database.

        Creates the async engine and the session factory on first use; later
        calls are no-ops while an engine exists.
        """
        if self.engine is None:
            self.engine = create_async_engine(
                self.DATABASE_URL,
                echo=False,
                future=True,
            )
            self.async_session_maker = sessionmaker(
                self.engine,  # type: ignore
                class_=AsyncSession,
                expire_on_commit=False,
            )  # type: ignore

    @asynccontextmanager
    async def get_session(self):
        """Context manager for database sessions."""
        async with self.async_session_maker() as session:  # type: ignore
            yield session

    @staticmethod
    def _apply_metadata_filters(query, metadata_filters: dict | None):
        """Add one ``json_extract(metadata, path) = value`` condition per filter.

        Both the JSON path and the compared value are passed as bound
        parameters, so a filter key is never spliced into the SQL text or
        into a bind-parameter name.

        Args:
            query: The select statement to extend.
            metadata_filters (dict | None): Mapping of metadata key to the
                required value; ``None`` or empty leaves the query unchanged.

        Returns:
            The query with the filter conditions applied.

        """
        if not metadata_filters:
            return query

        for position, (key, val) in enumerate(metadata_filters.items()):
            path_name = f"mf_path_{position}"
            value_name = f"mf_val_{position}"
            condition = text(
                f"json_extract(metadata, :{path_name}) = :{value_name}",
            ).bindparams(**{path_name: f"$.{key}", value_name: val})
            query = query.where(condition)
        return query

    async def get_documents(
        self,
        metadata_filters: dict,
        ids: list | None = None,
        offset: int | None = 0,
        limit: int | None = 100,
    ) -> list[dict]:
        """Retrieve documents by metadata filters and ids.

        Args:
            metadata_filters (dict): The metadata filters to apply.
            ids (list | None): Optional list of document IDs to filter.
                Entries equal to ``-1`` are ignored.
            offset (int | None): Offset for pagination.
            limit (int | None): Limit for pagination.

        Returns:
            list: The list of documents that match the filters. Empty when
            the connection is not initialized, or when ``ids`` is non-empty
            but contains only ``-1`` entries.

        """
        if self.engine is None:
            logger.warning(
                "Database connection is not initialized, returning empty result",
            )
            return []

        async with self.get_session() as session:
            query = self._apply_metadata_filters(select(Document), metadata_filters)

            # ``len()`` rather than truthiness: callers may pass array-like
            # ids for which ``bool()`` is ambiguous.
            if ids is not None and len(ids) > 0:
                valid_ids = [int(i) for i in ids if i != -1]
                if not valid_ids:
                    # An id filter was requested but no usable id remains;
                    # dropping the filter would return unrelated documents.
                    return []
                query = query.where(col(Document.id).in_(valid_ids))

            if limit is not None:
                query = query.limit(limit)
            if offset is not None:
                query = query.offset(offset)

            result = await session.execute(query)
            documents = result.scalars().all()

            return [self._document_to_dict(doc) for doc in documents]

    async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
        """Insert a single document and return its integer ID.

        Args:
            doc_id (str): The document ID (UUID string).
            text (str): The document text.
            metadata (dict): The document metadata.

        Returns:
            int: The integer ID of the inserted document.

        """
        assert self.engine is not None, "Database connection is not initialized."

        async with self.get_session() as session, session.begin():
            document = Document(
                doc_id=doc_id,
                text=text,
                metadata_=json.dumps(metadata),
                created_at=datetime.now(),
                updated_at=datetime.now(),
            )
            session.add(document)
            await session.flush()  # Flush to get the ID
            return document.id  # type: ignore

    async def insert_documents_batch(
        self,
        doc_ids: list[str],
        texts: list[str],
        metadatas: list[dict],
    ) -> list[int]:
        """Batch insert documents and return their integer IDs.

        Args:
            doc_ids (list[str]): List of document IDs (UUID strings).
            texts (list[str]): List of document texts.
            metadatas (list[dict]): List of document metadata.

        Returns:
            list[int]: List of integer IDs of the inserted documents.

        Raises:
            ValueError: If the three input lists differ in length.

        """
        # Reject mismatched inputs up front: zip() would silently drop the
        # surplus entries and return fewer IDs than the caller supplied.
        if not (len(doc_ids) == len(texts) == len(metadatas)):
            raise ValueError(
                "doc_ids, texts and metadatas must have the same length, got "
                f"{len(doc_ids)}, {len(texts)} and {len(metadatas)}",
            )

        assert self.engine is not None, "Database connection is not initialized."

        async with self.get_session() as session, session.begin():
            documents = []
            for doc_id, body, metadata in zip(doc_ids, texts, metadatas):
                now = datetime.now()
                documents.append(
                    Document(
                        doc_id=doc_id,
                        text=body,
                        metadata_=json.dumps(metadata),
                        created_at=now,
                        updated_at=now,
                    ),
                )
            session.add_all(documents)

            await session.flush()  # Flush to get all IDs
            return [doc.id for doc in documents]  # type: ignore

    async def delete_document_by_doc_id(self, doc_id: str):
        """Delete a document by its doc_id.

        Does nothing when no row has the given doc_id.

        Args:
            doc_id (str): The doc_id of the document to delete.

        """
        assert self.engine is not None, "Database connection is not initialized."

        async with self.get_session() as session, session.begin():
            query = select(Document).where(col(Document.doc_id) == doc_id)
            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if document:
                await session.delete(document)

    async def get_document_by_doc_id(self, doc_id: str):
        """Retrieve a document by its doc_id.

        Args:
            doc_id (str): The doc_id of the document to retrieve.

        Returns:
            dict: The document data or None if not found.

        """
        assert self.engine is not None, "Database connection is not initialized."

        async with self.get_session() as session:
            query = select(Document).where(col(Document.doc_id) == doc_id)
            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if document:
                return self._document_to_dict(document)
            return None

    async def update_document_by_doc_id(self, doc_id: str, new_text: str):
        """Update a document's text by its doc_id and refresh ``updated_at``.

        Does nothing when no row has the given doc_id.

        Args:
            doc_id (str): The doc_id.
            new_text (str): The new text to update the document with.

        """
        assert self.engine is not None, "Database connection is not initialized."

        async with self.get_session() as session, session.begin():
            query = select(Document).where(col(Document.doc_id) == doc_id)
            result = await session.execute(query)
            document = result.scalar_one_or_none()

            if document:
                document.text = new_text
                document.updated_at = datetime.now()
                session.add(document)

    async def delete_documents(self, metadata_filters: dict):
        """Delete documents by their metadata filters.

        Args:
            metadata_filters (dict): The metadata filters to apply. An empty
                mapping matches every document.

        """
        if self.engine is None:
            logger.warning(
                "Database connection is not initialized, skipping delete operation",
            )
            return

        async with self.get_session() as session, session.begin():
            query = self._apply_metadata_filters(select(Document), metadata_filters)

            result = await session.execute(query)
            documents = result.scalars().all()

            for doc in documents:
                await session.delete(doc)

    async def count_documents(self, metadata_filters: dict | None = None) -> int:
        """Count documents in the database.

        Args:
            metadata_filters (dict | None): Metadata filters to apply.

        Returns:
            int: The count of documents (0 when the connection is not
            initialized).

        """
        if self.engine is None:
            logger.warning("Database connection is not initialized, returning 0")
            return 0

        async with self.get_session() as session:
            query = self._apply_metadata_filters(
                select(func.count(col(Document.id))),
                metadata_filters,
            )

            result = await session.execute(query)
            count = result.scalar_one_or_none()
            return count if count is not None else 0

    async def get_user_ids(self) -> list[str]:
        """Retrieve all distinct, non-NULL user IDs from the documents table.

        Reads the ``user_id`` column that ``initialize()`` attempts to add.

        Returns:
            list: A list of user IDs.

        """
        assert self.engine is not None, "Database connection is not initialized."

        async with self.get_session() as session:
            query = text(
                "SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL",
            )
            result = await session.execute(query)
            rows = result.fetchall()
            return [row[0] for row in rows]

    def _document_to_dict(self, document: Document) -> dict:
        """Convert a Document model to a dictionary.

        ``datetime`` timestamps are rendered with ``isoformat()``; any other
        value (including ``None``) is passed through unchanged. The
        ``metadata`` entry is the stored string, not a parsed object.

        Args:
            document (Document): The document to convert.

        Returns:
            dict: The converted dictionary.

        """
        return {
            "id": document.id,
            "doc_id": document.doc_id,
            "text": document.text,
            "metadata": document.metadata_,
            "created_at": document.created_at.isoformat()
            if isinstance(document.created_at, datetime)
            else document.created_at,
            "updated_at": document.updated_at.isoformat()
            if isinstance(document.updated_at, datetime)
            else document.updated_at,
        }

    async def tuple_to_dict(self, row):
        """Convert a tuple to a dictionary.

        Args:
            row (tuple): The row to convert, ordered as
                (id, doc_id, text, metadata, created_at, updated_at).

        Returns:
            dict: The converted dictionary.

        Note: This method is kept for backward compatibility but is no longer used internally.

        """
        return {
            "id": row[0],
            "doc_id": row[1],
            "text": row[2],
            "metadata": row[3],
            "created_at": row[4],
            "updated_at": row[5],
        }

    async def close(self):
        """Close the connection to the SQLite database."""
        if self.engine:
            await self.engine.dispose()
            self.engine = None
            self.async_session_maker = None
@@ -0,0 +1,93 @@
1
+ try:
2
+ import faiss
3
+ except ModuleNotFoundError:
4
+ raise ImportError(
5
+ "faiss 未安装。请使用 'pip install faiss-cpu' 或 'pip install faiss-gpu' 安装。",
6
+ )
7
+ import os
8
+
9
+ import numpy as np
10
+
11
+
12
+ class EmbeddingStorage:
13
+ def __init__(self, dimension: int, path: str | None = None):
14
+ self.dimension = dimension
15
+ self.path = path
16
+ self.index = None
17
+ if path and os.path.exists(path):
18
+ self.index = faiss.read_index(path)
19
+ else:
20
+ base_index = faiss.IndexFlatL2(dimension)
21
+ self.index = faiss.IndexIDMap(base_index)
22
+
23
+ async def insert(self, vector: np.ndarray, id: int):
24
+ """插入向量
25
+
26
+ Args:
27
+ vector (np.ndarray): 要插入的向量
28
+ id (int): 向量的ID
29
+ Raises:
30
+ ValueError: 如果向量的维度与存储的维度不匹配
31
+
32
+ """
33
+ assert self.index is not None, "FAISS index is not initialized."
34
+ if vector.shape[0] != self.dimension:
35
+ raise ValueError(
36
+ f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
37
+ )
38
+ self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
39
+ await self.save_index()
40
+
41
+ async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
42
+ """批量插入向量
43
+
44
+ Args:
45
+ vectors (np.ndarray): 要插入的向量数组
46
+ ids (list[int]): 向量的ID列表
47
+ Raises:
48
+ ValueError: 如果向量的维度与存储的维度不匹配
49
+
50
+ """
51
+ assert self.index is not None, "FAISS index is not initialized."
52
+ if vectors.shape[1] != self.dimension:
53
+ raise ValueError(
54
+ f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
55
+ )
56
+ self.index.add_with_ids(vectors, np.array(ids))
57
+ await self.save_index()
58
+
59
+ async def search(self, vector: np.ndarray, k: int) -> tuple:
60
+ """搜索最相似的向量
61
+
62
+ Args:
63
+ vector (np.ndarray): 查询向量
64
+ k (int): 返回的最相似向量的数量
65
+ Returns:
66
+ tuple: (距离, 索引)
67
+
68
+ """
69
+ assert self.index is not None, "FAISS index is not initialized."
70
+ faiss.normalize_L2(vector)
71
+ distances, indices = self.index.search(vector, k)
72
+ return distances, indices
73
+
74
+ async def delete(self, ids: list[int]):
75
+ """删除向量
76
+
77
+ Args:
78
+ ids (list[int]): 要删除的向量ID列表
79
+
80
+ """
81
+ assert self.index is not None, "FAISS index is not initialized."
82
+ id_array = np.array(ids, dtype=np.int64)
83
+ self.index.remove_ids(id_array)
84
+ await self.save_index()
85
+
86
+ async def save_index(self):
87
+ """保存索引
88
+
89
+ Args:
90
+ path (str): 保存索引的路径
91
+
92
+ """
93
+ faiss.write_index(self.index, self.path)
@@ -0,0 +1,17 @@
1
-- Create the document storage table: the id of the document in faiss, the
-- document text, created_at and updated_at.
-- NOTE(review): in the Python code visible alongside this script, the path to
-- this file is stored but the file is never read -- confirm it is still used.
CREATE TABLE documents (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    doc_id TEXT NOT NULL,
    text TEXT NOT NULL,
    metadata TEXT,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

-- Generated columns extracted from the JSON metadata so they can be indexed.
-- NOTE(review): DocumentStorage.initialize() adds kb_doc_id and user_id, not
-- group_id -- confirm which column set is the intended schema.
ALTER TABLE documents
ADD COLUMN group_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.group_id')) STORED;
ALTER TABLE documents
ADD COLUMN user_id TEXT GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED;

-- Indexes on the generated columns.
CREATE INDEX idx_documents_user_id ON documents(user_id);
CREATE INDEX idx_documents_group_id ON documents(group_id);