AstrBot 4.3.3__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +18 -4
- astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
- astrbot/core/astr_agent_context.py +1 -0
- astrbot/core/astrbot_config_mgr.py +23 -51
- astrbot/core/config/default.py +139 -14
- astrbot/core/conversation_mgr.py +36 -1
- astrbot/core/core_lifecycle.py +24 -5
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/vec_db/base.py +33 -2
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
- astrbot/core/file_token_service.py +6 -1
- astrbot/core/initial_loader.py +6 -3
- astrbot/core/knowledge_base/chunking/__init__.py +11 -0
- astrbot/core/knowledge_base/chunking/base.py +24 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
- astrbot/core/knowledge_base/chunking/recursive.py +155 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
- astrbot/core/knowledge_base/kb_helper.py +348 -0
- astrbot/core/knowledge_base/kb_mgr.py +287 -0
- astrbot/core/knowledge_base/models.py +114 -0
- astrbot/core/knowledge_base/parsers/__init__.py +15 -0
- astrbot/core/knowledge_base/parsers/base.py +50 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +273 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +61 -21
- astrbot/core/pipeline/process_stage/utils.py +80 -0
- astrbot/core/pipeline/scheduler.py +1 -1
- astrbot/core/platform/astr_message_event.py +8 -7
- astrbot/core/platform/manager.py +4 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
- astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
- astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
- astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
- astrbot/core/platform/sources/satori/satori_event.py +270 -77
- astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
- astrbot/core/provider/manager.py +14 -9
- astrbot/core/provider/provider.py +67 -0
- astrbot/core/provider/sources/anthropic_source.py +4 -4
- astrbot/core/provider/sources/dashscope_source.py +10 -9
- astrbot/core/provider/sources/dify_source.py +6 -8
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_source.py +18 -15
- astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
- astrbot/core/star/context.py +3 -0
- astrbot/core/star/star.py +6 -0
- astrbot/core/star/star_manager.py +13 -7
- astrbot/core/umop_config_router.py +81 -0
- astrbot/core/updator.py +1 -1
- astrbot/core/utils/io.py +23 -12
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/config.py +137 -9
- astrbot/dashboard/routes/knowledge_base.py +1065 -0
- astrbot/dashboard/routes/plugin.py +24 -5
- astrbot/dashboard/routes/tools.py +14 -0
- astrbot/dashboard/routes/update.py +1 -1
- astrbot/dashboard/server.py +6 -0
- astrbot/dashboard/utils.py +161 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/METADATA +91 -55
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/RECORD +83 -50
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/core_lifecycle.py
CHANGED
@@ -17,7 +17,6 @@ import os
 from .event_bus import EventBus
 from . import astrbot_config, html_renderer
 from asyncio import Queue
-from typing import List
 from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
 from astrbot.core.star import PluginManager
 from astrbot.core.platform.manager import PlatformManager
@@ -26,14 +25,17 @@ from astrbot.core.persona_mgr import PersonaManager
 from astrbot.core.provider.manager import ProviderManager
 from astrbot.core import LogBroker
 from astrbot.core.db import BaseDatabase
+from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
 from astrbot.core.updator import AstrBotUpdator
 from astrbot.core import logger, sp
 from astrbot.core.config.default import VERSION
 from astrbot.core.conversation_mgr import ConversationManager
 from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
+from astrbot.core.umop_config_router import UmopConfigRouter
 from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
 from astrbot.core.star.star_handler import star_handlers_registry, EventType
 from astrbot.core.star.star_handler import star_map
+from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
 
 
 class AstrBotCoreLifecycle:
@@ -84,11 +86,21 @@ class AstrBotCoreLifecycle:
 
         await html_renderer.initialize()
 
+        # 初始化 UMOP 配置路由器
+        self.umop_config_router = UmopConfigRouter(sp=sp)
+
         # 初始化 AstrBot 配置管理器
         self.astrbot_config_mgr = AstrBotConfigManager(
-            default_config=self.astrbot_config, sp=sp
+            default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
         )
 
+        # 4.5 to 4.6 migration for umop_config_router
+        try:
+            await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
+        except Exception as e:
+            logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
+            logger.error(traceback.format_exc())
+
         # 初始化事件队列
         self.event_queue = Queue()
 
@@ -110,6 +122,9 @@ class AstrBotCoreLifecycle:
         # 初始化平台消息历史管理器
         self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
 
+        # 初始化知识库管理器
+        self.kb_manager = KnowledgeBaseManager(self.provider_manager)
+
         # 初始化提供给插件的上下文
         self.star_context = Context(
             self.event_queue,
@@ -121,6 +136,7 @@ class AstrBotCoreLifecycle:
             self.platform_message_history_manager,
             self.persona_mgr,
             self.astrbot_config_mgr,
+            self.kb_manager,
         )
 
         # 初始化插件管理器
@@ -132,8 +148,9 @@ class AstrBotCoreLifecycle:
         # 根据配置实例化各个 Provider
         await self.provider_manager.initialize()
 
-
+        await self.kb_manager.initialize()
 
+        # 初始化消息事件流水线调度器
         self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
 
         # 初始化更新器
@@ -148,7 +165,7 @@ class AstrBotCoreLifecycle:
         self.start_time = int(time.time())
 
         # 初始化当前任务列表
-        self.curr_tasks:
+        self.curr_tasks: list[asyncio.Task] = []
 
         # 根据配置实例化各个平台适配器
         await self.platform_manager.initialize()
@@ -233,6 +250,7 @@ class AstrBotCoreLifecycle:
 
         await self.provider_manager.terminate()
         await self.platform_manager.terminate()
+        await self.kb_manager.terminate()
         self.dashboard_shutdown_event.set()
 
         # 再次遍历curr_tasks等待每个任务真正结束
@@ -248,12 +266,13 @@ class AstrBotCoreLifecycle:
         """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
         await self.provider_manager.terminate()
         await self.platform_manager.terminate()
+        await self.kb_manager.terminate()
         self.dashboard_shutdown_event.set()
         threading.Thread(
             target=self.astrbot_updator._reboot, name="restart", daemon=True
         ).start()
 
-    def load_platform(self) ->
+    def load_platform(self) -> list[asyncio.Task]:
         """加载平台实例并返回所有平台实例的异步任务列表"""
         tasks = []
         platform_insts = self.platform_manager.get_insts()
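From the core_lifecycle.py wiring above, the new knowledge base manager is treated like the other managers: constructed with the provider manager, initialized after the providers come up, and terminated during both stop and restart. The minimal stub below sketches that contract; it is illustrative only and is not the actual KnowledgeBaseManager from astrbot/core/knowledge_base/kb_mgr.py.

# Illustrative stand-in for the lifecycle contract used above; the real
# implementation lives in astrbot/core/knowledge_base/kb_mgr.py.
class StubKnowledgeBaseManager:
    def __init__(self, provider_manager) -> None:
        # The lifecycle passes its ProviderManager so the KB layer can reach
        # embedding providers (an assumption based on the constructor call above).
        self.provider_manager = provider_manager

    async def initialize(self) -> None:
        # Invoked after provider_manager.initialize(), before the pipeline scheduler.
        pass

    async def terminate(self) -> None:
        # Invoked from both stop() and restart(), alongside the other managers.
        pass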
astrbot/core/db/migration/migra_45_to_46.py
ADDED
@@ -0,0 +1,44 @@
+from astrbot.api import logger, sp
+from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
+from astrbot.core.umop_config_router import UmopConfigRouter
+
+
+async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
+    abconf_data = acm.abconf_data
+
+    if not isinstance(abconf_data, dict):
+        # should be unreachable
+        logger.warning(
+            f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
+        )
+        return
+
+    # 如果任何一项带有 umop,则说明需要迁移
+    need_migration = False
+    for conf_id, conf_info in abconf_data.items():
+        if isinstance(conf_info, dict) and "umop" in conf_info:
+            need_migration = True
+            break
+
+    if not need_migration:
+        return
+
+    logger.info("Starting migration from version 4.5 to 4.6")
+
+    # extract umo->conf_id mapping
+    umo_to_conf_id = {}
+    for conf_id, conf_info in abconf_data.items():
+        if isinstance(conf_info, dict) and "umop" in conf_info:
+            umop_ls = conf_info.pop("umop")
+            if not isinstance(umop_ls, list):
+                continue
+            for umo in umop_ls:
+                if isinstance(umo, str) and umo not in umo_to_conf_id:
+                    umo_to_conf_id[umo] = conf_id
+
+    # update the abconf data
+    await sp.global_put("abconf_mapping", abconf_data)
+    # update the umop config router
+    await ucr.update_routing_data(umo_to_conf_id)
+
+    logger.info("Migration from version 45 to 46 completed successfully")
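The migration above takes the 4.5-era layout, where each config entry carried its own "umop" list, and flattens it into a single UMO-to-config-ID mapping for UmopConfigRouter. The standalone sketch below reproduces that transformation on hypothetical sample data; the helper name and the UMO strings are made up for illustration.

# Illustrative only: mirrors the mapping logic of migrate_45_to_46 on sample data.
def flatten_umop(abconf_data: dict) -> dict[str, str]:
    umo_to_conf_id: dict[str, str] = {}
    for conf_id, conf_info in abconf_data.items():
        if not isinstance(conf_info, dict) or "umop" not in conf_info:
            continue
        umop_ls = conf_info.pop("umop")
        if not isinstance(umop_ls, list):
            continue
        for umo in umop_ls:
            if isinstance(umo, str) and umo not in umo_to_conf_id:
                umo_to_conf_id[umo] = conf_id  # first config to claim a UMO wins
    return umo_to_conf_id


# Hypothetical pre-4.6 data: each config entry had its own "umop" list.
sample = {
    "default": {"umop": ["aiocqhttp:GroupMessage:12345"], "config": {}},
    "conf_b": {"umop": ["telegram:FriendMessage:42"], "config": {}},
}
print(flatten_umop(sample))
# {'aiocqhttp:GroupMessage:12345': 'default', 'telegram:FriendMessage:42': 'conf_b'}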
astrbot/core/db/vec_db/base.py
CHANGED
@@ -16,14 +16,42 @@ class BaseVecDB:
         pass
 
     @abc.abstractmethod
-    async def insert(
+    async def insert(
+        self, content: str, metadata: dict | None = None, id: str | None = None
+    ) -> int:
         """
         插入一条文本和其对应向量,自动生成 ID 并保持一致性。
         """
         ...
 
     @abc.abstractmethod
-    async def
+    async def insert_batch(
+        self,
+        contents: list[str],
+        metadatas: list[dict] | None = None,
+        ids: list[str] | None = None,
+        batch_size: int = 32,
+        tasks_limit: int = 3,
+        max_retries: int = 3,
+        progress_callback=None,
+    ) -> int:
+        """
+        批量插入文本和其对应向量,自动生成 ID 并保持一致性。
+
+        Args:
+            progress_callback: 进度回调函数,接收参数 (current, total)
+        """
+        ...
+
+    @abc.abstractmethod
+    async def retrieve(
+        self,
+        query: str,
+        top_k: int = 5,
+        fetch_k: int = 20,
+        rerank: bool = False,
+        metadata_filters: dict | None = None,
+    ) -> list[Result]:
         """
         搜索最相似的文档。
         Args:
@@ -44,3 +72,6 @@ class BaseVecDB:
             bool: 删除是否成功
         """
         ...
+
+    @abc.abstractmethod
+    async def close(self): ...
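The widened BaseVecDB contract above adds batched insertion with a progress callback plus a retrieval signature with fetch_k, rerank, and metadata filters. A caller-side sketch against that interface follows, assuming vec_db is some concrete implementation (for example the FAISS-backed one in this release) and that the metadata keys shown are arbitrary.

# Sketch against the abstract interface above; `vec_db` is assumed to be a
# concrete BaseVecDB implementation and `chunks` a list of pre-chunked text.
async def index_and_query(vec_db, chunks: list[str]) -> None:
    def on_progress(current: int, total: int) -> None:
        # Matches the documented callback shape: (current, total).
        print(f"embedded {current}/{total} chunks")

    await vec_db.insert_batch(
        contents=chunks,
        metadatas=[{"user_id": "alice"} for _ in chunks],  # hypothetical metadata
        batch_size=32,
        progress_callback=on_progress,
    )

    results = await vec_db.retrieve(
        query="How do I configure the bot?",
        top_k=5,
        fetch_k=20,
        metadata_filters={"user_id": "alice"},
    )
    for result in results:
        print(result)

    await vec_db.close()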
astrbot/core/db/vec_db/faiss_impl/document_storage.py
CHANGED
@@ -1,59 +1,224 @@
-import aiosqlite
 import os
+import json
+from datetime import datetime
+from contextlib import asynccontextmanager
+
+from sqlalchemy import Text, Column
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
+from sqlalchemy.orm import sessionmaker
+from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
+from astrbot.core import logger
+
+
+class BaseDocModel(SQLModel, table=False):
+    metadata = MetaData()
+
+
+class Document(BaseDocModel, table=True):
+    """SQLModel for documents table."""
+
+    __tablename__ = "documents"  # type: ignore
+
+    id: int | None = Field(
+        default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
+    )
+    doc_id: str = Field(nullable=False)
+    text: str = Field(nullable=False)
+    metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
+    created_at: datetime | None = Field(default=None)
+    updated_at: datetime | None = Field(default=None)
 
 
 class DocumentStorage:
     def __init__(self, db_path: str):
         self.db_path = db_path
-        self.
+        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
+        self.engine: AsyncEngine | None = None
+        self.async_session_maker: sessionmaker | None = None
         self.sqlite_init_path = os.path.join(
             os.path.dirname(__file__), "sqlite_init.sql"
         )
 
     async def initialize(self):
         """Initialize the SQLite database and create the documents table if it doesn't exist."""
-
-
-
-
-
-
-
-
-
+        await self.connect()
+        async with self.engine.begin() as conn:  # type: ignore
+            # Create tables using SQLModel
+            await conn.run_sync(BaseDocModel.metadata.create_all)
+
+            try:
+                await conn.execute(
+                    text(
+                        "ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
+                        "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
+                    )
+                )
+                await conn.execute(
+                    text(
+                        "ALTER TABLE documents ADD COLUMN user_id TEXT "
+                        "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
+                    )
+                )
+
+                # Create indexes
+                await conn.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
+                    )
+                )
+                await conn.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
+                    )
+                )
+            except BaseException:
+                pass
+
+            await conn.commit()
 
     async def connect(self):
         """Connect to the SQLite database."""
-        self.
+        if self.engine is None:
+            self.engine = create_async_engine(
+                self.DATABASE_URL,
+                echo=False,
+                future=True,
+            )
+            self.async_session_maker = sessionmaker(
+                self.engine,  # type: ignore
+                class_=AsyncSession,
+                expire_on_commit=False,
+            )  # type: ignore
+
+    @asynccontextmanager
+    async def get_session(self):
+        """Context manager for database sessions."""
+        async with self.async_session_maker() as session:  # type: ignore
+            yield session
 
-    async def get_documents(
+    async def get_documents(
+        self,
+        metadata_filters: dict,
+        ids: list | None = None,
+        offset: int | None = 0,
+        limit: int | None = 100,
+    ) -> list[dict]:
         """Retrieve documents by metadata filters and ids.
 
         Args:
             metadata_filters (dict): The metadata filters to apply.
+            ids (list | None): Optional list of document IDs to filter.
+            offset (int | None): Offset for pagination.
+            limit (int | None): Limit for pagination.
+
+        Returns:
+            list: The list of documents that match the filters.
+        """
+        if self.engine is None:
+            logger.warning(
+                "Database connection is not initialized, returning empty result"
+            )
+            return []
+
+        async with self.get_session() as session:
+            query = select(Document)
+
+            for key, val in metadata_filters.items():
+                query = query.where(
+                    text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
+                ).params(**{f"filter_{key}": val})
+
+            if ids is not None and len(ids) > 0:
+                valid_ids = [int(i) for i in ids if i != -1]
+                if valid_ids:
+                    query = query.where(col(Document.id).in_(valid_ids))
+
+            if limit is not None:
+                query = query.limit(limit)
+            if offset is not None:
+                query = query.offset(offset)
+
+            result = await session.execute(query)
+            documents = result.scalars().all()
+
+            return [self._document_to_dict(doc) for doc in documents]
+
+    async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
+        """Insert a single document and return its integer ID.
+
+        Args:
+            doc_id (str): The document ID (UUID string).
+            text (str): The document text.
+            metadata (dict): The document metadata.
 
         Returns:
-
+            int: The integer ID of the inserted document.
+        """
+        assert self.engine is not None, "Database connection is not initialized."
+
+        async with self.get_session() as session:
+            async with session.begin():
+                document = Document(
+                    doc_id=doc_id,
+                    text=text,
+                    metadata_=json.dumps(metadata),
+                    created_at=datetime.now(),
+                    updated_at=datetime.now(),
+                )
+                session.add(document)
+                await session.flush()  # Flush to get the ID
+                return document.id  # type: ignore
+
+    async def insert_documents_batch(
+        self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
+    ) -> list[int]:
+        """Batch insert documents and return their integer IDs.
+
+        Args:
+            doc_ids (list[str]): List of document IDs (UUID strings).
+            texts (list[str]): List of document texts.
+            metadatas (list[dict]): List of document metadata.
+
+        Returns:
+            list[int]: List of integer IDs of the inserted documents.
+        """
+        assert self.engine is not None, "Database connection is not initialized."
+
+        async with self.get_session() as session:
+            async with session.begin():
+                import json
+
+                documents = []
+                for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
+                    document = Document(
+                        doc_id=doc_id,
+                        text=text,
+                        metadata_=json.dumps(metadata),
+                        created_at=datetime.now(),
+                        updated_at=datetime.now(),
+                    )
+                    documents.append(document)
+                    session.add(document)
+
+                await session.flush()  # Flush to get all IDs
+                return [doc.id for doc in documents]  # type: ignore
+
+    async def delete_document_by_doc_id(self, doc_id: str):
+        """Delete a document by its doc_id.
+
+        Args:
+            doc_id (str): The doc_id of the document to delete.
         """
-
-
-
-
-
-
-
-
-
-
-            where_sql = " AND ".join(where_clauses) or "1=1"
-
-        result = []
-        async with self.connection.cursor() as cursor:
-            sql = "SELECT * FROM documents WHERE " + where_sql
-            await cursor.execute(sql, values)
-            for row in await cursor.fetchall():
-                result.append(await self.tuple_to_dict(row))
-            return result
+        assert self.engine is not None, "Database connection is not initialized."
+
+        async with self.get_session() as session:
+            async with session.begin():
+                query = select(Document).where(col(Document.doc_id) == doc_id)
+                result = await session.execute(query)
+                document = result.scalar_one_or_none()
+
+                if document:
+                    await session.delete(document)
 
     async def get_document_by_doc_id(self, doc_id: str):
         """Retrieve a document by its doc_id.
@@ -62,28 +227,91 @@ class DocumentStorage:
             doc_id (str): The doc_id of the document to retrieve.
 
         Returns:
-            dict: The document data.
+            dict: The document data or None if not found.
         """
-
-
-
-
-
-
-
+        assert self.engine is not None, "Database connection is not initialized."
+
+        async with self.get_session() as session:
+            query = select(Document).where(col(Document.doc_id) == doc_id)
+            result = await session.execute(query)
+            document = result.scalar_one_or_none()
+
+            if document:
+                return self._document_to_dict(document)
+            return None
 
     async def update_document_by_doc_id(self, doc_id: str, new_text: str):
-        """
+        """Update a document by its doc_id.
 
         Args:
             doc_id (str): The doc_id.
             new_text (str): The new text to update the document with.
         """
-
-
-
+        assert self.engine is not None, "Database connection is not initialized."
+
+        async with self.get_session() as session:
+            async with session.begin():
+                query = select(Document).where(col(Document.doc_id) == doc_id)
+                result = await session.execute(query)
+                document = result.scalar_one_or_none()
+
+                if document:
+                    document.text = new_text
+                    document.updated_at = datetime.now()
+                    session.add(document)
+
+    async def delete_documents(self, metadata_filters: dict):
+        """Delete documents by their metadata filters.
+
+        Args:
+            metadata_filters (dict): The metadata filters to apply.
+        """
+        if self.engine is None:
+            logger.warning(
+                "Database connection is not initialized, skipping delete operation"
             )
-
+            return
+
+        async with self.get_session() as session:
+            async with session.begin():
+                query = select(Document)
+
+                for key, val in metadata_filters.items():
+                    query = query.where(
+                        text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
+                    ).params(**{f"filter_{key}": val})
+
+                result = await session.execute(query)
+                documents = result.scalars().all()
+
+                for doc in documents:
+                    await session.delete(doc)
+
+    async def count_documents(self, metadata_filters: dict | None = None) -> int:
+        """Count documents in the database.
+
+        Args:
+            metadata_filters (dict | None): Metadata filters to apply.
+
+        Returns:
+            int: The count of documents.
+        """
+        if self.engine is None:
+            logger.warning("Database connection is not initialized, returning 0")
+            return 0
+
+        async with self.get_session() as session:
+            query = select(func.count(col(Document.id)))
+
+            if metadata_filters:
+                for key, val in metadata_filters.items():
+                    query = query.where(
+                        text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
+                    ).params(**{f"filter_{key}": val})
+
+            result = await session.execute(query)
+            count = result.scalar_one_or_none()
+            return count if count is not None else 0
 
     async def get_user_ids(self) -> list[str]:
         """Retrieve all user IDs from the documents table.
@@ -91,11 +319,38 @@ class DocumentStorage:
         Returns:
             list: A list of user IDs.
         """
-
-
-
+        assert self.engine is not None, "Database connection is not initialized."
+
+        async with self.get_session() as session:
+            query = text(
+                "SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
+            )
+            result = await session.execute(query)
+            rows = result.fetchall()
             return [row[0] for row in rows]
 
+    def _document_to_dict(self, document: Document) -> dict:
+        """Convert a Document model to a dictionary.
+
+        Args:
+            document (Document): The document to convert.
+
+        Returns:
+            dict: The converted dictionary.
+        """
+        return {
+            "id": document.id,
+            "doc_id": document.doc_id,
+            "text": document.text,
+            "metadata": document.metadata_,
+            "created_at": document.created_at.isoformat()
+            if isinstance(document.created_at, datetime)
+            else document.created_at,
+            "updated_at": document.updated_at.isoformat()
+            if isinstance(document.updated_at, datetime)
+            else document.updated_at,
+        }
+
     async def tuple_to_dict(self, row):
         """Convert a tuple to a dictionary.
 
@@ -104,6 +359,8 @@ class DocumentStorage:
 
         Returns:
             dict: The converted dictionary.
+
+        Note: This method is kept for backward compatibility but is no longer used internally.
         """
         return {
             "id": row[0],
@@ -116,6 +373,7 @@ class DocumentStorage:
 
     async def close(self):
         """Close the connection to the SQLite database."""
-        if self.
-        await self.
-        self.
+        if self.engine:
+            await self.engine.dispose()
+            self.engine = None
+        self.async_session_maker = None