AstrBot 4.5.1__py3-none-any.whl → 4.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244)
  1. astrbot/api/__init__.py +10 -11
  2. astrbot/api/event/__init__.py +5 -6
  3. astrbot/api/event/filter/__init__.py +37 -36
  4. astrbot/api/platform/__init__.py +7 -8
  5. astrbot/api/provider/__init__.py +7 -7
  6. astrbot/api/star/__init__.py +3 -4
  7. astrbot/api/util/__init__.py +2 -2
  8. astrbot/cli/__main__.py +5 -5
  9. astrbot/cli/commands/__init__.py +3 -3
  10. astrbot/cli/commands/cmd_conf.py +19 -16
  11. astrbot/cli/commands/cmd_init.py +3 -2
  12. astrbot/cli/commands/cmd_plug.py +8 -10
  13. astrbot/cli/commands/cmd_run.py +5 -6
  14. astrbot/cli/utils/__init__.py +6 -6
  15. astrbot/cli/utils/basic.py +14 -14
  16. astrbot/cli/utils/plugin.py +24 -15
  17. astrbot/cli/utils/version_comparator.py +10 -12
  18. astrbot/core/__init__.py +8 -6
  19. astrbot/core/agent/agent.py +3 -2
  20. astrbot/core/agent/handoff.py +6 -2
  21. astrbot/core/agent/hooks.py +9 -6
  22. astrbot/core/agent/mcp_client.py +50 -15
  23. astrbot/core/agent/message.py +168 -0
  24. astrbot/core/agent/response.py +2 -1
  25. astrbot/core/agent/run_context.py +2 -3
  26. astrbot/core/agent/runners/base.py +10 -13
  27. astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
  28. astrbot/core/agent/tool.py +60 -41
  29. astrbot/core/agent/tool_executor.py +9 -3
  30. astrbot/core/astr_agent_context.py +3 -1
  31. astrbot/core/astrbot_config_mgr.py +29 -9
  32. astrbot/core/config/__init__.py +2 -2
  33. astrbot/core/config/astrbot_config.py +28 -26
  34. astrbot/core/config/default.py +4 -6
  35. astrbot/core/conversation_mgr.py +105 -36
  36. astrbot/core/core_lifecycle.py +68 -54
  37. astrbot/core/db/__init__.py +33 -18
  38. astrbot/core/db/migration/helper.py +12 -10
  39. astrbot/core/db/migration/migra_3_to_4.py +53 -34
  40. astrbot/core/db/migration/migra_45_to_46.py +1 -1
  41. astrbot/core/db/migration/shared_preferences_v3.py +2 -1
  42. astrbot/core/db/migration/sqlite_v3.py +26 -23
  43. astrbot/core/db/po.py +27 -18
  44. astrbot/core/db/sqlite.py +74 -45
  45. astrbot/core/db/vec_db/base.py +10 -14
  46. astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
  47. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
  48. astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
  49. astrbot/core/event_bus.py +8 -6
  50. astrbot/core/file_token_service.py +6 -5
  51. astrbot/core/initial_loader.py +7 -5
  52. astrbot/core/knowledge_base/chunking/__init__.py +1 -3
  53. astrbot/core/knowledge_base/chunking/base.py +1 -0
  54. astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
  55. astrbot/core/knowledge_base/chunking/recursive.py +16 -10
  56. astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
  57. astrbot/core/knowledge_base/kb_helper.py +30 -17
  58. astrbot/core/knowledge_base/kb_mgr.py +6 -7
  59. astrbot/core/knowledge_base/models.py +10 -4
  60. astrbot/core/knowledge_base/parsers/__init__.py +3 -5
  61. astrbot/core/knowledge_base/parsers/base.py +1 -0
  62. astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
  63. astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
  64. astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
  65. astrbot/core/knowledge_base/parsers/util.py +1 -1
  66. astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
  67. astrbot/core/knowledge_base/retrieval/manager.py +17 -14
  68. astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
  69. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
  70. astrbot/core/log.py +21 -13
  71. astrbot/core/message/components.py +123 -217
  72. astrbot/core/message/message_event_result.py +24 -24
  73. astrbot/core/persona_mgr.py +20 -11
  74. astrbot/core/pipeline/__init__.py +7 -7
  75. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  76. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  77. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
  78. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
  79. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  80. astrbot/core/pipeline/context.py +4 -1
  81. astrbot/core/pipeline/context_utils.py +77 -7
  82. astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
  83. astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
  84. astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
  85. astrbot/core/pipeline/process_stage/stage.py +13 -10
  86. astrbot/core/pipeline/process_stage/utils.py +6 -5
  87. astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
  88. astrbot/core/pipeline/respond/stage.py +23 -20
  89. astrbot/core/pipeline/result_decorate/stage.py +31 -23
  90. astrbot/core/pipeline/scheduler.py +12 -8
  91. astrbot/core/pipeline/session_status_check/stage.py +12 -8
  92. astrbot/core/pipeline/stage.py +10 -4
  93. astrbot/core/pipeline/waking_check/stage.py +24 -18
  94. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  95. astrbot/core/platform/__init__.py +6 -6
  96. astrbot/core/platform/astr_message_event.py +76 -110
  97. astrbot/core/platform/astrbot_message.py +11 -13
  98. astrbot/core/platform/manager.py +16 -15
  99. astrbot/core/platform/message_session.py +5 -3
  100. astrbot/core/platform/platform.py +16 -24
  101. astrbot/core/platform/platform_metadata.py +4 -4
  102. astrbot/core/platform/register.py +8 -8
  103. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
  104. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
  105. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
  106. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
  107. astrbot/core/platform/sources/discord/client.py +9 -6
  108. astrbot/core/platform/sources/discord/components.py +18 -14
  109. astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
  110. astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
  111. astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
  112. astrbot/core/platform/sources/lark/lark_event.py +21 -14
  113. astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
  114. astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
  115. astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
  116. astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
  117. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
  118. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  119. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  120. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  121. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
  122. astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
  123. astrbot/core/platform/sources/satori/satori_event.py +34 -25
  124. astrbot/core/platform/sources/slack/client.py +11 -9
  125. astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
  126. astrbot/core/platform/sources/slack/slack_event.py +34 -24
  127. astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
  128. astrbot/core/platform/sources/telegram/tg_event.py +32 -18
  129. astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
  130. astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
  131. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
  132. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
  133. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
  134. astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
  135. astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
  136. astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
  137. astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
  138. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
  139. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
  140. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
  141. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
  142. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
  143. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
  144. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
  145. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
  146. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
  147. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
  148. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
  149. astrbot/core/platform_message_history_mgr.py +5 -3
  150. astrbot/core/provider/__init__.py +2 -3
  151. astrbot/core/provider/entites.py +8 -8
  152. astrbot/core/provider/entities.py +61 -75
  153. astrbot/core/provider/func_tool_manager.py +59 -55
  154. astrbot/core/provider/manager.py +32 -22
  155. astrbot/core/provider/provider.py +72 -46
  156. astrbot/core/provider/register.py +7 -7
  157. astrbot/core/provider/sources/anthropic_source.py +48 -30
  158. astrbot/core/provider/sources/azure_tts_source.py +17 -13
  159. astrbot/core/provider/sources/coze_api_client.py +27 -17
  160. astrbot/core/provider/sources/coze_source.py +104 -87
  161. astrbot/core/provider/sources/dashscope_source.py +18 -11
  162. astrbot/core/provider/sources/dashscope_tts.py +36 -23
  163. astrbot/core/provider/sources/dify_source.py +25 -20
  164. astrbot/core/provider/sources/edge_tts_source.py +21 -17
  165. astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
  166. astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
  167. astrbot/core/provider/sources/gemini_source.py +72 -58
  168. astrbot/core/provider/sources/gemini_tts_source.py +8 -6
  169. astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
  170. astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
  171. astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
  172. astrbot/core/provider/sources/openai_embedding_source.py +6 -8
  173. astrbot/core/provider/sources/openai_source.py +77 -69
  174. astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
  175. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  176. astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
  177. astrbot/core/provider/sources/volcengine_tts.py +38 -31
  178. astrbot/core/provider/sources/whisper_api_source.py +14 -12
  179. astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
  180. astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
  181. astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
  182. astrbot/core/star/__init__.py +16 -11
  183. astrbot/core/star/config.py +10 -15
  184. astrbot/core/star/context.py +97 -75
  185. astrbot/core/star/filter/__init__.py +4 -3
  186. astrbot/core/star/filter/command.py +30 -28
  187. astrbot/core/star/filter/command_group.py +27 -24
  188. astrbot/core/star/filter/custom_filter.py +6 -5
  189. astrbot/core/star/filter/event_message_type.py +4 -2
  190. astrbot/core/star/filter/permission.py +4 -2
  191. astrbot/core/star/filter/platform_adapter_type.py +4 -2
  192. astrbot/core/star/filter/regex.py +4 -2
  193. astrbot/core/star/register/__init__.py +19 -19
  194. astrbot/core/star/register/star.py +6 -2
  195. astrbot/core/star/register/star_handler.py +96 -73
  196. astrbot/core/star/session_llm_manager.py +48 -14
  197. astrbot/core/star/session_plugin_manager.py +29 -15
  198. astrbot/core/star/star.py +1 -2
  199. astrbot/core/star/star_handler.py +13 -8
  200. astrbot/core/star/star_manager.py +151 -59
  201. astrbot/core/star/star_tools.py +44 -37
  202. astrbot/core/star/updator.py +10 -10
  203. astrbot/core/umop_config_router.py +10 -4
  204. astrbot/core/updator.py +13 -5
  205. astrbot/core/utils/astrbot_path.py +3 -5
  206. astrbot/core/utils/dify_api_client.py +33 -15
  207. astrbot/core/utils/io.py +66 -42
  208. astrbot/core/utils/log_pipe.py +1 -1
  209. astrbot/core/utils/metrics.py +7 -7
  210. astrbot/core/utils/path_util.py +15 -16
  211. astrbot/core/utils/pip_installer.py +5 -5
  212. astrbot/core/utils/session_waiter.py +19 -20
  213. astrbot/core/utils/shared_preferences.py +45 -20
  214. astrbot/core/utils/t2i/__init__.py +4 -1
  215. astrbot/core/utils/t2i/network_strategy.py +35 -26
  216. astrbot/core/utils/t2i/renderer.py +11 -5
  217. astrbot/core/utils/t2i/template_manager.py +14 -15
  218. astrbot/core/utils/tencent_record_helper.py +19 -13
  219. astrbot/core/utils/version_comparator.py +10 -13
  220. astrbot/core/zip_updator.py +43 -40
  221. astrbot/dashboard/routes/__init__.py +18 -18
  222. astrbot/dashboard/routes/auth.py +10 -8
  223. astrbot/dashboard/routes/chat.py +30 -21
  224. astrbot/dashboard/routes/config.py +92 -75
  225. astrbot/dashboard/routes/conversation.py +46 -39
  226. astrbot/dashboard/routes/file.py +4 -2
  227. astrbot/dashboard/routes/knowledge_base.py +47 -40
  228. astrbot/dashboard/routes/log.py +9 -4
  229. astrbot/dashboard/routes/persona.py +19 -16
  230. astrbot/dashboard/routes/plugin.py +69 -55
  231. astrbot/dashboard/routes/route.py +3 -1
  232. astrbot/dashboard/routes/session_management.py +130 -116
  233. astrbot/dashboard/routes/stat.py +34 -34
  234. astrbot/dashboard/routes/t2i.py +15 -12
  235. astrbot/dashboard/routes/tools.py +47 -52
  236. astrbot/dashboard/routes/update.py +32 -28
  237. astrbot/dashboard/server.py +30 -26
  238. astrbot/dashboard/utils.py +8 -4
  239. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/METADATA +2 -1
  240. astrbot-4.5.2.dist-info/RECORD +261 -0
  241. astrbot-4.5.1.dist-info/RECORD +0 -260
  242. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
  243. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
  244. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
