AstrBot 4.5.0__py3-none-any.whl → 4.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. astrbot/api/__init__.py +10 -11
  2. astrbot/api/event/__init__.py +5 -6
  3. astrbot/api/event/filter/__init__.py +37 -36
  4. astrbot/api/platform/__init__.py +7 -8
  5. astrbot/api/provider/__init__.py +7 -7
  6. astrbot/api/star/__init__.py +3 -4
  7. astrbot/api/util/__init__.py +2 -2
  8. astrbot/cli/__main__.py +5 -5
  9. astrbot/cli/commands/__init__.py +3 -3
  10. astrbot/cli/commands/cmd_conf.py +19 -16
  11. astrbot/cli/commands/cmd_init.py +3 -2
  12. astrbot/cli/commands/cmd_plug.py +8 -10
  13. astrbot/cli/commands/cmd_run.py +5 -6
  14. astrbot/cli/utils/__init__.py +6 -6
  15. astrbot/cli/utils/basic.py +14 -14
  16. astrbot/cli/utils/plugin.py +24 -15
  17. astrbot/cli/utils/version_comparator.py +10 -12
  18. astrbot/core/__init__.py +8 -6
  19. astrbot/core/agent/agent.py +3 -2
  20. astrbot/core/agent/handoff.py +6 -2
  21. astrbot/core/agent/hooks.py +9 -6
  22. astrbot/core/agent/mcp_client.py +50 -15
  23. astrbot/core/agent/message.py +168 -0
  24. astrbot/core/agent/response.py +2 -1
  25. astrbot/core/agent/run_context.py +2 -3
  26. astrbot/core/agent/runners/base.py +10 -13
  27. astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
  28. astrbot/core/agent/tool.py +60 -41
  29. astrbot/core/agent/tool_executor.py +9 -3
  30. astrbot/core/astr_agent_context.py +3 -1
  31. astrbot/core/astrbot_config_mgr.py +29 -9
  32. astrbot/core/config/__init__.py +2 -2
  33. astrbot/core/config/astrbot_config.py +28 -26
  34. astrbot/core/config/default.py +44 -6
  35. astrbot/core/conversation_mgr.py +105 -36
  36. astrbot/core/core_lifecycle.py +68 -54
  37. astrbot/core/db/__init__.py +33 -18
  38. astrbot/core/db/migration/helper.py +18 -13
  39. astrbot/core/db/migration/migra_3_to_4.py +53 -34
  40. astrbot/core/db/migration/migra_45_to_46.py +1 -1
  41. astrbot/core/db/migration/shared_preferences_v3.py +2 -1
  42. astrbot/core/db/migration/sqlite_v3.py +26 -23
  43. astrbot/core/db/po.py +27 -18
  44. astrbot/core/db/sqlite.py +74 -45
  45. astrbot/core/db/vec_db/base.py +10 -14
  46. astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
  47. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
  48. astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
  49. astrbot/core/event_bus.py +8 -6
  50. astrbot/core/file_token_service.py +6 -5
  51. astrbot/core/initial_loader.py +7 -5
  52. astrbot/core/knowledge_base/chunking/__init__.py +1 -3
  53. astrbot/core/knowledge_base/chunking/base.py +1 -0
  54. astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
  55. astrbot/core/knowledge_base/chunking/recursive.py +16 -10
  56. astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
  57. astrbot/core/knowledge_base/kb_helper.py +30 -17
  58. astrbot/core/knowledge_base/kb_mgr.py +6 -7
  59. astrbot/core/knowledge_base/models.py +10 -4
  60. astrbot/core/knowledge_base/parsers/__init__.py +3 -5
  61. astrbot/core/knowledge_base/parsers/base.py +1 -0
  62. astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
  63. astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
  64. astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
  65. astrbot/core/knowledge_base/parsers/util.py +1 -1
  66. astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
  67. astrbot/core/knowledge_base/retrieval/manager.py +17 -14
  68. astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
  69. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
  70. astrbot/core/log.py +21 -13
  71. astrbot/core/message/components.py +123 -217
  72. astrbot/core/message/message_event_result.py +24 -24
  73. astrbot/core/persona_mgr.py +20 -11
  74. astrbot/core/pipeline/__init__.py +7 -7
  75. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  76. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  77. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
  78. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
  79. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  80. astrbot/core/pipeline/context.py +4 -1
  81. astrbot/core/pipeline/context_utils.py +77 -7
  82. astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
  83. astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
  84. astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
  85. astrbot/core/pipeline/process_stage/stage.py +13 -10
  86. astrbot/core/pipeline/process_stage/utils.py +6 -5
  87. astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
  88. astrbot/core/pipeline/respond/stage.py +23 -20
  89. astrbot/core/pipeline/result_decorate/stage.py +31 -23
  90. astrbot/core/pipeline/scheduler.py +12 -8
  91. astrbot/core/pipeline/session_status_check/stage.py +12 -8
  92. astrbot/core/pipeline/stage.py +10 -4
  93. astrbot/core/pipeline/waking_check/stage.py +24 -18
  94. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  95. astrbot/core/platform/__init__.py +6 -6
  96. astrbot/core/platform/astr_message_event.py +76 -110
  97. astrbot/core/platform/astrbot_message.py +11 -13
  98. astrbot/core/platform/manager.py +16 -15
  99. astrbot/core/platform/message_session.py +5 -3
  100. astrbot/core/platform/platform.py +16 -24
  101. astrbot/core/platform/platform_metadata.py +4 -4
  102. astrbot/core/platform/register.py +8 -8
  103. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
  104. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
  105. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +47 -29
  106. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
  107. astrbot/core/platform/sources/discord/client.py +9 -6
  108. astrbot/core/platform/sources/discord/components.py +18 -14
  109. astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
  110. astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
  111. astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
  112. astrbot/core/platform/sources/lark/lark_event.py +21 -14
  113. astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
  114. astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
  115. astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
  116. astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
  117. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
  118. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  119. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  120. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  121. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
  122. astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
  123. astrbot/core/platform/sources/satori/satori_event.py +34 -25
  124. astrbot/core/platform/sources/slack/client.py +11 -9
  125. astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
  126. astrbot/core/platform/sources/slack/slack_event.py +34 -24
  127. astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
  128. astrbot/core/platform/sources/telegram/tg_event.py +32 -18
  129. astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
  130. astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
  131. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
  132. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
  133. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
  134. astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
  135. astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
  136. astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
  137. astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
  138. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
  139. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
  140. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
  141. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
  142. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
  143. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
  144. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
  145. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
  146. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
  147. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
  148. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
  149. astrbot/core/platform_message_history_mgr.py +5 -3
  150. astrbot/core/provider/__init__.py +2 -3
  151. astrbot/core/provider/entites.py +8 -8
  152. astrbot/core/provider/entities.py +61 -75
  153. astrbot/core/provider/func_tool_manager.py +59 -55
  154. astrbot/core/provider/manager.py +40 -22
  155. astrbot/core/provider/provider.py +72 -46
  156. astrbot/core/provider/register.py +7 -7
  157. astrbot/core/provider/sources/anthropic_source.py +48 -30
  158. astrbot/core/provider/sources/azure_tts_source.py +17 -13
  159. astrbot/core/provider/sources/coze_api_client.py +27 -17
  160. astrbot/core/provider/sources/coze_source.py +104 -87
  161. astrbot/core/provider/sources/dashscope_source.py +18 -11
  162. astrbot/core/provider/sources/dashscope_tts.py +36 -23
  163. astrbot/core/provider/sources/dify_source.py +25 -20
  164. astrbot/core/provider/sources/edge_tts_source.py +21 -17
  165. astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
  166. astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
  167. astrbot/core/provider/sources/gemini_source.py +72 -58
  168. astrbot/core/provider/sources/gemini_tts_source.py +8 -6
  169. astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
  170. astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
  171. astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
  172. astrbot/core/provider/sources/openai_embedding_source.py +6 -8
  173. astrbot/core/provider/sources/openai_source.py +102 -69
  174. astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
  175. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  176. astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
  177. astrbot/core/provider/sources/volcengine_tts.py +38 -31
  178. astrbot/core/provider/sources/whisper_api_source.py +14 -12
  179. astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
  180. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  181. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  182. astrbot/core/star/__init__.py +16 -11
  183. astrbot/core/star/config.py +10 -15
  184. astrbot/core/star/context.py +109 -84
  185. astrbot/core/star/filter/__init__.py +4 -3
  186. astrbot/core/star/filter/command.py +30 -28
  187. astrbot/core/star/filter/command_group.py +27 -24
  188. astrbot/core/star/filter/custom_filter.py +6 -5
  189. astrbot/core/star/filter/event_message_type.py +4 -2
  190. astrbot/core/star/filter/permission.py +4 -2
  191. astrbot/core/star/filter/platform_adapter_type.py +4 -2
  192. astrbot/core/star/filter/regex.py +4 -2
  193. astrbot/core/star/register/__init__.py +19 -19
  194. astrbot/core/star/register/star.py +6 -2
  195. astrbot/core/star/register/star_handler.py +96 -73
  196. astrbot/core/star/session_llm_manager.py +48 -14
  197. astrbot/core/star/session_plugin_manager.py +29 -15
  198. astrbot/core/star/star.py +1 -2
  199. astrbot/core/star/star_handler.py +13 -8
  200. astrbot/core/star/star_manager.py +151 -59
  201. astrbot/core/star/star_tools.py +44 -37
  202. astrbot/core/star/updator.py +10 -10
  203. astrbot/core/umop_config_router.py +10 -4
  204. astrbot/core/updator.py +13 -5
  205. astrbot/core/utils/astrbot_path.py +3 -5
  206. astrbot/core/utils/dify_api_client.py +33 -15
  207. astrbot/core/utils/io.py +66 -42
  208. astrbot/core/utils/log_pipe.py +1 -1
  209. astrbot/core/utils/metrics.py +7 -7
  210. astrbot/core/utils/path_util.py +15 -16
  211. astrbot/core/utils/pip_installer.py +5 -5
  212. astrbot/core/utils/session_waiter.py +19 -20
  213. astrbot/core/utils/shared_preferences.py +45 -20
  214. astrbot/core/utils/t2i/__init__.py +4 -1
  215. astrbot/core/utils/t2i/network_strategy.py +35 -26
  216. astrbot/core/utils/t2i/renderer.py +11 -5
  217. astrbot/core/utils/t2i/template_manager.py +14 -15
  218. astrbot/core/utils/tencent_record_helper.py +19 -13
  219. astrbot/core/utils/version_comparator.py +10 -13
  220. astrbot/core/zip_updator.py +43 -40
  221. astrbot/dashboard/routes/__init__.py +18 -18
  222. astrbot/dashboard/routes/auth.py +10 -8
  223. astrbot/dashboard/routes/chat.py +30 -21
  224. astrbot/dashboard/routes/config.py +92 -75
  225. astrbot/dashboard/routes/conversation.py +46 -39
  226. astrbot/dashboard/routes/file.py +4 -2
  227. astrbot/dashboard/routes/knowledge_base.py +47 -40
  228. astrbot/dashboard/routes/log.py +9 -4
  229. astrbot/dashboard/routes/persona.py +19 -16
  230. astrbot/dashboard/routes/plugin.py +69 -55
  231. astrbot/dashboard/routes/route.py +3 -1
  232. astrbot/dashboard/routes/session_management.py +130 -116
  233. astrbot/dashboard/routes/stat.py +34 -34
  234. astrbot/dashboard/routes/t2i.py +15 -12
  235. astrbot/dashboard/routes/tools.py +47 -52
  236. astrbot/dashboard/routes/update.py +32 -28
  237. astrbot/dashboard/server.py +30 -26
  238. astrbot/dashboard/utils.py +8 -4
  239. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/METADATA +4 -2
  240. astrbot-4.5.2.dist-info/RECORD +261 -0
  241. astrbot-4.5.0.dist-info/RECORD +0 -258
  242. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
  243. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
  244. {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,8 +1,9 @@
1
1
  import sqlite3
2
2
  import time
3
- from astrbot.core.db.po import Platform, Stats
4
- from typing import Tuple, List, Dict, Any
5
3
  from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ from astrbot.core.db.po import Platform, Stats
6
7
 
7
8
 
8
9
  @dataclass
@@ -94,7 +95,7 @@ class SQLiteDatabase:
94
95
  c.execute(
95
96
  """
96
97
  PRAGMA table_info(webchat_conversation)
97
- """
98
+ """,
98
99
  )
99
100
  res = c.fetchall()
100
101
  has_title = False
@@ -108,14 +109,14 @@ class SQLiteDatabase:
108
109
  c.execute(
109
110
  """
110
111
  ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
111
- """
112
+ """,
112
113
  )
113
114
  self.conn.commit()
114
115
  if not has_persona_id:
115
116
  c.execute(
116
117
  """
117
118
  ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
118
- """
119
+ """,
119
120
  )
120
121
  self.conn.commit()
121
122
 
@@ -126,7 +127,7 @@ class SQLiteDatabase:
126
127
  conn.text_factory = str
127
128
  return conn
128
129
 
129
- def _exec_sql(self, sql: str, params: Tuple = None):
130
+ def _exec_sql(self, sql: str, params: tuple = None):
130
131
  conn = self.conn
131
132
  try:
132
133
  c = self.conn.cursor()
@@ -174,7 +175,7 @@ class SQLiteDatabase:
174
175
  """
175
176
  SELECT * FROM platform
176
177
  """
177
- + where_clause
178
+ + where_clause,
178
179
  )
179
180
 
180
181
  platform = []
@@ -194,7 +195,7 @@ class SQLiteDatabase:
194
195
  c.execute(
195
196
  """
196
197
  SELECT SUM(count) FROM platform
197
- """
198
+ """,
198
199
  )
199
200
  res = c.fetchone()
200
201
  c.close()
@@ -214,7 +215,7 @@ class SQLiteDatabase:
214
215
  SELECT name, SUM(count), timestamp FROM platform
215
216
  """
216
217
  + where_clause
217
- + " GROUP BY name"
218
+ + " GROUP BY name",
218
219
  )
219
220
 
220
221
  platform = []
@@ -242,7 +243,7 @@ class SQLiteDatabase:
242
243
  c.close()
243
244
 
244
245
  if not res:
245
- return
246
+ return None
246
247
 
247
248
  return Conversation(*res)
248
249
 
@@ -257,7 +258,7 @@ class SQLiteDatabase:
257
258
  (user_id, cid, history, updated_at, created_at),
258
259
  )
259
260
 
260
- def get_conversations(self, user_id: str) -> Tuple:
261
+ def get_conversations(self, user_id: str) -> tuple:
261
262
  try:
262
263
  c = self.conn.cursor()
263
264
  except sqlite3.ProgrammingError:
@@ -280,7 +281,7 @@ class SQLiteDatabase:
280
281
  title = row[3]
281
282
  persona_id = row[4]
282
283
  conversations.append(
283
- Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
284
+ Conversation("", cid, "[]", created_at, updated_at, title, persona_id),
284
285
  )
285
286
  return conversations
286
287
 
@@ -319,8 +320,10 @@ class SQLiteDatabase:
319
320
  )
320
321
 
321
322
  def get_all_conversations(
322
- self, page: int = 1, page_size: int = 20
323
- ) -> Tuple[List[Dict[str, Any]], int]:
323
+ self,
324
+ page: int = 1,
325
+ page_size: int = 20,
326
+ ) -> tuple[list[dict[str, Any]], int]:
324
327
  """获取所有对话,支持分页,按更新时间降序排序"""
325
328
  try:
326
329
  c = self.conn.cursor()
@@ -366,7 +369,7 @@ class SQLiteDatabase:
366
369
  "persona_id": persona_id or "",
367
370
  "created_at": created_at or 0,
368
371
  "updated_at": updated_at or 0,
369
- }
372
+ },
370
373
  )
371
374
 
372
375
  return conversations, total_count
@@ -381,12 +384,12 @@ class SQLiteDatabase:
381
384
  self,
382
385
  page: int = 1,
383
386
  page_size: int = 20,
384
- platforms: List[str] = None,
385
- message_types: List[str] = None,
386
- search_query: str = None,
387
- exclude_ids: List[str] = None,
388
- exclude_platforms: List[str] = None,
389
- ) -> Tuple[List[Dict[str, Any]], int]:
387
+ platforms: list[str] | None = None,
388
+ message_types: list[str] | None = None,
389
+ search_query: str | None = None,
390
+ exclude_ids: list[str] | None = None,
391
+ exclude_platforms: list[str] | None = None,
392
+ ) -> tuple[list[dict[str, Any]], int]:
390
393
  """获取筛选后的对话列表"""
391
394
  try:
392
395
  c = self.conn.cursor()
@@ -422,7 +425,7 @@ class SQLiteDatabase:
422
425
  if search_query:
423
426
  search_query = search_query.encode("unicode_escape").decode("utf-8")
424
427
  where_clauses.append(
425
- "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
428
+ "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
426
429
  )
427
430
  search_param = f"%{search_query}%"
428
431
  params.extend([search_param, search_param, search_param, search_param])
@@ -482,7 +485,7 @@ class SQLiteDatabase:
482
485
  "persona_id": persona_id or "",
483
486
  "created_at": created_at or 0,
484
487
  "updated_at": updated_at or 0,
485
- }
488
+ },
486
489
  )
487
490
 
488
491
  return conversations, total_count
astrbot/core/db/po.py CHANGED
@@ -1,15 +1,15 @@
1
1
  import uuid
2
-
3
- from datetime import datetime, timezone
4
2
  from dataclasses import dataclass, field
3
+ from datetime import datetime, timezone
4
+ from typing import TypedDict
5
+
5
6
  from sqlmodel import (
7
+ JSON,
8
+ Field,
6
9
  SQLModel,
7
10
  Text,
8
- JSON,
9
11
  UniqueConstraint,
10
- Field,
11
12
  )
12
- from typing import Optional, TypedDict
13
13
 
14
14
 
15
15
  class PlatformStat(SQLModel, table=True):
@@ -40,7 +40,8 @@ class ConversationV2(SQLModel, table=True):
40
40
  __tablename__ = "conversations"
41
41
 
42
42
  inner_conversation_id: int = Field(
43
- primary_key=True, sa_column_kwargs={"autoincrement": True}
43
+ primary_key=True,
44
+ sa_column_kwargs={"autoincrement": True},
44
45
  )
45
46
  conversation_id: str = Field(
46
47
  max_length=36,
@@ -50,14 +51,14 @@ class ConversationV2(SQLModel, table=True):
50
51
  )
51
52
  platform_id: str = Field(nullable=False)
52
53
  user_id: str = Field(nullable=False)
53
- content: Optional[list] = Field(default=None, sa_type=JSON)
54
+ content: list | None = Field(default=None, sa_type=JSON)
54
55
  created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
55
56
  updated_at: datetime = Field(
56
57
  default_factory=lambda: datetime.now(timezone.utc),
57
58
  sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
58
59
  )
59
- title: Optional[str] = Field(default=None, max_length=255)
60
- persona_id: Optional[str] = Field(default=None)
60
+ title: str | None = Field(default=None, max_length=255)
61
+ persona_id: str | None = Field(default=None)
61
62
 
62
63
  __table_args__ = (
63
64
  UniqueConstraint(
@@ -76,13 +77,15 @@ class Persona(SQLModel, table=True):
76
77
  __tablename__ = "personas"
77
78
 
78
79
  id: int | None = Field(
79
- primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
80
+ primary_key=True,
81
+ sa_column_kwargs={"autoincrement": True},
82
+ default=None,
80
83
  )
81
84
  persona_id: str = Field(max_length=255, nullable=False)
82
85
  system_prompt: str = Field(sa_type=Text, nullable=False)
83
- begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
86
+ begin_dialogs: list | None = Field(default=None, sa_type=JSON)
84
87
  """a list of strings, each representing a dialog to start with"""
85
- tools: Optional[list] = Field(default=None, sa_type=JSON)
88
+ tools: list | None = Field(default=None, sa_type=JSON)
86
89
  """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
87
90
  created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
88
91
  updated_at: datetime = Field(
@@ -104,7 +107,9 @@ class Preference(SQLModel, table=True):
104
107
  __tablename__ = "preferences"
105
108
 
106
109
  id: int | None = Field(
107
- default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
110
+ default=None,
111
+ primary_key=True,
112
+ sa_column_kwargs={"autoincrement": True},
108
113
  )
109
114
  scope: str = Field(nullable=False)
110
115
  """Scope of the preference, such as 'global', 'umo', 'plugin'."""
@@ -138,13 +143,15 @@ class PlatformMessageHistory(SQLModel, table=True):
138
143
  __tablename__ = "platform_message_history"
139
144
 
140
145
  id: int | None = Field(
141
- primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
146
+ primary_key=True,
147
+ sa_column_kwargs={"autoincrement": True},
148
+ default=None,
142
149
  )
143
150
  platform_id: str = Field(nullable=False)
144
151
  user_id: str = Field(nullable=False) # An id of group, user in platform
145
- sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
146
- sender_name: Optional[str] = Field(
147
- default=None
152
+ sender_id: str | None = Field(default=None) # ID of the sender in the platform
153
+ sender_name: str | None = Field(
154
+ default=None,
148
155
  ) # Name of the sender in the platform
149
156
  content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
150
157
  created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -163,7 +170,9 @@ class Attachment(SQLModel, table=True):
163
170
  __tablename__ = "attachments"
164
171
 
165
172
  inner_attachment_id: int | None = Field(
166
- primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
173
+ primary_key=True,
174
+ sa_column_kwargs={"autoincrement": True},
175
+ default=None,
167
176
  )
168
177
  attachment_id: str = Field(
169
178
  max_length=36,
astrbot/core/db/sqlite.py CHANGED
@@ -1,22 +1,27 @@
1
1
  import asyncio
2
- import typing as T
3
2
  import threading
3
+ import typing as T
4
4
  from datetime import datetime, timedelta
5
+
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+ from sqlmodel import col, delete, desc, func, or_, select, text, update
8
+
5
9
  from astrbot.core.db import BaseDatabase
6
10
  from astrbot.core.db.po import (
7
- ConversationV2,
8
- PlatformStat,
9
- PlatformMessageHistory,
10
11
  Attachment,
12
+ ConversationV2,
11
13
  Persona,
14
+ PlatformMessageHistory,
15
+ PlatformStat,
12
16
  Preference,
13
- Stats as DeprecatedStats,
14
- Platform as DeprecatedPlatformStat,
15
17
  SQLModel,
16
18
  )
17
-
18
- from sqlmodel import select, update, delete, text, func, or_, desc, col
19
- from sqlalchemy.ext.asyncio import AsyncSession
19
+ from astrbot.core.db.po import (
20
+ Platform as DeprecatedPlatformStat,
21
+ )
22
+ from astrbot.core.db.po import (
23
+ Stats as DeprecatedStats,
24
+ )
20
25
 
21
26
  NOT_GIVEN = T.TypeVar("NOT_GIVEN")
22
27
 
@@ -57,7 +62,9 @@ class SQLiteDatabase(BaseDatabase):
57
62
  async with session.begin():
58
63
  if timestamp is None:
59
64
  timestamp = datetime.now().replace(
60
- minute=0, second=0, microsecond=0
65
+ minute=0,
66
+ second=0,
67
+ microsecond=0,
61
68
  )
62
69
  current_hour = timestamp
63
70
  await session.execute(
@@ -81,13 +88,13 @@ class SQLiteDatabase(BaseDatabase):
81
88
  session: AsyncSession
82
89
  result = await session.execute(
83
90
  select(func.count(col(PlatformStat.platform_id))).select_from(
84
- PlatformStat
85
- )
91
+ PlatformStat,
92
+ ),
86
93
  )
87
94
  count = result.scalar_one_or_none()
88
95
  return count if count is not None else 0
89
96
 
90
- async def get_platform_stats(self, offset_sec: int = 86400) -> T.List[PlatformStat]:
97
+ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
91
98
  """Get platform statistics within the specified offset in seconds and group by platform_id."""
92
99
  async with self.get_db() as session:
93
100
  session: AsyncSession
@@ -138,7 +145,7 @@ class SQLiteDatabase(BaseDatabase):
138
145
  select(ConversationV2)
139
146
  .order_by(desc(ConversationV2.created_at))
140
147
  .offset(offset)
141
- .limit(page_size)
148
+ .limit(page_size),
142
149
  )
143
150
  return result.scalars().all()
144
151
 
@@ -157,7 +164,7 @@ class SQLiteDatabase(BaseDatabase):
157
164
 
158
165
  if platform_ids:
159
166
  base_query = base_query.where(
160
- col(ConversationV2.platform_id).in_(platform_ids)
167
+ col(ConversationV2.platform_id).in_(platform_ids),
161
168
  )
162
169
  if search_query:
163
170
  search_query = search_query.encode("unicode_escape").decode("utf-8")
@@ -167,16 +174,16 @@ class SQLiteDatabase(BaseDatabase):
167
174
  col(ConversationV2.content).ilike(f"%{search_query}%"),
168
175
  col(ConversationV2.user_id).ilike(f"%{search_query}%"),
169
176
  col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
170
- )
177
+ ),
171
178
  )
