AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +16 -4
- astrbot/api/all.py +2 -1
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -34
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +8 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__init__.py +1 -0
- astrbot/cli/__main__.py +18 -197
- astrbot/cli/commands/__init__.py +6 -0
- astrbot/cli/commands/cmd_conf.py +209 -0
- astrbot/cli/commands/cmd_init.py +56 -0
- astrbot/cli/commands/cmd_plug.py +245 -0
- astrbot/cli/commands/cmd_run.py +62 -0
- astrbot/cli/utils/__init__.py +18 -0
- astrbot/cli/utils/basic.py +76 -0
- astrbot/cli/utils/plugin.py +246 -0
- astrbot/cli/utils/version_comparator.py +90 -0
- astrbot/core/__init__.py +17 -19
- astrbot/core/agent/agent.py +14 -0
- astrbot/core/agent/handoff.py +38 -0
- astrbot/core/agent/hooks.py +30 -0
- astrbot/core/agent/mcp_client.py +385 -0
- astrbot/core/agent/message.py +175 -0
- astrbot/core/agent/response.py +14 -0
- astrbot/core/agent/run_context.py +22 -0
- astrbot/core/agent/runners/__init__.py +3 -0
- astrbot/core/agent/runners/base.py +65 -0
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
- astrbot/core/agent/tool.py +285 -0
- astrbot/core/agent/tool_executor.py +17 -0
- astrbot/core/astr_agent_context.py +19 -0
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/astrbot_config_mgr.py +275 -0
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +60 -20
- astrbot/core/config/default.py +1972 -453
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/conversation_mgr.py +285 -75
- astrbot/core/core_lifecycle.py +167 -62
- astrbot/core/db/__init__.py +305 -102
- astrbot/core/db/migration/helper.py +69 -0
- astrbot/core/db/migration/migra_3_to_4.py +357 -0
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/migration/shared_preferences_v3.py +48 -0
- astrbot/core/db/migration/sqlite_v3.py +497 -0
- astrbot/core/db/po.py +259 -55
- astrbot/core/db/sqlite.py +773 -528
- astrbot/core/db/vec_db/base.py +73 -0
- astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
- astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
- astrbot/core/event_bus.py +26 -22
- astrbot/core/exceptions.py +9 -0
- astrbot/core/file_token_service.py +98 -0
- astrbot/core/initial_loader.py +19 -10
- astrbot/core/knowledge_base/chunking/__init__.py +9 -0
- astrbot/core/knowledge_base/chunking/base.py +25 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
- astrbot/core/knowledge_base/chunking/recursive.py +161 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
- astrbot/core/knowledge_base/kb_helper.py +642 -0
- astrbot/core/knowledge_base/kb_mgr.py +330 -0
- astrbot/core/knowledge_base/models.py +120 -0
- astrbot/core/knowledge_base/parsers/__init__.py +13 -0
- astrbot/core/knowledge_base/parsers/base.py +51 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +276 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
- astrbot/core/log.py +21 -15
- astrbot/core/message/components.py +413 -287
- astrbot/core/message/message_event_result.py +35 -24
- astrbot/core/persona_mgr.py +192 -0
- astrbot/core/pipeline/__init__.py +14 -14
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +7 -1
- astrbot/core/pipeline/context_utils.py +107 -0
- astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
- astrbot/core/pipeline/process_stage/stage.py +21 -15
- astrbot/core/pipeline/process_stage/utils.py +125 -0
- astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
- astrbot/core/pipeline/respond/stage.py +142 -101
- astrbot/core/pipeline/result_decorate/stage.py +124 -57
- astrbot/core/pipeline/scheduler.py +21 -16
- astrbot/core/pipeline/session_status_check/stage.py +37 -0
- astrbot/core/pipeline/stage.py +11 -76
- astrbot/core/pipeline/waking_check/stage.py +69 -33
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +107 -129
- astrbot/core/platform/astrbot_message.py +32 -12
- astrbot/core/platform/manager.py +62 -18
- astrbot/core/platform/message_session.py +30 -0
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +9 -4
- astrbot/core/platform/register.py +12 -7
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
- astrbot/core/platform/sources/discord/client.py +129 -0
- astrbot/core/platform/sources/discord/components.py +139 -0
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
- astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
- astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
- astrbot/core/platform/sources/lark/lark_event.py +39 -13
- astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
- astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
- astrbot/core/platform/sources/satori/satori_event.py +432 -0
- astrbot/core/platform/sources/slack/client.py +164 -0
- astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
- astrbot/core/platform/sources/slack/slack_event.py +253 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
- astrbot/core/platform/sources/telegram/tg_event.py +136 -36
- astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
- astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
- astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
- astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
- astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
- astrbot/core/platform_message_history_mgr.py +49 -0
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +154 -98
- astrbot/core/provider/func_tool_manager.py +446 -458
- astrbot/core/provider/manager.py +345 -207
- astrbot/core/provider/provider.py +188 -73
- astrbot/core/provider/register.py +9 -7
- astrbot/core/provider/sources/anthropic_source.py +295 -115
- astrbot/core/provider/sources/azure_tts_source.py +224 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/dashscope_tts.py +138 -14
- astrbot/core/provider/sources/edge_tts_source.py +24 -19
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
- astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
- astrbot/core/provider/sources/gemini_source.py +310 -132
- astrbot/core/provider/sources/gemini_tts_source.py +81 -0
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
- astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
- astrbot/core/provider/sources/openai_embedding_source.py +40 -0
- astrbot/core/provider/sources/openai_source.py +241 -145
- astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
- astrbot/core/provider/sources/volcengine_tts.py +115 -0
- astrbot/core/provider/sources/whisper_api_source.py +18 -13
- astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/provider/sources/zhipu_source.py +6 -73
- astrbot/core/star/__init__.py +43 -11
- astrbot/core/star/config.py +17 -18
- astrbot/core/star/context.py +362 -138
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +111 -35
- astrbot/core/star/filter/command_group.py +46 -34
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +45 -12
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -15
- astrbot/core/star/register/star.py +41 -13
- astrbot/core/star/register/star_handler.py +236 -86
- astrbot/core/star/session_llm_manager.py +280 -0
- astrbot/core/star/session_plugin_manager.py +170 -0
- astrbot/core/star/star.py +36 -43
- astrbot/core/star/star_handler.py +47 -85
- astrbot/core/star/star_manager.py +442 -260
- astrbot/core/star/star_tools.py +167 -45
- astrbot/core/star/updator.py +17 -20
- astrbot/core/umop_config_router.py +106 -0
- astrbot/core/updator.py +38 -13
- astrbot/core/utils/astrbot_path.py +39 -0
- astrbot/core/utils/command_parser.py +1 -1
- astrbot/core/utils/io.py +119 -60
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +11 -10
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/path_util.py +63 -62
- astrbot/core/utils/pip_installer.py +37 -15
- astrbot/core/utils/session_lock.py +29 -0
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +174 -34
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/local_strategy.py +386 -238
- astrbot/core/utils/t2i/network_strategy.py +109 -49
- astrbot/core/utils/t2i/renderer.py +29 -14
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +111 -0
- astrbot/core/utils/tencent_record_helper.py +115 -1
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +112 -65
- astrbot/dashboard/routes/__init__.py +20 -13
- astrbot/dashboard/routes/auth.py +20 -9
- astrbot/dashboard/routes/chat.py +297 -141
- astrbot/dashboard/routes/config.py +652 -55
- astrbot/dashboard/routes/conversation.py +107 -37
- astrbot/dashboard/routes/file.py +26 -0
- astrbot/dashboard/routes/knowledge_base.py +1244 -0
- astrbot/dashboard/routes/log.py +27 -2
- astrbot/dashboard/routes/persona.py +202 -0
- astrbot/dashboard/routes/plugin.py +197 -139
- astrbot/dashboard/routes/route.py +27 -7
- astrbot/dashboard/routes/session_management.py +354 -0
- astrbot/dashboard/routes/stat.py +85 -18
- astrbot/dashboard/routes/static_file.py +5 -2
- astrbot/dashboard/routes/t2i.py +233 -0
- astrbot/dashboard/routes/tools.py +184 -120
- astrbot/dashboard/routes/update.py +59 -36
- astrbot/dashboard/server.py +96 -36
- astrbot/dashboard/utils.py +165 -0
- astrbot-4.7.0.dist-info/METADATA +294 -0
- astrbot-4.7.0.dist-info/RECORD +274 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
- astrbot/core/db/plugin/sqlite_impl.py +0 -112
- astrbot/core/db/sqlite_init.sql +0 -50
- astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
- astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
- astrbot/core/platform/sources/gewechat/client.py +0 -806
- astrbot/core/platform/sources/gewechat/downloader.py +0 -55
- astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
- astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
- astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
- astrbot/core/provider/sources/dashscope_source.py +0 -203
- astrbot/core/provider/sources/dify_source.py +0 -281
- astrbot/core/provider/sources/llmtuner_source.py +0 -132
- astrbot/core/rag/embedding/openai_source.py +0 -20
- astrbot/core/rag/knowledge_db_mgr.py +0 -94
- astrbot/core/rag/store/__init__.py +0 -9
- astrbot/core/rag/store/chroma_db.py +0 -42
- astrbot/core/utils/dify_api_client.py +0 -152
- astrbot-3.5.6.dist-info/METADATA +0 -249
- astrbot-3.5.6.dist-info/RECORD +0 -158
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/db/sqlite.py
CHANGED
|
@@ -1,565 +1,810 @@
|
|
|
1
|
-
import
|
|
2
|
-
import
|
|
3
|
-
import
|
|
4
|
-
from
|
|
5
|
-
|
|
6
|
-
from
|
|
1
|
+
import asyncio
|
|
2
|
+
import threading
|
|
3
|
+
import typing as T
|
|
4
|
+
from datetime import datetime, timedelta, timezone
|
|
5
|
+
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
|
8
|
+
|
|
9
|
+
from astrbot.core.db import BaseDatabase
|
|
10
|
+
from astrbot.core.db.po import (
|
|
11
|
+
Attachment,
|
|
12
|
+
ConversationV2,
|
|
13
|
+
Persona,
|
|
14
|
+
PlatformMessageHistory,
|
|
15
|
+
PlatformSession,
|
|
16
|
+
PlatformStat,
|
|
17
|
+
Preference,
|
|
18
|
+
SQLModel,
|
|
19
|
+
)
|
|
20
|
+
from astrbot.core.db.po import (
|
|
21
|
+
Platform as DeprecatedPlatformStat,
|
|
22
|
+
)
|
|
23
|
+
from astrbot.core.db.po import (
|
|
24
|
+
Stats as DeprecatedStats,
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
# Module-level TypeVar named "NOT_GIVEN". It is not referenced in the visible
# part of this file; presumably it is meant as a "value not supplied" marker
# elsewhere — TODO confirm, since a dedicated sentinel object (e.g. object())
# would be the conventional choice for that role.
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
|
7
28
|
|
|
8
29
|
|
|
9
30
|
class SQLiteDatabase(BaseDatabase):
|
|
10
31
|
def __init__(self, db_path: str) -> None:
    """Remember the SQLite file location and derive its async connection URL.

    ``db_path``, ``DATABASE_URL`` and ``inited`` are all assigned before
    ``BaseDatabase.__init__`` runs; presumably the base class reads
    ``DATABASE_URL`` to build its engine — TODO confirm against BaseDatabase.

    Args:
        db_path: Filesystem path of the SQLite database file.
    """
    connection_url = f"sqlite+aiosqlite:///{db_path}"
    self.db_path = db_path
    self.DATABASE_URL = connection_url
    self.inited = False
    super().__init__()
|
|
36
|
+
|
|
37
|
+
async def initialize(self) -> None:
    """Create any missing tables, then apply SQLite tuning PRAGMAs.

    Everything runs on the single connection obtained from
    ``self.engine.begin()`` and is committed explicitly at the end.

    NOTE(review): in SQLite only ``journal_mode=WAL`` is persisted in the
    database file; ``synchronous``, ``cache_size``, ``temp_store`` and
    ``mmap_size`` are per-connection settings, so here they only affect the
    connection used by this method — confirm whether other pooled
    connections are expected to pick them up.
    """
    tuning_statements = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        "PRAGMA cache_size=20000",
        "PRAGMA temp_store=MEMORY",
        "PRAGMA mmap_size=134217728",
        "PRAGMA optimize",
    )
    async with self.engine.begin() as conn:
        # create_all only creates tables that do not exist yet.
        await conn.run_sync(SQLModel.metadata.create_all)
        for statement in tuning_statements:
            await conn.execute(text(statement))
        await conn.commit()