astrbot/core/db/__init__.py CHANGED
@@ -1,27 +1,27 @@
1
1
  import abc
2
2
  import datetime
3
3
  import typing as T
4
- from deprecated import deprecated
4
+ from contextlib import asynccontextmanager
5
5
  from dataclasses import dataclass
6
+
7
+ from deprecated import deprecated
8
+ from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
9
+ from sqlalchemy.orm import sessionmaker
10
+
6
11
  from astrbot.core.db.po import (
7
- Stats,
8
- PlatformStat,
9
- ConversationV2,
10
- PlatformMessageHistory,
11
12
  Attachment,
13
+ ConversationV2,
12
14
  Persona,
15
+ PlatformMessageHistory,
16
+ PlatformStat,
13
17
  Preference,
18
+ Stats,
14
19
  )
15
- from contextlib import asynccontextmanager
16
- from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
17
- from sqlalchemy.orm import sessionmaker
18
20
 
19
21
 
20
22
  @dataclass
21
23
  class BaseDatabase(abc.ABC):
22
- """
23
- 数据库基类
24
- """
24
+ """数据库基类"""
25
25
 
26
26
  DATABASE_URL = ""
27
27
 
@@ -32,12 +32,13 @@ class BaseDatabase(abc.ABC):
32
32
  future=True,