172
179
  if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
173
180
  for msg_type in kwargs["message_types"]:
174
181
  base_query = base_query.where(
175
- col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
182
+ col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
176
183
  )
177
184
  if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
178
185
  base_query = base_query.where(
179
- col(ConversationV2.platform_id).in_(kwargs["platforms"])
186
+ col(ConversationV2.platform_id).in_(kwargs["platforms"]),
180
187
  )
181
188
 
182
189
  # Get total count matching the filters
@@ -233,7 +240,7 @@ class SQLiteDatabase(BaseDatabase):
233
240
  session: AsyncSession
234
241
  async with session.begin():
235
242
  query = update(ConversationV2).where(
236
- col(ConversationV2.conversation_id) == cid
243
+ col(ConversationV2.conversation_id) == cid,
237
244
  )
238
245
  values = {}
239
246
  if title is not None:
@@ -243,7 +250,7 @@ class SQLiteDatabase(BaseDatabase):
243
250
  if content is not None:
244
251
  values["content"] = content
245
252
  if not values:
246
- return
253
+ return None
247
254
  query = query.values(**values)
248
255
  await session.execute(query)
249
256
  return await self.get_conversation_by_id(cid)
@@ -254,8 +261,8 @@ class SQLiteDatabase(BaseDatabase):
254
261
  async with session.begin():