|
|
48
|
+
|
|
49
|
+
# ====
|
|
50
|
+
# Platform Statistics
|
|
51
|
+
# ====
|
|
52
|
+
|
|
53
|
+
async def insert_platform_stats(
    self,
    platform_id,
    platform_type,
    count=1,
    timestamp=None,
) -> None:
    """Add ``count`` to a platform's time-bucketed counter, creating it if absent.

    Args:
        platform_id: Identifier of the platform.
        platform_type: Type of the platform.
        count: Amount added to the bucket (defaults to 1).
        timestamp: Bucket key. When omitted, the current local time with
            minutes, seconds and microseconds zeroed is used; a supplied
            value is stored unchanged.

    If a row already exists for (timestamp, platform_id, platform_type), the
    SQLite ``ON CONFLICT ... DO UPDATE`` clause adds ``count`` to it instead
    of inserting a duplicate.
    """
    upsert = text("""
        INSERT INTO platform_stats (timestamp, platform_id, platform_type, count)
        VALUES (:timestamp, :platform_id, :platform_type, :count)
        ON CONFLICT(timestamp, platform_id, platform_type) DO UPDATE SET
            count = platform_stats.count + EXCLUDED.count
        """)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            hour_bucket = (
                timestamp
                if timestamp is not None
                else datetime.now().replace(minute=0, second=0, microsecond=0)
            )
            await session.execute(
                upsert,
                {
                    "timestamp": hour_bucket,
                    "platform_id": platform_id,
                    "platform_type": platform_type,
                    "count": count,
                },
            )
