AstrBot 4.3.5__py3-none-any.whl → 4.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
  2. astrbot/core/astrbot_config_mgr.py +23 -51
  3. astrbot/core/config/default.py +92 -12
  4. astrbot/core/conversation_mgr.py +36 -1
  5. astrbot/core/core_lifecycle.py +24 -5
  6. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  7. astrbot/core/db/vec_db/base.py +33 -2
  8. astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
  9. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
  10. astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
  11. astrbot/core/file_token_service.py +6 -1
  12. astrbot/core/initial_loader.py +6 -3
  13. astrbot/core/knowledge_base/chunking/__init__.py +11 -0
  14. astrbot/core/knowledge_base/chunking/base.py +24 -0
  15. astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
  16. astrbot/core/knowledge_base/chunking/recursive.py +155 -0
  17. astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
  18. astrbot/core/knowledge_base/kb_helper.py +348 -0
  19. astrbot/core/knowledge_base/kb_mgr.py +287 -0
  20. astrbot/core/knowledge_base/models.py +114 -0
  21. astrbot/core/knowledge_base/parsers/__init__.py +15 -0
  22. astrbot/core/knowledge_base/parsers/base.py +50 -0
  23. astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
  24. astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
  25. astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
  26. astrbot/core/knowledge_base/parsers/util.py +13 -0
  27. astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
  28. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  29. astrbot/core/knowledge_base/retrieval/manager.py +273 -0
  30. astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
  31. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
  32. astrbot/core/pipeline/process_stage/method/llm_request.py +29 -7
  33. astrbot/core/pipeline/process_stage/utils.py +80 -0
  34. astrbot/core/platform/astr_message_event.py +8 -7
  35. astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
  36. astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
  37. astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
  38. astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
  39. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
  40. astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
  41. astrbot/core/platform/sources/satori/satori_event.py +270 -99
  42. astrbot/core/provider/manager.py +14 -9
  43. astrbot/core/provider/provider.py +67 -0
  44. astrbot/core/provider/sources/anthropic_source.py +4 -4
  45. astrbot/core/provider/sources/dashscope_source.py +10 -9
  46. astrbot/core/provider/sources/dify_source.py +6 -8
  47. astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
  48. astrbot/core/provider/sources/openai_embedding_source.py +1 -2
  49. astrbot/core/provider/sources/openai_source.py +18 -15
  50. astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
  51. astrbot/core/star/context.py +3 -0
  52. astrbot/core/star/star.py +6 -0
  53. astrbot/core/star/star_manager.py +13 -7
  54. astrbot/core/umop_config_router.py +81 -0
  55. astrbot/core/updator.py +1 -1
  56. astrbot/core/utils/io.py +23 -12
  57. astrbot/dashboard/routes/__init__.py +2 -0
  58. astrbot/dashboard/routes/config.py +137 -9
  59. astrbot/dashboard/routes/knowledge_base.py +1065 -0
  60. astrbot/dashboard/routes/plugin.py +24 -5
  61. astrbot/dashboard/routes/update.py +1 -1
  62. astrbot/dashboard/server.py +6 -0
  63. astrbot/dashboard/utils.py +161 -0
  64. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/METADATA +29 -13
  65. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/RECORD +68 -44
  66. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
  67. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
  68. {astrbot-4.3.5.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -16,14 +16,42 @@ class BaseVecDB:
16
16
  pass
17
17
 
18
18
  @abc.abstractmethod
19
- async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
19
+ async def insert(
20
+ self, content: str, metadata: dict | None = None, id: str | None = None
21
+ ) -> int:
20
22
  """
21
23
  插入一条文本和其对应向量,自动生成 ID 并保持一致性。
22
24
  """
23
25
  ...
24
26
 
25
27
  @abc.abstractmethod
26
- async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
28
+ async def insert_batch(
29
+ self,
30
+ contents: list[str],
31
+ metadatas: list[dict] | None = None,
32
+ ids: list[str] | None = None,
33
+ batch_size: int = 32,
34
+ tasks_limit: int = 3,
35
+ max_retries: int = 3,
36
+ progress_callback=None,
37
+ ) -> int:
38
+ """
39
+ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。
40
+
41
+ Args:
42
+ progress_callback: 进度回调函数,接收参数 (current, total)
43
+ """
44
+ ...
45
+
46
+ @abc.abstractmethod
47
+ async def retrieve(
48
+ self,
49
+ query: str,
50
+ top_k: int = 5,
51
+ fetch_k: int = 20,
52
+ rerank: bool = False,
53
+ metadata_filters: dict | None = None,
54
+ ) -> list[Result]:
27
55
  """
28
56
  搜索最相似的文档。
29
57
  Args:
@@ -44,3 +72,6 @@ class BaseVecDB:
44
72
  bool: 删除是否成功
45
73
  """
46
74
  ...
75
+
76
+ @abc.abstractmethod
77
+ async def close(self): ...
@@ -1,59 +1,224 @@
1
- import aiosqlite
2
1
  import os
2
+ import json
3
+ from datetime import datetime
4
+ from contextlib import asynccontextmanager
5
+
6
+ from sqlalchemy import Text, Column
7
+ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
8
+ from sqlalchemy.orm import sessionmaker
9
+ from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
10
+ from astrbot.core import logger
11
+
12
+
13
+ class BaseDocModel(SQLModel, table=False):
14
+ metadata = MetaData()
15
+
16
+
17
+ class Document(BaseDocModel, table=True):
18
+ """SQLModel for documents table."""
19
+
20
+ __tablename__ = "documents" # type: ignore
21
+
22
+ id: int | None = Field(
23
+ default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
24
+ )
25
+ doc_id: str = Field(nullable=False)
26
+ text: str = Field(nullable=False)
27
+ metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
28
+ created_at: datetime | None = Field(default=None)
29
+ updated_at: datetime | None = Field(default=None)
3
30
 
4
31
 
5
32
  class DocumentStorage:
6
33
  def __init__(self, db_path: str):
7
34
  self.db_path = db_path
8
- self.connection = None
35
+ self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
36
+ self.engine: AsyncEngine | None = None
37
+ self.async_session_maker: sessionmaker | None = None
9
38
  self.sqlite_init_path = os.path.join(
10
39
  os.path.dirname(__file__), "sqlite_init.sql"
11
40
  )
12
41
 
13
42
  async def initialize(self):
14
43
  """Initialize the SQLite database and create the documents table if it doesn't exist."""
15
- if not os.path.exists(self.db_path):
16
- await self.connect()
17
- async with self.connection.cursor() as cursor:
18
- with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
19
- sql_script = f.read()
20
- await cursor.executescript(sql_script)
21
- await self.connection.commit()
22
- else:
23
- await self.connect()
44
+ await self.connect()
45
+ async with self.engine.begin() as conn: # type: ignore
46
+ # Create tables using SQLModel
47
+ await conn.run_sync(BaseDocModel.metadata.create_all)
48
+
49
+ try:
50
+ await conn.execute(
51
+ text(
52
+ "ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
53
+ "GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
54
+ )
55
+ )
56
+ await conn.execute(
57
+ text(
58
+ "ALTER TABLE documents ADD COLUMN user_id TEXT "
59
+ "GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
60
+ )
61
+ )
62
+
63
+ # Create indexes
64
+ await conn.execute(
65
+ text(
66
+ "CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
67
+ )
68
+ )
69
+ await conn.execute(
70
+ text(
71
+ "CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
72
+ )
73
+ )
74
+ except BaseException:
75
+ pass
76
+
77
+ await conn.commit()
24
78
 
25
79
  async def connect(self):
26
80
  """Connect to the SQLite database."""
27
- self.connection = await aiosqlite.connect(self.db_path)
81
+ if self.engine is None:
82
+ self.engine = create_async_engine(
83
+ self.DATABASE_URL,
84
+ echo=False,
85
+ future=True,
86
+ )
87
+ self.async_session_maker = sessionmaker(
88
+ self.engine, # type: ignore
89
+ class_=AsyncSession,
90
+ expire_on_commit=False,
91
+ ) # type: ignore
92
+
93
+ @asynccontextmanager
94
+ async def get_session(self):
95
+ """Context manager for database sessions."""
96
+ async with self.async_session_maker() as session: # type: ignore
97
+ yield session
28
98
 
29
- async def get_documents(self, metadata_filters: dict, ids: list = None):
99
+ async def get_documents(
100
+ self,
101
+ metadata_filters: dict,
102
+ ids: list | None = None,
103
+ offset: int | None = 0,
104
+ limit: int | None = 100,
105
+ ) -> list[dict]:
30
106
  """Retrieve documents by metadata filters and ids.
31
107
 
32
108
  Args:
33
109
  metadata_filters (dict): The metadata filters to apply.
110
+ ids (list | None): Optional list of document IDs to filter.
111
+ offset (int | None): Offset for pagination.
112
+ limit (int | None): Limit for pagination.
113
+
114
+ Returns:
115
+ list: The list of documents that match the filters.
116
+ """
117
+ if self.engine is None:
118
+ logger.warning(
119
+ "Database connection is not initialized, returning empty result"
120
+ )
121
+ return []
122
+
123
+ async with self.get_session() as session:
124
+ query = select(Document)
125
+
126
+ for key, val in metadata_filters.items():
127
+ query = query.where(
128
+ text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
129
+ ).params(**{f"filter_{key}": val})
130
+
131
+ if ids is not None and len(ids) > 0:
132
+ valid_ids = [int(i) for i in ids if i != -1]
133
+ if valid_ids:
134
+ query = query.where(col(Document.id).in_(valid_ids))
135
+
136
+ if limit is not None:
137
+ query = query.limit(limit)
138
+ if offset is not None:
139
+ query = query.offset(offset)
140
+
141
+ result = await session.execute(query)
142
+ documents = result.scalars().all()
143
+
144
+ return [self._document_to_dict(doc) for doc in documents]
145
+
146
+ async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
147
+ """Insert a single document and return its integer ID.
148
+
149
+ Args:
150
+ doc_id (str): The document ID (UUID string).
151
+ text (str): The document text.
152
+ metadata (dict): The document metadata.
34
153
 
35
154
  Returns:
36
- list: The list of document IDs(primary key, not doc_id) that match the filters.
155
+ int: The integer ID of the inserted document.
156
+ """
157
+ assert self.engine is not None, "Database connection is not initialized."
158
+
159
+ async with self.get_session() as session:
160
+ async with session.begin():
161
+ document = Document(
162
+ doc_id=doc_id,
163
+ text=text,
164
+ metadata_=json.dumps(metadata),
165
+ created_at=datetime.now(),
166
+ updated_at=datetime.now(),
167
+ )
168
+ session.add(document)
169
+ await session.flush() # Flush to get the ID
170
+ return document.id # type: ignore
171
+
172
+ async def insert_documents_batch(
173
+ self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
174
+ ) -> list[int]:
175
+ """Batch insert documents and return their integer IDs.
176
+
177
+ Args:
178
+ doc_ids (list[str]): List of document IDs (UUID strings).
179
+ texts (list[str]): List of document texts.
180
+ metadatas (list[dict]): List of document metadata.
181
+
182
+ Returns:
183
+ list[int]: List of integer IDs of the inserted documents.
184
+ """
185
+ assert self.engine is not None, "Database connection is not initialized."
186
+
187
+ async with self.get_session() as session:
188
+ async with session.begin():
189
+ import json
190
+
191
+ documents = []
192
+ for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
193
+ document = Document(
194
+ doc_id=doc_id,
195
+ text=text,
196
+ metadata_=json.dumps(metadata),
197
+ created_at=datetime.now(),
198
+ updated_at=datetime.now(),
199
+ )
200
+ documents.append(document)
201
+ session.add(document)
202
+
203
+ await session.flush() # Flush to get all IDs
204
+ return [doc.id for doc in documents] # type: ignore
205
+
206
+ async def delete_document_by_doc_id(self, doc_id: str):
207
+ """Delete a document by its doc_id.
208
+
209
+ Args:
210
+ doc_id (str): The doc_id of the document to delete.
37
211
  """
38
- # metadata filter -> SQL WHERE clause
39
- where_clauses = []
40
- values = []
41
- for key, val in metadata_filters.items():
42
- where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
43
- values.append(val)
44
- if ids is not None and len(ids) > 0:
45
- ids = [str(i) for i in ids if i != -1]
46
- where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
47
- values.extend(ids)
48
- where_sql = " AND ".join(where_clauses) or "1=1"
49
-
50
- result = []
51
- async with self.connection.cursor() as cursor:
52
- sql = "SELECT * FROM documents WHERE " + where_sql
53
- await cursor.execute(sql, values)
54
- for row in await cursor.fetchall():
55
- result.append(await self.tuple_to_dict(row))
56
- return result
212
+ assert self.engine is not None, "Database connection is not initialized."
213
+
214
+ async with self.get_session() as session:
215
+ async with session.begin():
216
+ query = select(Document).where(col(Document.doc_id) == doc_id)
217
+ result = await session.execute(query)
218
+ document = result.scalar_one_or_none()
219
+
220
+ if document:
221
+ await session.delete(document)
57
222
 
58
223
  async def get_document_by_doc_id(self, doc_id: str):
59
224
  """Retrieve a document by its doc_id.
@@ -62,28 +227,91 @@ class DocumentStorage:
62
227
  doc_id (str): The doc_id of the document to retrieve.
63
228
 
64
229
  Returns:
65
- dict: The document data.
230
+ dict: The document data or None if not found.
66
231
  """
67
- async with self.connection.cursor() as cursor:
68
- await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
69
- row = await cursor.fetchone()
70
- if row:
71
- return await self.tuple_to_dict(row)
72
- else:
73
- return None
232
+ assert self.engine is not None, "Database connection is not initialized."
233
+
234
+ async with self.get_session() as session:
235
+ query = select(Document).where(col(Document.doc_id) == doc_id)
236
+ result = await session.execute(query)
237
+ document = result.scalar_one_or_none()
238
+
239
+ if document:
240
+ return self._document_to_dict(document)
241
+ return None
74
242
 
75
243
  async def update_document_by_doc_id(self, doc_id: str, new_text: str):
76
- """Retrieve a document by its doc_id.
244
+ """Update a document by its doc_id.
77
245
 
78
246
  Args:
79
247
  doc_id (str): The doc_id.
80
248
  new_text (str): The new text to update the document with.
81
249
  """
82
- async with self.connection.cursor() as cursor:
83
- await cursor.execute(
84
- "UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
250
+ assert self.engine is not None, "Database connection is not initialized."
251
+
252
+ async with self.get_session() as session:
253
+ async with session.begin():
254
+ query = select(Document).where(col(Document.doc_id) == doc_id)
255
+ result = await session.execute(query)
256
+ document = result.scalar_one_or_none()
257
+
258
+ if document:
259
+ document.text = new_text
260
+ document.updated_at = datetime.now()
261
+ session.add(document)
262
+
263
+ async def delete_documents(self, metadata_filters: dict):
264
+ """Delete documents by their metadata filters.
265
+
266
+ Args:
267
+ metadata_filters (dict): The metadata filters to apply.
268
+ """
269
+ if self.engine is None:
270
+ logger.warning(
271
+ "Database connection is not initialized, skipping delete operation"
85
272
  )
86
- await self.connection.commit()
273
+ return
274
+
275
+ async with self.get_session() as session:
276
+ async with session.begin():
277
+ query = select(Document)
278
+
279
+ for key, val in metadata_filters.items():
280
+ query = query.where(
281
+ text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
282
+ ).params(**{f"filter_{key}": val})
283
+
284
+ result = await session.execute(query)
285
+ documents = result.scalars().all()
286
+
287
+ for doc in documents:
288
+ await session.delete(doc)
289
+
290
+ async def count_documents(self, metadata_filters: dict | None = None) -> int:
291
+ """Count documents in the database.
292
+
293
+ Args:
294
+ metadata_filters (dict | None): Metadata filters to apply.
295
+
296
+ Returns:
297
+ int: The count of documents.
298
+ """
299
+ if self.engine is None:
300
+ logger.warning("Database connection is not initialized, returning 0")
301
+ return 0
302
+
303
+ async with self.get_session() as session:
304
+ query = select(func.count(col(Document.id)))
305
+
306
+ if metadata_filters:
307
+ for key, val in metadata_filters.items():
308
+ query = query.where(
309
+ text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
310
+ ).params(**{f"filter_{key}": val})
311
+
312
+ result = await session.execute(query)
313
+ count = result.scalar_one_or_none()
314
+ return count if count is not None else 0
87
315
 
88
316
  async def get_user_ids(self) -> list[str]:
89
317
  """Retrieve all user IDs from the documents table.
@@ -91,11 +319,38 @@ class DocumentStorage:
91
319
  Returns:
92
320
  list: A list of user IDs.
93
321
  """
94
- async with self.connection.cursor() as cursor:
95
- await cursor.execute("SELECT DISTINCT user_id FROM documents")
96
- rows = await cursor.fetchall()
322
+ assert self.engine is not None, "Database connection is not initialized."
323
+
324
+ async with self.get_session() as session:
325
+ query = text(
326
+ "SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
327
+ )
328
+ result = await session.execute(query)
329
+ rows = result.fetchall()
97
330
  return [row[0] for row in rows]
98
331
 
332
+ def _document_to_dict(self, document: Document) -> dict:
333
+ """Convert a Document model to a dictionary.
334
+
335
+ Args:
336
+ document (Document): The document to convert.
337
+
338
+ Returns:
339
+ dict: The converted dictionary.
340
+ """
341
+ return {
342
+ "id": document.id,
343
+ "doc_id": document.doc_id,
344
+ "text": document.text,
345
+ "metadata": document.metadata_,
346
+ "created_at": document.created_at.isoformat()
347
+ if isinstance(document.created_at, datetime)
348
+ else document.created_at,
349
+ "updated_at": document.updated_at.isoformat()
350
+ if isinstance(document.updated_at, datetime)
351
+ else document.updated_at,
352
+ }
353
+
99
354
  async def tuple_to_dict(self, row):
100
355
  """Convert a tuple to a dictionary.
101
356
 
@@ -104,6 +359,8 @@ class DocumentStorage:
104
359
 
105
360
  Returns:
106
361
  dict: The converted dictionary.
362
+
363
+ Note: This method is kept for backward compatibility but is no longer used internally.
107
364
  """
108
365
  return {
109
366
  "id": row[0],
@@ -116,6 +373,7 @@ class DocumentStorage:
116
373
 
117
374
  async def close(self):
118
375
  """Close the connection to the SQLite database."""
119
- if self.connection:
120
- await self.connection.close()
121
- self.connection = None
376
+ if self.engine:
377
+ await self.engine.dispose()
378
+ self.engine = None
379
+ self.async_session_maker = None
@@ -9,7 +9,7 @@ import numpy as np
9
9
 
10
10
 
11
11
  class EmbeddingStorage:
12
- def __init__(self, dimension: int, path: str = None):
12
+ def __init__(self, dimension: int, path: str | None = None):
13
13
  self.dimension = dimension
14
14
  self.path = path
15
15
  self.index = None
@@ -18,7 +18,6 @@ class EmbeddingStorage:
18
18
  else:
19
19
  base_index = faiss.IndexFlatL2(dimension)
20
20
  self.index = faiss.IndexIDMap(base_index)
21
- self.storage = {}
22
21
 
23
22
  async def insert(self, vector: np.ndarray, id: int):
24
23
  """插入向量
@@ -29,12 +28,29 @@ class EmbeddingStorage:
29
28
  Raises:
30
29
  ValueError: 如果向量的维度与存储的维度不匹配
31
30
  """
31
+ assert self.index is not None, "FAISS index is not initialized."
32
32
  if vector.shape[0] != self.dimension:
33
33
  raise ValueError(
34
34
  f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
35
35
  )
36
36
  self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
37
- self.storage[id] = vector
37
+ await self.save_index()
38
+
39
+ async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
40
+ """批量插入向量
41
+
42
+ Args:
43
+ vectors (np.ndarray): 要插入的向量数组
44
+ ids (list[int]): 向量的ID列表
45
+ Raises:
46
+ ValueError: 如果向量的维度与存储的维度不匹配
47
+ """
48
+ assert self.index is not None, "FAISS index is not initialized."
49
+ if vectors.shape[1] != self.dimension:
50
+ raise ValueError(
51
+ f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
52
+ )
53
+ self.index.add_with_ids(vectors, np.array(ids))
38
54
  await self.save_index()
39
55
 
40
56
  async def search(self, vector: np.ndarray, k: int) -> tuple:
@@ -46,10 +62,22 @@ class EmbeddingStorage:
46
62
  Returns:
47
63
  tuple: (距离, 索引)
48
64
  """
65
+ assert self.index is not None, "FAISS index is not initialized."
49
66
  faiss.normalize_L2(vector)
50
67
  distances, indices = self.index.search(vector, k)
51
68
  return distances, indices
52
69
 
70
+ async def delete(self, ids: list[int]):
71
+ """删除向量
72
+
73
+ Args:
74
+ ids (list[int]): 要删除的向量ID列表
75
+ """
76
+ assert self.index is not None, "FAISS index is not initialized."
77
+ id_array = np.array(ids, dtype=np.int64)
78
+ self.index.remove_ids(id_array)
79
+ await self.save_index()
80
+
53
81
  async def save_index(self):
54
82
  """保存索引
55
83
 
@@ -1,11 +1,12 @@
1
1
  import uuid
2
- import json
2
+ import time
3
3
  import numpy as np
4
4
  from .document_storage import DocumentStorage
5
5
  from .embedding_storage import EmbeddingStorage
6
6
  from ..base import Result, BaseVecDB
7
7
  from astrbot.core.provider.provider import EmbeddingProvider
8
8
  from astrbot.core.provider.provider import RerankProvider
9
+ from astrbot import logger
9
10
 
10
11
 
11
12
  class FaissVecDB(BaseVecDB):
@@ -44,18 +45,56 @@ class FaissVecDB(BaseVecDB):
44
45
 
45
46
  vector = await self.embedding_provider.get_embedding(content)
46
47
  vector = np.array(vector, dtype=np.float32)
47
- async with self.document_storage.connection.cursor() as cursor:
48
- await cursor.execute(
49
- "INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
50
- (str_id, content, json.dumps(metadata)),
51
- )
52
- await self.document_storage.connection.commit()
53
- result = await self.document_storage.get_document_by_doc_id(str_id)
54
- int_id = result["id"]
55
48
 
56
- # 插入向量到 FAISS
57
- await self.embedding_storage.insert(vector, int_id)
58
- return int_id
49
+ # 使用 DocumentStorage 的方法插入文档
50
+ int_id = await self.document_storage.insert_document(str_id, content, metadata)
51
+
52
+ # 插入向量到 FAISS
53
+ await self.embedding_storage.insert(vector, int_id)
54
+ return int_id
55
+
56
+ async def insert_batch(
57
+ self,
58
+ contents: list[str],
59
+ metadatas: list[dict] | None = None,
60
+ ids: list[str] | None = None,
61
+ batch_size: int = 32,
62
+ tasks_limit: int = 3,
63
+ max_retries: int = 3,
64
+ progress_callback=None,
65
+ ) -> list[int]:
66
+ """
67
+ 批量插入文本和其对应向量,自动生成 ID 并保持一致性。
68
+
69
+ Args:
70
+ progress_callback: 进度回调函数,接收参数 (current, total)
71
+ """
72
+ metadatas = metadatas or [{} for _ in contents]
73
+ ids = ids or [str(uuid.uuid4()) for _ in contents]
74
+
75
+ start = time.time()
76
+ logger.debug(f"Generating embeddings for {len(contents)} contents...")
77
+ vectors = await self.embedding_provider.get_embeddings_batch(
78
+ contents,
79
+ batch_size=batch_size,
80
+ tasks_limit=tasks_limit,
81
+ max_retries=max_retries,
82
+ progress_callback=progress_callback,
83
+ )
84
+ end = time.time()
85
+ logger.debug(
86
+ f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
87
+ )
88
+
89
+ # 使用 DocumentStorage 的批量插入方法
90
+ int_ids = await self.document_storage.insert_documents_batch(
91
+ ids, contents, metadatas
92
+ )
93
+
94
+ # 批量插入向量到 FAISS
95
+ vectors_array = np.array(vectors).astype("float32")
96
+ await self.embedding_storage.insert_batch(vectors_array, int_ids)
97
+ return int_ids
59
98
 
60
99
  async def retrieve(
61
100
  self,
@@ -119,23 +158,42 @@ class FaissVecDB(BaseVecDB):
119
158
 
120
159
  return top_k_results
121
160
 
122
- async def delete(self, doc_id: int):
161
+ async def delete(self, doc_id: str):
123
162
  """
124
- 删除一条文档
163
+ 删除一条文档块(chunk)
125
164
  """
126
- await self.document_storage.connection.execute(
127
- "DELETE FROM documents WHERE doc_id = ?", (doc_id,)
128
- )
129
- await self.document_storage.connection.commit()
165
+ # 获得对应的 int id
166
+ result = await self.document_storage.get_document_by_doc_id(doc_id)
167
+ int_id = result["id"] if result else None
168
+ if int_id is None:
169
+ return
170
+
171
+ # 使用 DocumentStorage 的删除方法
172
+ await self.document_storage.delete_document_by_doc_id(doc_id)
173
+ await self.embedding_storage.delete([int_id])
130
174
 
131
175
  async def close(self):
132
176
  await self.document_storage.close()
133
177
 
134
- async def count_documents(self) -> int:
178
+ async def count_documents(self, metadata_filter: dict | None = None) -> int:
135
179
  """
136
180
  计算文档数量
181
+
182
+ Args:
183
+ metadata_filter (dict | None): 元数据过滤器
137
184
  """
138
- async with self.document_storage.connection.cursor() as cursor:
139
- await cursor.execute("SELECT COUNT(*) FROM documents")
140
- count = await cursor.fetchone()
141
- return count[0] if count else 0
185
+ count = await self.document_storage.count_documents(
186
+ metadata_filters=metadata_filter or {}
187
+ )
188
+ return count
189
+
190
+ async def delete_documents(self, metadata_filters: dict):
191
+ """
192
+ 根据元数据过滤器删除文档
193
+ """
194
+ docs = await self.document_storage.get_documents(
195
+ metadata_filters=metadata_filters, offset=None, limit=None
196
+ )
197
+ doc_ids: list[int] = [doc["id"] for doc in docs]
198
+ await self.embedding_storage.delete(doc_ids)
199
+ await self.document_storage.delete_documents(metadata_filters=metadata_filters)