255
262
  await session.execute(
256
263
  delete(ConversationV2).where(
257
- col(ConversationV2.conversation_id) == cid
258
- )
264
+ col(ConversationV2.conversation_id) == cid,
265
+ ),
259
266
  )
260
267
 
261
268
  async def delete_conversations_by_user_id(self, user_id: str) -> None:
@@ -263,7 +270,9 @@ class SQLiteDatabase(BaseDatabase):
263
270
  session: AsyncSession
264
271
  async with session.begin():
265
272
  await session.execute(
266
- delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
273
+ delete(ConversationV2).where(
274
+ col(ConversationV2.user_id) == user_id
275
+ ),
267
276
  )
268
277
 
269
278
  async def get_session_conversations(
@@ -282,7 +291,7 @@ class SQLiteDatabase(BaseDatabase):
282
291
  select(
283
292
  col(Preference.scope_id).label("session_id"),
284
293
  func.json_extract(Preference.value, "$.val").label(
285
- "conversation_id"
294
+ "conversation_id",
286
295
  ), # type: ignore
287
296
  col(ConversationV2.persona_id).label("persona_id"),
288
297
  col(ConversationV2.title).label("title"),
@@ -295,7 +304,8 @@ class SQLiteDatabase(BaseDatabase):
295
304
  == ConversationV2.conversation_id,
296
305
  )
297
306
  .outerjoin(
298
- Persona, col(ConversationV2.persona_id) == Persona.persona_id
307
+ Persona,
308
+ col(ConversationV2.persona_id) == Persona.persona_id,
299
309
  )
300
310
  .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
301
311
  )