33
33
  )
34
34
  self.AsyncSessionLocal = sessionmaker(
35
- self.engine, class_=AsyncSession, expire_on_commit=False
35
+ self.engine,
36
+ class_=AsyncSession,
37
+ expire_on_commit=False,
36
38
  )
37
39
 
38
40
  async def initialize(self):
39
41
  """初始化数据库连接"""
40
- pass
41
42
 
42
43
  @asynccontextmanager
43
44
  async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]:
@@ -91,7 +92,9 @@ class BaseDatabase(abc.ABC):
91
92
 
92
93
  @abc.abstractmethod
93
94
  async def get_conversations(
94
- self, user_id: str | None = None, platform_id: str | None = None
95
+ self,
96
+ user_id: str | None = None,
97
+ platform_id: str | None = None,
95
98
  ) -> list[ConversationV2]:
96
99
  """Get all conversations for a specific user and platform_id(optional).
97
100
 
@@ -106,7 +109,9 @@ class BaseDatabase(abc.ABC):
106
109
 
107
110
  @abc.abstractmethod
108
111
  async def get_all_conversations(
109
- self, page: int = 1, page_size: int = 20
112
+ self,
113
+ page: int = 1,
114
+ page_size: int = 20,
110
115
  ) -> list[ConversationV2]:
111
116
  """Get all conversations with pagination."""
112
117
  ...
@@ -173,7 +178,10 @@ class BaseDatabase(abc.ABC):
173
178
 
174
179
  @abc.abstractmethod
175
180
  async def delete_platform_message_offset(
176
- self, platform_id: str, user_id: str, offset_sec: int = 86400
181
+ self,
182
+ platform_id: str,
183
+ user_id: str,
184
+ offset_sec: int = 86400,
177
185
  ) -> None:
178
186
  """Delete platform message history records older than the specified offset."""
179
187
  ...
@@ -243,7 +251,11 @@ class BaseDatabase(abc.ABC):
243
251
 
244
252
  @abc.abstractmethod
245
253
  async def insert_preference_or_update(
246
- self, scope: str, scope_id: str, key: str, value: dict
254
+ self,
255
+ scope: str,
256
+ scope_id: str,
257
+ key: str,
258
+ value: dict,
247
259
  ) -> Preference:
248
260
  """Insert a new preference record."""
249
261
  ...
@@ -255,7 +267,10 @@ class BaseDatabase(abc.ABC):
255
267
 
256
268
  @abc.abstractmethod
257
269
  async def get_preferences(
258
- self, scope: str, scope_id: str | None = None, key: str | None = None
270
+ self,
271
+ scope: str,
272
+ scope_id: str | None = None,
273
+ key: str | None = None,
259
274
  ) -> list[Preference]:
260
275
  """Get all preferences for a specific scope ID or key."""
261
276
  ...
astrbot/core/db/migration/helper.py CHANGED
@@ -1,20 +1,21 @@
1
1
  import os
2
- from astrbot.core.utils.astrbot_path import get_astrbot_data_path
3
- from astrbot.core.db import BaseDatabase
4
- from astrbot.core.config import AstrBotConfig
2
+
5
3
  from astrbot.api import logger, sp
4
+ from astrbot.core.config import AstrBotConfig
5
+ from astrbot.core.db import BaseDatabase
6
+ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
7
+
6
8
  from .migra_3_to_4 import (
7
9
  migration_conversation_table,
8
- migration_platform_table,
9
- migration_webchat_data,
10
10
  migration_persona_data,
11
+ migration_platform_table,
11
12
  migration_preferences,
13
+ migration_webchat_data,
12
14
  )
13
15
 
14
16
 
15
17
  async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
16
- """
17
- 检查是否需要进行数据库迁移
18
+ """检查是否需要进行数据库迁移
18
19
  如果存在 data_v3.db 并且 preference 中没有 migration_done_v4,则需要进行迁移。
19
20
  """
