AstrBot 4.3.3__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +18 -4
- astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
- astrbot/core/astr_agent_context.py +1 -0
- astrbot/core/astrbot_config_mgr.py +23 -51
- astrbot/core/config/default.py +139 -14
- astrbot/core/conversation_mgr.py +36 -1
- astrbot/core/core_lifecycle.py +24 -5
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/vec_db/base.py +33 -2
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
- astrbot/core/file_token_service.py +6 -1
- astrbot/core/initial_loader.py +6 -3
- astrbot/core/knowledge_base/chunking/__init__.py +11 -0
- astrbot/core/knowledge_base/chunking/base.py +24 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
- astrbot/core/knowledge_base/chunking/recursive.py +155 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
- astrbot/core/knowledge_base/kb_helper.py +348 -0
- astrbot/core/knowledge_base/kb_mgr.py +287 -0
- astrbot/core/knowledge_base/models.py +114 -0
- astrbot/core/knowledge_base/parsers/__init__.py +15 -0
- astrbot/core/knowledge_base/parsers/base.py +50 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +273 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +61 -21
- astrbot/core/pipeline/process_stage/utils.py +80 -0
- astrbot/core/pipeline/scheduler.py +1 -1
- astrbot/core/platform/astr_message_event.py +8 -7
- astrbot/core/platform/manager.py +4 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
- astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
- astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
- astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
- astrbot/core/platform/sources/satori/satori_event.py +270 -77
- astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
- astrbot/core/provider/manager.py +14 -9
- astrbot/core/provider/provider.py +67 -0
- astrbot/core/provider/sources/anthropic_source.py +4 -4
- astrbot/core/provider/sources/dashscope_source.py +10 -9
- astrbot/core/provider/sources/dify_source.py +6 -8
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_source.py +18 -15
- astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
- astrbot/core/star/context.py +3 -0
- astrbot/core/star/star.py +6 -0
- astrbot/core/star/star_manager.py +13 -7
- astrbot/core/umop_config_router.py +81 -0
- astrbot/core/updator.py +1 -1
- astrbot/core/utils/io.py +23 -12
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/config.py +137 -9
- astrbot/dashboard/routes/knowledge_base.py +1065 -0
- astrbot/dashboard/routes/plugin.py +24 -5
- astrbot/dashboard/routes/tools.py +14 -0
- astrbot/dashboard/routes/update.py +1 -1
- astrbot/dashboard/server.py +6 -0
- astrbot/dashboard/utils.py +161 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/METADATA +91 -55
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/RECORD +83 -50
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
astrbot/core/knowledge_base/kb_db_sqlite.py (new file)
@@ -0,0 +1,299 @@
+from contextlib import asynccontextmanager
+from pathlib import Path
+
+from sqlmodel import col, desc
+from sqlalchemy import text, func, select, update, delete
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
+
+from astrbot.core import logger
+from astrbot.core.knowledge_base.models import (
+    BaseKBModel,
+    KBDocument,
+    KBMedia,
+    KnowledgeBase,
+)
+from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
+
+
+class KBSQLiteDatabase:
+    def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
+        """Initialize the knowledge base database.
+
+        Args:
+            db_path: Path to the database file; defaults to data/knowledge_base/kb.db.
+        """
+        self.db_path = db_path
+        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
+        self.inited = False
+
+        # Make sure the directory exists
+        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
+
+        # Create the async engine
+        self.engine = create_async_engine(
+            self.DATABASE_URL,
+            echo=False,
+            pool_pre_ping=True,
+            pool_recycle=3600,
+        )
+
+        # Create the session factory
+        self.async_session = async_sessionmaker(
+            self.engine,
+            class_=AsyncSession,
+            expire_on_commit=False,
+        )
+
+    @asynccontextmanager
+    async def get_db(self):
+        """Get a database session.
+
+        Usage:
+            async with kb_db.get_db() as session:
+                # perform database operations
+                result = await session.execute(stmt)
+        """
+        async with self.async_session() as session:
+            yield session
+
+    async def initialize(self) -> None:
+        """Initialize the database: create tables and configure SQLite parameters."""
+        async with self.engine.begin() as conn:
+            # Create all knowledge-base-related tables
+            await conn.run_sync(BaseKBModel.metadata.create_all)
+
+            # Configure SQLite performance tuning parameters
+            await conn.execute(text("PRAGMA journal_mode=WAL"))
+            await conn.execute(text("PRAGMA synchronous=NORMAL"))
+            await conn.execute(text("PRAGMA cache_size=20000"))
+            await conn.execute(text("PRAGMA temp_store=MEMORY"))
+            await conn.execute(text("PRAGMA mmap_size=134217728"))
+            await conn.execute(text("PRAGMA optimize"))
+            await conn.commit()
+
+        self.inited = True
+
+    async def migrate_to_v1(self) -> None:
+        """Run the v1 migration of the knowledge base database.
+
+        Creates all necessary indexes to optimize query performance.
+        """
+        async with self.get_db() as session:
+            session: AsyncSession
+            async with session.begin():
+                # Indexes on the knowledge base table
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
+                        "ON knowledge_bases(kb_id)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_kb_name "
+                        "ON knowledge_bases(kb_name)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_kb_created_at "
+                        "ON knowledge_bases(created_at)"
+                    )
+                )
+
+                # Indexes on the document table
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
+                        "ON kb_documents(doc_id)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
+                        "ON kb_documents(kb_id)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_doc_name "
+                        "ON kb_documents(doc_name)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_doc_type "
+                        "ON kb_documents(file_type)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_doc_created_at "
+                        "ON kb_documents(created_at)"
+                    )
+                )
+
+                # Indexes on the media table
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_media_media_id "
+                        "ON kb_media(media_id)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_media_doc_id "
+                        "ON kb_media(doc_id)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
+                    )
+                )
+                await session.execute(
+                    text(
+                        "CREATE INDEX IF NOT EXISTS idx_media_type "
+                        "ON kb_media(media_type)"
+                    )
+                )
+
+                await session.commit()
+
+    async def close(self) -> None:
+        """Close the database connection."""
+        await self.engine.dispose()
+        logger.info(f"Knowledge base database closed: {self.db_path}")
+
+    async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
+        """Get a knowledge base by ID."""
+        async with self.get_db() as session:
+            stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
+            result = await session.execute(stmt)
+            return result.scalar_one_or_none()
+
+    async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
+        """Get a knowledge base by name."""
+        async with self.get_db() as session:
+            stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
+            result = await session.execute(stmt)
+            return result.scalar_one_or_none()
+
+    async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
+        """List all knowledge bases."""
+        async with self.get_db() as session:
+            stmt = (
+                select(KnowledgeBase)
+                .offset(offset)
+                .limit(limit)
+                .order_by(desc(KnowledgeBase.created_at))
+            )
+            result = await session.execute(stmt)
+            return list(result.scalars().all())
+
+    async def count_kbs(self) -> int:
+        """Count the knowledge bases."""
+        async with self.get_db() as session:
+            stmt = select(func.count(col(KnowledgeBase.id)))
+            result = await session.execute(stmt)
+            return result.scalar() or 0
+
+    # ===== Document queries =====
+
+    async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
+        """Get a document by ID."""
+        async with self.get_db() as session:
+            stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
+            result = await session.execute(stmt)
+            return result.scalar_one_or_none()
+
+    async def list_documents_by_kb(
+        self, kb_id: str, offset: int = 0, limit: int = 100
+    ) -> list[KBDocument]:
+        """List all documents of a knowledge base."""
+        async with self.get_db() as session:
+            stmt = (
+                select(KBDocument)
+                .where(col(KBDocument.kb_id) == kb_id)
+                .offset(offset)
+                .limit(limit)
+                .order_by(desc(KBDocument.created_at))
+            )
+            result = await session.execute(stmt)
+            return list(result.scalars().all())
+
+    async def count_documents_by_kb(self, kb_id: str) -> int:
+        """Count the documents of a knowledge base."""
+        async with self.get_db() as session:
+            stmt = select(func.count(col(KBDocument.id))).where(
+                col(KBDocument.kb_id) == kb_id
+            )
+            result = await session.execute(stmt)
+            return result.scalar() or 0
+
+    async def get_document_with_metadata(self, doc_id: str) -> dict | None:
+        async with self.get_db() as session:
+            stmt = (
+                select(KBDocument, KnowledgeBase)
+                .join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
+                .where(col(KBDocument.doc_id) == doc_id)
+            )
+            result = await session.execute(stmt)
+            row = result.first()
+
+            if not row:
+                return None
+
+            return {
+                "document": row[0],
+                "knowledge_base": row[1],
+            }
+
+    async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
+        """Delete a single document and its related data."""
+        # Delete from the knowledge base tables
+        async with self.get_db() as session:
+            async with session.begin():
+                # Delete the document record
+                delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
+                await session.execute(delete_stmt)
+                await session.commit()
+
+        # Delete the related vectors from the vec db
+        await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
+
+    # ===== Media queries =====
+
+    async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
+        """List all media resources of a document."""
+        async with self.get_db() as session:
+            stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
+            result = await session.execute(stmt)
+            return list(result.scalars().all())
+
+    async def get_media_by_id(self, media_id: str) -> KBMedia | None:
+        """Get a media resource by ID."""
+        async with self.get_db() as session:
+            stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
+            result = await session.execute(stmt)
+            return result.scalar_one_or_none()
+
+    async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
+        """Update the knowledge base statistics."""
+        chunk_cnt = await vec_db.count_documents()
+
+        async with self.get_db() as session:
+            async with session.begin():
+                update_stmt = (
+                    update(KnowledgeBase)
+                    .where(col(KnowledgeBase.kb_id) == kb_id)
+                    .values(
+                        doc_count=select(func.count(col(KBDocument.id)))
+                        .where(col(KBDocument.kb_id) == kb_id)
+                        .scalar_subquery(),
+                        chunk_count=chunk_cnt,
+                    )
+                )
+
+                await session.execute(update_stmt)
+                await session.commit()
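
Taken together, kb_db_sqlite.py defines a simple lifecycle: construct, initialize() (creates the tables and applies the WAL and other PRAGMAs), optionally migrate_to_v1() for the query indexes, then query and close(). A minimal usage sketch, assuming the astrbot package is importable; the path mirrors the constructor's default and the printed fields are illustrative only:

import asyncio

from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase


async def main() -> None:
    db = KBSQLiteDatabase(db_path="data/knowledge_base/kb.db")  # the default path
    await db.initialize()     # create tables, apply the SQLite PRAGMAs
    await db.migrate_to_v1()  # create the query indexes

    kbs = await db.list_kbs(offset=0, limit=10)
    print(await db.count_kbs(), [kb.kb_name for kb in kbs])

    await db.close()


asyncio.run(main())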
astrbot/core/knowledge_base/kb_helper.py (new file)
@@ -0,0 +1,348 @@
+import uuid
+import aiofiles
+import json
+from pathlib import Path
+from .models import KnowledgeBase, KBDocument, KBMedia
+from .kb_db_sqlite import KBSQLiteDatabase
+from astrbot.core.db.vec_db.base import BaseVecDB
+from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
+from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
+from astrbot.core.provider.manager import ProviderManager
+from .parsers.util import select_parser
+from .chunking.base import BaseChunker
+from astrbot.core import logger
+
+
+class KBHelper:
+    vec_db: BaseVecDB
+    kb: KnowledgeBase
+
+    def __init__(
+        self,
+        kb_db: KBSQLiteDatabase,
+        kb: KnowledgeBase,
+        provider_manager: ProviderManager,
+        kb_root_dir: str,
+        chunker: BaseChunker,
+    ):
+        self.kb_db = kb_db
+        self.kb = kb
+        self.prov_mgr = provider_manager
+        self.kb_root_dir = kb_root_dir
+        self.chunker = chunker
+
+        self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
+        self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
+        self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
+
+        self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
+        self.kb_files_dir.mkdir(parents=True, exist_ok=True)
+
+    async def initialize(self):
+        await self._ensure_vec_db()
+
+    async def get_ep(self) -> EmbeddingProvider:
+        if not self.kb.embedding_provider_id:
+            raise ValueError(f"Knowledge base {self.kb.kb_name} has no Embedding Provider configured")
+        ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
+            self.kb.embedding_provider_id
+        )  # type: ignore
+        if not ep:
+            raise ValueError(
+                f"Cannot find an Embedding Provider with ID {self.kb.embedding_provider_id}"
+            )
+        return ep
+
+    async def get_rp(self) -> RerankProvider | None:
+        if not self.kb.rerank_provider_id:
+            return None
+        rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
+            self.kb.rerank_provider_id
+        )  # type: ignore
+        if not rp:
+            raise ValueError(
+                f"Cannot find a Rerank Provider with ID {self.kb.rerank_provider_id}"
+            )
+        return rp
+
+    async def _ensure_vec_db(self) -> FaissVecDB:
+        if not self.kb.embedding_provider_id:
+            raise ValueError(f"Knowledge base {self.kb.kb_name} has no Embedding Provider configured")
+
+        ep = await self.get_ep()
+        rp = await self.get_rp()
+
+        vec_db = FaissVecDB(
+            doc_store_path=str(self.kb_dir / "doc.db"),
+            index_store_path=str(self.kb_dir / "index.faiss"),
+            embedding_provider=ep,
+            rerank_provider=rp,
+        )
+        await vec_db.initialize()
+        self.vec_db = vec_db
+        return vec_db
+
+    async def delete_vec_db(self):
+        """Delete the knowledge base's vector database and all related files."""
+        import shutil
+
+        await self.terminate()
+        if self.kb_dir.exists():
+            shutil.rmtree(self.kb_dir)
+
+    async def terminate(self):
+        if self.vec_db:
+            await self.vec_db.close()
+
+    async def upload_document(
+        self,
+        file_name: str,
+        file_content: bytes,
+        file_type: str,
+        chunk_size: int = 512,
+        chunk_overlap: int = 50,
+        batch_size: int = 32,
+        tasks_limit: int = 3,
+        max_retries: int = 3,
+        progress_callback=None,
+    ) -> KBDocument:
+        """Upload and process a document (with atomicity guarantees and failure cleanup).
+
+        Workflow:
+        1. Save the original file
+        2. Parse the document content
+        3. Extract media resources
+        4. Chunk the text
+        5. Generate embeddings and store them
+        6. Save metadata (transactional)
+        7. Update statistics
+
+        Args:
+            progress_callback: Progress callback invoked with (stage, current, total)
+                - stage: current stage ('parsing', 'chunking', 'embedding')
+                - current: current progress
+                - total: total amount
+        """
+        await self._ensure_vec_db()
+        doc_id = str(uuid.uuid4())
+        media_paths: list[Path] = []
+
+        # file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
+        # async with aiofiles.open(file_path, "wb") as f:
+        #     await f.write(file_content)
+
+        try:
+            # Stage 1: parse the document
+            if progress_callback:
+                await progress_callback("parsing", 0, 100)
+
+            parser = await select_parser(f".{file_type}")
+            parse_result = await parser.parse(file_content, file_name)
+            text_content = parse_result.text
+            media_items = parse_result.media
+
+            if progress_callback:
+                await progress_callback("parsing", 100, 100)
+
+            # Save media files
+            saved_media = []
+            for media_item in media_items:
+                media = await self._save_media(
+                    doc_id=doc_id,
+                    media_type=media_item.media_type,
+                    file_name=media_item.file_name,
+                    content=media_item.content,
+                    mime_type=media_item.mime_type,
+                )
+                saved_media.append(media)
+                media_paths.append(Path(media.file_path))
+
+            # Stage 2: chunking
+            if progress_callback:
+                await progress_callback("chunking", 0, 100)
+
+            chunks_text = await self.chunker.chunk(
+                text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+            )
+            contents = []
+            metadatas = []
+            for idx, chunk_text in enumerate(chunks_text):
+                contents.append(chunk_text)
+                metadatas.append(
+                    {
+                        "kb_id": self.kb.kb_id,
+                        "kb_doc_id": doc_id,
+                        "chunk_index": idx,
+                    }
+                )
+
+            if progress_callback:
+                await progress_callback("chunking", 100, 100)
+
+            # Stage 3: generate embeddings (with progress callback)
+            async def embedding_progress_callback(current, total):
+                if progress_callback:
+                    await progress_callback("embedding", current, total)
+
+            await self.vec_db.insert_batch(
+                contents=contents,
+                metadatas=metadatas,
+                batch_size=batch_size,
+                tasks_limit=tasks_limit,
+                max_retries=max_retries,
+                progress_callback=embedding_progress_callback,
+            )
+
+            # Save the document metadata
+            doc = KBDocument(
+                doc_id=doc_id,
+                kb_id=self.kb.kb_id,
+                doc_name=file_name,
+                file_type=file_type,
+                file_size=len(file_content),
+                # file_path=str(file_path),
+                file_path="",
+                chunk_count=len(chunks_text),
+                media_count=0,
+            )
+            async with self.kb_db.get_db() as session:
+                async with session.begin():
+                    session.add(doc)
+                    for media in saved_media:
+                        session.add(media)
+                    await session.commit()
+
+                await session.refresh(doc)
+
+            vec_db: FaissVecDB = self.vec_db  # type: ignore
+            await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
+            await self.refresh_kb()
+            await self.refresh_document(doc_id)
+            return doc
+        except Exception as e:
+            logger.error(f"Failed to upload document: {e}")
+            # if file_path.exists():
+            #     file_path.unlink()
+
+            for media_path in media_paths:
+                try:
+                    if media_path.exists():
+                        media_path.unlink()
+                except Exception as me:
+                    logger.warning(f"Failed to clean up media file {media_path}: {me}")
+
+            raise e
+
+    async def list_documents(
+        self, offset: int = 0, limit: int = 100
+    ) -> list[KBDocument]:
+        """List all documents of the knowledge base."""
+        docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
+        return docs
+
+    async def get_document(self, doc_id: str) -> KBDocument | None:
+        """Get a single document."""
+        doc = await self.kb_db.get_document_by_id(doc_id)
+        return doc
+
+    async def delete_document(self, doc_id: str):
+        """Delete a single document and its related data."""
+        await self.kb_db.delete_document_by_id(
+            doc_id=doc_id,
+            vec_db=self.vec_db,  # type: ignore
+        )
+        await self.kb_db.update_kb_stats(
+            kb_id=self.kb.kb_id,
+            vec_db=self.vec_db,  # type: ignore
+        )
+        await self.refresh_kb()
+
+    async def delete_chunk(self, chunk_id: str, doc_id: str):
+        """Delete a single text chunk and its related data."""
+        vec_db: FaissVecDB = self.vec_db  # type: ignore
+        await vec_db.delete(chunk_id)
+        await self.kb_db.update_kb_stats(
+            kb_id=self.kb.kb_id,
+            vec_db=self.vec_db,  # type: ignore
+        )
+        await self.refresh_kb()
+        await self.refresh_document(doc_id)
+
+    async def refresh_kb(self):
+        if self.kb:
+            kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
+            if kb:
+                self.kb = kb
+
+    async def refresh_document(self, doc_id: str) -> None:
+        """Refresh the document's metadata."""
+        doc = await self.get_document(doc_id)
+        if not doc:
+            raise ValueError(f"Cannot find a document with ID {doc_id}")
+        chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
+        doc.chunk_count = chunk_count
+        async with self.kb_db.get_db() as session:
+            async with session.begin():
+                session.add(doc)
+                await session.commit()
+            await session.refresh(doc)
+
+    async def get_chunks_by_doc_id(
+        self, doc_id: str, offset: int = 0, limit: int = 100
+    ) -> list[dict]:
+        """Get all chunks of a document with their metadata."""
+        vec_db: FaissVecDB = self.vec_db  # type: ignore
+        chunks = await vec_db.document_storage.get_documents(
+            metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
+        )
+        result = []
+        for chunk in chunks:
+            chunk_md = json.loads(chunk["metadata"])
+            result.append(
+                {
+                    "chunk_id": chunk["doc_id"],
+                    "doc_id": chunk_md["kb_doc_id"],
+                    "kb_id": chunk_md["kb_id"],
+                    "chunk_index": chunk_md["chunk_index"],
+                    "content": chunk["text"],
+                    "char_count": len(chunk["text"]),
+                }
+            )
+        return result
+
+    async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
+        """Get the number of chunks of a document."""
+        vec_db: FaissVecDB = self.vec_db  # type: ignore
+        count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
+        return count
+
+    async def _save_media(
+        self,
+        doc_id: str,
+        media_type: str,
+        file_name: str,
+        content: bytes,
+        mime_type: str,
+    ) -> KBMedia:
+        """Save a media resource."""
+        media_id = str(uuid.uuid4())
+        ext = Path(file_name).suffix
+
+        # Save the file
+        file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+        async with aiofiles.open(file_path, "wb") as f:
+            await f.write(content)
+
+        media = KBMedia(
+            media_id=media_id,
+            doc_id=doc_id,
+            kb_id=self.kb.kb_id,
+            media_type=media_type,
+            file_name=file_name,
+            file_path=str(file_path),
+            file_size=len(content),
+            mime_type=mime_type,
+        )
+
+        return media
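
kb_helper.py is the write path: upload_document drives a three-stage pipeline (parsing, chunking, embedding) and reports progress through the (stage, current, total) callback documented in its docstring. A minimal caller sketch, assuming a KnowledgeBase row, a ProviderManager, and a chunker already exist (their construction is outside this diff); the root directory, file name, and content below are illustrative only:

from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.kb_helper import KBHelper


async def report(stage: str, current: int, total: int) -> None:
    # stage is one of 'parsing', 'chunking', 'embedding'
    print(f"{stage}: {current}/{total}")


async def ingest(kb_db: KBSQLiteDatabase, kb, provider_manager, chunker) -> None:
    helper = KBHelper(
        kb_db=kb_db,
        kb=kb,
        provider_manager=provider_manager,
        kb_root_dir="data/knowledge_base",  # assumed root directory
        chunker=chunker,
    )
    await helper.initialize()  # builds the per-KB FaissVecDB
    doc = await helper.upload_document(
        file_name="handbook.md",  # hypothetical file
        file_content=b"# Hello\nSome text to index.",
        file_type="md",
        chunk_size=512,
        chunk_overlap=50,
        progress_callback=report,
    )
    print(doc.doc_id, doc.chunk_count)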