@@ -308,14 +318,14 @@ class SQLiteDatabase(BaseDatabase):
308
318
  col(Preference.scope_id).ilike(search_pattern),
309
319
  col(ConversationV2.title).ilike(search_pattern),
310
320
  col(Persona.persona_id).ilike(search_pattern),
311
- )
321
+ ),
312
322
  )
313
323
 
314
324
  # 平台筛选
315
325
  if platform:
316
326
  platform_pattern = f"{platform}:%"
317
327
  base_query = base_query.where(
318
- col(Preference.scope_id).like(platform_pattern)
328
+ col(Preference.scope_id).like(platform_pattern),
319
329
  )
320
330
 
321
331
  # 排序
@@ -336,7 +346,8 @@ class SQLiteDatabase(BaseDatabase):
336
346
  == ConversationV2.conversation_id,
337
347
  )
338
348
  .outerjoin(
339
- Persona, col(ConversationV2.persona_id) == Persona.persona_id
349
+ Persona,
350
+ col(ConversationV2.persona_id) == Persona.persona_id,
340
351
  )
341
352
  .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
342
353
  )
@@ -349,13 +360,13 @@ class SQLiteDatabase(BaseDatabase):
349
360
  col(Preference.scope_id).ilike(search_pattern),
350
361
  col(ConversationV2.title).ilike(search_pattern),