20
21
  # 仅当 data 目录下存在旧版本数据(data_v3.db 文件)时才考虑迁移
@@ -24,7 +25,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
24
25
  if not os.path.exists(data_v3_db):
25
26
  return False
26
27
  migration_done = await db_helper.get_preference(
27
- "global", "global", "migration_done_v4"
28
+ "global",
29
+ "global",
30
+ "migration_done_v4",
28
31
  )
29
32
  if migration_done:
30
33
  return False
@@ -36,8 +39,7 @@ async def do_migration_v4(
36
39
  platform_id_map: dict[str, dict[str, str]],
37
40
  astrbot_config: AstrBotConfig,
38
41
  ) -> None:
39
- """
40
- 执行数据库迁移
42
+ """执行数据库迁移
41
43
  迁移旧的 webchat_conversation 表到新的 conversation 表。
42
44
  迁移旧的 platform 到新的 platform_stats 表。
43
45
  """
astrbot/core/db/migration/migra_3_to_4.py CHANGED
@@ -1,15 +1,18 @@
1
- import json
2
1
  import datetime
3
- from .. import BaseDatabase
4
- from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
5
- from .shared_preferences_v3 import sp as sp_v3
6
- from astrbot.core.config.default import DB_PATH
2
+ import json
3
+
4
+ from sqlalchemy import text
5
+ from sqlalchemy.ext.asyncio import AsyncSession
6
+
7
7
  from astrbot.api import logger, sp
8
8
  from astrbot.core.config import AstrBotConfig
9
- from astrbot.core.platform.astr_message_event import MessageSesion
10
- from sqlalchemy.ext.asyncio import AsyncSession
9
+ from astrbot.core.config.default import DB_PATH
11
10
  from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
12
- from sqlalchemy import text
11
+ from astrbot.core.platform.astr_message_event import MessageSesion
12
+
13
+ from .. import BaseDatabase
14
+ from .shared_preferences_v3 import sp as sp_v3
15
+ from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
13
16
 
14
17
  """
