AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288) hide show
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/db/sqlite.py CHANGED
@@ -1,565 +1,810 @@
1
- import sqlite3
2
- import os
3
- import time
4
- from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
5
- from . import BaseDatabase
6
- from typing import Tuple, List, Dict, Any
1
+ import asyncio
2
+ import threading
3
+ import typing as T
4
+ from datetime import datetime, timedelta, timezone
5
+
6
+ from sqlalchemy.ext.asyncio import AsyncSession
7
+ from sqlmodel import col, delete, desc, func, or_, select, text, update
8
+
9
+ from astrbot.core.db import BaseDatabase
10
+ from astrbot.core.db.po import (
11
+ Attachment,
12
+ ConversationV2,
13
+ Persona,
14
+ PlatformMessageHistory,
15
+ PlatformSession,
16
+ PlatformStat,
17
+ Preference,
18
+ SQLModel,
19
+ )
20
+ from astrbot.core.db.po import (
21
+ Platform as DeprecatedPlatformStat,
22
+ )
23
+ from astrbot.core.db.po import (
24
+ Stats as DeprecatedStats,
25
+ )
26
+
27
+ NOT_GIVEN = T.TypeVar("NOT_GIVEN")
7
28
 
8
29
 
9
30
  class SQLiteDatabase(BaseDatabase):
10
31
  def __init__(self, db_path: str) -> None:
11
- super().__init__()
12
32
  self.db_path = db_path
33
+ self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
34
+ self.inited = False
35
+ super().__init__()
36
+
37
+ async def initialize(self) -> None:
38
+ """Initialize the database by creating tables if they do not exist."""
39
+ async with self.engine.begin() as conn:
40
+ await conn.run_sync(SQLModel.metadata.create_all)
41
+ await conn.execute(text("PRAGMA journal_mode=WAL"))
42
+ await conn.execute(text("PRAGMA synchronous=NORMAL"))
43
+ await conn.execute(text("PRAGMA cache_size=20000"))
44
+ await conn.execute(text("PRAGMA temp_store=MEMORY"))
45
+ await conn.execute(text("PRAGMA mmap_size=134217728"))
46
+ await conn.execute(text("PRAGMA optimize"))
47
+ await conn.commit()
48
+
49
+ # ====
50
+ # Platform Statistics
51
+ # ====
52
+
53
+ async def insert_platform_stats(
54
+ self,
55
+ platform_id,
56
+ platform_type,
57
+ count=1,
58
+ timestamp=None,
59
+ ) -> None:
60
+ """Insert a new platform statistic record."""
61
+ async with self.get_db() as session:
62
+ session: AsyncSession
63
+ async with session.begin():
64
+ if timestamp is None:
65
+ timestamp = datetime.now().replace(
66
+ minute=0,
67
+ second=0,
68
+ microsecond=0,
69
+ )
70
+ current_hour = timestamp
71
+ await session.execute(
72
+ text("""
73
+ INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
74
+ VALUES (:timestamp, :platform_id, :platform_type, :count)
75
+ ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
76
+ count = platform_stats.count + EXCLUDED.count
77
+ """),
78
+ {
79
+ "timestamp": current_hour,
80
+ "platform_id": platform_id,
81
+ "platform_type": platform_type,
82
+ "count": count,
83
+ },
84
+ )
13
85
 
