AstrBot: astrbot-4.3.3-py3-none-any.whl → astrbot-4.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. astrbot/core/agent/mcp_client.py +18 -4
  2. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  3. astrbot/core/astr_agent_context.py +1 -0
  4. astrbot/core/astrbot_config_mgr.py +23 -51
  5. astrbot/core/config/default.py +139 -14
  6. astrbot/core/conversation_mgr.py +36 -1
  7. astrbot/core/core_lifecycle.py +24 -5
  8. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  9. astrbot/core/db/vec_db/base.py +33 -2
  10. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  11. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  12. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  13. astrbot/core/file_token_service.py +6 -1
  14. astrbot/core/initial_loader.py +6 -3
  15. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  16. astrbot/core/knowledge_base/chunking/base.py +24 -0
  17. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  18. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  19. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  20. astrbot/core/knowledge_base/kb_helper.py +348 -0
  21. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  22. astrbot/core/knowledge_base/models.py +114 -0
  23. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  24. astrbot/core/knowledge_base/parsers/base.py +50 -0
  25. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  26. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  27. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  28. astrbot/core/knowledge_base/parsers/util.py +13 -0
  29. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  30. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  31. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  32. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  33. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  34. astrbot/core/pipeline/process_stage/method/llm_request.py +61 -21
  35. astrbot/core/pipeline/process_stage/utils.py +80 -0
  36. astrbot/core/pipeline/scheduler.py +1 -1
  37. astrbot/core/platform/astr_message_event.py +8 -7
  38. astrbot/core/platform/manager.py +4 -0
  39. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  40. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  41. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  42. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  43. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  44. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  45. astrbot/core/platform/sources/satori/satori_event.py +270 -77
  46. astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
  47. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
  48. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
  49. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
  50. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
  51. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
  52. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
  53. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
  54. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
  55. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
  56. astrbot/core/provider/manager.py +14 -9
  57. astrbot/core/provider/provider.py +67 -0
  58. astrbot/core/provider/sources/anthropic_source.py +4 -4
  59. astrbot/core/provider/sources/dashscope_source.py +10 -9
  60. astrbot/core/provider/sources/dify_source.py +6 -8
  61. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  62. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  63. astrbot/core/provider/sources/openai_source.py +18 -15
  64. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  65. astrbot/core/star/context.py +3 -0
  66. astrbot/core/star/star.py +6 -0
  67. astrbot/core/star/star_manager.py +13 -7
  68. astrbot/core/umop_config_router.py +81 -0
  69. astrbot/core/updator.py +1 -1
  70. astrbot/core/utils/io.py +23 -12
  71. astrbot/dashboard/routes/__init__.py +2 -0
  72. astrbot/dashboard/routes/config.py +137 -9
  73. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  74. astrbot/dashboard/routes/plugin.py +24 -5
  75. astrbot/dashboard/routes/tools.py +14 -0
  76. astrbot/dashboard/routes/update.py +1 -1
  77. astrbot/dashboard/server.py +6 -0
  78. astrbot/dashboard/utils.py +161 -0
  79. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/METADATA +91 -55
  80. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/RECORD +83 -50
  81. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
  82. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
  83. {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -17,7 +17,6 @@ import os
17
17
  from .event_bus import EventBus
18
18
  from . import astrbot_config, html_renderer
19
19
  from asyncio import Queue
20
- from typing import List
21
20
  from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
22
21
  from astrbot.core.star import PluginManager
23
22
  from astrbot.core.platform.manager import PlatformManager
@@ -26,14 +25,17 @@ from astrbot.core.persona_mgr import PersonaManager
26
25
  from astrbot.core.provider.manager import ProviderManager
27
26
  from astrbot.core import LogBroker
28
27
  from astrbot.core.db import BaseDatabase
28
+ from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
29
29
  from astrbot.core.updator import AstrBotUpdator
30
30
  from astrbot.core import logger, sp
31
31
  from astrbot.core.config.default import VERSION
32
32
  from astrbot.core.conversation_mgr import ConversationManager
33
33
  from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
34
+ from astrbot.core.umop_config_router import UmopConfigRouter
34
35
  from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
35
36
  from astrbot.core.star.star_handler import star_handlers_registry, EventType
36
37
  from astrbot.core.star.star_handler import star_map
38
+ from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
37
39
 
38
40
 
39
41
  class AstrBotCoreLifecycle:
@@ -84,11 +86,21 @@ class AstrBotCoreLifecycle:
84
86
 
85
87
  await html_renderer.initialize()
86
88
 
89
+ # 初始化 UMOP 配置路由器
90
+ self.umop_config_router = UmopConfigRouter(sp=sp)
91
+
87
92
  # 初始化 AstrBot 配置管理器
88
93
  self.astrbot_config_mgr = AstrBotConfigManager(
89
- default_config=self.astrbot_config, sp=sp
94
+ default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
90
95
  )
91
96
 
97
+ # 4.5 to 4.6 migration for umop_config_router
98
+ try:
99
+ await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
100
+ except Exception as e:
101
+ logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
102
+ logger.error(traceback.format_exc())
103
+
92
104
  # 初始化事件队列
93
105
  self.event_queue = Queue()
94
106
 
@@ -110,6 +122,9 @@ class AstrBotCoreLifecycle:
110
122
  # 初始化平台消息历史管理器
111
123
  self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
112
124
 
125
+ # 初始化知识库管理器
126
+ self.kb_manager = KnowledgeBaseManager(self.provider_manager)
127
+
113
128
  # 初始化提供给插件的上下文
114
129
  self.star_context = Context(
115
130
  self.event_queue,
@@ -121,6 +136,7 @@ class AstrBotCoreLifecycle:
121
136
  self.platform_message_history_manager,
122
137
  self.persona_mgr,
123
138
  self.astrbot_config_mgr,
139
+ self.kb_manager,
124
140
  )
125
141
 
126
142
  # 初始化插件管理器
@@ -132,8 +148,9 @@ class AstrBotCoreLifecycle:
132
148
  # 根据配置实例化各个 Provider
133
149
  await self.provider_manager.initialize()
134
150
 
135
- # 初始化消息事件流水线调度器
151
+ await self.kb_manager.initialize()
136
152
 
153
+ # 初始化消息事件流水线调度器
137
154
  self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
138
155
 
139
156
  # 初始化更新器
@@ -148,7 +165,7 @@ class AstrBotCoreLifecycle:
148
165
  self.start_time = int(time.time())
149
166
 
150
167
  # 初始化当前任务列表
151
- self.curr_tasks: List[asyncio.Task] = []
168
+ self.curr_tasks: list[asyncio.Task] = []
152
169
 
153
170
  # 根据配置实例化各个平台适配器
154
171
  await self.platform_manager.initialize()
@@ -233,6 +250,7 @@ class AstrBotCoreLifecycle:
233
250
 
234
251
  await self.provider_manager.terminate()
235
252
  await self.platform_manager.terminate()
253
+ await self.kb_manager.terminate()
236
254
  self.dashboard_shutdown_event.set()
237
255
 
238
256
  # 再次遍历curr_tasks等待每个任务真正结束
@@ -248,12 +266,13 @@ class AstrBotCoreLifecycle:
248
266
  """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
249
267
  await self.provider_manager.terminate()
250
268
  await self.platform_manager.terminate()
269
+ await self.kb_manager.terminate()
251
270
  self.dashboard_shutdown_event.set()
252
271
  threading.Thread(
253
272
  target=self.astrbot_updator._reboot, name="restart", daemon=True
254
273
  ).start()
255
274
 
256
- def load_platform(self) -> List[asyncio.Task]:
275
+ def load_platform(self) -> list[asyncio.Task]:
257
276
  """加载平台实例并返回所有平台实例的异步任务列表"""
258
277
  tasks = []
259
278
  platform_insts = self.platform_manager.get_insts()
@@ -0,0 +1,44 @@
1
+ from astrbot.api import logger, sp
2
+ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
3
+ from astrbot.core.umop_config_router import UmopConfigRouter
4
+
5
+
6
+ async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
7
+ abconf_data = acm.abconf_data
8
+
9
+ if not isinstance(abconf_data, dict):
10
+ # should be unreachable
11
+ logger.warning(
12
+ f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
13
+ )
14
+ return
15
+
16
+ # 如果任何一项带有 umop,则说明需要迁移
17
+ need_migration = False
18
+ for conf_id, conf_info in abconf_data.items():
19
+ if isinstance(conf_info, dict) and "umop" in conf_info:
20
+ need_migration = True
21
+ break
22
+
23
+ if not need_migration:
24
+ return
25
+
26
+ logger.info("Starting migration from version 4.5 to 4.6")
27
+
28
+ # extract umo->conf_id mapping
29
+ umo_to_conf_id = {}
30
+ for conf_id, conf_info in abconf_data.items():
31
+ if isinstance(conf_info, dict) and "umop" in conf_info:
32
+ umop_ls = conf_info.pop("umop")
33
+ if not isinstance(umop_ls, list):
34
+ continue
35
+ for umo in umop_ls:
36
+ if isinstance(umo, str) and umo not in umo_to_conf_id:
37
+ umo_to_conf_id[umo] = conf_id
38
+
39
+ # update the abconf data
40
+ await sp.global_put("abconf_mapping", abconf_data)
41
+ # update the umop config router
42
+ await ucr.update_routing_data(umo_to_conf_id)
43
+
44
+ logger.info("Migration from version 45 to 46 completed successfully")
@@ -16,14 +16,42 @@ class BaseVecDB:
16
16
  pass
17
17
 
18
18
  @abc.abstractmethod
19
- async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
19
+ async def insert(
20
+ self, content: str, metadata: dict | None = None, id: str | None = None
21
+ ) -> int:
20
22
  """
21
23
  插入一条文本和其对应向量,自动生成 ID 并保持一致性。
22
24
  """
23
25
  ...
24
26
 
25
27
  @abc.abstractmethod
26
- async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
28
+ async def insert_batch(
29
+ self,
30
+ contents: list[str],
31
+ metadatas: list[dict] | None = None,
32
+ ids: list[str] | None = None,
33
+ batch_size: int = 32,
34
+ tasks_limit: int = 3,
35
+ max_retries: int = 3,
36
+ progress_callback=None,
37
+ ) -> int:
38
+ """
39
+ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。
40
+
41
+ Args:
42
+ progress_callback: 进度回调函数,接收参数 (current, total)
43
+ """
44
+ ...
45
+
46
+ @abc.abstractmethod
47
+ async def retrieve(
48
+ self,
49
+ query: str,
50
+ top_k: int = 5,
51
+ fetch_k: int = 20,
52
+ rerank: bool = False,
53
+ metadata_filters: dict | None = None,
54
+ ) -> list[Result]:
27
55
  """
28
56
  搜索最相似的文档。
29
57
  Args:
@@ -44,3 +72,6 @@ class BaseVecDB:
44
72
  bool: 删除是否成功
45
73
  """
46
74
  ...
75
+
76
+ @abc.abstractmethod
77
+ async def close(self): ...
@@ -1,59 +1,224 @@
1
- import aiosqlite
2
1
  import os
2
+ import json
3
+ from datetime import datetime
4
+ from contextlib import asynccontextmanager
5
+
6
+ from sqlalchemy import Text, Column
7
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
8
+ from sqlalchemy.orm import sessionmaker
9
+ from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
10
+ from astrbot.core import logger
11
+
12
+
13
+ class BaseDocModel(SQLModel, table=False):
14
+ metadata = MetaData()
15
+
16
+
17
+ class Document(BaseDocModel, table=True):
18
+ """SQLModel for documents table."""
19
+
20
+ __tablename__ = "documents" # type: ignore
21
+
22
+ id: int | None = Field(
23
+ default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
24
+ )
25
+ doc_id: str = Field(nullable=False)
26
+ text: str = Field(nullable=False)
27
+ metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
28
+ created_at: datetime | None = Field(default=None)
29
+ updated_at: datetime | None = Field(default=None)
3
30
 
4
31
 
5
32
  class DocumentStorage:
6
33
  def __init__(self, db_path: str):
7
34
  self.db_path = db_path
8
- self.connection = None
35
+ self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
36
+ self.engine: AsyncEngine | None = None
37
+ self.async_session_maker: sessionmaker | None = None
9
38
  self.sqlite_init_path = os.path.join(
10
39
  os.path.dirname(__file__), "sqlite_init.sql"
11
40
  )
12
41
 
13
42
  async def initialize(self):
14
43
  """Initialize the SQLite database and create the documents table if it doesn't exist."""
15
- if not os.path.exists(self.db_path):
16
- await self.connect()
17
- async with self.connection.cursor() as cursor:
18
- with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
19
- sql_script = f.read()
20
- await cursor.executescript(sql_script)
21
- await self.connection.commit()
22
- else:
23
- await self.connect()
44
+ await self.connect()
45
+ async with self.engine.begin() as conn: # type: ignore
46
+ # Create tables using SQLModel
47
+ await conn.run_sync(BaseDocModel.metadata.create_all)
48
+
49
+ try:
50
+ await conn.execute(
51
+ text(
52
+ "ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
53
+ "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
54
+ )
55
+ )
56
+ await conn.execute(
57
+ text(
58
+ "ALTER TABLE documents ADD COLUMN user_id TEXT "
59
+ "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
60
+ )
61
+ )
62
+
63
+ # Create indexes
64
+ await conn.execute(
65
+ text(
66
+ "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
67
+ )
68
+ )
69
+ await conn.execute(
70
+ text(
71
+ "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
72
+ )
73
+ )
74
+ except BaseException:
75
+ pass
76
+
77
+ await conn.commit()
24
78
 
25
79
  async def connect(self):
26
80
  """Connect to the SQLite database."""
27
- self.connection = await aiosqlite.connect(self.db_path)
81
+ if self.engine is None:
82
+ self.engine = create_async_engine(
83
+ self.DATABASE_URL,
84
+ echo=False,
85
+ future=True,
86
+ )
87
+ self.async_session_maker = sessionmaker(
88
+ self.engine, # type: ignore
89
+ class_=AsyncSession,
90
+ expire_on_commit=False,
91
+ ) # type: ignore
92
+
93
+ @asynccontextmanager
94
+ async def get_session(self):
95
+ """Context manager for database sessions."""
96
+ async with self.async_session_maker() as session: # type: ignore
97
+ yield session
28
98
 
29
- async def get_documents(self, metadata_filters: dict, ids: list = None):
99
+ async def get_documents(
100
+ self,
101
+ metadata_filters: dict,
102
+ ids: list | None = None,
103
+ offset: int | None = 0,
104
+ limit: int | None = 100,
105
+ ) -> list[dict]:
30
106
  """Retrieve documents by metadata filters and ids.
31
107
 
32
108
  Args:
33
109
  metadata_filters (dict): The metadata filters to apply.
110
+ ids (list | None): Optional list of document IDs to filter.
111
+ offset (int | None): Offset for pagination.
112
+ limit (int | None): Limit for pagination.
113
+
114
+ Returns:
115
+ list: The list of documents that match the filters.
116
+ """
117
+ if self.engine is None:
118
+ logger.warning(
119
+ "Database connection is not initialized, returning empty result"
120
+ )
121
+ return []
122
+
123
+ async with self.get_session() as session:
124
+ query = select(Document)
125
+
126
+ for key, val in metadata_filters.items():
127
+ query = query.where(
128
+ text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
129
+ ).params(**{f"filter_{key}": val})
130
+
131
+ if ids is not None and len(ids) > 0:
132
+ valid_ids = [int(i) for i in ids if i != -1]
133
+ if valid_ids:
134
+ query = query.where(col(Document.id).in_(valid_ids))
135
+
136
+ if limit is not None:
137
+ query = query.limit(limit)
138
+ if offset is not None:
139
+ query = query.offset(offset)
140
+
141
+ result = await session.execute(query)
142
+ documents = result.scalars().all()
143
+
144
+ return [self._document_to_dict(doc) for doc in documents]
145
+
146
+ async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
147
+ """Insert a single document and return its integer ID.
148
+
149
+ Args:
150
+ doc_id (str): The document ID (UUID string).
151
+ text (str): The document text.
152
+ metadata (dict): The document metadata.
34
153
 
35
154
  Returns:
36
- list: The list of document IDs(primary key, not doc_id) that match the filters.
155
+ int: The integer ID of the inserted document.
156
+ """
157
+ assert self.engine is not None, "Database connection is not initialized."
158
+
159
+ async with self.get_session() as session:
160
+ async with session.begin():
161
+ document = Document(
162
+ doc_id=doc_id,
163
+ text=text,
164
+ metadata_=json.dumps(metadata),
165
+ created_at=datetime.now(),
166
+ updated_at=datetime.now(),
167
+ )
168
+ session.add(document)
169
+ await session.flush() # Flush to get the ID
170
+ return document.id # type: ignore
171
+
172
+ async def insert_documents_batch(
173
+ self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
174
+ ) -> list[int]:
175
+ """Batch insert documents and return their integer IDs.
176
+
177
+ Args:
178
+ doc_ids (list[str]): List of document IDs (UUID strings).
179
+ texts (list[str]): List of document texts.
180
+ metadatas (list[dict]): List of document metadata.
181
+
182
+ Returns:
183
+ list[int]: List of integer IDs of the inserted documents.
184
+ """
185
+ assert self.engine is not None, "Database connection is not initialized."
186
+
187
+ async with self.get_session() as session:
188
+ async with session.begin():
189
+ import json
190
+
191
+ documents = []
192
+ for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
193
+ document = Document(
194
+ doc_id=doc_id,
195
+ text=text,
196
+ metadata_=json.dumps(metadata),
197
+ created_at=datetime.now(),
198
+ updated_at=datetime.now(),
199
+ )
200
+ documents.append(document)
201
+ session.add(document)
202
+
203
+ await session.flush() # Flush to get all IDs
204
+ return [doc.id for doc in documents] # type: ignore
205
+
206
+ async def delete_document_by_doc_id(self, doc_id: str):
207
+ """Delete a document by its doc_id.
208
+
209
+ Args:
210
+ doc_id (str): The doc_id of the document to delete.
37
211
  """
38
- # metadata filter -> SQL WHERE clause
39
- where_clauses = []
40
- values = []
41
- for key, val in metadata_filters.items():
42
- where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
43
- values.append(val)
44
- if ids is not None and len(ids) > 0:
45
- ids = [str(i) for i in ids if i != -1]
46
- where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
47
- values.extend(ids)
48
- where_sql = " AND ".join(where_clauses) or "1=1"
49
-
50
- result = []
51
- async with self.connection.cursor() as cursor:
52
- sql = "SELECT * FROM documents WHERE " + where_sql
53
- await cursor.execute(sql, values)
54
- for row in await cursor.fetchall():
55
- result.append(await self.tuple_to_dict(row))
56
- return result
212
+ assert self.engine is not None, "Database connection is not initialized."
213
+
214
+ async with self.get_session() as session:
215
+ async with session.begin():
216
+ query = select(Document).where(col(Document.doc_id) == doc_id)
217
+ result = await session.execute(query)
218
+ document = result.scalar_one_or_none()
219
+
220
+ if document:
221
+ await session.delete(document)
57
222
 
58
223
  async def get_document_by_doc_id(self, doc_id: str):
59
224
  """Retrieve a document by its doc_id.
@@ -62,28 +227,91 @@ class DocumentStorage:
62
227
  doc_id (str): The doc_id of the document to retrieve.
63
228
 
64
229
  Returns:
65
- dict: The document data.
230
+ dict: The document data or None if not found.
66
231
  """
67
- async with self.connection.cursor() as cursor:
68
- await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
69
- row = await cursor.fetchone()
70
- if row:
71
- return await self.tuple_to_dict(row)
72
- else:
73
- return None
232
+ assert self.engine is not None, "Database connection is not initialized."
233
+
234
+ async with self.get_session() as session:
235
+ query = select(Document).where(col(Document.doc_id) == doc_id)
236
+ result = await session.execute(query)
237
+ document = result.scalar_one_or_none()
238
+
239
+ if document:
240
+ return self._document_to_dict(document)
241
+ return None
74
242
 
75
243
  async def update_document_by_doc_id(self, doc_id: str, new_text: str):
76
- """Retrieve a document by its doc_id.
244
+ """Update a document by its doc_id.
77
245
 
78
246
  Args:
79
247
  doc_id (str): The doc_id.
80
248
  new_text (str): The new text to update the document with.
81
249
  """
82
- async with self.connection.cursor() as cursor:
83
- await cursor.execute(
84
- "UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
250
+ assert self.engine is not None, "Database connection is not initialized."
251
+
252
+ async with self.get_session() as session:
253
+ async with session.begin():
254
+ query = select(Document).where(col(Document.doc_id) == doc_id)
255
+ result = await session.execute(query)
256
+ document = result.scalar_one_or_none()
257
+
258
+ if document:
259
+ document.text = new_text
260
+ document.updated_at = datetime.now()
261
+ session.add(document)
262
+
263
+ async def delete_documents(self, metadata_filters: dict):
264
+ """Delete documents by their metadata filters.
265
+
266
+ Args:
267
+ metadata_filters (dict): The metadata filters to apply.
268
+ """
269
+ if self.engine is None:
270
+ logger.warning(
271
+ "Database connection is not initialized, skipping delete operation"
85
272
  )
86
- await self.connection.commit()
273
+ return
274
+
275
+ async with self.get_session() as session:
276
+ async with session.begin():
277
+ query = select(Document)
278
+
279
+ for key, val in metadata_filters.items():
280
+ query = query.where(
281
+ text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
282
+ ).params(**{f"filter_{key}": val})
283
+
284
+ result = await session.execute(query)
285
+ documents = result.scalars().all()
286
+
287
+ for doc in documents:
288
+ await session.delete(doc)
289
+
290
+ async def count_documents(self, metadata_filters: dict | None = None) -> int:
291
+ """Count documents in the database.
292
+
293
+ Args:
294
+ metadata_filters (dict | None): Metadata filters to apply.
295
+
296
+ Returns:
297
+ int: The count of documents.
298
+ """
299
+ if self.engine is None:
300
+ logger.warning("Database connection is not initialized, returning 0")
301
+ return 0
302
+
303
+ async with self.get_session() as session:
304
+ query = select(func.count(col(Document.id)))
305
+
306
+ if metadata_filters:
307
+ for key, val in metadata_filters.items():
308
+ query = query.where(
309
+ text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
310
+ ).params(**{f"filter_{key}": val})
311
+
312
+ result = await session.execute(query)
313
+ count = result.scalar_one_or_none()
314
+ return count if count is not None else 0
87
315
 
88
316
  async def get_user_ids(self) -> list[str]:
89
317
  """Retrieve all user IDs from the documents table.
@@ -91,11 +319,38 @@ class DocumentStorage:
91
319
  Returns:
92
320
  list: A list of user IDs.
93
321
  """
94
- async with self.connection.cursor() as cursor:
95
- await cursor.execute("SELECT DISTINCT user_id FROM documents")
96
- rows = await cursor.fetchall()
322
+ assert self.engine is not None, "Database connection is not initialized."
323
+
324
+ async with self.get_session() as session:
325
+ query = text(
326
+ "SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
327
+ )
328
+ result = await session.execute(query)
329
+ rows = result.fetchall()
97
330
  return [row[0] for row in rows]
98
331
 
332
+ def _document_to_dict(self, document: Document) -> dict:
333
+ """Convert a Document model to a dictionary.
334
+
335
+ Args:
336
+ document (Document): The document to convert.
337
+
338
+ Returns:
339
+ dict: The converted dictionary.
340
+ """
341
+ return {
342
+ "id": document.id,
343
+ "doc_id": document.doc_id,
344
+ "text": document.text,
345
+ "metadata": document.metadata_,
346
+ "created_at": document.created_at.isoformat()
347
+ if isinstance(document.created_at, datetime)
348
+ else document.created_at,
349
+ "updated_at": document.updated_at.isoformat()
350
+ if isinstance(document.updated_at, datetime)
351
+ else document.updated_at,
352
+ }
353
+
99
354
  async def tuple_to_dict(self, row):
100
355
  """Convert a tuple to a dictionary.
101
356
 
@@ -104,6 +359,8 @@ class DocumentStorage:
104
359
 
105
360
  Returns:
106
361
  dict: The converted dictionary.
362
+
363
+ Note: This method is kept for backward compatibility but is no longer used internally.
107
364
  """
108
365
  return {
109
366
  "id": row[0],
@@ -116,6 +373,7 @@ class DocumentStorage:
116
373
 
117
374
  async def close(self):
118
375
  """Close the connection to the SQLite database."""
119
- if self.connection:
120
- await self.connection.close()
121
- self.connection = None
376
+ if self.engine:
377
+ await self.engine.dispose()
378
+ self.engine = None
379
+ self.async_session_maker = None