15
18
  1. 迁移旧的 webchat_conversation 表到新的 conversation 表。
@@ -18,7 +21,8 @@ from sqlalchemy import text
18
21
 
19
22
 
20
23
  def get_platform_id(
21
- platform_id_map: dict[str, dict[str, str]], old_platform_name: str
24
+ platform_id_map: dict[str, dict[str, str]],
25
+ old_platform_name: str,
22
26
  ) -> str:
23
27
  return platform_id_map.get(
24
28
  old_platform_name,
@@ -27,7 +31,8 @@ def get_platform_id(
27
31
 
28
32
 
29
33
  def get_platform_type(
30
- platform_id_map: dict[str, dict[str, str]], old_platform_name: str
34
+ platform_id_map: dict[str, dict[str, str]],
35
+ old_platform_name: str,
31
36
  ) -> str:
32
37
  return platform_id_map.get(
33
38
  old_platform_name,
@@ -36,13 +41,15 @@ def get_platform_type(
36
41
 
37
42
 
38
43
  async def migration_conversation_table(
39
- db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
44
+ db_helper: BaseDatabase,
45
+ platform_id_map: dict[str, dict[str, str]],
40
46
  ):
41
47
  db_helper_v3 = SQLiteV3DatabaseV3(
42
- db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
48
+ db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
43
49
  )
44
50
  conversations, total_cnt = db_helper_v3.get_all_conversations(
45
- page=1, page_size=10000000
51
+ page=1,
52
+ page_size=10000000,
46
53
  )
47
54
  logger.info(f"迁移 {total_cnt} 条旧的会话数据到新的表中...")
48
55
 
@@ -61,13 +68,14 @@ async def migration_conversation_table(
61
68
  )
62
69
  if not conv:
63
70
  logger.info(
64
- f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
71
+ f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
65
72
  )
66
73
  if ":" not in conv.user_id:
67
74
  continue
68
75
  session = MessageSesion.from_str(session_str=conv.user_id)
69
76
  platform_id = get_platform_id(
70
- platform_id_map, session.platform_name
77
+ platform_id_map,
78
+ session.platform_name,
71
79
  )
72
80
  session.platform_id = platform_id # 更新平台名称为新的 ID
73
81
  conv_v2 = ConversationV2(
@@ -90,10 +98,11 @@ async def migration_conversation_table(
90
98
 
91
99
 
92
100
  async def migration_platform_table(
93
- db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
101
+ db_helper: BaseDatabase,
102
+ platform_id_map: dict[str, dict[str, str]],
94
103
  ):
95
104
  db_helper_v3 = SQLiteV3DatabaseV3(
96
- db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
105
+ db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
97
106
  )
98
107
  secs_from_2023_4_10_to_now = (
99
108
  datetime.datetime.now(datetime.timezone.utc)
@@ -134,10 +143,12 @@ async def migration_platform_table(
134
143
  if cnt == 0:
135
144
  continue
136
145
  platform_id = get_platform_id(
137
- platform_id_map, platform_stats_v3[idx].name
146
+ platform_id_map,
147
+ platform_stats_v3[idx].name,
138
148
  )
139
149
  platform_type = get_platform_type(
140
- platform_id_map, platform_stats_v3[idx].name
150
+ platform_id_map,
151
+ platform_stats_v3[idx].name,
141
152
  )
142
153
  try:
143
154
  await dbsession.execute(
@@ -149,7 +160,8 @@ async def migration_platform_table(
149
160
  """),
150
161
  {
151
162
  "timestamp": datetime.datetime.fromtimestamp(
152
- bucket_end, tz=datetime.timezone.utc
163
+ bucket_end,
164
+ tz=datetime.timezone.utc,
153
165
  ),
154
166
  "platform_id": platform_id,
155
167
  "platform_type": platform_type,
@@ -165,14 +177,16 @@ async def migration_platform_table(
165
177
 
166
178
 
167
179
  async def migration_webchat_data(
168
- db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
180
+ db_helper: BaseDatabase,
181
+ platform_id_map: dict[str, dict[str, str]],
169
182
  ):
170
183
  """迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
171
184
  db_helper_v3 = SQLiteV3DatabaseV3(
172
- db_path=DB_PATH.replace("data_v4.db", "data_v3.db")
185
+ db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
173
186
  )
174
187
  conversations, total_cnt = db_helper_v3.get_all_conversations(
175
- page=1, page_size=10000000
188
+ page=1,
189
+ page_size=10000000,
176
190
  )
177
191
  logger.info(f"迁移 {total_cnt} 条旧的 WebChat 会话数据到新的表中...")
178
192
 
@@ -191,7 +205,7 @@ async def migration_webchat_data(
191
205
  )
192
206
  if not conv:
193
207
  logger.info(
194
- f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
208
+ f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
195
209
  )
196
210
  if ":" in conv.user_id:
197
211
  continue
@@ -218,10 +232,10 @@ async def migration_webchat_data(
218
232
 
219
233
 
220
234
  async def migration_persona_data(
221
- db_helper: BaseDatabase, astrbot_config: AstrBotConfig
235
+ db_helper: BaseDatabase,
236
+ astrbot_config: AstrBotConfig,
222
237
  ):
223
- """
224
- 迁移 Persona 数据到新的表中。
238
+ """迁移 Persona 数据到新的表中。
225
239
  旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
226
240
  """
227
241
  v3_persona_config: list[dict] = astrbot_config.get("persona", [])
@@ -236,14 +250,15 @@ async def migration_persona_data(
236
250
  try:
237
251
  begin_dialogs = persona.get("begin_dialogs", [])
238
252
  mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
239
- mood_prompt = ""
253
+ parts = []
240
254
  user_turn = True
241
255
  for mood_dialog in mood_imitation_dialogs:
242
256
  if user_turn:
243
- mood_prompt += f"A: {mood_dialog}\n"
257
+ parts.append(f"A: {mood_dialog}\n")
244
258
  else:
245
- mood_prompt += f"B: {mood_dialog}\n"
259
+ parts.append(f"B: {mood_dialog}\n")
246
260
  user_turn = not user_turn
261
+ mood_prompt = "".join(parts)
247
262
  system_prompt = persona.get("prompt", "")
248
263
  if mood_prompt:
249
264
  system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
@@ -253,14 +268,15 @@ async def migration_persona_data(
253
268
  begin_dialogs=begin_dialogs,
254
269
  )
255
270
  logger.info(
256
- f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
271
+ f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。",
257
272
  )
258
273
  except Exception as e:
259
274
  logger.error(f"解析 Persona 配置失败:{e}")
260
275
 
261
276
 
262
277
  async def migration_preferences(
263
- db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
278
+ db_helper: BaseDatabase,
279
+ platform_id_map: dict[str, dict[str, str]],
264
280
  ):
265
281
  # 1. global scope migration
266
282
  keys = [
@@ -329,10 +345,13 @@ async def migration_preferences(
329
345
 
330
346
  for provider_type, provider_id in perf.items():
331
347
  await sp.put_async(
332
- "umo", str(session), f"provider_perf_{provider_type}", provider_id
348
+ "umo",
349
+ str(session),
350
+ f"provider_perf_{provider_type}",
351
+ provider_id,
333
352
  )
334
353
  logger.info(
335
- f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}"
354
+ f"迁移会话 {umo} 的提供商偏好到新表成功,平台 ID: {platform_id}",
336
355
  )
337
356
  except Exception as e:
338
357
  logger.error(f"迁移会话 {umo} 的提供商偏好失败: {e}", exc_info=True)
astrbot/core/db/migration/migra_45_to_46.py CHANGED
@@ -9,7 +9,7 @@ async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
9
9
  if not isinstance(abconf_data, dict):
10
10
  # should be unreachable
11
11
  logger.warning(
12
- f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
12
+ f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}",
13
13
  )
14
14
  return
15
15
 
astrbot/core/db/migration/shared_preferences_v3.py CHANGED
@@ -1,6 +1,7 @@
1
1
  import json
2
2
  import os
3
3
  from typing import TypeVar
4
+
4
5
  from astrbot.core.utils.astrbot_path import get_astrbot_data_path
5
6
 
6
7
  _VT = TypeVar("_VT")
@@ -16,7 +17,7 @@ class SharedPreferences:
16
17
  def _load_preferences(self):
17
18
  if os.path.exists(self.path):
18
19
  try:
19
- with open(self.path, "r") as f:
20
+ with open(self.path) as f:
20
21
  return json.load(f)
21
22
  except json.JSONDecodeError:
22
23
  os.remove(self.path)
astrbot/core/db/migration/sqlite_v3.py CHANGED
@@ -1,8 +1,9 @@
1
1
  import sqlite3
2
2
  import time
3
- from astrbot.core.db.po import Platform, Stats
4
- from typing import Tuple, List, Dict, Any
5
3
  from dataclasses import dataclass
4
+ from typing import Any
5
+
6
+ from astrbot.core.db.po import Platform, Stats
6
7
 
7
8
 
8
9
  @dataclass
@@ -94,7 +95,7 @@ class SQLiteDatabase:
94
95
  c.execute(
95
96
  """
96
97
  PRAGMA table_info(webchat_conversation)
97
- """
98
+ """,
98
99
  )
99
100
  res = c.fetchall()
100
101
  has_title = False
@@ -108,14 +109,14 @@ class SQLiteDatabase:
108
109
  c.execute(
109
110
  """
110
111
  ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
111
- """
112
+ """,
112
113
  )
113
114
  self.conn.commit()
114
115
  if not has_persona_id:
115
116
  c.execute(
116
117
  """
117
118
  ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
118
- """
119
+ """,
119
120
  )
120
121
  self.conn.commit()
121
122
 
@@ -126,7 +127,7 @@ class SQLiteDatabase:
126
127
  conn.text_factory = str
127
128
  return conn
128
129
 
129
- def _exec_sql(self, sql: str, params: Tuple = None):
130
+ def _exec_sql(self, sql: str, params: tuple = None):
130
131
  conn = self.conn
131
132
  try:
132
133
  c = self.conn.cursor()
@@ -174,7 +175,7 @@ class SQLiteDatabase:
174
175
  """
175
176
  SELECT * FROM platform
176
177
  """
177
- + where_clause
178
+ + where_clause,
178
179
  )
179
180
 
180
181
  platform = []
@@ -194,7 +195,7 @@ class SQLiteDatabase:
194
195
  c.execute(
195
196
  """
196
197
  SELECT SUM(count) FROM platform
197
- """
198
+ """,
198
199
  )
199
200
  res = c.fetchone()
200
201
  c.close()
@@ -214,7 +215,7 @@ class SQLiteDatabase:
214
215
  SELECT name, SUM(count), timestamp FROM platform
215
216
  """
216
217
  + where_clause
217
- + " GROUP BY name"
218
+ + " GROUP BY name",
218
219
  )
219
220
 
220
221
  platform = []
@@ -242,7 +243,7 @@ class SQLiteDatabase:
242
243
  c.close()
243
244
 
244
245
  if not res:
245
- return
246
+ return None
246
247
 
247
248
  return Conversation(*res)
248
249
 
@@ -257,7 +258,7 @@ class SQLiteDatabase:
257
258
  (user_id, cid, history, updated_at, created_at),
258
259
  )
259
260
 
260
- def get_conversations(self, user_id: str) -> Tuple:
261
+ def get_conversations(self, user_id: str) -> tuple:
261
262
  try:
262
263
  c = self.conn.cursor()
263
264
  except sqlite3.ProgrammingError:
@@ -280,7 +281,7 @@ class SQLiteDatabase:
280
281
  title = row[3]
281
282
  persona_id = row[4]
282
283
  conversations.append(
283
- Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
284
+ Conversation("", cid, "[]", created_at, updated_at, title, persona_id),
284
285
  )
285
286
  return conversations
286
287
 
@@ -319,8 +320,10 @@ class SQLiteDatabase:
319
320
  )
320
321
 
321
322
  def get_all_conversations(
322
- self, page: int = 1, page_size: int = 20
323
- ) -> Tuple[List[Dict[str, Any]], int]:
323
+ self,
324
+ page: int = 1,
325
+ page_size: int = 20,
326
+ ) -> tuple[list[dict[str, Any]], int]:
324
327
  """获取所有对话,支持分页,按更新时间降序排序"""
325
328
  try:
326
329
  c = self.conn.cursor()
@@ -366,7 +369,7 @@ class SQLiteDatabase:
366
369
  "persona_id": persona_id or "",
367
370
  "created_at": created_at or 0,
368
371
  "updated_at": updated_at or 0,
369
- }
372
+ },
370
373
  )
371
374
 
372
375
  return conversations, total_count
@@ -381,12 +384,12 @@ class SQLiteDatabase:
381
384
  self,
382
385
  page: int = 1,
383
386
  page_size: int = 20,
384
- platforms: List[str] = None,
385
- message_types: List[str] = None,
386
- search_query: str = None,
387
- exclude_ids: List[str] = None,
388
- exclude_platforms: List[str] = None,
389
- ) -> Tuple[List[Dict[str, Any]], int]:
387
+ platforms: list[str] | None = None,
388
+ message_types: list[str] | None = None,
389
+ search_query: str | None = None,
390
+ exclude_ids: list[str] | None = None,
391
+ exclude_platforms: list[str] | None = None,
392
+ ) -> tuple[list[dict[str, Any]], int]:
390
393
  """获取筛选后的对话列表"""
391
394
  try:
392
395
  c = self.conn.cursor()
@@ -422,7 +425,7 @@ class SQLiteDatabase:
422
425
  if search_query:
423
426
  search_query = search_query.encode("unicode_escape").decode("utf-8")
424
427
  where_clauses.append(
425
- "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
428
+ "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
426
429
  )
427
430
  search_param = f"%{search_query}%"
428
431
  params.extend([search_param, search_param, search_param, search_param])
@@ -482,7 +485,7 @@ class SQLiteDatabase:
482
485
  "persona_id": persona_id or "",
483
486
  "created_at": created_at or 0,
484
487
  "updated_at": updated_at or 0,
485
- }
488
+ },
486
489
  )
