AstrBot 4.5.0__py3-none-any.whl → 4.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +10 -11
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -36
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +7 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__main__.py +5 -5
- astrbot/cli/commands/__init__.py +3 -3
- astrbot/cli/commands/cmd_conf.py +19 -16
- astrbot/cli/commands/cmd_init.py +3 -2
- astrbot/cli/commands/cmd_plug.py +8 -10
- astrbot/cli/commands/cmd_run.py +5 -6
- astrbot/cli/utils/__init__.py +6 -6
- astrbot/cli/utils/basic.py +14 -14
- astrbot/cli/utils/plugin.py +24 -15
- astrbot/cli/utils/version_comparator.py +10 -12
- astrbot/core/__init__.py +8 -6
- astrbot/core/agent/agent.py +3 -2
- astrbot/core/agent/handoff.py +6 -2
- astrbot/core/agent/hooks.py +9 -6
- astrbot/core/agent/mcp_client.py +50 -15
- astrbot/core/agent/message.py +168 -0
- astrbot/core/agent/response.py +2 -1
- astrbot/core/agent/run_context.py +2 -3
- astrbot/core/agent/runners/base.py +10 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
- astrbot/core/agent/tool.py +60 -41
- astrbot/core/agent/tool_executor.py +9 -3
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astrbot_config_mgr.py +29 -9
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +28 -26
- astrbot/core/config/default.py +44 -6
- astrbot/core/conversation_mgr.py +105 -36
- astrbot/core/core_lifecycle.py +68 -54
- astrbot/core/db/__init__.py +33 -18
- astrbot/core/db/migration/helper.py +18 -13
- astrbot/core/db/migration/migra_3_to_4.py +53 -34
- astrbot/core/db/migration/migra_45_to_46.py +1 -1
- astrbot/core/db/migration/shared_preferences_v3.py +2 -1
- astrbot/core/db/migration/sqlite_v3.py +26 -23
- astrbot/core/db/po.py +27 -18
- astrbot/core/db/sqlite.py +74 -45
- astrbot/core/db/vec_db/base.py +10 -14
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
- astrbot/core/event_bus.py +8 -6
- astrbot/core/file_token_service.py +6 -5
- astrbot/core/initial_loader.py +7 -5
- astrbot/core/knowledge_base/chunking/__init__.py +1 -3
- astrbot/core/knowledge_base/chunking/base.py +1 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
- astrbot/core/knowledge_base/chunking/recursive.py +16 -10
- astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
- astrbot/core/knowledge_base/kb_helper.py +30 -17
- astrbot/core/knowledge_base/kb_mgr.py +6 -7
- astrbot/core/knowledge_base/models.py +10 -4
- astrbot/core/knowledge_base/parsers/__init__.py +3 -5
- astrbot/core/knowledge_base/parsers/base.py +1 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
- astrbot/core/knowledge_base/parsers/util.py +1 -1
- astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
- astrbot/core/knowledge_base/retrieval/manager.py +17 -14
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
- astrbot/core/log.py +21 -13
- astrbot/core/message/components.py +123 -217
- astrbot/core/message/message_event_result.py +24 -24
- astrbot/core/persona_mgr.py +20 -11
- astrbot/core/pipeline/__init__.py +7 -7
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +4 -1
- astrbot/core/pipeline/context_utils.py +77 -7
- astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
- astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
- astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
- astrbot/core/pipeline/process_stage/stage.py +13 -10
- astrbot/core/pipeline/process_stage/utils.py +6 -5
- astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
- astrbot/core/pipeline/respond/stage.py +23 -20
- astrbot/core/pipeline/result_decorate/stage.py +31 -23
- astrbot/core/pipeline/scheduler.py +12 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -8
- astrbot/core/pipeline/stage.py +10 -4
- astrbot/core/pipeline/waking_check/stage.py +24 -18
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +76 -110
- astrbot/core/platform/astrbot_message.py +11 -13
- astrbot/core/platform/manager.py +16 -15
- astrbot/core/platform/message_session.py +5 -3
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +4 -4
- astrbot/core/platform/register.py +8 -8
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +47 -29
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
- astrbot/core/platform/sources/discord/client.py +9 -6
- astrbot/core/platform/sources/discord/components.py +18 -14
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
- astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
- astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
- astrbot/core/platform/sources/lark/lark_event.py +21 -14
- astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
- astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
- astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
- astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
- astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
- astrbot/core/platform/sources/satori/satori_event.py +34 -25
- astrbot/core/platform/sources/slack/client.py +11 -9
- astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
- astrbot/core/platform/sources/slack/slack_event.py +34 -24
- astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
- astrbot/core/platform/sources/telegram/tg_event.py +32 -18
- astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
- astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
- astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
- astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
- astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
- astrbot/core/platform_message_history_mgr.py +5 -3
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +61 -75
- astrbot/core/provider/func_tool_manager.py +59 -55
- astrbot/core/provider/manager.py +40 -22
- astrbot/core/provider/provider.py +72 -46
- astrbot/core/provider/register.py +7 -7
- astrbot/core/provider/sources/anthropic_source.py +48 -30
- astrbot/core/provider/sources/azure_tts_source.py +17 -13
- astrbot/core/provider/sources/coze_api_client.py +27 -17
- astrbot/core/provider/sources/coze_source.py +104 -87
- astrbot/core/provider/sources/dashscope_source.py +18 -11
- astrbot/core/provider/sources/dashscope_tts.py +36 -23
- astrbot/core/provider/sources/dify_source.py +25 -20
- astrbot/core/provider/sources/edge_tts_source.py +21 -17
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
- astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
- astrbot/core/provider/sources/gemini_source.py +72 -58
- astrbot/core/provider/sources/gemini_tts_source.py +8 -6
- astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
- astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
- astrbot/core/provider/sources/openai_embedding_source.py +6 -8
- astrbot/core/provider/sources/openai_source.py +102 -69
- astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
- astrbot/core/provider/sources/volcengine_tts.py +38 -31
- astrbot/core/provider/sources/whisper_api_source.py +14 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/star/__init__.py +16 -11
- astrbot/core/star/config.py +10 -15
- astrbot/core/star/context.py +109 -84
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +30 -28
- astrbot/core/star/filter/command_group.py +27 -24
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +4 -2
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -19
- astrbot/core/star/register/star.py +6 -2
- astrbot/core/star/register/star_handler.py +96 -73
- astrbot/core/star/session_llm_manager.py +48 -14
- astrbot/core/star/session_plugin_manager.py +29 -15
- astrbot/core/star/star.py +1 -2
- astrbot/core/star/star_handler.py +13 -8
- astrbot/core/star/star_manager.py +151 -59
- astrbot/core/star/star_tools.py +44 -37
- astrbot/core/star/updator.py +10 -10
- astrbot/core/umop_config_router.py +10 -4
- astrbot/core/updator.py +13 -5
- astrbot/core/utils/astrbot_path.py +3 -5
- astrbot/core/utils/dify_api_client.py +33 -15
- astrbot/core/utils/io.py +66 -42
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +7 -7
- astrbot/core/utils/path_util.py +15 -16
- astrbot/core/utils/pip_installer.py +5 -5
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +45 -20
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/network_strategy.py +35 -26
- astrbot/core/utils/t2i/renderer.py +11 -5
- astrbot/core/utils/t2i/template_manager.py +14 -15
- astrbot/core/utils/tencent_record_helper.py +19 -13
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +43 -40
- astrbot/dashboard/routes/__init__.py +18 -18
- astrbot/dashboard/routes/auth.py +10 -8
- astrbot/dashboard/routes/chat.py +30 -21
- astrbot/dashboard/routes/config.py +92 -75
- astrbot/dashboard/routes/conversation.py +46 -39
- astrbot/dashboard/routes/file.py +4 -2
- astrbot/dashboard/routes/knowledge_base.py +47 -40
- astrbot/dashboard/routes/log.py +9 -4
- astrbot/dashboard/routes/persona.py +19 -16
- astrbot/dashboard/routes/plugin.py +69 -55
- astrbot/dashboard/routes/route.py +3 -1
- astrbot/dashboard/routes/session_management.py +130 -116
- astrbot/dashboard/routes/stat.py +34 -34
- astrbot/dashboard/routes/t2i.py +15 -12
- astrbot/dashboard/routes/tools.py +47 -52
- astrbot/dashboard/routes/update.py +32 -28
- astrbot/dashboard/server.py +30 -26
- astrbot/dashboard/utils.py +8 -4
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/METADATA +4 -2
- astrbot-4.5.2.dist-info/RECORD +261 -0
- astrbot-4.5.0.dist-info/RECORD +0 -258
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
import sqlite3
|
|
2
2
|
import time
|
|
3
|
-
from astrbot.core.db.po import Platform, Stats
|
|
4
|
-
from typing import Tuple, List, Dict, Any
|
|
5
3
|
from dataclasses import dataclass
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from astrbot.core.db.po import Platform, Stats
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
@dataclass
|
|
@@ -94,7 +95,7 @@ class SQLiteDatabase:
|
|
|
94
95
|
c.execute(
|
|
95
96
|
"""
|
|
96
97
|
PRAGMA table_info(webchat_conversation)
|
|
97
|
-
"""
|
|
98
|
+
""",
|
|
98
99
|
)
|
|
99
100
|
res = c.fetchall()
|
|
100
101
|
has_title = False
|
|
@@ -108,14 +109,14 @@ class SQLiteDatabase:
|
|
|
108
109
|
c.execute(
|
|
109
110
|
"""
|
|
110
111
|
ALTER TABLE webchat_conversation ADD COLUMN title TEXT;
|
|
111
|
-
"""
|
|
112
|
+
""",
|
|
112
113
|
)
|
|
113
114
|
self.conn.commit()
|
|
114
115
|
if not has_persona_id:
|
|
115
116
|
c.execute(
|
|
116
117
|
"""
|
|
117
118
|
ALTER TABLE webchat_conversation ADD COLUMN persona_id TEXT;
|
|
118
|
-
"""
|
|
119
|
+
""",
|
|
119
120
|
)
|
|
120
121
|
self.conn.commit()
|
|
121
122
|
|
|
@@ -126,7 +127,7 @@ class SQLiteDatabase:
|
|
|
126
127
|
conn.text_factory = str
|
|
127
128
|
return conn
|
|
128
129
|
|
|
129
|
-
def _exec_sql(self, sql: str, params:
|
|
130
|
+
def _exec_sql(self, sql: str, params: tuple = None):
|
|
130
131
|
conn = self.conn
|
|
131
132
|
try:
|
|
132
133
|
c = self.conn.cursor()
|
|
@@ -174,7 +175,7 @@ class SQLiteDatabase:
|
|
|
174
175
|
"""
|
|
175
176
|
SELECT * FROM platform
|
|
176
177
|
"""
|
|
177
|
-
+ where_clause
|
|
178
|
+
+ where_clause,
|
|
178
179
|
)
|
|
179
180
|
|
|
180
181
|
platform = []
|
|
@@ -194,7 +195,7 @@ class SQLiteDatabase:
|
|
|
194
195
|
c.execute(
|
|
195
196
|
"""
|
|
196
197
|
SELECT SUM(count) FROM platform
|
|
197
|
-
"""
|
|
198
|
+
""",
|
|
198
199
|
)
|
|
199
200
|
res = c.fetchone()
|
|
200
201
|
c.close()
|
|
@@ -214,7 +215,7 @@ class SQLiteDatabase:
|
|
|
214
215
|
SELECT name, SUM(count), timestamp FROM platform
|
|
215
216
|
"""
|
|
216
217
|
+ where_clause
|
|
217
|
-
+ " GROUP BY name"
|
|
218
|
+
+ " GROUP BY name",
|
|
218
219
|
)
|
|
219
220
|
|
|
220
221
|
platform = []
|
|
@@ -242,7 +243,7 @@ class SQLiteDatabase:
|
|
|
242
243
|
c.close()
|
|
243
244
|
|
|
244
245
|
if not res:
|
|
245
|
-
return
|
|
246
|
+
return None
|
|
246
247
|
|
|
247
248
|
return Conversation(*res)
|
|
248
249
|
|
|
@@ -257,7 +258,7 @@ class SQLiteDatabase:
|
|
|
257
258
|
(user_id, cid, history, updated_at, created_at),
|
|
258
259
|
)
|
|
259
260
|
|
|
260
|
-
def get_conversations(self, user_id: str) ->
|
|
261
|
+
def get_conversations(self, user_id: str) -> tuple:
|
|
261
262
|
try:
|
|
262
263
|
c = self.conn.cursor()
|
|
263
264
|
except sqlite3.ProgrammingError:
|
|
@@ -280,7 +281,7 @@ class SQLiteDatabase:
|
|
|
280
281
|
title = row[3]
|
|
281
282
|
persona_id = row[4]
|
|
282
283
|
conversations.append(
|
|
283
|
-
Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
|
|
284
|
+
Conversation("", cid, "[]", created_at, updated_at, title, persona_id),
|
|
284
285
|
)
|
|
285
286
|
return conversations
|
|
286
287
|
|
|
@@ -319,8 +320,10 @@ class SQLiteDatabase:
|
|
|
319
320
|
)
|
|
320
321
|
|
|
321
322
|
def get_all_conversations(
|
|
322
|
-
self,
|
|
323
|
-
|
|
323
|
+
self,
|
|
324
|
+
page: int = 1,
|
|
325
|
+
page_size: int = 20,
|
|
326
|
+
) -> tuple[list[dict[str, Any]], int]:
|
|
324
327
|
"""获取所有对话,支持分页,按更新时间降序排序"""
|
|
325
328
|
try:
|
|
326
329
|
c = self.conn.cursor()
|
|
@@ -366,7 +369,7 @@ class SQLiteDatabase:
|
|
|
366
369
|
"persona_id": persona_id or "",
|
|
367
370
|
"created_at": created_at or 0,
|
|
368
371
|
"updated_at": updated_at or 0,
|
|
369
|
-
}
|
|
372
|
+
},
|
|
370
373
|
)
|
|
371
374
|
|
|
372
375
|
return conversations, total_count
|
|
@@ -381,12 +384,12 @@ class SQLiteDatabase:
|
|
|
381
384
|
self,
|
|
382
385
|
page: int = 1,
|
|
383
386
|
page_size: int = 20,
|
|
384
|
-
platforms:
|
|
385
|
-
message_types:
|
|
386
|
-
search_query: str = None,
|
|
387
|
-
exclude_ids:
|
|
388
|
-
exclude_platforms:
|
|
389
|
-
) ->
|
|
387
|
+
platforms: list[str] | None = None,
|
|
388
|
+
message_types: list[str] | None = None,
|
|
389
|
+
search_query: str | None = None,
|
|
390
|
+
exclude_ids: list[str] | None = None,
|
|
391
|
+
exclude_platforms: list[str] | None = None,
|
|
392
|
+
) -> tuple[list[dict[str, Any]], int]:
|
|
390
393
|
"""获取筛选后的对话列表"""
|
|
391
394
|
try:
|
|
392
395
|
c = self.conn.cursor()
|
|
@@ -422,7 +425,7 @@ class SQLiteDatabase:
|
|
|
422
425
|
if search_query:
|
|
423
426
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
|
424
427
|
where_clauses.append(
|
|
425
|
-
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
|
|
428
|
+
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)",
|
|
426
429
|
)
|
|
427
430
|
search_param = f"%{search_query}%"
|
|
428
431
|
params.extend([search_param, search_param, search_param, search_param])
|
|
@@ -482,7 +485,7 @@ class SQLiteDatabase:
|
|
|
482
485
|
"persona_id": persona_id or "",
|
|
483
486
|
"created_at": created_at or 0,
|
|
484
487
|
"updated_at": updated_at or 0,
|
|
485
|
-
}
|
|
488
|
+
},
|
|
486
489
|
)
|
|
487
490
|
|
|
488
491
|
return conversations, total_count
|
astrbot/core/db/po.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
import uuid
|
|
2
|
-
|
|
3
|
-
from datetime import datetime, timezone
|
|
4
2
|
from dataclasses import dataclass, field
|
|
3
|
+
from datetime import datetime, timezone
|
|
4
|
+
from typing import TypedDict
|
|
5
|
+
|
|
5
6
|
from sqlmodel import (
|
|
7
|
+
JSON,
|
|
8
|
+
Field,
|
|
6
9
|
SQLModel,
|
|
7
10
|
Text,
|
|
8
|
-
JSON,
|
|
9
11
|
UniqueConstraint,
|
|
10
|
-
Field,
|
|
11
12
|
)
|
|
12
|
-
from typing import Optional, TypedDict
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
class PlatformStat(SQLModel, table=True):
|
|
@@ -40,7 +40,8 @@ class ConversationV2(SQLModel, table=True):
|
|
|
40
40
|
__tablename__ = "conversations"
|
|
41
41
|
|
|
42
42
|
inner_conversation_id: int = Field(
|
|
43
|
-
primary_key=True,
|
|
43
|
+
primary_key=True,
|
|
44
|
+
sa_column_kwargs={"autoincrement": True},
|
|
44
45
|
)
|
|
45
46
|
conversation_id: str = Field(
|
|
46
47
|
max_length=36,
|
|
@@ -50,14 +51,14 @@ class ConversationV2(SQLModel, table=True):
|
|
|
50
51
|
)
|
|
51
52
|
platform_id: str = Field(nullable=False)
|
|
52
53
|
user_id: str = Field(nullable=False)
|
|
53
|
-
content:
|
|
54
|
+
content: list | None = Field(default=None, sa_type=JSON)
|
|
54
55
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
55
56
|
updated_at: datetime = Field(
|
|
56
57
|
default_factory=lambda: datetime.now(timezone.utc),
|
|
57
58
|
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
|
58
59
|
)
|
|
59
|
-
title:
|
|
60
|
-
persona_id:
|
|
60
|
+
title: str | None = Field(default=None, max_length=255)
|
|
61
|
+
persona_id: str | None = Field(default=None)
|
|
61
62
|
|
|
62
63
|
__table_args__ = (
|
|
63
64
|
UniqueConstraint(
|
|
@@ -76,13 +77,15 @@ class Persona(SQLModel, table=True):
|
|
|
76
77
|
__tablename__ = "personas"
|
|
77
78
|
|
|
78
79
|
id: int | None = Field(
|
|
79
|
-
primary_key=True,
|
|
80
|
+
primary_key=True,
|
|
81
|
+
sa_column_kwargs={"autoincrement": True},
|
|
82
|
+
default=None,
|
|
80
83
|
)
|
|
81
84
|
persona_id: str = Field(max_length=255, nullable=False)
|
|
82
85
|
system_prompt: str = Field(sa_type=Text, nullable=False)
|
|
83
|
-
begin_dialogs:
|
|
86
|
+
begin_dialogs: list | None = Field(default=None, sa_type=JSON)
|
|
84
87
|
"""a list of strings, each representing a dialog to start with"""
|
|
85
|
-
tools:
|
|
88
|
+
tools: list | None = Field(default=None, sa_type=JSON)
|
|
86
89
|
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
|
|
87
90
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
88
91
|
updated_at: datetime = Field(
|
|
@@ -104,7 +107,9 @@ class Preference(SQLModel, table=True):
|
|
|
104
107
|
__tablename__ = "preferences"
|
|
105
108
|
|
|
106
109
|
id: int | None = Field(
|
|
107
|
-
default=None,
|
|
110
|
+
default=None,
|
|
111
|
+
primary_key=True,
|
|
112
|
+
sa_column_kwargs={"autoincrement": True},
|
|
108
113
|
)
|
|
109
114
|
scope: str = Field(nullable=False)
|
|
110
115
|
"""Scope of the preference, such as 'global', 'umo', 'plugin'."""
|
|
@@ -138,13 +143,15 @@ class PlatformMessageHistory(SQLModel, table=True):
|
|
|
138
143
|
__tablename__ = "platform_message_history"
|
|
139
144
|
|
|
140
145
|
id: int | None = Field(
|
|
141
|
-
primary_key=True,
|
|
146
|
+
primary_key=True,
|
|
147
|
+
sa_column_kwargs={"autoincrement": True},
|
|
148
|
+
default=None,
|
|
142
149
|
)
|
|
143
150
|
platform_id: str = Field(nullable=False)
|
|
144
151
|
user_id: str = Field(nullable=False) # An id of group, user in platform
|
|
145
|
-
sender_id:
|
|
146
|
-
sender_name:
|
|
147
|
-
default=None
|
|
152
|
+
sender_id: str | None = Field(default=None) # ID of the sender in the platform
|
|
153
|
+
sender_name: str | None = Field(
|
|
154
|
+
default=None,
|
|
148
155
|
) # Name of the sender in the platform
|
|
149
156
|
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
|
|
150
157
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
@@ -163,7 +170,9 @@ class Attachment(SQLModel, table=True):
|
|
|
163
170
|
__tablename__ = "attachments"
|
|
164
171
|
|
|
165
172
|
inner_attachment_id: int | None = Field(
|
|
166
|
-
primary_key=True,
|
|
173
|
+
primary_key=True,
|
|
174
|
+
sa_column_kwargs={"autoincrement": True},
|
|
175
|
+
default=None,
|
|
167
176
|
)
|
|
168
177
|
attachment_id: str = Field(
|
|
169
178
|
max_length=36,
|
astrbot/core/db/sqlite.py
CHANGED
|
@@ -1,22 +1,27 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import typing as T
|
|
3
2
|
import threading
|
|
3
|
+
import typing as T
|
|
4
4
|
from datetime import datetime, timedelta
|
|
5
|
+
|
|
6
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
7
|
+
from sqlmodel import col, delete, desc, func, or_, select, text, update
|
|
8
|
+
|
|
5
9
|
from astrbot.core.db import BaseDatabase
|
|
6
10
|
from astrbot.core.db.po import (
|
|
7
|
-
ConversationV2,
|
|
8
|
-
PlatformStat,
|
|
9
|
-
PlatformMessageHistory,
|
|
10
11
|
Attachment,
|
|
12
|
+
ConversationV2,
|
|
11
13
|
Persona,
|
|
14
|
+
PlatformMessageHistory,
|
|
15
|
+
PlatformStat,
|
|
12
16
|
Preference,
|
|
13
|
-
Stats as DeprecatedStats,
|
|
14
|
-
Platform as DeprecatedPlatformStat,
|
|
15
17
|
SQLModel,
|
|
16
18
|
)
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
19
|
+
from astrbot.core.db.po import (
|
|
20
|
+
Platform as DeprecatedPlatformStat,
|
|
21
|
+
)
|
|
22
|
+
from astrbot.core.db.po import (
|
|
23
|
+
Stats as DeprecatedStats,
|
|
24
|
+
)
|
|
20
25
|
|
|
21
26
|
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
|
22
27
|
|
|
@@ -57,7 +62,9 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
57
62
|
async with session.begin():
|
|
58
63
|
if timestamp is None:
|
|
59
64
|
timestamp = datetime.now().replace(
|
|
60
|
-
minute=0,
|
|
65
|
+
minute=0,
|
|
66
|
+
second=0,
|
|
67
|
+
microsecond=0,
|
|
61
68
|
)
|
|
62
69
|
current_hour = timestamp
|
|
63
70
|
await session.execute(
|
|
@@ -81,13 +88,13 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
81
88
|
session: AsyncSession
|
|
82
89
|
result = await session.execute(
|
|
83
90
|
select(func.count(col(PlatformStat.platform_id))).select_from(
|
|
84
|
-
PlatformStat
|
|
85
|
-
)
|
|
91
|
+
PlatformStat,
|
|
92
|
+
),
|
|
86
93
|
)
|
|
87
94
|
count = result.scalar_one_or_none()
|
|
88
95
|
return count if count is not None else 0
|
|
89
96
|
|
|
90
|
-
async def get_platform_stats(self, offset_sec: int = 86400) ->
|
|
97
|
+
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
|
|
91
98
|
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
|
|
92
99
|
async with self.get_db() as session:
|
|
93
100
|
session: AsyncSession
|
|
@@ -138,7 +145,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
138
145
|
select(ConversationV2)
|
|
139
146
|
.order_by(desc(ConversationV2.created_at))
|
|
140
147
|
.offset(offset)
|
|
141
|
-
.limit(page_size)
|
|
148
|
+
.limit(page_size),
|
|
142
149
|
)
|
|
143
150
|
return result.scalars().all()
|
|
144
151
|
|
|
@@ -157,7 +164,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
157
164
|
|
|
158
165
|
if platform_ids:
|
|
159
166
|
base_query = base_query.where(
|
|
160
|
-
col(ConversationV2.platform_id).in_(platform_ids)
|
|
167
|
+
col(ConversationV2.platform_id).in_(platform_ids),
|
|
161
168
|
)
|
|
162
169
|
if search_query:
|
|
163
170
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
|
@@ -167,16 +174,16 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
167
174
|
col(ConversationV2.content).ilike(f"%{search_query}%"),
|
|
168
175
|
col(ConversationV2.user_id).ilike(f"%{search_query}%"),
|
|
169
176
|
col(ConversationV2.conversation_id).ilike(f"%{search_query}%"),
|
|
170
|
-
)
|
|
177
|
+
),
|
|
171
178
|
)
|
|
172
179
|
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
|
173
180
|
for msg_type in kwargs["message_types"]:
|
|
174
181
|
base_query = base_query.where(
|
|
175
|
-
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%")
|
|
182
|
+
col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"),
|
|
176
183
|
)
|
|
177
184
|
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
|
178
185
|
base_query = base_query.where(
|
|
179
|
-
col(ConversationV2.platform_id).in_(kwargs["platforms"])
|
|
186
|
+
col(ConversationV2.platform_id).in_(kwargs["platforms"]),
|
|
180
187
|
)
|
|
181
188
|
|
|
182
189
|
# Get total count matching the filters
|
|
@@ -233,7 +240,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
233
240
|
session: AsyncSession
|
|
234
241
|
async with session.begin():
|
|
235
242
|
query = update(ConversationV2).where(
|
|
236
|
-
col(ConversationV2.conversation_id) == cid
|
|
243
|
+
col(ConversationV2.conversation_id) == cid,
|
|
237
244
|
)
|
|
238
245
|
values = {}
|
|
239
246
|
if title is not None:
|
|
@@ -243,7 +250,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
243
250
|
if content is not None:
|
|
244
251
|
values["content"] = content
|
|
245
252
|
if not values:
|
|
246
|
-
return
|
|
253
|
+
return None
|
|
247
254
|
query = query.values(**values)
|
|
248
255
|
await session.execute(query)
|
|
249
256
|
return await self.get_conversation_by_id(cid)
|
|
@@ -254,8 +261,8 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
254
261
|
async with session.begin():
|
|
255
262
|
await session.execute(
|
|
256
263
|
delete(ConversationV2).where(
|
|
257
|
-
col(ConversationV2.conversation_id) == cid
|
|
258
|
-
)
|
|
264
|
+
col(ConversationV2.conversation_id) == cid,
|
|
265
|
+
),
|
|
259
266
|
)
|
|
260
267
|
|
|
261
268
|
async def delete_conversations_by_user_id(self, user_id: str) -> None:
|
|
@@ -263,7 +270,9 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
263
270
|
session: AsyncSession
|
|
264
271
|
async with session.begin():
|
|
265
272
|
await session.execute(
|
|
266
|
-
delete(ConversationV2).where(
|
|
273
|
+
delete(ConversationV2).where(
|
|
274
|
+
col(ConversationV2.user_id) == user_id
|
|
275
|
+
),
|
|
267
276
|
)
|
|
268
277
|
|
|
269
278
|
async def get_session_conversations(
|
|
@@ -282,7 +291,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
282
291
|
select(
|
|
283
292
|
col(Preference.scope_id).label("session_id"),
|
|
284
293
|
func.json_extract(Preference.value, "$.val").label(
|
|
285
|
-
"conversation_id"
|
|
294
|
+
"conversation_id",
|
|
286
295
|
), # type: ignore
|
|
287
296
|
col(ConversationV2.persona_id).label("persona_id"),
|
|
288
297
|
col(ConversationV2.title).label("title"),
|
|
@@ -295,7 +304,8 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
295
304
|
== ConversationV2.conversation_id,
|
|
296
305
|
)
|
|
297
306
|
.outerjoin(
|
|
298
|
-
Persona,
|
|
307
|
+
Persona,
|
|
308
|
+
col(ConversationV2.persona_id) == Persona.persona_id,
|
|
299
309
|
)
|
|
300
310
|
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
301
311
|
)
|
|
@@ -308,14 +318,14 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
308
318
|
col(Preference.scope_id).ilike(search_pattern),
|
|
309
319
|
col(ConversationV2.title).ilike(search_pattern),
|
|
310
320
|
col(Persona.persona_id).ilike(search_pattern),
|
|
311
|
-
)
|
|
321
|
+
),
|
|
312
322
|
)
|
|
313
323
|
|
|
314
324
|
# 平台筛选
|
|
315
325
|
if platform:
|
|
316
326
|
platform_pattern = f"{platform}:%"
|
|
317
327
|
base_query = base_query.where(
|
|
318
|
-
col(Preference.scope_id).like(platform_pattern)
|
|
328
|
+
col(Preference.scope_id).like(platform_pattern),
|
|
319
329
|
)
|
|
320
330
|
|
|
321
331
|
# 排序
|
|
@@ -336,7 +346,8 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
336
346
|
== ConversationV2.conversation_id,
|
|
337
347
|
)
|
|
338
348
|
.outerjoin(
|
|
339
|
-
Persona,
|
|
349
|
+
Persona,
|
|
350
|
+
col(ConversationV2.persona_id) == Persona.persona_id,
|
|
340
351
|
)
|
|
341
352
|
.where(Preference.scope == "umo", Preference.key == "sel_conv_id")
|
|
342
353
|
)
|
|
@@ -349,13 +360,13 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
349
360
|
col(Preference.scope_id).ilike(search_pattern),
|
|
350
361
|
col(ConversationV2.title).ilike(search_pattern),
|
|
351
362
|
col(Persona.persona_id).ilike(search_pattern),
|
|
352
|
-
)
|
|
363
|
+
),
|
|
353
364
|
)
|
|
354
365
|
|
|
355
366
|
if platform:
|
|
356
367
|
platform_pattern = f"{platform}:%"
|
|
357
368
|
count_base_query = count_base_query.where(
|
|
358
|
-
col(Preference.scope_id).like(platform_pattern)
|
|
369
|
+
col(Preference.scope_id).like(platform_pattern),
|
|
359
370
|
)
|
|
360
371
|
|
|
361
372
|
total_result = await session.execute(count_base_query)
|
|
@@ -396,7 +407,10 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
396
407
|
return new_history
|
|
397
408
|
|
|
398
409
|
async def delete_platform_message_offset(
|
|
399
|
-
self,
|
|
410
|
+
self,
|
|
411
|
+
platform_id,
|
|
412
|
+
user_id,
|
|
413
|
+
offset_sec=86400,
|
|
400
414
|
):
|
|
401
415
|
"""Delete platform message history records older than the specified offset."""
|
|
402
416
|
async with self.get_db() as session:
|
|
@@ -409,11 +423,15 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
409
423
|
col(PlatformMessageHistory.platform_id) == platform_id,
|
|
410
424
|
col(PlatformMessageHistory.user_id) == user_id,
|
|
411
425
|
col(PlatformMessageHistory.created_at) < cutoff_time,
|
|
412
|
-
)
|
|
426
|
+
),
|
|
413
427
|
)
|
|
414
428
|
|
|
415
429
|
async def get_platform_message_history(
|
|
416
|
-
self,
|
|
430
|
+
self,
|
|
431
|
+
platform_id,
|
|
432
|
+
user_id,
|
|
433
|
+
page=1,
|
|
434
|
+
page_size=20,
|
|
417
435
|
):
|
|
418
436
|
"""Get platform message history records."""
|
|
419
437
|
async with self.get_db() as session:
|
|
@@ -452,7 +470,11 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
452
470
|
return result.scalar_one_or_none()
|
|
453
471
|
|
|
454
472
|
async def insert_persona(
|
|
455
|
-
self,
|
|
473
|
+
self,
|
|
474
|
+
persona_id,
|
|
475
|
+
system_prompt,
|
|
476
|
+
begin_dialogs=None,
|
|
477
|
+
tools=None,
|
|
456
478
|
):
|
|
457
479
|
"""Insert a new persona record."""
|
|
458
480
|
async with self.get_db() as session:
|
|
@@ -484,7 +506,11 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
484
506
|
return result.scalars().all()
|
|
485
507
|
|
|
486
508
|
async def update_persona(
|
|
487
|
-
self,
|
|
509
|
+
self,
|
|
510
|
+
persona_id,
|
|
511
|
+
system_prompt=None,
|
|
512
|
+
begin_dialogs=None,
|
|
513
|
+
tools=NOT_GIVEN,
|
|
488
514
|
):
|
|
489
515
|
"""Update a persona's system prompt or begin dialogs."""
|
|
490
516
|
async with self.get_db() as session:
|
|
@@ -499,7 +525,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
499
525
|
if tools is not NOT_GIVEN:
|
|
500
526
|
values["tools"] = tools
|
|
501
527
|
if not values:
|
|
502
|
-
return
|
|
528
|
+
return None
|
|
503
529
|
query = query.values(**values)
|
|
504
530
|
await session.execute(query)
|
|
505
531
|
return await self.get_persona_by_id(persona_id)
|
|
@@ -510,7 +536,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
510
536
|
session: AsyncSession
|
|
511
537
|
async with session.begin():
|
|
512
538
|
await session.execute(
|
|
513
|
-
delete(Persona).where(col(Persona.persona_id) == persona_id)
|
|
539
|
+
delete(Persona).where(col(Persona.persona_id) == persona_id),
|
|
514
540
|
)
|
|
515
541
|
|
|
516
542
|
async def insert_preference_or_update(self, scope, scope_id, key, value):
|
|
@@ -529,7 +555,10 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
529
555
|
existing_preference.value = value
|
|
530
556
|
else:
|
|
531
557
|
new_preference = Preference(
|
|
532
|
-
scope=scope,
|
|
558
|
+
scope=scope,
|
|
559
|
+
scope_id=scope_id,
|
|
560
|
+
key=key,
|
|
561
|
+
value=value,
|
|
533
562
|
)
|
|
534
563
|
session.add(new_preference)
|
|
535
564
|
return existing_preference or new_preference
|
|
@@ -568,7 +597,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
568
597
|
col(Preference.scope) == scope,
|
|
569
598
|
col(Preference.scope_id) == scope_id,
|
|
570
599
|
col(Preference.key) == key,
|
|
571
|
-
)
|
|
600
|
+
),
|
|
572
601
|
)
|
|
573
602
|
await session.commit()
|
|
574
603
|
|
|
@@ -581,7 +610,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
581
610
|
delete(Preference).where(
|
|
582
611
|
col(Preference.scope) == scope,
|
|
583
612
|
col(Preference.scope_id) == scope_id,
|
|
584
|
-
)
|
|
613
|
+
),
|
|
585
614
|
)
|
|
586
615
|
await session.commit()
|
|
587
616
|
|
|
@@ -598,7 +627,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
598
627
|
now = datetime.now()
|
|
599
628
|
start_time = now - timedelta(seconds=offset_sec)
|
|
600
629
|
result = await session.execute(
|
|
601
|
-
select(PlatformStat).where(PlatformStat.timestamp >= start_time)
|
|
630
|
+
select(PlatformStat).where(PlatformStat.timestamp >= start_time),
|
|
602
631
|
)
|
|
603
632
|
all_datas = result.scalars().all()
|
|
604
633
|
deprecated_stats = DeprecatedStats()
|
|
@@ -608,7 +637,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
608
637
|
name=data.platform_id,
|
|
609
638
|
count=data.count,
|
|
610
639
|
timestamp=int(data.timestamp.timestamp()),
|
|
611
|
-
)
|
|
640
|
+
),
|
|
612
641
|
)
|
|
613
642
|
return deprecated_stats
|
|
614
643
|
|
|
@@ -630,7 +659,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
630
659
|
async with self.get_db() as session:
|
|
631
660
|
session: AsyncSession
|
|
632
661
|
result = await session.execute(
|
|
633
|
-
select(func.sum(PlatformStat.count)).select_from(PlatformStat)
|
|
662
|
+
select(func.sum(PlatformStat.count)).select_from(PlatformStat),
|
|
634
663
|
)
|
|
635
664
|
total_count = result.scalar_one_or_none()
|
|
636
665
|
return total_count if total_count is not None else 0
|
|
@@ -656,7 +685,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
656
685
|
result = await session.execute(
|
|
657
686
|
select(PlatformStat.platform_id, func.sum(PlatformStat.count))
|
|
658
687
|
.where(PlatformStat.timestamp >= start_time)
|
|
659
|
-
.group_by(PlatformStat.platform_id)
|
|
688
|
+
.group_by(PlatformStat.platform_id),
|
|
660
689
|
)
|
|
661
690
|
grouped_stats = result.all()
|
|
662
691
|
deprecated_stats = DeprecatedStats()
|
|
@@ -666,7 +695,7 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
666
695
|
name=platform_id,
|
|
667
696
|
count=count,
|
|
668
697
|
timestamp=int(start_time.timestamp()),
|
|
669
|
-
)
|
|
698
|
+
),
|
|
670
699
|
)
|
|
671
700
|
return deprecated_stats
|
|
672
701
|
|
astrbot/core/db/vec_db/base.py
CHANGED
|
@@ -10,18 +10,16 @@ class Result:
|
|
|
10
10
|
|
|
11
11
|
class BaseVecDB:
|
|
12
12
|
async def initialize(self):
|
|
13
|
-
"""
|
|
14
|
-
初始化向量数据库
|
|
15
|
-
"""
|
|
16
|
-
pass
|
|
13
|
+
"""初始化向量数据库"""
|
|
17
14
|
|
|
18
15
|
@abc.abstractmethod
|
|
19
16
|
async def insert(
|
|
20
|
-
self,
|
|
17
|
+
self,
|
|
18
|
+
content: str,
|
|
19
|
+
metadata: dict | None = None,
|
|
20
|
+
id: str | None = None,
|
|
21
21
|
) -> int:
|
|
22
|
-
"""
|
|
23
|
-
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
|
24
|
-
"""
|
|
22
|
+
"""插入一条文本和其对应向量,自动生成 ID 并保持一致性。"""
|
|
25
23
|
...
|
|
26
24
|
|
|
27
25
|
@abc.abstractmethod
|
|
@@ -35,11 +33,11 @@ class BaseVecDB:
|
|
|
35
33
|
max_retries: int = 3,
|
|
36
34
|
progress_callback=None,
|
|
37
35
|
) -> int:
|
|
38
|
-
"""
|
|
39
|
-
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
|
36
|
+
"""批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
|
40
37
|
|
|
41
38
|
Args:
|
|
42
39
|
progress_callback: 进度回调函数,接收参数 (current, total)
|
|
40
|
+
|
|
43
41
|
"""
|
|
44
42
|
...
|
|
45
43
|
|
|
@@ -52,8 +50,7 @@ class BaseVecDB:
|
|
|
52
50
|
rerank: bool = False,
|
|
53
51
|
metadata_filters: dict | None = None,
|
|
54
52
|
) -> list[Result]:
|
|
55
|
-
"""
|
|
56
|
-
搜索最相似的文档。
|
|
53
|
+
"""搜索最相似的文档。
|
|
57
54
|
Args:
|
|
58
55
|
query (str): 查询文本
|
|
59
56
|
top_k (int): 返回的最相似文档的数量
|
|
@@ -64,8 +61,7 @@ class BaseVecDB:
|
|
|
64
61
|
|
|
65
62
|
@abc.abstractmethod
|
|
66
63
|
async def delete(self, doc_id: str) -> bool:
|
|
67
|
-
"""
|
|
68
|
-
删除指定文档。
|
|
64
|
+
"""删除指定文档。
|
|
69
65
|
Args:
|
|
70
66
|
doc_id (str): 要删除的文档 ID
|
|
71
67
|
Returns:
|