351
362
  col(Persona.persona_id).ilike(search_pattern),
352
- )
363
+ ),
353
364
  )
354
365
 
355
366
  if platform:
356
367
  platform_pattern = f"{platform}:%"
357
368
  count_base_query = count_base_query.where(
358
- col(Preference.scope_id).like(platform_pattern)
369
+ col(Preference.scope_id).like(platform_pattern),
359
370
  )
360
371
 
361
372
  total_result = await session.execute(count_base_query)
@@ -396,7 +407,10 @@ class SQLiteDatabase(BaseDatabase):
396
407
  return new_history
397
408
 
398
409
  async def delete_platform_message_offset(
399
- self, platform_id, user_id, offset_sec=86400
410
+ self,
411
+ platform_id,
412
+ user_id,
413
+ offset_sec=86400,
400
414
  ):
401
415
  """Delete platform message history records older than the specified offset."""
402
416
  async with self.get_db() as session:
@@ -409,11 +423,15 @@ class SQLiteDatabase(BaseDatabase):
409
423
  col(PlatformMessageHistory.platform_id) == platform_id,
410
424
  col(PlatformMessageHistory.user_id) == user_id,
411
425
  col(PlatformMessageHistory.created_at) < cutoff_time,
412
- )
426
+ ),
413
427
  )