487
490
 
488
491
  return conversations, total_count
astrbot/core/db/po.py CHANGED
@@ -1,15 +1,15 @@
1
1
  import uuid
2
-
3
- from datetime import datetime, timezone
4
2
  from dataclasses import dataclass, field
3
+ from datetime import datetime, timezone
4
+ from typing import TypedDict
5
+
5
6
  from sqlmodel import (
7
+ JSON,
8
+ Field,
6
9
  SQLModel,
7
10
  Text,
8
- JSON,
9
11
  UniqueConstraint,
10
- Field,
11
12
  )
12
- from typing import Optional, TypedDict
13
13
 
14
14
 
15
15
  class PlatformStat(SQLModel, table=True):
@@ -40,7 +40,8 @@ class ConversationV2(SQLModel, table=True):
40
40
  __tablename__ = "conversations"
41
41
 
42
42
  inner_conversation_id: int = Field(
43
- primary_key=True, sa_column_kwargs={"autoincrement": True}
43
+ primary_key=True,
44
+ sa_column_kwargs={"autoincrement": True},
44
45
  )
45
46
  conversation_id: str = Field(
46
47
  max_length=36,
@@ -50,14 +51,14 @@ class ConversationV2(SQLModel, table=True):
50
51
  )
51
52
  platform_id: str = Field(nullable=False)