14
- with open(os.path.dirname(__file__) + "/sqlite_init.sql", "r") as f:
15
- sql = f.read()
16
-
17
- # 初始化数据库
18
- self.conn = self._get_conn(self.db_path)
19
- c = self.conn.cursor()
20
- c.executescript(sql)
21
- self.conn.commit()
22
-
23
- # 检查 webchat_conversation 的 title 字段是否存在
24
- c.execute(
25
- """
26
- PRAGMA table_info(webchat_conversation)
27
- """
28
- )
29
- res = c.fetchall()
30
- has_title = False
31
- has_persona_id = False
32
- for row in res:
33
- if row[1] == "title":
34
- has_title = True
35
- if row[1] == "persona_id":
36
- has_persona_id = True
37
- if not has_title:
38
- c.execute(
39
- """
40
- ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
41
- """
86
+ async def count_platform_stats(self) -> int:
87
+ """Count the number of platform statistics records."""
88
+ async with self.get_db() as session:
89
+ session: AsyncSession
90
+ result = await session.execute(
91
+ select(func.count(col(PlatformStat.platform_id))).select_from(
92
+ PlatformStat,
93
+ ),
42
94
  )
43
- self.conn.commit()
44
- if not has_persona_id:
45
- c.execute(
46
- """
47
- ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
48
- """
95
+ count = result.scalar_one_or_none()
96
+ return count if count is not None else 0
97
+
98
+ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
99
+ """Get platform statistics within the specified offset in seconds and group by platform_id."""
100
+ async with self.get_db() as session:
101
+ session: AsyncSession
102
+ now = datetime.now()
103
+ start_time = now - timedelta(seconds=offset_sec)
104
+ result = await session.execute(
105
+ text("""
106
+ SELECT * FROM platform_stats
107
+ WHERE timestamp >= :start_time
108
+ ORDER BY timestamp DESC
109
+ GROUP BY platform_id
110
+ """),
111
+ {"start_time": start_time},
49
112
  )
50
- self.conn.commit()
51
-
52
- c.close()
53
-
54
- def _get_conn(self, db_path: str) -> sqlite3.Connection:
55
- conn = sqlite3.connect(self.db_path)
56
- conn.text_factory = str
57
- return conn
58
-
59
- def _exec_sql(self, sql: str, params: Tuple = None):
60
- conn = self.conn
61
- try:
62
- c = self.conn.cursor()
63
- except sqlite3.ProgrammingError:
64
- conn = self._get_conn(self.db_path)
65
- c = conn.cursor()
66
-
67
- if params:
68
- c.execute(sql, params)
69
- c.close()
70
- else:
71
- c.execute(sql)
72
- c.close()
73
-
74
- conn.commit()
75
-
76
- def insert_platform_metrics(self, metrics: dict):
77
- for k, v in metrics.items():
78
- self._exec_sql(
79
- """
80
- INSERT INTO platform(name, count, timestamp) VALUES (?, ?, ?)
81
- """,
82
- (k, v, int(time.time())),
113
+ return list(result.scalars().all())
114
+
115
+ # ====
116
+ # Conversation Management
117
+ # ====
118
+
119
+ async def get_conversations(self, user_id=None, platform_id=None):
120
+ async with self.get_db() as session:
121
+ session: AsyncSession
122
+ query = select(ConversationV2)
123
+
124
+ if user_id:
125
+ query = query.where(ConversationV2.user_id == user_id)
126
+ if platform_id:
127
+ query = query.where(ConversationV2.platform_id == platform_id)
128
+ # order by
129
+ query = query.order_by(desc(ConversationV2.created_at))
130
+ result = await session.execute(query)
131
+
132
+ return result.scalars().all()
133
+
134
+ async def get_conversation_by_id(self, cid):
135
+ async with self.get_db() as session:
136
+ session: AsyncSession
137
+ query = select(ConversationV2).where(ConversationV2.conversation_id == cid)
138
+ result = await session.execute(query)
139
+ return result.scalar_one_or_none()
140
+
141
+ async def get_all_conversations(self, page=1, page_size=20):
142
+ async with self.get_db() as session:
143
+ session: AsyncSession
144
+ offset = (page - 1) * page_size
145
+ result = await session.execute(
146
+ select(ConversationV2)
147
+ .order_by(desc(ConversationV2.created_at))
148
+ .offset(offset)
149
+ .limit(page_size),
83
150
  )
151
+ return result.scalars().all()
84
152
 
85
- def insert_plugin_metrics(self, metrics: dict):
86
- pass
153
+ async def get_filtered_conversations(
154
+ self,
155
+ page=1,
156
+ page_size=20,
157
+ platform_ids=None,
158
+ search_query="",
159
+ **kwargs,
160
+ ):
161
+ async with self.get_db() as session:
162
+ session: AsyncSession
163
+ # Build the base query with filters
164
+ base_query = select(ConversationV2)
165
+
166
+ if platform_ids:
167
+ base_query = base_query.where(
168
+ col(ConversationV2.platform_id).in_(platform_ids),
169
+ )
170
+ if search_query:
171
+ search_query = search_query.encode("unicode_escape").decode("utf-8")
172
+ base_query = base_query.where(
173
+ or_(
174
+ col(ConversationV2.title).ilike(f"%{search_query}%"),
175
+ col(ConversationV2.content).ilike(f"%{search_query}%"),
176
+ col(ConversationV2.user_id).ilike(f"%{search_query}%"),
177
+ col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
178
+ ),
179
+ )
180
+ if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
181
+ for msg_type in kwargs["message_types"]:
182
+ base_query = base_query.where(
183
+ col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
184
+ )
185
+ if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
186
+ base_query = base_query.where(
187
+ col(ConversationV2.platform_id).in_(kwargs["platforms"]),
188
+ )
87
189
 
88
- def insert_command_metrics(self, metrics: dict):
89
- for k, v in metrics.items():
90
- self._exec_sql(
91
- """
92
- INSERT INTO command(name, count, timestamp) VALUES (?, ?, ?)
93
- """,
94
- (k, v, int(time.time())),
95
- )
190
+ # Get total count matching the filters
191
+ count_query = select(func.count()).select_from(base_query.subquery())
192
+ total_count = await session.execute(count_query)
193
+ total = total_count.scalar_one()
96
194
 
97
- def insert_llm_metrics(self, metrics: dict):
98
- for k, v in metrics.items():
99
- self._exec_sql(
100
- """
101
- INSERT INTO llm(name, count, timestamp) VALUES (?, ?, ?)
102
- """,
103
- (k, v, int(time.time())),
195
+ # Get paginated results
196
+ offset = (page - 1) * page_size
197
+ result_query = (
198
+ base_query.order_by(desc(ConversationV2.created_at))
199
+ .offset(offset)
200
+ .limit(page_size)
104
201
  )
202
+ result = await session.execute(result_query)
203
+ conversations = result.scalars().all()
105
204
 
106
- def update_llm_history(self, session_id: str, content: str, provider_type: str):
107
- res = self.get_llm_history(session_id, provider_type)
108
- if res:
109
- self._exec_sql(
110
- """
111
- UPDATE llm_history SET content = ? WHERE session_id = ? AND provider_type = ?
112
- """,
113
- (content, session_id, provider_type),
114
- )
115
- else:
116
- self._exec_sql(
117
- """
118
- INSERT INTO llm_history(provider_type, session_id, content) VALUES (?, ?, ?)
119
- """,
120
- (provider_type, session_id, content),
121
- )
205
+ return conversations, total
122
206
 
123
- def get_llm_history(
124
- self, session_id: str = None, provider_type: str = None
125
- ) -> Tuple:
126
- try:
127
- c = self.conn.cursor()
128
- except sqlite3.ProgrammingError:
129
- c = self._get_conn(self.db_path).cursor()
207
+ async def create_conversation(
208
+ self,
209
+ user_id,
210
+ platform_id,
211
+ content=None,
212
+ title=None,
213
+ persona_id=None,
214
+ cid=None,
215
+ created_at=None,
216
+ updated_at=None,
217
+ ):
218
+ kwargs = {}
219
+ if cid:
220
+ kwargs["conversation_id"] = cid
221
+ if created_at:
222
+ kwargs["created_at"] = created_at
223
+ if updated_at:
224
+ kwargs["updated_at"] = updated_at
225
+ async with self.get_db() as session:
226
+ session: AsyncSession
227
+ async with session.begin():
228
+ new_conversation = ConversationV2(
229
+ user_id=user_id,
230
+ content=content or [],
231
+ platform_id=platform_id,
232
+ title=title,
233
+ persona_id=persona_id,
234
+ **kwargs,
235
+ )
236
+ session.add(new_conversation)
237
+ return new_conversation
238
+
239
+ async def update_conversation(self, cid, title=None, persona_id=None, content=None):
240
+ async with self.get_db() as session:
241
+ session: AsyncSession
242
+ async with session.begin():
243
+ query = update(ConversationV2).where(
244
+ col(ConversationV2.conversation_id) == cid,
245
+ )
246
+ values = {}
247
+ if title is not None:
248
+ values["title"] = title
249
+ if persona_id is not None:
250
+ values["persona_id"] = persona_id
251
+ if content is not None:
252
+ values["content"] = content
253
+ if not values:
254
+ return None
255
+ query = query.values(**values)
256
+ await session.execute(query)
257
+ return await self.get_conversation_by_id(cid)
258
+
259
+ async def delete_conversation(self, cid):
260
+ async with self.get_db() as session:
261
+ session: AsyncSession
262
+ async with session.begin():
263
+ await session.execute(
264
+ delete(ConversationV2).where(
265
+ col(ConversationV2.conversation_id) == cid,
266
+ ),
267
+ )
130
268
 
131
- conditions = []
132
- params = []
269
+ async def delete_conversations_by_user_id(self, user_id: str) -> None:
270
+ async with self.get_db() as session:
271
+ session: AsyncSession
272
+ async with session.begin():
273
+ await session.execute(
274
+ delete(ConversationV2).where(
275
+ col(ConversationV2.user_id) == user_id
276
+ ),
277
+ )
133
278
 
134
- if session_id:
135
- conditions.append("session_id = ?")
136
- params.append(session_id)
137
-
138
- if provider_type:
139
- conditions.append("provider_type = ?")
140
- params.append(provider_type)
141
-
142
- sql = "SELECT * FROM llm_history"
143
- if conditions:
144
- sql += " WHERE " + " AND ".join(conditions)
145
-
146
- c.execute(sql, params)
147
-
148
- res = c.fetchall()
149
- histories = []
150
- for row in res:
151
- histories.append(LLMHistory(*row))
152
- c.close()
153
- return histories
154
-
155
- def get_base_stats(self, offset_sec: int = 86400) -> Stats:
156
- """获取 offset_sec 秒前到现在的基础统计数据"""
157
- where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
158
-
159
- try:
160
- c = self.conn.cursor()
161
- except sqlite3.ProgrammingError:
162
- c = self._get_conn(self.db_path).cursor()
163
-
164
- c.execute(
165
- """
166
- SELECT * FROM platform
167
- """
168
- + where_clause
169
- )
170
-
171
- platform = []
172
- for row in c.fetchall():
173
- platform.append(Platform(*row))
174
-
175
- # c.execute(
176
- # '''
177
- # SELECT * FROM command
178
- # ''' + where_clause
179
- # )
180
-
181
- # command = []
182
- # for row in c.fetchall():
183
- # command.append(Command(*row))
184
-
185
- # c.execute(
186
- # '''
187
- # SELECT * FROM llm
188
- # ''' + where_clause
189
- # )
190
-
191
- # llm = []
192
- # for row in c.fetchall():
193
- # llm.append(Provider(*row))
194
-
195
- c.close()
196
-
197
- return Stats(platform, [], [])
198
-
199
- def get_total_message_count(self) -> int:
200
- try:
201
- c = self.conn.cursor()
202
- except sqlite3.ProgrammingError:
203
- c = self._get_conn(self.db_path).cursor()
204
-
205
- c.execute(
206
- """
207
- SELECT SUM(count) FROM platform
208
- """
209
- )
210
- res = c.fetchone()
211
- c.close()
212
- return res[0]
213
-
214
- def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
215
- """获取 offset_sec 秒前到现在的基础统计数据(合并)"""
216
- where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
217
-
218
- try:
219
- c = self.conn.cursor()
220
- except sqlite3.ProgrammingError:
221
- c = self._get_conn(self.db_path).cursor()
222
-
223
- c.execute(
224
- """
225
- SELECT name, SUM(count), timestamp FROM platform
226
- """
227
- + where_clause
228
- + " GROUP BY name"
229
- )
230
-
231
- platform = []
232
- for row in c.fetchall():
233
- platform.append(Platform(*row))
234
-
235
- c.close()
236
-
237
- return Stats(platform, [], [])
238
-
239
- def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
240
- try:
241
- c = self.conn.cursor()
242
- except sqlite3.ProgrammingError:
243
- c = self._get_conn(self.db_path).cursor()
244
-
245
- c.execute(
246
- """
247
- SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
248
- """,
249
- (user_id, cid),
250
- )
251
-
252
- res = c.fetchone()
253
- c.close()
254
-
255
- if not res:
256
- return
257
-
258
- return Conversation(*res)
259
-
260
- def new_conversation(self, user_id: str, cid: str):
261
- history = "[]"
262
- updated_at = int(time.time())
263
- created_at = updated_at
264
- self._exec_sql(
265
- """
266
- INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
267
- """,
268
- (user_id, cid, history, updated_at, created_at),
269
- )
270
-
271
- def get_conversations(self, user_id: str) -> Tuple:
272
- try:
273
- c = self.conn.cursor()
274
- except sqlite3.ProgrammingError:
275
- c = self._get_conn(self.db_path).cursor()
276
-
277
- c.execute(
278
- """
279
- SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
280
- """,
281
- (user_id,),
282
- )
283
-
284
- res = c.fetchall()
285
- c.close()
286
- conversations = []
287
- for row in res:
288
- cid = row[0]
289
- created_at = row[1]
290
- updated_at = row[2]
291
- title = row[3]
292
- persona_id = row[4]
293
- conversations.append(
294
- Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
295
- )
296
- return conversations
297
-
298
- def update_conversation(self, user_id: str, cid: str, history: str):
299
- """更新对话,并且同时更新时间"""
300
- updated_at = int(time.time())
301
- self._exec_sql(
302
- """
303
- UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
304
- """,
305
- (history, updated_at, user_id, cid),
306
- )
307
-
308
- def update_conversation_title(self, user_id: str, cid: str, title: str):
309
- self._exec_sql(
310
- """
311
- UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
312
- """,
313
- (title, user_id, cid),
314
- )
315
-
316
- def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
317
- self._exec_sql(
318
- """
319
- UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
320
- """,
321
- (persona_id, user_id, cid),
322
- )
323
-
324
- def delete_conversation(self, user_id: str, cid: str):
325
- self._exec_sql(
326
- """
327
- DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
328
- """,
329
- (user_id, cid),
330
- )
331
-
332
- def insert_atri_vision_data(self, vision: ATRIVision):
333
- ts = int(time.time())
334
- keywords = ",".join(vision.keywords)
335
- self._exec_sql(
336
- """
337
- INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
338
- """,
339
- (
340
- vision.id,
341
- vision.url_or_path,
342
- vision.caption,
343
- vision.is_meme,
344
- keywords,
345
- vision.platform_name,
346
- vision.session_id,
347
- vision.sender_nickname,
348
- ts,
349
- ),
350
- )
351
-
352
- def get_atri_vision_data(self) -> Tuple:
353
- try:
354
- c = self.conn.cursor()
355
- except sqlite3.ProgrammingError:
356
- c = self._get_conn(self.db_path).cursor()
357
-
358
- c.execute(
359
- """
360
- SELECT * FROM atri_vision
361
- """
362
- )
363
-
364
- res = c.fetchall()
365
- visions = []
366
- for row in res:
367
- visions.append(ATRIVision(*row))
368
- c.close()
369
- return visions
370
-
371
- def get_atri_vision_data_by_path_or_id(
372
- self, url_or_path: str, id: str
373
- ) -> ATRIVision:
374
- try:
375
- c = self.conn.cursor()
376
- except sqlite3.ProgrammingError:
377
- c = self._get_conn(self.db_path).cursor()
378
-
379
- c.execute(
380
- """
381
- SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
382
- """,
383
- (url_or_path, id),
384
- )
385
-
386
- res = c.fetchone()
387
- c.close()
388
- if res:
389
- return ATRIVision(*res)
390
- return None
391
-
392
- def get_all_conversations(
393
- self, page: int = 1, page_size: int = 20
394
- ) -> Tuple[List[Dict[str, Any]], int]:
395
- """获取所有对话,支持分页,按更新时间降序排序"""
396
- try:
397
- c = self.conn.cursor()
398
- except sqlite3.ProgrammingError:
399
- c = self._get_conn(self.db_path).cursor()
400
-
401
- try:
402
- # 获取总记录数
403
- c.execute("""
404
- SELECT COUNT(*) FROM webchat_conversation
405
- """)
406
- total_count = c.fetchone()[0]
407
-
408
- # 计算偏移量
279
+ async def get_session_conversations(
280
+ self,
281
+ page=1,
282
+ page_size=20,
283
+ search_query=None,
284
+ platform=None,
285
+ ) -> tuple[list[dict], int]:
286
+ """Get paginated session conversations with joined conversation and persona details."""
287
+ async with self.get_db() as session:
288
+ session: AsyncSession
409
289
  offset = (page - 1) * page_size
410
290
 
411
- # 获取分页数据,按更新时间降序排序
412
- c.execute(
413
- """
414
- SELECT user_id, cid, created_at, updated_at, title, persona_id
415
- FROM webchat_conversation
416
- ORDER BY updated_at DESC
417
- LIMIT ? OFFSET ?
418
- """,
419
- (page_size, offset),
291
+ base_query = (
292
+ select(
293
+ col(Preference.scope_id).label("session_id"),
294
+ func.json_extract(Preference.value, "$.val").label(
295
+ "conversation_id",
296
+ ), # type: ignore
297
+ col(ConversationV2.persona_id).label("persona_id"),
298
+ col(ConversationV2.title).label("title"),
299
+ col(Persona.persona_id).label("persona_name"),
300
+ )
301
+ .select_from(Preference)
302
+ .outerjoin(
303
+ ConversationV2,
304
+ func.json_extract(Preference.value, "$.val")
305
+ == ConversationV2.conversation_id,
306
+ )
307
+ .outerjoin(
308
+ Persona,
309
+ col(ConversationV2.persona_id) == Persona.persona_id,
310
+ )
311
+ .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
420
312
  )
421
313
 
422
- rows = c.fetchall()
423
-
424
- conversations = []
314
+ # 搜索筛选
315
+ if search_query:
316
+ search_pattern = f"%{search_query}%"
317
+ base_query = base_query.where(
318
+ or_(
319
+ col(Preference.scope_id).ilike(search_pattern),
320
+ col(ConversationV2.title).ilike(search_pattern),
321
+ col(Persona.persona_id).ilike(search_pattern),
322
+ ),
323
+ )
425
324
 
426
- for row in rows:
427
- user_id, cid, created_at, updated_at, title, persona_id = row
428
- # 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
429
- safe_cid = str(cid) if cid else "unknown"
430
- display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
325
+ # 平台筛选
326
+ if platform:
327
+ platform_pattern = f"{platform}:%"
328
+ base_query = base_query.where(
329
+ col(Preference.scope_id).like(platform_pattern),
330
+ )
431
331
 
432
- conversations.append(
433
- {
434
- "user_id": user_id or "",
435
- "cid": safe_cid,
436
- "title": title or f"对话 {display_cid}",
437
- "persona_id": persona_id or "",
438
- "created_at": created_at or 0,
439
- "updated_at": updated_at or 0,
440
- }
332
+ # 排序
333
+ base_query = base_query.order_by(Preference.scope_id)
334
+
335
+ # 分页结果
336
+ result_query = base_query.offset(offset).limit(page_size)
337
+ result = await session.execute(result_query)
338
+ rows = result.fetchall()
339
+
340
+ # 查询总数(应用相同的筛选条件)
341
+ count_base_query = (
342
+ select(func.count(col(Preference.scope_id)))
343
+ .select_from(Preference)
344
+ .outerjoin(
345
+ ConversationV2,
346
+ func.json_extract(Preference.value, "$.val")
347
+ == ConversationV2.conversation_id,
441
348
  )
349
+ .outerjoin(
350
+ Persona,
351
+ col(ConversationV2.persona_id) == Persona.persona_id,
352
+ )
353
+ .where(Preference.scope == "umo", Preference.key == "sel_conv_id")
354
+ )
442
355
 
443
- return conversations, total_count
356
+ # 应用相同的搜索和平台筛选条件到计数查询
357
+ if search_query:
358
+ search_pattern = f"%{search_query}%"
359
+ count_base_query = count_base_query.where(
360
+ or_(
361
+ col(Preference.scope_id).ilike(search_pattern),
362
+ col(ConversationV2.title).ilike(search_pattern),
363
+ col(Persona.persona_id).ilike(search_pattern),
364
+ ),
365
+ )
444
366
 
445
- except Exception as _:
446
- # 返回空列表和0,确保即使出错也有有效的返回值
447
- return [], 0
448
- finally:
449
- c.close()
367
+ if platform:
368
+ platform_pattern = f"{platform}:%"
369
+ count_base_query = count_base_query.where(
370
+ col(Preference.scope_id).like(platform_pattern),
371
+ )
450
372
 
451
- def get_filtered_conversations(
373
+ total_result = await session.execute(count_base_query)
374
+ total = total_result.scalar() or 0
375
+
376
+ sessions_data = [
377
+ {
378
+ "session_id": row.session_id,
379
+ "conversation_id": row.conversation_id,
380
+ "persona_id": row.persona_id,
381
+ "title": row.title,
382
+ "persona_name": row.persona_name,
383
+ }
384
+ for row in rows
385
+ ]
386
+ return sessions_data, total
387
+
388
+ async def insert_platform_message_history(
452
389
  self,
453
- page: int = 1,
454
- page_size: int = 20,
455
- platforms: List[str] = None,
456
- message_types: List[str] = None,
457
- search_query: str = None,
458
- exclude_ids: List[str] = None,
459
- exclude_platforms: List[str] = None,
460
- ) -> Tuple[List[Dict[str, Any]], int]:
461
- """获取筛选后的对话列表"""
462
- try:
463
- c = self.conn.cursor()
464
- except sqlite3.ProgrammingError:
465
- c = self._get_conn(self.db_path).cursor()
466
-
467
- try:
468
- # 构建查询条件
469
- where_clauses = []
470
- params = []
390
+ platform_id,
391
+ user_id,
392
+ content,
393
+ sender_id=None,
394
+ sender_name=None,
395
+ ):
396
+ """Insert a new platform message history record."""
397
+ async with self.get_db() as session:
398
+ session: AsyncSession
399
+ async with session.begin():
400
+ new_history = PlatformMessageHistory(
401
+ platform_id=platform_id,
402
+ user_id=user_id,
403
+ content=content,
404
+ sender_id=sender_id,
405
+ sender_name=sender_name,
406
+ )
407
+ session.add(new_history)
408
+ return new_history
471
409
 
472
- # 平台筛选
473
- if platforms and len(platforms) > 0:
474
- platform_conditions = []
475
- for platform in platforms:
476
- platform_conditions.append("user_id LIKE ?")
477
- params.append(f"{platform}:%")
478
-
479
- if platform_conditions:
480
- where_clauses.append(f"({' OR '.join(platform_conditions)})")
481
-
482
- # 消息类型筛选
483
- if message_types and len(message_types) > 0:
484
- message_type_conditions = []
485
- for msg_type in message_types:
486
- message_type_conditions.append("user_id LIKE ?")
487
- params.append(f"%:{msg_type}:%")
488
-
489
- if message_type_conditions:
490
- where_clauses.append(f"({' OR '.join(message_type_conditions)})")
491
-
492
- # 搜索关键词
493
- if search_query:
494
- search_query = search_query.encode("unicode_escape").decode("utf-8")
495
- where_clauses.append(
496
- "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
410
+ async def delete_platform_message_offset(
411
+ self,
412
+ platform_id,
413
+ user_id,
414
+ offset_sec=86400,
415
+ ):
416
+ """Delete platform message history records newer than the specified offset."""
417
+ async with self.get_db() as session:
418
+ session: AsyncSession
419
+ async with session.begin():
420
+ now = datetime.now()
421
+ cutoff_time = now - timedelta(seconds=offset_sec)
422
+ await session.execute(
423
+ delete(PlatformMessageHistory).where(
424
+ col(PlatformMessageHistory.platform_id) == platform_id,
425
+ col(PlatformMessageHistory.user_id) == user_id,
426
+ col(PlatformMessageHistory.created_at) >= cutoff_time,
427
+ ),
497
428
  )
498
- search_param = f"%{search_query}%"
499
- params.extend([search_param, search_param, search_param, search_param])
500
429
 
501
- # 排除特定用户ID
502
- if exclude_ids and len(exclude_ids) > 0:
503
- for exclude_id in exclude_ids:
504
- where_clauses.append("user_id NOT LIKE ?")
505
- params.append(f"{exclude_id}%")
430
+ async def get_platform_message_history(
431
+ self,
432
+ platform_id,
433
+ user_id,
434
+ page=1,
435
+ page_size=20,
436
+ ):
437
+ """Get platform message history records."""
438
+ async with self.get_db() as session:
439
+ session: AsyncSession
440
+ offset = (page - 1) * page_size
441
+ query = (
442
+ select(PlatformMessageHistory)
443
+ .where(
444
+ PlatformMessageHistory.platform_id == platform_id,
445
+ PlatformMessageHistory.user_id == user_id,
446
+ )
447
+ .order_by(desc(PlatformMessageHistory.created_at))
448
+ )
449
+ result = await session.execute(query.offset(offset).limit(page_size))
450
+ return result.scalars().all()
451
+
452
+ async def insert_attachment(self, path, type, mime_type):
453
+ """Insert a new attachment record."""
454
+ async with self.get_db() as session:
455
+ session: AsyncSession
456
+ async with session.begin():
457
+ new_attachment = Attachment(
458
+ path=path,
459
+ type=type,
460
+ mime_type=mime_type,
461
+ )
462
+ session.add(new_attachment)
463
+ return new_attachment
464
+
465
+ async def get_attachment_by_id(self, attachment_id):
466
+ """Get an attachment by its ID."""
467
+ async with self.get_db() as session:
468
+ session: AsyncSession
469
+ query = select(Attachment).where(Attachment.attachment_id == attachment_id)
470
+ result = await session.execute(query)
471
+ return result.scalar_one_or_none()
472
+
473
+ async def insert_persona(
474
+ self,
475
+ persona_id,
476
+ system_prompt,
477
+ begin_dialogs=None,
478
+ tools=None,
479
+ ):
480
+ """Insert a new persona record."""
481
+ async with self.get_db() as session:
482
+ session: AsyncSession
483
+ async with session.begin():
484
+ new_persona = Persona(
485
+ persona_id=persona_id,
486
+ system_prompt=system_prompt,
487
+ begin_dialogs=begin_dialogs or [],
488
+ tools=tools,
489
+ )
490
+ session.add(new_persona)
491
+ return new_persona
492
+
493
+ async def get_persona_by_id(self, persona_id):
494
+ """Get a persona by its ID."""
495
+ async with self.get_db() as session:
496
+ session: AsyncSession
497
+ query = select(Persona).where(Persona.persona_id == persona_id)
498
+ result = await session.execute(query)
499
+ return result.scalar_one_or_none()
500
+
501
+ async def get_personas(self):
502
+ """Get all personas for a specific bot."""
503
+ async with self.get_db() as session:
504
+ session: AsyncSession
505
+ query = select(Persona)
506
+ result = await session.execute(query)
507
+ return result.scalars().all()
508
+
509
+ async def update_persona(
510
+ self,
511
+ persona_id,
512
+ system_prompt=None,
513
+ begin_dialogs=None,
514
+ tools=NOT_GIVEN,
515
+ ):
516
+ """Update a persona's system prompt or begin dialogs."""
517
+ async with self.get_db() as session:
518
+ session: AsyncSession
519
+ async with session.begin():
520
+ query = update(Persona).where(col(Persona.persona_id) == persona_id)
521
+ values = {}
522
+ if system_prompt is not None:
523
+ values["system_prompt"] = system_prompt
524
+ if begin_dialogs is not None:
525
+ values["begin_dialogs"] = begin_dialogs
526
+ if tools is not NOT_GIVEN:
527
+ values["tools"] = tools
528
+ if not values:
529
+ return None
530
+ query = query.values(**values)
531
+ await session.execute(query)
532
+ return await self.get_persona_by_id(persona_id)
533
+
534
+ async def delete_persona(self, persona_id):
535
+ """Delete a persona by its ID."""
536
+ async with self.get_db() as session:
537
+ session: AsyncSession
538
+ async with session.begin():
539
+ await session.execute(
540
+ delete(Persona).where(col(Persona.persona_id) == persona_id),
541
+ )
506
542
 
507
- # 排除特定平台
508
- if exclude_platforms and len(exclude_platforms) > 0:
509
- for exclude_platform in exclude_platforms:
510
- where_clauses.append("user_id NOT LIKE ?")
511
- params.append(f"{exclude_platform}:%")
543
+ async def insert_preference_or_update(self, scope, scope_id, key, value):
544
+ """Insert a new preference record or update if it exists."""
545
+ async with self.get_db() as session:
546
+ session: AsyncSession
547
+ async with session.begin():
548
+ query = select(Preference).where(
549
+ Preference.scope == scope,
550
+ Preference.scope_id == scope_id,
551
+ Preference.key == key,
552
+ )
553
+ result = await session.execute(query)
554
+ existing_preference = result.scalar_one_or_none()
555
+ if existing_preference:
556
+ existing_preference.value = value
557
+ else:
558
+ new_preference = Preference(
559
+ scope=scope,
560
+ scope_id=scope_id,
561
+ key=key,
562
+ value=value,
563
+ )
564
+ session.add(new_preference)
565
+ return existing_preference or new_preference
566
+
567
+ async def get_preference(self, scope, scope_id, key):
568
+ """Get a preference by key."""
569
+ async with self.get_db() as session:
570
+ session: AsyncSession
571
+ query = select(Preference).where(
572
+ Preference.scope == scope,
573
+ Preference.scope_id == scope_id,
574
+ Preference.key == key,
575
+ )
576
+ result = await session.execute(query)
577
+ return result.scalar_one_or_none()
578
+
579
+ async def get_preferences(self, scope, scope_id=None, key=None):
580
+ """Get all preferences for a specific scope ID or key."""
581
+ async with self.get_db() as session:
582
+ session: AsyncSession
583
+ query = select(Preference).where(Preference.scope == scope)
584
+ if scope_id is not None:
585
+ query = query.where(Preference.scope_id == scope_id)
586
+ if key is not None:
587
+ query = query.where(Preference.key == key)
588
+ result = await session.execute(query)
589
+ return result.scalars().all()
590
+
591
+ async def remove_preference(self, scope, scope_id, key):
592
+ """Remove a preference by scope ID and key."""
593
+ async with self.get_db() as session:
594
+ session: AsyncSession
595
+ async with session.begin():
596
+ await session.execute(
597
+ delete(Preference).where(
598
+ col(Preference.scope) == scope,
599
+ col(Preference.scope_id) == scope_id,
600
+ col(Preference.key) == key,
601
+ ),
602
+ )
603
+ await session.commit()
604
+
605
+ async def clear_preferences(self, scope, scope_id):
606
+ """Clear all preferences for a specific scope ID."""
607
+ async with self.get_db() as session:
608
+ session: AsyncSession
609
+ async with session.begin():
610
+ await session.execute(
611
+ delete(Preference).where(
612
+ col(Preference.scope) == scope,
613
+ col(Preference.scope_id) == scope_id,
614
+ ),
615
+ )
616
+ await session.commit()
617
+
618
+ # ====
619
+ # Deprecated Methods
620
+ # ====
621
+
622
+ def get_base_stats(self, offset_sec=86400):
623
+ """Get base statistics within the specified offset in seconds."""
624
+
625
+ async def _inner():
626
+ async with self.get_db() as session:
627
+ session: AsyncSession
628
+ now = datetime.now()
629
+ start_time = now - timedelta(seconds=offset_sec)
630
+ result = await session.execute(
631
+ select(PlatformStat).where(PlatformStat.timestamp >= start_time),
632
+ )
633
+ all_datas = result.scalars().all()
634
+ deprecated_stats = DeprecatedStats()
635
+ for data in all_datas:
636
+ deprecated_stats.platform.append(
637
+ DeprecatedPlatformStat(
638
+ name=data.platform_id,
639
+ count=data.count,
640
+ timestamp=int(data.timestamp.timestamp()),
641
+ ),
642
+ )
643
+ return deprecated_stats
644
+
645
+ result = None
646
+
647
+ def runner():
648
+ nonlocal result
649
+ result = asyncio.run(_inner())
650
+
651
+ t = threading.Thread(target=runner)
652
+ t.start()
653
+ t.join()
654
+ return result
655
+
656
+ def get_total_message_count(self):
657
+ """Get the total message count from platform statistics."""
658
+
659
+ async def _inner():
660
+ async with self.get_db() as session:
661
+ session: AsyncSession
662
+ result = await session.execute(
663
+ select(func.sum(PlatformStat.count)).select_from(PlatformStat),
664
+ )
665
+ total_count = result.scalar_one_or_none()
666
+ return total_count if total_count is not None else 0
667
+
668
+ result = None
669
+
670
+ def runner():
671
+ nonlocal result
672
+ result = asyncio.run(_inner())
673
+
674
+ t = threading.Thread(target=runner)
675
+ t.start()
676
+ t.join()
677
+ return result
678
+
679
+ def get_grouped_base_stats(self, offset_sec=86400):
680
+ # group by platform_id
681
+ async def _inner():
682
+ async with self.get_db() as session:
683
+ session: AsyncSession
684
+ now = datetime.now()
685
+ start_time = now - timedelta(seconds=offset_sec)
686
+ result = await session.execute(
687
+ select(PlatformStat.platform_id, func.sum(PlatformStat.count))
688
+ .where(PlatformStat.timestamp >= start_time)
689
+ .group_by(PlatformStat.platform_id),
690
+ )
691
+ grouped_stats = result.all()
692
+ deprecated_stats = DeprecatedStats()
693
+ for platform_id, count in grouped_stats:
694
+ deprecated_stats.platform.append(
695
+ DeprecatedPlatformStat(
696
+ name=platform_id,
697
+ count=count,
698
+ timestamp=int(start_time.timestamp()),
699
+ ),
700
+ )
701
+ return deprecated_stats
702
+
703
+ result = None
704
+
705
+ def runner():
706
+ nonlocal result
707
+ result = asyncio.run(_inner())
708
+
709
+ t = threading.Thread(target=runner)
710
+ t.start()
711
+ t.join()
712
+ return result
713
+
714
+ # ====
715
+ # Platform Session Management
716
+ # ====
717
+
718
+ async def create_platform_session(
719
+ self,
720
+ creator: str,
721
+ platform_id: str = "webchat",
722
+ session_id: str | None = None,
723
+ display_name: str | None = None,
724
+ is_group: int = 0,
725
+ ) -> PlatformSession:
726
+ """Create a new Platform session."""
727
+ kwargs = {}
728
+ if session_id:
729
+ kwargs["session_id"] = session_id
730
+
731
+ async with self.get_db() as session:
732
+ session: AsyncSession
733
+ async with session.begin():
734
+ new_session = PlatformSession(
735
+ creator=creator,
736
+ platform_id=platform_id,
737
+ display_name=display_name,
738
+ is_group=is_group,
739
+ **kwargs,
740
+ )
741
+ session.add(new_session)
742
+ await session.flush()
743
+ await session.refresh(new_session)
744
+ return new_session
745
+
746
+ async def get_platform_session_by_id(
747
+ self, session_id: str
748
+ ) -> PlatformSession | None:
749
+ """Get a Platform session by its ID."""
750
+ async with self.get_db() as session:
751
+ session: AsyncSession
752
+ query = select(PlatformSession).where(
753
+ PlatformSession.session_id == session_id,
754
+ )
755
+ result = await session.execute(query)
756
+ return result.scalar_one_or_none()
512
757
 
513
- # 构建完整的 WHERE 子句
514
- where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
758
+ async def get_platform_sessions_by_creator(
759
+ self,
760
+ creator: str,
761
+ platform_id: str | None = None,
762
+ page: int = 1,
763
+ page_size: int = 20,
764
+ ) -> list[PlatformSession]:
765
+ """Get all Platform sessions for a specific creator (username) and optionally platform."""
766
+ async with self.get_db() as session:
767
+ session: AsyncSession
768
+ offset = (page - 1) * page_size
769
+ query = select(PlatformSession).where(PlatformSession.creator == creator)
515
770
 
516
- # 构建计数查询
517
- count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
771
+ if platform_id:
772
+ query = query.where(PlatformSession.platform_id == platform_id)
518
773
 
519
- # 获取总记录数
520
- c.execute(count_sql, params)
521
- total_count = c.fetchone()[0]
774
+ query = (
775
+ query.order_by(desc(PlatformSession.updated_at))
776
+ .offset(offset)
777
+ .limit(page_size)
778
+ )
779
+ result = await session.execute(query)
780
+ return list(result.scalars().all())
522
781
 
523
- # 计算偏移量
524
- offset = (page - 1) * page_size
782
+ async def update_platform_session(
783
+ self,
784
+ session_id: str,
785
+ display_name: str | None = None,
786
+ ) -> None:
787
+ """Update a Platform session's updated_at timestamp and optionally display_name."""
788
+ async with self.get_db() as session:
789
+ session: AsyncSession
790
+ async with session.begin():
791
+ values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
792
+ if display_name is not None:
793
+ values["display_name"] = display_name
794
+
795
+ await session.execute(
796
+ update(PlatformSession)
797
+ .where(col(PlatformSession.session_id) == session_id)
798
+ .values(**values),
799
+ )
525
800
 
526
- # 构建分页数据查询
527
- data_sql = f"""
528
- SELECT user_id, cid, created_at, updated_at, title, persona_id
529
- FROM webchat_conversation
530
- {where_sql}
531
- ORDER BY updated_at DESC
532
- LIMIT ? OFFSET ?
533
- """
534
- query_params = params + [page_size, offset]
535
-
536
- # 获取分页数据
537
- c.execute(data_sql, query_params)
538
- rows = c.fetchall()
539
-
540
- conversations = []
541
-
542
- for row in rows:
543
- user_id, cid, created_at, updated_at, title, persona_id = row
544
- # 确保 cid 是字符串类型,否则使用一个默认值
545
- safe_cid = str(cid) if cid else "unknown"
546
- display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
547
-
548
- conversations.append(
549
- {
550
- "user_id": user_id or "",
551
- "cid": safe_cid,
552
- "title": title or f"对话 {display_cid}",
553
- "persona_id": persona_id or "",
554
- "created_at": created_at or 0,
555
- "updated_at": updated_at or 0,
556
- }
557
- )
558
-
559
- return conversations, total_count
560
-
561
- except Exception as _:
562
- # 返回空列表和0,确保即使出错也有有效的返回值
563
- return [], 0
564
- finally:
565
- c.close()
801
+ async def delete_platform_session(self, session_id: str) -> None:
802
+ """Delete a Platform session by its ID."""
803
+ async with self.get_db() as session:
804
+ session: AsyncSession
805
+ async with session.begin():
806
+ await session.execute(
807
+ delete(PlatformSession).where(
808
+ col(PlatformSession.session_id) == session_id,
809
+ ),
810
+ )