414
428
 
415
429
  async def get_platform_message_history(
416
- self, platform_id, user_id, page=1, page_size=20
430
+ self,
431
+ platform_id,
432
+ user_id,
433
+ page=1,
434
+ page_size=20,
417
435
  ):
418
436
  """Get platform message history records."""
419
437
  async with self.get_db() as session:
@@ -452,7 +470,11 @@ class SQLiteDatabase(BaseDatabase):
452
470
  return result.scalar_one_or_none()
453
471
 
454
472
  async def insert_persona(
455
- self, persona_id, system_prompt, begin_dialogs=None, tools=None
473
+ self,
474
+ persona_id,
475
+ system_prompt,
476
+ begin_dialogs=None,
477
+ tools=None,
456
478
  ):
457
479
  """Insert a new persona record."""
458
480
  async with self.get_db() as session:
@@ -484,7 +506,11 @@ class SQLiteDatabase(BaseDatabase):
484
506
  return result.scalars().all()
485
507
 
486
508
  async def update_persona(
487
- self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
509
+ self,
510
+ persona_id,
511
+ system_prompt=None,
512
+ begin_dialogs=None,
513
+ tools=NOT_GIVEN,
488
514
  ):
489
515
  """Update a persona's system prompt or begin dialogs."""
490
516
  async with self.get_db() as session:
@@ -499,7 +525,7 @@ class SQLiteDatabase(BaseDatabase):
499
525
  if tools is not NOT_GIVEN:
500
526
  values["tools"] = tools
501
527
  if not values:
502
- return
528
+ return None
503
529
  query = query.values(**values)
504
530
  await session.execute(query)
505
531
  return await self.get_persona_by_id(persona_id)
@@ -510,7 +536,7 @@ class SQLiteDatabase(BaseDatabase):
510
536
  session: AsyncSession
511
537
  async with session.begin():
512
538
  await session.execute(
513
- delete(Persona).where(col(Persona.persona_id) == persona_id)
539
+ delete(Persona).where(col(Persona.persona_id) == persona_id),
514
540
  )
515
541
 
516
542
  async def insert_preference_or_update(self, scope, scope_id, key, value):
@@ -529,7 +555,10 @@ class SQLiteDatabase(BaseDatabase):
529
555
  existing_preference.value = value
530
556
  else:
531
557
  new_preference = Preference(
532
- scope=scope, scope_id=scope_id, key=key, value=value
558
+ scope=scope,
559
+ scope_id=scope_id,
560
+ key=key,
561
+ value=value,
533
562
  )
534
563
  session.add(new_preference)
535
564
  return existing_preference or new_preference