52
53
  user_id: str = Field(nullable=False)
53
- content: Optional[list] = Field(default=None, sa_type=JSON)
54
+ content: list | None = Field(default=None, sa_type=JSON)
54
55
  created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
55
56
  updated_at: datetime = Field(
56
57
  default_factory=lambda: datetime.now(timezone.utc),
57
58
  sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
58
59
  )
59
- title: Optional[str] = Field(default=None, max_length=255)
60
- persona_id: Optional[str] = Field(default=None)
60
+ title: str | None = Field(default=None, max_length=255)
61
+ persona_id: str | None = Field(default=None)
61
62
 
62
63
  __table_args__ = (
63
64
  UniqueConstraint(
@@ -76,13 +77,15 @@ class Persona(SQLModel, table=True):
76
77
  __tablename__ = "personas"
77
78
 
78
79
  id: int | None = Field(
79
- primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
80
+ primary_key=True,
81
+ sa_column_kwargs={"autoincrement": True},
82
+ default=None,
80
83
  )
81
84
  persona_id: str = Field(max_length=255, nullable=False)
82
85
  system_prompt: str = Field(sa_type=Text, nullable=False)
83
- begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
86
+ begin_dialogs: list | None = Field(default=None, sa_type=JSON)
84
87
  """a list of strings, each representing a dialog to start with"""
85
- tools: Optional[list] = Field(default=None, sa_type=JSON)
88
+ tools: list | None = Field(default=None, sa_type=JSON)
86
89
  """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
87
90
  created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
88
91
  updated_at: datetime = Field(
@@ -104,7 +107,9 @@ class Preference(SQLModel, table=True):
104
107
  __tablename__ = "preferences"
105
108
 
106
109
  id: int | None = Field(
107
- default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
110
+ default=None,
111
+ primary_key=True,
112
+ sa_column_kwargs={"autoincrement": True},
108
113
  )
109
114
  scope: str = Field(nullable=False)
110
115
  """Scope of the preference, such as 'global', 'umo', 'plugin'."""
@@ -138,13 +143,15 @@ class PlatformMessageHistory(SQLModel, table=True):
138
143
  __tablename__ = "platform_message_history"
139
144
 
140
145
  id: int | None = Field(
141
- primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
146
+ primary_key=True,
147
+ sa_column_kwargs={"autoincrement": True},
148
+ default=None,
142
149
  )
143
150
  platform_id: str = Field(nullable=False)
144
151
  user_id: str = Field(nullable=False) # An id of group, user in platform
145
- sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
146
- sender_name: Optional[str] = Field(
147
- default=None
152
+ sender_id: str | None = Field(default=None) # ID of the sender in the platform
153
+ sender_name: str | None = Field(
154
+ default=None,
148
155
  ) # Name of the sender in the platform
149
156
  content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
150
157
  created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
@@ -163,7 +170,9 @@ class Attachment(SQLModel, table=True):
163
170
  __tablename__ = "attachments"
164
171
 
165
172
  inner_attachment_id: int | None = Field(
166
- primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
173
+ primary_key=True,
174
+ sa_column_kwargs={"autoincrement": True},
175
+ default=None,
167
176
  )
168
177
  attachment_id: str = Field(
169
178
  max_length=36,