|
|
13
85
|
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
# 检查 webchat_conversation 的 title 字段是否存在
|
|
24
|
-
c.execute(
|
|
25
|
-
"""
|
|
26
|
-
PRAGMA table_info(webchat_conversation)
|
|
27
|
-
"""
|
|
28
|
-
)
|
|
29
|
-
res = c.fetchall()
|
|
30
|
-
has_title = False
|
|
31
|
-
has_persona_id = False
|
|
32
|
-
for row in res:
|
|
33
|
-
if row[1] == "title":
|
|
34
|
-
has_title = True
|
|
35
|
-
if row[1] == "persona_id":
|
|
36
|
-
has_persona_id = True
|
|
37
|
-
if not has_title:
|
|
38
|
-
c.execute(
|
|
39
|
-
"""
|
|
40
|
-
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
|
41
|
-
"""
|
|
86
|
+
async def count_platform_stats(self) -> int:
    """Count PlatformStat rows with a non-NULL ``platform_id``.

    Returns:
        The row count, or 0 when the query yields no value.
    """
    count_stmt = select(func.count(col(PlatformStat.platform_id))).select_from(
        PlatformStat,
    )
    async with self.get_db() as session:
        session: AsyncSession
        result = await session.execute(count_stmt)
        # `or 0` maps a missing (None) scalar to 0 and leaves real counts as-is.
        return result.scalar_one_or_none() or 0
|
|
97
|
+
|
|
98
|
+
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
    """Return the newest platform-stats row for each platform_id in a time window.

    Args:
        offset_sec: Look-back window in seconds. Only rows whose ``timestamp``
            is at or after ``datetime.now() - offset_sec`` (naive local time,
            the same clock ``insert_platform_stats`` uses) are considered.
            Defaults to 86400.

    Returns:
        One ``PlatformStat`` per ``platform_id`` (the row with the latest
        timestamp), ordered from most to least recent.
    """
    start_time = datetime.now() - timedelta(seconds=offset_sec)
    async with self.get_db() as session:
        session: AsyncSession
        # Fetch newest-first and keep the first row seen per platform_id.
        # A raw "... ORDER BY ... GROUP BY ..." statement is rejected by SQLite
        # (GROUP BY must precede ORDER BY), and Result.scalars() on a text
        # "SELECT *" yields only the first column, so an ORM select plus a
        # Python-side reduction is used to return real PlatformStat objects.
        result = await session.execute(
            select(PlatformStat)
            .where(col(PlatformStat.timestamp) >= start_time)
            .order_by(desc(PlatformStat.timestamp)),
        )
        newest_per_platform: dict = {}
        for stat in result.scalars():
            newest_per_platform.setdefault(stat.platform_id, stat)
        # dicts keep insertion order, so the newest-first ordering is preserved.
        return list(newest_per_platform.values())
|
|
114
|
+
|
|
115
|
+
# ====
|
|
116
|
+
# Conversation Management
|
|
117
|
+
# ====
|
|
118
|
+
|
|
119
|
+
async def get_conversations(self, user_id=None, platform_id=None):
    """List conversations, newest first, optionally narrowed by user and platform.

    Args:
        user_id: When truthy, keep only conversations with this ``user_id``.
        platform_id: When truthy, keep only conversations on this platform.

    Returns:
        The matching ConversationV2 rows ordered by ``created_at`` descending.
    """
    stmt = select(ConversationV2)
    if user_id:
        stmt = stmt.where(ConversationV2.user_id == user_id)
    if platform_id:
        stmt = stmt.where(ConversationV2.platform_id == platform_id)
    stmt = stmt.order_by(desc(ConversationV2.created_at))
    async with self.get_db() as session:
        session: AsyncSession
        rows = await session.execute(stmt)
        return rows.scalars().all()
|
|
133
|
+
|
|
134
|
+
async def get_conversation_by_id(self, cid):
    """Fetch the conversation whose ``conversation_id`` equals ``cid``.

    Returns:
        The ConversationV2 row, or None when no row matches.
    """
    lookup = select(ConversationV2).where(ConversationV2.conversation_id == cid)
    async with self.get_db() as session:
        session: AsyncSession
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
|
|
140
|
+
|
|
141
|
+
async def get_all_conversations(self, page=1, page_size=20):
    """Return one page of conversations ordered by ``created_at``, newest first.

    Args:
        page: 1-based page number.
        page_size: Maximum number of conversations on the page.
    """
    skip = (page - 1) * page_size
    paged = (
        select(ConversationV2)
        .order_by(desc(ConversationV2.created_at))
        .offset(skip)
        .limit(page_size)
    )
    async with self.get_db() as session:
        session: AsyncSession
        result = await session.execute(paged)
        return result.scalars().all()
|
|
84
152
|
|
|
85
|
-
def
|
|
86
|
-
|
|
153
|
+
async def get_filtered_conversations(
    self,
    page=1,
    page_size=20,
    platform_ids=None,
    search_query="",
    **kwargs,
):
    """Return one page of conversations matching the filters, plus the total.

    Args:
        page: 1-based page number.
        page_size: Maximum number of conversations on the page.
        platform_ids: If non-empty, keep only conversations whose
            ``platform_id`` is in this collection.
        search_query: Case-insensitive substring looked up in ``title``,
            ``content``, ``user_id`` and ``conversation_id``.
        **kwargs: Optional extra filters:
            ``message_types`` — keep conversations whose ``user_id`` contains
            ``:<type>:`` for ANY of the listed types;
            ``platforms`` — an additional ``platform_id`` whitelist.

    Returns:
        ``(conversations, total)`` where ``total`` counts every matching row
        regardless of pagination.
    """
    async with self.get_db() as session:
        session: AsyncSession
        # Build the base query with filters
        base_query = select(ConversationV2)

        if platform_ids:
            base_query = base_query.where(
                col(ConversationV2.platform_id).in_(platform_ids),
            )

        if search_query:
            # The escaped form is kept because stored content presumably holds
            # non-ASCII text as \uXXXX escapes (JSON serialization — TODO
            # confirm); plain-text columns such as title hold the original
            # characters, so the raw query is matched as well. Matching both
            # forms on every column keeps all previously-found rows.
            escaped_query = search_query.encode("unicode_escape").decode("utf-8")
            patterns = dict.fromkeys((f"%{search_query}%", f"%{escaped_query}%"))
            searchable_columns = (
                ConversationV2.title,
                ConversationV2.content,
                ConversationV2.user_id,
                ConversationV2.conversation_id,
            )
            base_query = base_query.where(
                or_(
                    *(
                        col(column).ilike(pattern)
                        for column in searchable_columns
                        for pattern in patterns
                    ),
                ),
            )

        message_types = kwargs.get("message_types")
        if message_types:
            # A user_id presumably embeds exactly one ":<type>:" segment
            # (TODO confirm the unified-origin format), so the types must be
            # OR-ed; AND-ing them made any multi-type selection match nothing.
            base_query = base_query.where(
                or_(
                    *(
                        col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
                        for msg_type in message_types
                    ),
                ),
            )

        extra_platforms = kwargs.get("platforms")
        if extra_platforms:
            base_query = base_query.where(
                col(ConversationV2.platform_id).in_(extra_platforms),
            )

        # Get total count matching the filters
        count_query = select(func.count()).select_from(base_query.subquery())
        total_count = await session.execute(count_query)
        total = total_count.scalar_one()

        # Get paginated results
        offset = (page - 1) * page_size
        result_query = (
            base_query.order_by(desc(ConversationV2.created_at))
            .offset(offset)
            .limit(page_size)
        )
        result = await session.execute(result_query)
        conversations = result.scalars().all()

        return conversations, total
|
|
122
206
|
|
|
123
|
-
def
|
|
124
|
-
self,
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
207
|
+
async def create_conversation(
    self,
    user_id,
    platform_id,
    content=None,
    title=None,
    persona_id=None,
    cid=None,
    created_at=None,
    updated_at=None,
):
    """Persist a new ConversationV2 and return it.

    Args:
        user_id: Owner identifier stored on the conversation.
        platform_id: Platform the conversation belongs to.
        content: Initial content; any falsy value becomes ``[]``.
        title: Optional title.
        persona_id: Optional persona identifier.
        cid: Explicit ``conversation_id``; not passed to the model when falsy.
        created_at: Explicit creation time; not passed when falsy.
        updated_at: Explicit update time; not passed when falsy.

    Fields that are not passed fall back to ConversationV2's own defaults
    (presumably a generated id and timestamps — TODO confirm in the model).
    """
    supplied = {
        "conversation_id": cid,
        "created_at": created_at,
        "updated_at": updated_at,
    }
    overrides = {key: val for key, val in supplied.items() if val}
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            conversation = ConversationV2(
                user_id=user_id,
                content=content or [],
                platform_id=platform_id,
                title=title,
                persona_id=persona_id,
                **overrides,
            )
            session.add(conversation)
        return conversation
|
|
238
|
+
|
|
239
|
+
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
    """Update selected fields of a conversation and return the re-read row.

    Only arguments that are not None are written. When none are provided,
    no UPDATE is issued and None is returned. The row is re-read through
    ``get_conversation_by_id`` after the update's transaction has committed.
    """
    changes = {
        field: value
        for field, value in (
            ("title", title),
            ("persona_id", persona_id),
            ("content", content),
        )
        if value is not None
    }
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            if not changes:
                return None
            await session.execute(
                update(ConversationV2)
                .where(col(ConversationV2.conversation_id) == cid)
                .values(**changes),
            )
    return await self.get_conversation_by_id(cid)
|
|
258
|
+
|
|
259
|
+
async def delete_conversation(self, cid):
    """Delete the conversation whose ``conversation_id`` equals ``cid``."""
    removal = delete(ConversationV2).where(
        col(ConversationV2.conversation_id) == cid,
    )
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(removal)
|
|
130
268
|
|
|
131
|
-
|
|
132
|
-
|
|
269
|
+
async def delete_conversations_by_user_id(self, user_id: str) -> None:
    """Delete every conversation owned by ``user_id`` in one transaction."""
    removal = delete(ConversationV2).where(col(ConversationV2.user_id) == user_id)
    async with self.get_db() as session:
        session: AsyncSession
        async with session.begin():
            await session.execute(removal)
|
|
133
278
|
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
sql += " WHERE " + " AND ".join(conditions)
|
|
145
|
-
|
|
146
|
-
c.execute(sql, params)
|
|
147
|
-
|
|
148
|
-
res = c.fetchall()
|
|
149
|
-
histories = []
|
|
150
|
-
for row in res:
|
|
151
|
-
histories.append(LLMHistory(*row))
|
|
152
|
-
c.close()
|
|
153
|
-
return histories
|
|
154
|
-
|
|
155
|
-
def get_base_stats(self, offset_sec: int = 86400) -> Stats:
|
|
156
|
-
"""获取 offset_sec 秒前到现在的基础统计数据"""
|
|
157
|
-
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
|
158
|
-
|
|
159
|
-
try:
|
|
160
|
-
c = self.conn.cursor()
|
|
161
|
-
except sqlite3.ProgrammingError:
|
|
162
|
-
c = self._get_conn(self.db_path).cursor()
|
|
163
|
-
|
|
164
|
-
c.execute(
|
|
165
|
-
"""
|
|
166
|
-
SELECT * FROM platform
|
|
167
|
-
"""
|
|
168
|
-
+ where_clause
|
|
169
|
-
)
|
|
170
|
-
|
|
171
|
-
platform = []
|
|
172
|
-
for row in c.fetchall():
|
|
173
|
-
platform.append(Platform(*row))
|
|
174
|
-
|
|
175
|
-
# c.execute(
|
|
176
|
-
# '''
|
|
177
|
-
# SELECT * FROM command
|
|
178
|
-
# ''' + where_clause
|
|
179
|
-
# )
|
|
180
|
-
|
|
181
|
-
# command = []
|
|
182
|
-
# for row in c.fetchall():
|
|
183
|
-
# command.append(Command(*row))
|
|
184
|
-
|
|
185
|
-
# c.execute(
|
|
186
|
-
# '''
|
|
187
|
-
# SELECT * FROM llm
|
|
188
|
-
# ''' + where_clause
|
|
189
|
-
# )
|
|
190
|
-
|
|
191
|
-
# llm = []
|
|
192
|
-
# for row in c.fetchall():
|
|
193
|
-
# llm.append(Provider(*row))
|
|
194
|
-
|
|
195
|
-
c.close()
|
|
196
|
-
|
|
197
|
-
return Stats(platform, [], [])
|
|
198
|
-
|
|
199
|
-
def get_total_message_count(self) -> int:
|
|
200
|
-
try:
|
|
201
|
-
c = self.conn.cursor()
|
|
202
|
-
except sqlite3.ProgrammingError:
|
|
203
|
-
c = self._get_conn(self.db_path).cursor()
|
|
204
|
-
|
|
205
|
-
c.execute(
|
|
206
|
-
"""
|
|
207
|
-
SELECT SUM(count) FROM platform
|
|
208
|
-
"""
|
|
209
|
-
)
|
|
210
|
-
res = c.fetchone()
|
|
211
|
-
c.close()
|
|
212
|
-
return res[0]
|
|
213
|
-
|
|
214
|
-
def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
|
|
215
|
-
"""获取 offset_sec 秒前到现在的基础统计数据(合并)"""
|
|
216
|
-
where_clause = f" WHERE timestamp >= {int(time.time()) - offset_sec}"
|
|
217
|
-
|
|
218
|
-
try:
|
|
219
|
-
c = self.conn.cursor()
|
|
220
|
-
except sqlite3.ProgrammingError:
|
|
221
|
-
c = self._get_conn(self.db_path).cursor()
|
|
222
|
-
|
|
223
|
-
c.execute(
|
|
224
|
-
"""
|
|
225
|
-
SELECT name, SUM(count), timestamp FROM platform
|
|
226
|
-
"""
|
|
227
|
-
+ where_clause
|
|
228
|
-
+ " GROUP BY name"
|
|
229
|
-
)
|
|
230
|
-
|
|
231
|
-
platform = []
|
|
232
|
-
for row in c.fetchall():
|
|
233
|
-
platform.append(Platform(*row))
|
|
234
|
-
|
|
235
|
-
c.close()
|
|
236
|
-
|
|
237
|
-
return Stats(platform, [], [])
|
|
238
|
-
|
|
239
|
-
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
|
|
240
|
-
try:
|
|
241
|
-
c = self.conn.cursor()
|
|
242
|
-
except sqlite3.ProgrammingError:
|
|
243
|
-
c = self._get_conn(self.db_path).cursor()
|
|
244
|
-
|
|
245
|
-
c.execute(
|
|
246
|
-
"""
|
|
247
|
-
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
|
248
|
-
""",
|
|
249
|
-
(user_id, cid),
|
|
250
|
-
)
|
|
251
|
-
|
|
252
|
-
res = c.fetchone()
|
|
253
|
-
c.close()
|
|
254
|
-
|
|
255
|
-
if not res:
|
|
256
|
-
return
|
|
257
|
-
|
|
258
|
-
return Conversation(*res)
|
|
259
|
-
|
|
260
|
-
def new_conversation(self, user_id: str, cid: str):
|
|
261
|
-
history = "[]"
|
|
262
|
-
updated_at = int(time.time())
|
|
263
|
-
created_at = updated_at
|
|
264
|
-
self._exec_sql(
|
|
265
|
-
"""
|
|
266
|
-
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
|
267
|
-
""",
|
|
268
|
-
(user_id, cid, history, updated_at, created_at),
|
|
269
|
-
)
|
|
270
|
-
|
|
271
|
-
def get_conversations(self, user_id: str) -> Tuple:
|
|
272
|
-
try:
|
|
273
|
-
c = self.conn.cursor()
|
|
274
|
-
except sqlite3.ProgrammingError:
|
|
275
|
-
c = self._get_conn(self.db_path).cursor()
|
|
276
|
-
|
|
277
|
-
c.execute(
|
|
278
|
-
"""
|
|
279
|
-
SELECT cid, created_at, updated_at, title, persona_id FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
|
280
|
-
""",
|
|
281
|
-
(user_id,),
|
|
282
|
-
)
|
|
283
|
-
|
|
284
|
-
res = c.fetchall()
|
|
285
|
-
c.close()
|
|
286
|
-
conversations = []
|
|
287
|
-
for row in res:
|
|
288
|
-
cid = row[0]
|
|
289
|
-
created_at = row[1]
|
|
290
|
-
updated_at = row[2]
|
|
291
|
-
title = row[3]
|
|
292
|
-
persona_id = row[4]
|
|
293
|
-
conversations.append(
|
|
294
|
-
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
|
|
295
|
-
)
|
|
296
|
-
return conversations
|
|
297
|
-
|
|
298
|
-
def update_conversation(self, user_id: str, cid: str, history: str):
|
|
299
|
-
"""更新对话,并且同时更新时间"""
|
|
300
|
-
updated_at = int(time.time())
|
|
301
|
-
self._exec_sql(
|
|
302
|
-
"""
|
|
303
|
-
UPDATE webchat_conversation SET history = ?, updated_at = ? WHERE user_id = ? AND cid = ?
|
|
304
|
-
""",
|
|
305
|
-
(history, updated_at, user_id, cid),
|
|
306
|
-
)
|
|
307
|
-
|
|
308
|
-
def update_conversation_title(self, user_id: str, cid: str, title: str):
|
|
309
|
-
self._exec_sql(
|
|
310
|
-
"""
|
|
311
|
-
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
|
|
312
|
-
""",
|
|
313
|
-
(title, user_id, cid),
|
|
314
|
-
)
|
|
315
|
-
|
|
316
|
-
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
|
|
317
|
-
self._exec_sql(
|
|
318
|
-
"""
|
|
319
|
-
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
|
|
320
|
-
""",
|
|
321
|
-
(persona_id, user_id, cid),
|
|
322
|
-
)
|
|
323
|
-
|
|
324
|
-
def delete_conversation(self, user_id: str, cid: str):
|
|
325
|
-
self._exec_sql(
|
|
326
|
-
"""
|
|
327
|
-
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
|
328
|
-
""",
|
|
329
|
-
(user_id, cid),
|
|
330
|
-
)
|
|
331
|
-
|
|
332
|
-
def insert_atri_vision_data(self, vision: ATRIVision):
|
|
333
|
-
ts = int(time.time())
|
|
334
|
-
keywords = ",".join(vision.keywords)
|
|
335
|
-
self._exec_sql(
|
|
336
|
-
"""
|
|
337
|
-
INSERT INTO atri_vision(id, url_or_path, caption, is_meme, keywords, platform_name, session_id, sender_nickname, timestamp) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
338
|
-
""",
|
|
339
|
-
(
|
|
340
|
-
vision.id,
|
|
341
|
-
vision.url_or_path,
|
|
342
|
-
vision.caption,
|
|
343
|
-
vision.is_meme,
|
|
344
|
-
keywords,
|
|
345
|
-
vision.platform_name,
|
|
346
|
-
vision.session_id,
|
|
347
|
-
vision.sender_nickname,
|
|
348
|
-
ts,
|
|
349
|
-
),
|
|
350
|
-
)
|
|
351
|
-
|
|
352
|
-
def get_atri_vision_data(self) -> Tuple:
|
|
353
|
-
try:
|
|
354
|
-
c = self.conn.cursor()
|
|
355
|
-
except sqlite3.ProgrammingError:
|
|
356
|
-
c = self._get_conn(self.db_path).cursor()
|
|
357
|
-
|
|
358
|
-
c.execute(
|
|
359
|
-
"""
|
|
360
|
-
SELECT * FROM atri_vision
|
|
361
|
-
"""
|
|
362
|
-
)
|
|
363
|
-
|
|
364
|
-
res = c.fetchall()
|
|
365
|
-
visions = []
|
|
366
|
-
for row in res:
|
|
367
|
-
visions.append(ATRIVision(*row))
|
|
368
|
-
c.close()
|
|
369
|
-
return visions
|
|
370
|
-
|
|
371
|
-
def get_atri_vision_data_by_path_or_id(
|
|
372
|
-
self, url_or_path: str, id: str
|
|
373
|
-
) -> ATRIVision:
|
|
374
|
-
try:
|
|
375
|
-
c = self.conn.cursor()
|
|
376
|
-
except sqlite3.ProgrammingError:
|
|
377
|
-
c = self._get_conn(self.db_path).cursor()
|
|
378
|
-
|
|
379
|
-
c.execute(
|
|
380
|
-
"""
|
|
381
|
-
SELECT * FROM atri_vision WHERE url_or_path = ? OR id = ?
|
|
382
|
-
""",
|
|
383
|
-
(url_or_path, id),
|
|
384
|
-
)
|
|
385
|
-
|
|
386
|
-
res = c.fetchone()
|
|
387
|
-
c.close()
|
|
388
|
-
if res:
|
|
389
|
-
return ATRIVision(*res)
|
|
390
|
-
return None
|
|
391
|
-
|
|
392
|
-
def get_all_conversations(
|
|
393
|
-
self, page: int = 1, page_size: int = 20
|
|
394
|
-
) -> Tuple[List[Dict[str, Any]], int]:
|
|
395
|
-
"""获取所有对话,支持分页,按更新时间降序排序"""
|
|
396
|
-
try:
|
|
397
|
-
c = self.conn.cursor()
|
|
398
|
-
except sqlite3.ProgrammingError:
|
|
399
|
-
c = self._get_conn(self.db_path).cursor()
|
|
400
|
-
|
|
401
|
-
try:
|
|
402
|
-
# 获取总记录数
|
|
403
|
-
c.execute("""
|
|
404
|
-
SELECT COUNT(*) FROM webchat_conversation
|
|
405
|
-
""")
|
|
406
|
-
total_count = c.fetchone()[0]
|
|
407
|
-
|
|
408
|
-
# 计算偏移量
|
|
279
|
+
async def get_session_conversations(
|
|
280
|
+
self,
|
|
281
|
+
page=1,
|
|
282
|
+
page_size=20,
|
|
283
|
+
search_query=None,
|
|
284
|
+
platform=None,
|
|
285
|
+
) -> tuple[list[dict], int]:
|
|
286
|
+
"""Get paginated session conversations with joined conversation and persona details."""
|
|
287
|
+
async with self.get_db() as session:
|
|
288
|
+
session: AsyncSession
|
|
409
289
|
offset = (page - 1) * page_size
|
|
410
290
|
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
291
|
+
base_query = (
|
|
292
|
+
select(
|
|
293
|
+
col(Preference.scope_id).label("session_id"),
|
|
294
|
+
func.json_extract(Preference.value, "$.val").label(
|
|
295
|
+
"conversation_id",
|
|
296
|
+
), # type: ignore
|
|
297
|
+
col(ConversationV2.persona_id).label("persona_id"),
|
|
298
|
+
col(ConversationV2.title).label("title"),
|
|
299
|
+
col(Persona.persona_id).label("persona_name"),
|
|
300
|
+
)
|
|
301
|
+
.select_from(Preference)
|
|
302
|
+
.outerjoin(
|
|
303
|
+
ConversationV2,
|
|
304
|
+
func.json_extract(Preference.value, "$.val")
|
|
305
|
+
== ConversationV2.conversation_id,
|
|
306
|
+
)
|
|
307
|
+
.outerjoin(
|
|
308
|
+
Persona,
|
|
309
|
+
col(ConversationV2.persona_id) == Persona.persona_id,
|
|
310
|
+
)
|
|
311
|
+
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
420
312
|
)
|
|
421
313
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
314
|
+
# 搜索筛选
|
|
315
|
+
if search_query:
|
|
316
|
+
search_pattern = f"%{search_query}%"
|
|
317
|
+
base_query = base_query.where(
|
|
318
|
+
or_(
|
|
319
|
+
col(Preference.scope_id).ilike(search_pattern),
|
|
320
|
+
col(ConversationV2.title).ilike(search_pattern),
|
|
321
|
+
col(Persona.persona_id).ilike(search_pattern),
|
|
322
|
+
),
|
|
323
|
+
)
|
|
425
324
|
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
325
|
+
# 平台筛选
|
|
326
|
+
if platform:
|
|
327
|
+
platform_pattern = f"{platform}:%"
|
|
328
|
+
base_query = base_query.where(
|
|
329
|
+
col(Preference.scope_id).like(platform_pattern),
|
|
330
|
+
)
|
|
431
331
|
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
332
|
+
# 排序
|
|
333
|
+
base_query = base_query.order_by(Preference.scope_id)
|
|
334
|
+
|
|
335
|
+
# 分页结果
|
|
336
|
+
result_query = base_query.offset(offset).limit(page_size)
|
|
337
|
+
result = await session.execute(result_query)
|
|
338
|
+
rows = result.fetchall()
|
|
339
|
+
|
|
340
|
+
# 查询总数(应用相同的筛选条件)
|
|
341
|
+
count_base_query = (
|
|
342
|
+
select(func.count(col(Preference.scope_id)))
|
|
343
|
+
.select_from(Preference)
|
|
344
|
+
.outerjoin(
|
|
345
|
+
ConversationV2,
|
|
346
|
+
func.json_extract(Preference.value, "$.val")
|
|
347
|
+
== ConversationV2.conversation_id,
|
|
441
348
|
)
|
|
349
|
+
.outerjoin(
|
|
350
|
+
Persona,
|
|
351
|
+
col(ConversationV2.persona_id) == Persona.persona_id,
|
|
352
|
+
)
|
|
353
|
+
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
354
|
+
)
|
|
442
355
|
|
|
443
|
-
|
|
356
|
+
# 应用相同的搜索和平台筛选条件到计数查询
|
|
357
|
+
if search_query:
|
|
358
|
+
search_pattern = f"%{search_query}%"
|
|
359
|
+
count_base_query = count_base_query.where(
|
|
360
|
+
or_(
|
|
361
|
+
col(Preference.scope_id).ilike(search_pattern),
|
|
362
|
+
col(ConversationV2.title).ilike(search_pattern),
|
|
363
|
+
col(Persona.persona_id).ilike(search_pattern),
|
|
364
|
+
),
|
|
365
|
+
)
|
|
444
366
|
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
367
|
+
if platform:
|
|
368
|
+
platform_pattern = f"{platform}:%"
|
|
369
|
+
count_base_query = count_base_query.where(
|
|
370
|
+
col(Preference.scope_id).like(platform_pattern),
|
|
371
|
+
)
|
|
450
372
|
|
|
451
|
-
|
|
373
|
+
total_result = await session.execute(count_base_query)
|
|
374
|
+
total = total_result.scalar() or 0
|
|
375
|
+
|
|
376
|
+
sessions_data = [
|
|
377
|
+
{
|
|
378
|
+
"session_id": row.session_id,
|
|
379
|
+
"conversation_id": row.conversation_id,
|
|
380
|
+
"persona_id": row.persona_id,
|
|
381
|
+
"title": row.title,
|
|
382
|
+
"persona_name": row.persona_name,
|
|
383
|
+
}
|
|
384
|
+
for row in rows
|
|
385
|
+
]
|
|
386
|
+
return sessions_data, total
|
|
387
|
+
|
|
388
|
+
async def insert_platform_message_history(
|
|
452
389
|
self,
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
390
|
+
platform_id,
|
|
391
|
+
user_id,
|
|
392
|
+
content,
|
|
393
|
+
sender_id=None,
|
|
394
|
+
sender_name=None,
|
|
395
|
+
):
|
|
396
|
+
"""Insert a new platform message history record."""
|
|
397
|
+
async with self.get_db() as session:
|
|
398
|
+
session: AsyncSession
|
|
399
|
+
async with session.begin():
|
|
400
|
+
new_history = PlatformMessageHistory(
|
|
401
|
+
platform_id=platform_id,
|
|
402
|
+
user_id=user_id,
|
|
403
|
+
content=content,
|
|
404
|
+
sender_id=sender_id,
|
|
405
|
+
sender_name=sender_name,
|
|
406
|
+
)
|
|
407
|
+
session.add(new_history)
|
|
408
|
+
return new_history
|
|
471
409
|
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
|
|
491
|
-
|
|
492
|
-
# 搜索关键词
|
|
493
|
-
if search_query:
|
|
494
|
-
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
|
495
|
-
where_clauses.append(
|
|
496
|
-
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
|
410
|
+
async def delete_platform_message_offset(
|
|
411
|
+
self,
|
|
412
|
+
platform_id,
|
|
413
|
+
user_id,
|
|
414
|
+
offset_sec=86400,
|
|
415
|
+
):
|
|
416
|
+
"""Delete platform message history records newer than the specified offset."""
|
|
417
|
+
async with self.get_db() as session:
|
|
418
|
+
session: AsyncSession
|
|
419
|
+
async with session.begin():
|
|
420
|
+
now = datetime.now()
|
|
421
|
+
cutoff_time = now - timedelta(seconds=offset_sec)
|
|
422
|
+
await session.execute(
|
|
423
|
+
delete(PlatformMessageHistory).where(
|
|
424
|
+
col(PlatformMessageHistory.platform_id) == platform_id,
|
|
425
|
+
col(PlatformMessageHistory.user_id) == user_id,
|
|
426
|
+
col(PlatformMessageHistory.created_at) >= cutoff_time,
|
|
427
|
+
),
|
|
497
428
|
)
|
|
498
|
-
search_param = f"%{search_query}%"
|
|
499
|
-
params.extend([search_param, search_param, search_param, search_param])
|
|
500
429
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
430
|
+
async def get_platform_message_history(
|
|
431
|
+
self,
|
|
432
|
+
platform_id,
|
|
433
|
+
user_id,
|
|
434
|
+
page=1,
|
|
435
|
+
page_size=20,
|
|
436
|
+
):
|
|
437
|
+
"""Get platform message history records."""
|
|
438
|
+
async with self.get_db() as session:
|
|
439
|
+
session: AsyncSession
|
|
440
|
+
offset = (page - 1) * page_size
|
|
441
|
+
query = (
|
|
442
|
+
select(PlatformMessageHistory)
|
|
443
|
+
.where(
|
|
444
|
+
PlatformMessageHistory.platform_id == platform_id,
|
|
445
|
+
PlatformMessageHistory.user_id == user_id,
|
|
446
|
+
)
|
|
447
|
+
.order_by(desc(PlatformMessageHistory.created_at))
|
|
448
|
+
)
|
|
449
|
+
result = await session.execute(query.offset(offset).limit(page_size))
|
|
450
|
+
return result.scalars().all()
|
|
451
|
+
|
|
452
|
+
async def insert_attachment(self, path, type, mime_type):
|
|
453
|
+
"""Insert a new attachment record."""
|
|
454
|
+
async with self.get_db() as session:
|
|
455
|
+
session: AsyncSession
|
|
456
|
+
async with session.begin():
|
|
457
|
+
new_attachment = Attachment(
|
|
458
|
+
path=path,
|
|
459
|
+
type=type,
|
|
460
|
+
mime_type=mime_type,
|
|
461
|
+
)
|
|
462
|
+
session.add(new_attachment)
|
|
463
|
+
return new_attachment
|
|
464
|
+
|
|
465
|
+
async def get_attachment_by_id(self, attachment_id):
|
|
466
|
+
"""Get an attachment by its ID."""
|
|
467
|
+
async with self.get_db() as session:
|
|
468
|
+
session: AsyncSession
|
|
469
|
+
query = select(Attachment).where(Attachment.attachment_id == attachment_id)
|
|
470
|
+
result = await session.execute(query)
|
|
471
|
+
return result.scalar_one_or_none()
|
|
472
|
+
|
|
473
|
+
async def insert_persona(
|
|
474
|
+
self,
|
|
475
|
+
persona_id,
|
|
476
|
+
system_prompt,
|
|
477
|
+
begin_dialogs=None,
|
|
478
|
+
tools=None,
|
|
479
|
+
):
|
|
480
|
+
"""Insert a new persona record."""
|
|
481
|
+
async with self.get_db() as session:
|
|
482
|
+
session: AsyncSession
|
|
483
|
+
async with session.begin():
|
|
484
|
+
new_persona = Persona(
|
|
485
|
+
persona_id=persona_id,
|
|
486
|
+
system_prompt=system_prompt,
|
|
487
|
+
begin_dialogs=begin_dialogs or [],
|
|
488
|
+
tools=tools,
|
|
489
|
+
)
|
|
490
|
+
session.add(new_persona)
|
|
491
|
+
return new_persona
|
|
492
|
+
|
|
493
|
+
async def get_persona_by_id(self, persona_id):
|
|
494
|
+
"""Get a persona by its ID."""
|
|
495
|
+
async with self.get_db() as session:
|
|
496
|
+
session: AsyncSession
|
|
497
|
+
query = select(Persona).where(Persona.persona_id == persona_id)
|
|
498
|
+
result = await session.execute(query)
|
|
499
|
+
return result.scalar_one_or_none()
|
|
500
|
+
|
|
501
|
+
async def get_personas(self):
|
|
502
|
+
"""Get all personas for a specific bot."""
|
|
503
|
+
async with self.get_db() as session:
|
|
504
|
+
session: AsyncSession
|
|
505
|
+
query = select(Persona)
|
|
506
|
+
result = await session.execute(query)
|
|
507
|
+
return result.scalars().all()
|
|
508
|
+
|
|
509
|
+
async def update_persona(
|
|
510
|
+
self,
|
|
511
|
+
persona_id,
|
|
512
|
+
system_prompt=None,
|
|
513
|
+
begin_dialogs=None,
|
|
514
|
+
tools=NOT_GIVEN,
|
|
515
|
+
):
|
|
516
|
+
"""Update a persona's system prompt or begin dialogs."""
|
|
517
|
+
async with self.get_db() as session:
|
|
518
|
+
session: AsyncSession
|
|
519
|
+
async with session.begin():
|
|
520
|
+
query = update(Persona).where(col(Persona.persona_id) == persona_id)
|
|
521
|
+
values = {}
|
|
522
|
+
if system_prompt is not None:
|
|
523
|
+
values["system_prompt"] = system_prompt
|
|
524
|
+
if begin_dialogs is not None:
|
|
525
|
+
values["begin_dialogs"] = begin_dialogs
|
|
526
|
+
if tools is not NOT_GIVEN:
|
|
527
|
+
values["tools"] = tools
|
|
528
|
+
if not values:
|
|
529
|
+
return None
|
|
530
|
+
query = query.values(**values)
|
|
531
|
+
await session.execute(query)
|
|
532
|
+
return await self.get_persona_by_id(persona_id)
|
|
533
|
+
|
|
534
|
+
async def delete_persona(self, persona_id):
|
|
535
|
+
"""Delete a persona by its ID."""
|
|
536
|
+
async with self.get_db() as session:
|
|
537
|
+
session: AsyncSession
|
|
538
|
+
async with session.begin():
|
|
539
|
+
await session.execute(
|
|
540
|
+
delete(Persona).where(col(Persona.persona_id) == persona_id),
|
|
541
|
+
)
|
|
506
542
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
543
|
+
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
|
544
|
+
"""Insert a new preference record or update if it exists."""
|
|
545
|
+
async with self.get_db() as session:
|
|
546
|
+
session: AsyncSession
|
|
547
|
+
async with session.begin():
|
|
548
|
+
query = select(Preference).where(
|
|
549
|
+
Preference.scope == scope,
|
|
550
|
+
Preference.scope_id == scope_id,
|
|
551
|
+
Preference.key == key,
|
|
552
|
+
)
|
|
553
|
+
result = await session.execute(query)
|
|
554
|
+
existing_preference = result.scalar_one_or_none()
|
|
555
|
+
if existing_preference:
|
|
556
|
+
existing_preference.value = value
|
|
557
|
+
else:
|
|
558
|
+
new_preference = Preference(
|
|
559
|
+
scope=scope,
|
|
560
|
+
scope_id=scope_id,
|
|
561
|
+
key=key,
|
|
562
|
+
value=value,
|
|
563
|
+
)
|
|
564
|
+
session.add(new_preference)
|
|
565
|
+
return existing_preference or new_preference
|
|
566
|
+
|
|
567
|
+
async def get_preference(self, scope, scope_id, key):
|
|
568
|
+
"""Get a preference by key."""
|
|
569
|
+
async with self.get_db() as session:
|
|
570
|
+
session: AsyncSession
|
|
571
|
+
query = select(Preference).where(
|
|
572
|
+
Preference.scope == scope,
|
|
573
|
+
Preference.scope_id == scope_id,
|
|
574
|
+
Preference.key == key,
|
|
575
|
+
)
|
|
576
|
+
result = await session.execute(query)
|
|
577
|
+
return result.scalar_one_or_none()
|
|
578
|
+
|
|
579
|
+
async def get_preferences(self, scope, scope_id=None, key=None):
|
|
580
|
+
"""Get all preferences for a specific scope ID or key."""
|
|
581
|
+
async with self.get_db() as session:
|
|
582
|
+
session: AsyncSession
|
|
583
|
+
query = select(Preference).where(Preference.scope == scope)
|
|
584
|
+
if scope_id is not None:
|
|
585
|
+
query = query.where(Preference.scope_id == scope_id)
|
|
586
|
+
if key is not None:
|
|
587
|
+
query = query.where(Preference.key == key)
|
|
588
|
+
result = await session.execute(query)
|
|
589
|
+
return result.scalars().all()
|
|
590
|
+
|
|
591
|
+
async def remove_preference(self, scope, scope_id, key):
|
|
592
|
+
"""Remove a preference by scope ID and key."""
|
|
593
|
+
async with self.get_db() as session:
|
|
594
|
+
session: AsyncSession
|
|
595
|
+
async with session.begin():
|
|
596
|
+
await session.execute(
|
|
597
|
+
delete(Preference).where(
|
|
598
|
+
col(Preference.scope) == scope,
|
|
599
|
+
col(Preference.scope_id) == scope_id,
|
|
600
|
+
col(Preference.key) == key,
|
|
601
|
+
),
|
|
602
|
+
)
|
|
603
|
+
await session.commit()
|
|
604
|
+
|
|
605
|
+
async def clear_preferences(self, scope, scope_id):
|
|
606
|
+
"""Clear all preferences for a specific scope ID."""
|
|
607
|
+
async with self.get_db() as session:
|
|
608
|
+
session: AsyncSession
|
|
609
|
+
async with session.begin():
|
|
610
|
+
await session.execute(
|
|
611
|
+
delete(Preference).where(
|
|
612
|
+
col(Preference.scope) == scope,
|
|
613
|
+
col(Preference.scope_id) == scope_id,
|
|
614
|
+
),
|
|
615
|
+
)
|
|
616
|
+
await session.commit()
|
|
617
|
+
|
|
618
|
+
# ====
|
|
619
|
+
# Deprecated Methods
|
|
620
|
+
# ====
|
|
621
|
+
|
|
622
|
+
def get_base_stats(self, offset_sec=86400):
|
|
623
|
+
"""Get base statistics within the specified offset in seconds."""
|
|
624
|
+
|
|
625
|
+
async def _inner():
|
|
626
|
+
async with self.get_db() as session:
|
|
627
|
+
session: AsyncSession
|
|
628
|
+
now = datetime.now()
|
|
629
|
+
start_time = now - timedelta(seconds=offset_sec)
|
|
630
|
+
result = await session.execute(
|
|
631
|
+
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
|
|
632
|
+
)
|
|
633
|
+
all_datas = result.scalars().all()
|
|
634
|
+
deprecated_stats = DeprecatedStats()
|
|
635
|
+
for data in all_datas:
|
|
636
|
+
deprecated_stats.platform.append(
|
|
637
|
+
DeprecatedPlatformStat(
|
|
638
|
+
name=data.platform_id,
|
|
639
|
+
count=data.count,
|
|
640
|
+
timestamp=int(data.timestamp.timestamp()),
|
|
641
|
+
),
|
|
642
|
+
)
|
|
643
|
+
return deprecated_stats
|
|
644
|
+
|
|
645
|
+
result = None
|
|
646
|
+
|
|
647
|
+
def runner():
|
|
648
|
+
nonlocal result
|
|
649
|
+
result = asyncio.run(_inner())
|
|
650
|
+
|
|
651
|
+
t = threading.Thread(target=runner)
|
|
652
|
+
t.start()
|
|
653
|
+
t.join()
|
|
654
|
+
return result
|
|
655
|
+
|
|
656
|
+
def get_total_message_count(self):
|
|
657
|
+
"""Get the total message count from platform statistics."""
|
|
658
|
+
|
|
659
|
+
async def _inner():
|
|
660
|
+
async with self.get_db() as session:
|
|
661
|
+
session: AsyncSession
|
|
662
|
+
result = await session.execute(
|
|
663
|
+
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
|
664
|
+
)
|
|
665
|
+
total_count = result.scalar_one_or_none()
|
|
666
|
+
return total_count if total_count is not None else 0
|
|
667
|
+
|
|
668
|
+
result = None
|
|
669
|
+
|
|
670
|
+
def runner():
|
|
671
|
+
nonlocal result
|
|
672
|
+
result = asyncio.run(_inner())
|
|
673
|
+
|
|
674
|
+
t = threading.Thread(target=runner)
|
|
675
|
+
t.start()
|
|
676
|
+
t.join()
|
|
677
|
+
return result
|
|
678
|
+
|
|
679
|
+
def get_grouped_base_stats(self, offset_sec=86400):
|
|
680
|
+
# group by platform_id
|
|
681
|
+
async def _inner():
|
|
682
|
+
async with self.get_db() as session:
|
|
683
|
+
session: AsyncSession
|
|
684
|
+
now = datetime.now()
|
|
685
|
+
start_time = now - timedelta(seconds=offset_sec)
|
|
686
|
+
result = await session.execute(
|
|
687
|
+
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
|
|
688
|
+
.where(PlatformStat.timestamp >= start_time)
|
|
689
|
+
.group_by(PlatformStat.platform_id),
|
|
690
|
+
)
|
|
691
|
+
grouped_stats = result.all()
|
|
692
|
+
deprecated_stats = DeprecatedStats()
|
|
693
|
+
for platform_id, count in grouped_stats:
|
|
694
|
+
deprecated_stats.platform.append(
|
|
695
|
+
DeprecatedPlatformStat(
|
|
696
|
+
name=platform_id,
|
|
697
|
+
count=count,
|
|
698
|
+
timestamp=int(start_time.timestamp()),
|
|
699
|
+
),
|
|
700
|
+
)
|
|
701
|
+
return deprecated_stats
|
|
702
|
+
|
|
703
|
+
result = None
|
|
704
|
+
|
|
705
|
+
def runner():
|
|
706
|
+
nonlocal result
|
|
707
|
+
result = asyncio.run(_inner())
|
|
708
|
+
|
|
709
|
+
t = threading.Thread(target=runner)
|
|
710
|
+
t.start()
|
|
711
|
+
t.join()
|
|
712
|
+
return result
|
|
713
|
+
|
|
714
|
+
# ====
|
|
715
|
+
# Platform Session Management
|
|
716
|
+
# ====
|
|
717
|
+
|
|
718
|
+
async def create_platform_session(
|
|
719
|
+
self,
|
|
720
|
+
creator: str,
|
|
721
|
+
platform_id: str = "webchat",
|
|
722
|
+
session_id: str | None = None,
|
|
723
|
+
display_name: str | None = None,
|
|
724
|
+
is_group: int = 0,
|
|
725
|
+
) -> PlatformSession:
|
|
726
|
+
"""Create a new Platform session."""
|
|
727
|
+
kwargs = {}
|
|
728
|
+
if session_id:
|
|
729
|
+
kwargs["session_id"] = session_id
|
|
730
|
+
|
|
731
|
+
async with self.get_db() as session:
|
|
732
|
+
session: AsyncSession
|
|
733
|
+
async with session.begin():
|
|
734
|
+
new_session = PlatformSession(
|
|
735
|
+
creator=creator,
|
|
736
|
+
platform_id=platform_id,
|
|
737
|
+
display_name=display_name,
|
|
738
|
+
is_group=is_group,
|
|
739
|
+
**kwargs,
|
|
740
|
+
)
|
|
741
|
+
session.add(new_session)
|
|
742
|
+
await session.flush()
|
|
743
|
+
await session.refresh(new_session)
|
|
744
|
+
return new_session
|
|
745
|
+
|
|
746
|
+
async def get_platform_session_by_id(
|
|
747
|
+
self, session_id: str
|
|
748
|
+
) -> PlatformSession | None:
|
|
749
|
+
"""Get a Platform session by its ID."""
|
|
750
|
+
async with self.get_db() as session:
|
|
751
|
+
session: AsyncSession
|
|
752
|
+
query = select(PlatformSession).where(
|
|
753
|
+
PlatformSession.session_id == session_id,
|
|
754
|
+
)
|
|
755
|
+
result = await session.execute(query)
|
|
756
|
+
return result.scalar_one_or_none()
|
|
512
757
|
|
|
513
|
-
|
|
514
|
-
|
|
758
|
+
async def get_platform_sessions_by_creator(
|
|
759
|
+
self,
|
|
760
|
+
creator: str,
|
|
761
|
+
platform_id: str | None = None,
|
|
762
|
+
page: int = 1,
|
|
763
|
+
page_size: int = 20,
|
|
764
|
+
) -> list[PlatformSession]:
|
|
765
|
+
"""Get all Platform sessions for a specific creator (username) and optionally platform."""
|
|
766
|
+
async with self.get_db() as session:
|
|
767
|
+
session: AsyncSession
|
|
768
|
+
offset = (page - 1) * page_size
|
|
769
|
+
query = select(PlatformSession).where(PlatformSession.creator == creator)
|
|
515
770
|
|
|
516
|
-
|
|
517
|
-
|
|
771
|
+
if platform_id:
|
|
772
|
+
query = query.where(PlatformSession.platform_id == platform_id)
|
|
518
773
|
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
774
|
+
query = (
|
|
775
|
+
query.order_by(desc(PlatformSession.updated_at))
|
|
776
|
+
.offset(offset)
|
|
777
|
+
.limit(page_size)
|
|
778
|
+
)
|
|
779
|
+
result = await session.execute(query)
|
|
780
|
+
return list(result.scalars().all())
|
|
522
781
|
|
|
523
|
-
|
|
524
|
-
|
|
782
|
+
async def update_platform_session(
|
|
783
|
+
self,
|
|
784
|
+
session_id: str,
|
|
785
|
+
display_name: str | None = None,
|
|
786
|
+
) -> None:
|
|
787
|
+
"""Update a Platform session's updated_at timestamp and optionally display_name."""
|
|
788
|
+
async with self.get_db() as session:
|
|
789
|
+
session: AsyncSession
|
|
790
|
+
async with session.begin():
|
|
791
|
+
values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)}
|
|
792
|
+
if display_name is not None:
|
|
793
|
+
values["display_name"] = display_name
|
|
794
|
+
|
|
795
|
+
await session.execute(
|
|
796
|
+
update(PlatformSession)
|
|
797
|
+
.where(col(PlatformSession.session_id) == session_id)
|
|
798
|
+
.values(**values),
|
|
799
|
+
)
|
|
525
800
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
# 获取分页数据
|
|
537
|
-
c.execute(data_sql, query_params)
|
|
538
|
-
rows = c.fetchall()
|
|
539
|
-
|
|
540
|
-
conversations = []
|
|
541
|
-
|
|
542
|
-
for row in rows:
|
|
543
|
-
user_id, cid, created_at, updated_at, title, persona_id = row
|
|
544
|
-
# 确保 cid 是字符串类型,否则使用一个默认值
|
|
545
|
-
safe_cid = str(cid) if cid else "unknown"
|
|
546
|
-
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
|
|
547
|
-
|
|
548
|
-
conversations.append(
|
|
549
|
-
{
|
|
550
|
-
"user_id": user_id or "",
|
|
551
|
-
"cid": safe_cid,
|
|
552
|
-
"title": title or f"对话 {display_cid}",
|
|
553
|
-
"persona_id": persona_id or "",
|
|
554
|
-
"created_at": created_at or 0,
|
|
555
|
-
"updated_at": updated_at or 0,
|
|
556
|
-
}
|
|
557
|
-
)
|
|
558
|
-
|
|
559
|
-
return conversations, total_count
|
|
560
|
-
|
|
561
|
-
except Exception as _:
|
|
562
|
-
# 返回空列表和0,确保即使出错也有有效的返回值
|
|
563
|
-
return [], 0
|
|
564
|
-
finally:
|
|
565
|
-
c.close()
|
|
801
|
+
async def delete_platform_session(self, session_id: str) -> None:
|
|
802
|
+
"""Delete a Platform session by its ID."""
|
|
803
|
+
async with self.get_db() as session:
|
|
804
|
+
session: AsyncSession
|
|
805
|
+
async with session.begin():
|
|
806
|
+
await session.execute(
|
|
807
|
+
delete(PlatformSession).where(
|
|
808
|
+
col(PlatformSession.session_id) == session_id,
|
|
809
|
+
),
|
|
810
|
+
)
|