@@ -568,7 +597,7 @@ class SQLiteDatabase(BaseDatabase):
568
597
  col(Preference.scope) == scope,
569
598
  col(Preference.scope_id) == scope_id,
570
599
  col(Preference.key) == key,
571
- )
600
+ ),
572
601
  )
573
602
  await session.commit()
574
603
 
@@ -581,7 +610,7 @@ class SQLiteDatabase(BaseDatabase):
581
610
  delete(Preference).where(
582
611
  col(Preference.scope) == scope,
583
612
  col(Preference.scope_id) == scope_id,
584
- )
613
+ ),
585
614
  )
586
615
  await session.commit()
587
616
 
@@ -598,7 +627,7 @@ class SQLiteDatabase(BaseDatabase):
598
627
  now = datetime.now()
599
628
  start_time = now - timedelta(seconds=offset_sec)
600
629
  result = await session.execute(
601
- select(PlatformStat).where(PlatformStat.timestamp >= start_time)
630
+ select(PlatformStat).where(PlatformStat.timestamp >= start_time),
602
631
  )
603
632
  all_datas = result.scalars().all()
604
633
  deprecated_stats = DeprecatedStats()
@@ -608,7 +637,7 @@ class SQLiteDatabase(BaseDatabase):
608
637
  name=data.platform_id,
609
638
  count=data.count,
610
639
  timestamp=int(data.timestamp.timestamp()),
611
- )
640
+ ),
612
641
  )
613
642
  return deprecated_stats
614
643
 
@@ -630,7 +659,7 @@ class SQLiteDatabase(BaseDatabase):
630
659
  async with self.get_db() as session:
631
660
  session: AsyncSession
632
661
  result = await session.execute(
633
- select(func.sum(PlatformStat.count)).select_from(PlatformStat)
662
+ select(func.sum(PlatformStat.count)).select_from(PlatformStat),
634
663
  )
635
664
  total_count = result.scalar_one_or_none()
636
665
  return total_count if total_count is not None else 0
@@ -656,7 +685,7 @@ class SQLiteDatabase(BaseDatabase):
656
685
  result = await session.execute(
657
686
  select(PlatformStat.platform_id, func.sum(PlatformStat.count))
658
687
  .where(PlatformStat.timestamp >= start_time)
659
- .group_by(PlatformStat.platform_id)
688
+ .group_by(PlatformStat.platform_id),
660
689
  )
661
690
  grouped_stats = result.all()
662
691
  deprecated_stats = DeprecatedStats()
@@ -666,7 +695,7 @@ class SQLiteDatabase(BaseDatabase):
666
695
  name=platform_id,
667
696
  count=count,
668
697
  timestamp=int(start_time.timestamp()),
669
- )
698
+ ),
670
699
  )
671
700
  return deprecated_stats
672
701
 
@@ -10,18 +10,16 @@ class Result:
10
10
 
11
11
  class BaseVecDB:
12
12
  async def initialize(self):
13
- """
14
- 初始化向量数据库
15
- """
16
- pass
13
+ """初始化向量数据库"""
17
14
 
18
15
  @abc.abstractmethod
19
16
  async def insert(
20
- self, content: str, metadata: dict | None = None, id: str | None = None
17
+ self,
18
+ content: str,
19
+ metadata: dict | None = None,
20
+ id: str | None = None,
21
21
  ) -> int:
22
- """
23
- 插入一条文本和其对应向量,自动生成 ID 并保持一致性。
24
- """
22
+ """插入一条文本和其对应向量,自动生成 ID 并保持一致性。"""
25
23
  ...
26
24
 
27
25
  @abc.abstractmethod
@@ -35,11 +33,11 @@ class BaseVecDB:
35
33
  max_retries: int = 3,
36
34
  progress_callback=None,
37
35
  ) -> int:
38
- """
39
- 批量插入文本和其对应向量,自动生成 ID 并保持一致性。
36
+ """批量插入文本和其对应向量,自动生成 ID 并保持一致性。
40
37
 
41
38
  Args:
42
39
  progress_callback: 进度回调函数,接收参数 (current, total)
40
+
43
41
  """
44
42
  ...
45
43
 
@@ -52,8 +50,7 @@ class BaseVecDB:
52
50
  rerank: bool = False,
53
51
  metadata_filters: dict | None = None,
54
52
  ) -> list[Result]:
55
- """
56
- 搜索最相似的文档。
53
+ """搜索最相似的文档。
57
54
  Args:
58
55
  query (str): 查询文本
59
56
  top_k (int): 返回的最相似文档的数量
@@ -64,8 +61,7 @@ class BaseVecDB:
64
61
 
65
62
  @abc.abstractmethod
66
63
  async def delete(self, doc_id: str) -> bool:
67
- """
68
- 删除指定文档。
64
+ """删除指定文档。
69
65
  Args:
70
66
  doc_id (str): 要删除的文档 ID
71
67
  Returns: