AstrBot 4.3.3__py3-none-any.whl → 4.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/core/agent/mcp_client.py +18 -4
- astrbot/core/agent/runners/tool_loop_agent_runner.py +31 -2
- astrbot/core/astr_agent_context.py +1 -0
- astrbot/core/astrbot_config_mgr.py +23 -51
- astrbot/core/config/default.py +139 -14
- astrbot/core/conversation_mgr.py +36 -1
- astrbot/core/core_lifecycle.py +24 -5
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/vec_db/base.py +33 -2
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +310 -52
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +31 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +81 -23
- astrbot/core/file_token_service.py +6 -1
- astrbot/core/initial_loader.py +6 -3
- astrbot/core/knowledge_base/chunking/__init__.py +11 -0
- astrbot/core/knowledge_base/chunking/base.py +24 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +57 -0
- astrbot/core/knowledge_base/chunking/recursive.py +155 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +299 -0
- astrbot/core/knowledge_base/kb_helper.py +348 -0
- astrbot/core/knowledge_base/kb_mgr.py +287 -0
- astrbot/core/knowledge_base/models.py +114 -0
- astrbot/core/knowledge_base/parsers/__init__.py +15 -0
- astrbot/core/knowledge_base/parsers/base.py +50 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +25 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +100 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +41 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +16 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +273 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +138 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +130 -0
- astrbot/core/pipeline/process_stage/method/llm_request.py +61 -21
- astrbot/core/pipeline/process_stage/utils.py +80 -0
- astrbot/core/pipeline/scheduler.py +1 -1
- astrbot/core/platform/astr_message_event.py +8 -7
- astrbot/core/platform/manager.py +4 -0
- astrbot/core/platform/sources/misskey/misskey_adapter.py +380 -44
- astrbot/core/platform/sources/misskey/misskey_api.py +581 -45
- astrbot/core/platform/sources/misskey/misskey_event.py +76 -41
- astrbot/core/platform/sources/misskey/misskey_utils.py +254 -43
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +2 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +27 -1
- astrbot/core/platform/sources/satori/satori_event.py +270 -77
- astrbot/core/platform/sources/webchat/webchat_adapter.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +289 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +17 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +20 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +445 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +378 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +149 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +148 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +166 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +199 -0
- astrbot/core/provider/manager.py +14 -9
- astrbot/core/provider/provider.py +67 -0
- astrbot/core/provider/sources/anthropic_source.py +4 -4
- astrbot/core/provider/sources/dashscope_source.py +10 -9
- astrbot/core/provider/sources/dify_source.py +6 -8
- astrbot/core/provider/sources/gemini_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_embedding_source.py +1 -2
- astrbot/core/provider/sources/openai_source.py +18 -15
- astrbot/core/provider/sources/openai_tts_api_source.py +1 -1
- astrbot/core/star/context.py +3 -0
- astrbot/core/star/star.py +6 -0
- astrbot/core/star/star_manager.py +13 -7
- astrbot/core/umop_config_router.py +81 -0
- astrbot/core/updator.py +1 -1
- astrbot/core/utils/io.py +23 -12
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/config.py +137 -9
- astrbot/dashboard/routes/knowledge_base.py +1065 -0
- astrbot/dashboard/routes/plugin.py +24 -5
- astrbot/dashboard/routes/tools.py +14 -0
- astrbot/dashboard/routes/update.py +1 -1
- astrbot/dashboard/server.py +6 -0
- astrbot/dashboard/utils.py +161 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/METADATA +91 -55
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/RECORD +83 -50
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/WHEEL +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.3.3.dist-info → astrbot-4.5.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1065 @@
|
|
|
1
|
+
"""知识库管理 API 路由"""
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
import aiofiles
|
|
5
|
+
import os
|
|
6
|
+
import traceback
|
|
7
|
+
import asyncio
|
|
8
|
+
from quart import request
|
|
9
|
+
from astrbot.core import logger
|
|
10
|
+
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|
11
|
+
from .route import Route, Response, RouteContext
|
|
12
|
+
from ..utils import generate_tsne_visualization
|
|
13
|
+
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class KnowledgeBaseRoute(Route):
|
|
17
|
+
"""知识库管理路由
|
|
18
|
+
|
|
19
|
+
提供知识库、文档、检索、会话配置等 API 接口
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
def __init__(
    self,
    context: RouteContext,
    core_lifecycle: AstrBotCoreLifecycle,
) -> None:
    """Initialise route state and register all knowledge-base endpoints.

    Args:
        context: Shared dashboard route context, forwarded to ``Route``.
        core_lifecycle: Core lifecycle object; request handlers read the
            knowledge-base manager from it via ``_get_kb_manager()``.
    """
    super().__init__(context)
    self.core_lifecycle = core_lifecycle

    # Initialised to None; request handlers obtain the KB manager through
    # ``_get_kb_manager()`` rather than ``self.kb_manager``.
    self.kb_manager = None
    self.kb_db = None
    self.session_config_db = None
    self.retrieval_manager = None

    # task_id -> {status, file_index, file_total, file_name, stage, current, total}
    self.upload_progress = {}
    # task_id -> {"status", "result", "error"} for background upload jobs
    self.upload_tasks = {}

    knowledge_base_endpoints = {
        "/kb/list": ("GET", self.list_kbs),
        "/kb/create": ("POST", self.create_kb),
        "/kb/get": ("GET", self.get_kb),
        "/kb/update": ("POST", self.update_kb),
        "/kb/delete": ("POST", self.delete_kb),
        "/kb/stats": ("GET", self.get_kb_stats),
    }
    document_endpoints = {
        "/kb/document/list": ("GET", self.list_documents),
        "/kb/document/upload": ("POST", self.upload_document),
        "/kb/document/upload/progress": ("GET", self.get_upload_progress),
        "/kb/document/get": ("GET", self.get_document),
        "/kb/document/delete": ("POST", self.delete_document),
    }
    chunk_endpoints = {
        "/kb/chunk/list": ("GET", self.list_chunks),
        "/kb/chunk/delete": ("POST", self.delete_chunk),
    }
    # Media endpoints (/kb/media/list, /kb/media/delete) are not enabled.
    retrieval_endpoints = {
        "/kb/retrieve": ("POST", self.retrieve),
    }
    session_config_endpoints = {
        "/kb/session/config/get": ("GET", self.get_session_kb_config),
        "/kb/session/config/set": ("POST", self.set_session_kb_config),
        "/kb/session/config/delete": ("POST", self.delete_session_kb_config),
    }

    # Merge in a fixed order so registration order matches the groups above.
    self.routes = {
        **knowledge_base_endpoints,
        **document_endpoints,
        **chunk_endpoints,
        **retrieval_endpoints,
        **session_config_endpoints,
    }
    self.register_routes()
|
|
65
|
+
|
|
66
|
+
def _get_kb_manager(self):
    """Return the knowledge-base manager owned by the core lifecycle.

    The manager is looked up on every call instead of being cached on the route.
    """
    lifecycle = self.core_lifecycle
    return lifecycle.kb_manager
|
|
68
|
+
|
|
69
|
+
async def _background_upload_task(
    self,
    task_id: str,
    kb_helper,
    files_to_upload: list,
    chunk_size: int,
    chunk_overlap: int,
    batch_size: int,
    tasks_limit: int,
    max_retries: int,
):
    """Ingest a batch of files into a knowledge base while recording progress.

    State is written to two maps keyed by ``task_id``:

    * ``self.upload_tasks``: overall status (``processing`` then ``completed``
      or ``failed``) plus the final result or the error text.
    * ``self.upload_progress``: per-file progress, updated through the
      callback handed to ``kb_helper.upload_document``.

    A failure on one file is logged and listed under ``failed`` in the result;
    the remaining files are still processed.

    Args:
        task_id: Identifier the client polls with.
        kb_helper: Object whose ``upload_document`` coroutine ingests a file.
        files_to_upload: Dicts with ``file_name``, ``file_content`` and
            ``file_type`` keys.
        chunk_size, chunk_overlap, batch_size, tasks_limit, max_retries:
            Forwarded unchanged to ``kb_helper.upload_document``.
    """
    try:
        self.upload_tasks[task_id] = {
            "status": "processing",
            "result": None,
            "error": None,
        }
        self.upload_progress[task_id] = {
            "status": "processing",
            "file_index": 0,
            "file_total": len(files_to_upload),
            "stage": "waiting",
            "current": 0,
            "total": 100,
        }

        succeeded = []
        failures = []

        for index, info in enumerate(files_to_upload):
            try:
                # Reset the progress entry for the file about to be processed.
                self.upload_progress[task_id].update(
                    status="processing",
                    file_index=index,
                    file_name=info["file_name"],
                    stage="parsing",
                    current=0,
                    total=100,
                )

                async def progress_callback(stage, current, total):
                    # Skip the update if the progress entry has been removed.
                    if task_id not in self.upload_progress:
                        return
                    self.upload_progress[task_id].update(
                        status="processing",
                        file_index=index,
                        file_name=info["file_name"],
                        stage=stage,
                        current=current,
                        total=total,
                    )

                document = await kb_helper.upload_document(
                    file_name=info["file_name"],
                    file_content=info["file_content"],
                    file_type=info["file_type"],
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                    batch_size=batch_size,
                    tasks_limit=tasks_limit,
                    max_retries=max_retries,
                    progress_callback=progress_callback,
                )
                succeeded.append(document.model_dump())
            except Exception as exc:
                logger.error(f"上传文档 {info['file_name']} 失败: {exc}")
                failures.append({"file_name": info["file_name"], "error": str(exc)})

        self.upload_tasks[task_id] = {
            "status": "completed",
            "result": {
                "task_id": task_id,
                "uploaded": succeeded,
                "failed": failures,
                "total": len(files_to_upload),
                "success_count": len(succeeded),
                "failed_count": len(failures),
            },
            "error": None,
        }
        self.upload_progress[task_id]["status"] = "completed"

    except Exception as exc:
        logger.error(f"后台上传任务 {task_id} 失败: {exc}")
        logger.error(traceback.format_exc())
        self.upload_tasks[task_id] = {
            "status": "failed",
            "result": None,
            "error": str(exc),
        }
        if task_id in self.upload_progress:
            self.upload_progress[task_id]["status"] = "failed"
|
|
174
|
+
|
|
175
|
+
async def list_kbs(self):
    """Return every knowledge base.

    Query parameters:
        - page: page number, echoed back in the response (default 1)
        - page_size: page size, echoed back in the response (default 20)

    Note: the result is not sliced by ``page``/``page_size``; ``items``
    contains every knowledge base returned by the manager. A
    ``refresh_stats`` parameter is not read here.
    """
    try:
        manager = self._get_kb_manager()
        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 20, type=int)

        knowledge_bases = await manager.list_kbs()
        items = [entry.model_dump() for entry in knowledge_bases]

        payload = {"items": items, "page": page, "page_size": page_size}
        return Response().ok(payload).__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取知识库列表失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取知识库列表失败: {str(e)}").__dict__
|
|
206
|
+
|
|
207
|
+
async def create_kb(self):
    """Create a knowledge base after probing the configured providers.

    Body (JSON):
        - kb_name: knowledge base name (required)
        - description: description (optional)
        - emoji: icon (optional)
        - embedding_provider_id: embedding provider ID (required; must
          resolve to an ``EmbeddingProvider``)
        - rerank_provider_id: rerank provider ID (optional; when given it
          must resolve to a ``RerankProvider``)
        - chunk_size, chunk_overlap, top_k_dense, top_k_sparse, top_m_final:
          optional; passed to ``kb_manager.create_kb`` as given (``None``
          when omitted, leaving defaults to the manager)

    Before creating anything, the embedding provider is asked to embed a
    sample string and the vector length is compared with ``get_dim()``. The
    rerank provider, if any, is asked to rerank one sample document. Probe
    failures are logged with a traceback and reported to the client.
    """
    try:
        kb_manager = self._get_kb_manager()
        data = await request.json
        kb_name = data.get("kb_name")
        if not kb_name:
            return Response().error("知识库名称不能为空").__dict__

        description = data.get("description")
        emoji = data.get("emoji")
        embedding_provider_id = data.get("embedding_provider_id")
        rerank_provider_id = data.get("rerank_provider_id")
        chunk_size = data.get("chunk_size")
        chunk_overlap = data.get("chunk_overlap")
        top_k_dense = data.get("top_k_dense")
        top_k_sparse = data.get("top_k_sparse")
        top_m_final = data.get("top_m_final")

        # Embedding provider is mandatory; probe it so a dimension mismatch
        # is reported now rather than when documents are indexed.
        if not embedding_provider_id:
            return Response().error("缺少参数 embedding_provider_id").__dict__
        prv = await kb_manager.provider_manager.get_provider_by_id(
            embedding_provider_id
        )  # type: ignore
        if not prv or not isinstance(prv, EmbeddingProvider):
            return (
                Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__
            )
        try:
            vec = await prv.get_embedding("astrbot")
            if len(vec) != prv.get_dim():
                raise ValueError(
                    f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}"
                )
        except Exception as e:
            logger.error(f"测试嵌入模型失败: {e}")
            logger.error(traceback.format_exc())
            return Response().error(f"测试嵌入模型失败: {str(e)}").__dict__

        # Optional rerank provider: same existence/type check as embedding.
        if rerank_provider_id:
            rerank_prv = await kb_manager.provider_manager.get_provider_by_id(
                rerank_provider_id
            )  # type: ignore
            if not rerank_prv:
                return Response().error("重排序模型不存在").__dict__
            if not isinstance(rerank_prv, RerankProvider):
                return (
                    Response()
                    .error(f"重排序模型类型错误({type(rerank_prv)})")
                    .__dict__
                )
            try:
                res = await rerank_prv.rerank(
                    query="astrbot", documents=["astrbot knowledge base"]
                )
                if not res:
                    raise ValueError("重排序模型返回结果异常")
            except Exception as e:
                # The client is told to check the console, so log the cause.
                logger.error(f"测试重排序模型失败: {e}")
                logger.error(traceback.format_exc())
                return (
                    Response()
                    .error(f"测试重排序模型失败: {str(e)},请检查控制台日志输出。")
                    .__dict__
                )

        kb_helper = await kb_manager.create_kb(
            kb_name=kb_name,
            description=description,
            emoji=emoji,
            embedding_provider_id=embedding_provider_id,
            rerank_provider_id=rerank_provider_id,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            top_k_dense=top_k_dense,
            top_k_sparse=top_k_sparse,
            top_m_final=top_m_final,
        )
        kb = kb_helper.kb

        return Response().ok(kb.model_dump(), "创建知识库成功").__dict__

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"创建知识库失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"创建知识库失败: {str(e)}").__dict__
|
|
302
|
+
|
|
303
|
+
async def get_kb(self):
    """Return the details of a single knowledge base.

    Query parameters:
        - kb_id: knowledge base ID (required)
    """
    try:
        manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        helper = await manager.get_kb(kb_id)
        if not helper:
            return Response().error("知识库不存在").__dict__

        details = helper.kb.model_dump()
        return Response().ok(details).__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取知识库详情失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取知识库详情失败: {str(e)}").__dict__
|
|
328
|
+
|
|
329
|
+
async def update_kb(self):
    """Update selected fields of a knowledge base.

    Body (JSON):
        - kb_id: knowledge base ID (required)
        - any of kb_name, description, emoji, embedding_provider_id,
          rerank_provider_id, chunk_size, chunk_overlap, top_k_dense,
          top_k_sparse, top_m_final (at least one must be non-null)

    Every listed field is forwarded to ``kb_manager.update_kb``; fields that
    are absent (or null) are forwarded as ``None``.
    """
    try:
        manager = self._get_kb_manager()
        data = await request.json

        kb_id = data.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        updatable_fields = (
            "kb_name",
            "description",
            "emoji",
            "embedding_provider_id",
            "rerank_provider_id",
            "chunk_size",
            "chunk_overlap",
            "top_k_dense",
            "top_k_sparse",
            "top_m_final",
        )
        changes = {field: data.get(field) for field in updatable_fields}

        # A field sent as null counts the same as an omitted one.
        if all(value is None for value in changes.values()):
            return Response().error("至少需要提供一个更新字段").__dict__

        helper = await manager.update_kb(kb_id=kb_id, **changes)
        if not helper:
            return Response().error("知识库不存在").__dict__

        return Response().ok(helper.kb.model_dump(), "更新知识库成功").__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"更新知识库失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"更新知识库失败: {str(e)}").__dict__
|
|
408
|
+
|
|
409
|
+
async def delete_kb(self):
    """Delete a knowledge base.

    Body (JSON):
        - kb_id: knowledge base ID (required)

    Reports "not found" when ``kb_manager.delete_kb`` returns a falsy value.
    """
    try:
        manager = self._get_kb_manager()
        payload = await request.json

        kb_id = payload.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        deleted = await manager.delete_kb(kb_id)
        if deleted:
            return Response().ok(message="删除知识库成功").__dict__
        return Response().error("知识库不存在").__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"删除知识库失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"删除知识库失败: {str(e)}").__dict__
|
|
435
|
+
|
|
436
|
+
async def get_kb_stats(self):
    """Return summary statistics for one knowledge base.

    Query parameters:
        - kb_id: knowledge base ID (required)

    The response carries the id, name, document/chunk counts and the
    ISO-formatted creation and update timestamps.
    """
    try:
        manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        helper = await manager.get_kb(kb_id)
        if not helper:
            return Response().error("知识库不存在").__dict__

        record = helper.kb
        summary = {
            "kb_id": record.kb_id,
            "kb_name": record.kb_name,
            "doc_count": record.doc_count,
            "chunk_count": record.chunk_count,
            "created_at": record.created_at.isoformat(),
            "updated_at": record.updated_at.isoformat(),
        }
        return Response().ok(summary).__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取知识库统计失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取知识库统计失败: {str(e)}").__dict__
|
|
470
|
+
|
|
471
|
+
# ===== 文档管理 API =====
|
|
472
|
+
|
|
473
|
+
async def list_documents(self):
    """Return one page of documents from a knowledge base.

    Query parameters:
        - kb_id: knowledge base ID (required)
        - page: 1-based page number (default 1; values below 1 become 1)
        - page_size: documents per page (default 100; values below 1 become 1)

    The response echoes the effective ``page`` and ``page_size``.
    """
    try:
        kb_manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__
        kb_helper = await kb_manager.get_kb(kb_id)
        if not kb_helper:
            return Response().error("知识库不存在").__dict__

        # Clamp so the computed offset/limit can never be zero or negative.
        page = max(request.args.get("page", 1, type=int), 1)
        page_size = max(request.args.get("page_size", 100, type=int), 1)

        offset = (page - 1) * page_size
        docs = await kb_helper.list_documents(offset=offset, limit=page_size)
        doc_list = [doc.model_dump() for doc in docs]

        return (
            Response()
            .ok({"items": doc_list, "page": page, "page_size": page_size})
            .__dict__
        )

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取文档列表失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取文档列表失败: {str(e)}").__dict__
|
|
512
|
+
|
|
513
|
+
async def upload_document(self):
    """Accept uploaded files and start a background ingestion task.

    Only multipart/form-data is accepted: a request whose Content-Type is
    set to anything else is rejected. At most 10 files per request.

    Form fields:
        - kb_id: knowledge base ID (required)
        - file parts: every part under any field whose name starts with
          "file" (e.g. file, file1, file2, files[])
        - chunk_size (default 512), chunk_overlap (default 50),
          batch_size (default 32), tasks_limit (default 3),
          max_retries (default 3): integers forwarded to the upload helper

    Returns:
        ``task_id`` for polling ``/kb/document/upload/progress`` and the
        number of files queued.
    """
    try:
        kb_manager = self._get_kb_manager()

        content_type = request.content_type
        if content_type and "multipart/form-data" not in content_type:
            return (
                Response().error("Content-Type 须为 multipart/form-data").__dict__
            )
        form_data = await request.form
        files = await request.files

        kb_id = form_data.get("kb_id")
        chunk_size = int(form_data.get("chunk_size", 512))
        chunk_overlap = int(form_data.get("chunk_overlap", 50))
        batch_size = int(form_data.get("batch_size", 32))
        tasks_limit = int(form_data.get("tasks_limit", 3))
        max_retries = int(form_data.get("max_retries", 3))
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__

        # "file", "file1", ... and "files[]" all share the "file" prefix.
        file_list = []
        for key in files.keys():
            if key.startswith("file"):
                file_list.extend(files.getlist(key))

        if not file_list:
            return Response().error("缺少文件").__dict__
        if len(file_list) > 10:
            return Response().error("最多只能上传10个文件").__dict__

        # Resolve the knowledge base before any disk I/O so a bad kb_id fails fast.
        kb_helper = await kb_manager.get_kb(kb_id)
        if not kb_helper:
            return Response().error("知识库不存在").__dict__

        temp_dir = os.path.join("data", "temp")
        os.makedirs(temp_dir, exist_ok=True)

        files_to_upload = []
        for file in file_list:
            file_name = file.filename

            # Never build a filesystem path from the client-supplied name: it
            # may contain separators or "..". A random name is sufficient
            # because the file is read back and deleted immediately.
            temp_file_path = os.path.join(temp_dir, uuid.uuid4().hex)
            try:
                await file.save(temp_file_path)
                async with aiofiles.open(temp_file_path, "rb") as f:
                    file_content = await f.read()
            finally:
                # Also removes partially written files if save() failed.
                if os.path.exists(temp_file_path):
                    os.remove(temp_file_path)

            file_type = (
                file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
            )
            files_to_upload.append(
                {
                    "file_name": file_name,
                    "file_content": file_content,
                    "file_type": file_type,
                }
            )

        task_id = str(uuid.uuid4())
        self.upload_tasks[task_id] = {
            "status": "pending",
            "result": None,
            "error": None,
        }

        task = asyncio.create_task(
            self._background_upload_task(
                task_id=task_id,
                kb_helper=kb_helper,
                files_to_upload=files_to_upload,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                batch_size=batch_size,
                tasks_limit=tasks_limit,
                max_retries=max_retries,
            )
        )
        # The event loop only keeps weak references to tasks; hold a strong
        # reference until the task finishes so it cannot be collected mid-run.
        running = getattr(self, "_background_upload_refs", None)
        if running is None:
            running = self._background_upload_refs = set()
        running.add(task)
        task.add_done_callback(running.discard)

        return (
            Response()
            .ok(
                {
                    "task_id": task_id,
                    "file_count": len(files_to_upload),
                    "message": "task created, processing in background",
                }
            )
            .__dict__
        )

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"上传文档失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"上传文档失败: {str(e)}").__dict__
|
|
654
|
+
|
|
655
|
+
async def get_upload_progress(self):
    """Report the state of a background upload task.

    Query parameters:
        - task_id: task ID returned by the upload endpoint (required)

    ``status`` is one of pending / processing / completed / failed. While
    processing, the live progress entry is attached as ``progress``; once
    completed, the upload summary is attached as ``result``; on failure, the
    error text is attached as ``error``.

    NOTE: no eviction happens here (the cleanup on completion was left
    disabled), so finished entries in ``upload_tasks`` / ``upload_progress``
    remain unless removed elsewhere.
    """
    try:
        task_id = request.args.get("task_id")
        if not task_id:
            return Response().error("缺少参数 task_id").__dict__

        if task_id not in self.upload_tasks:
            return Response().error("找不到该任务").__dict__

        task_info = self.upload_tasks[task_id]
        status = task_info["status"]
        response_data = {"task_id": task_id, "status": status}

        if status == "processing":
            if task_id in self.upload_progress:
                response_data["progress"] = self.upload_progress[task_id]
        elif status == "completed":
            response_data["result"] = task_info["result"]
        elif status == "failed":
            response_data["error"] = task_info["error"]

        return Response().ok(response_data).__dict__
    except Exception as e:
        logger.error(f"获取上传进度失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取上传进度失败: {str(e)}").__dict__
|
|
707
|
+
|
|
708
|
+
async def get_document(self):
    """Return the details of one document.

    Query parameters:
        - kb_id: knowledge base ID (required)
        - doc_id: document ID (required)
    """
    try:
        manager = self._get_kb_manager()

        kb_id = request.args.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__
        doc_id = request.args.get("doc_id")
        if not doc_id:
            return Response().error("缺少参数 doc_id").__dict__

        helper = await manager.get_kb(kb_id)
        if not helper:
            return Response().error("知识库不存在").__dict__

        document = await helper.get_document(doc_id)
        if not document:
            return Response().error("文档不存在").__dict__

        return Response().ok(document.model_dump()).__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取文档详情失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取文档详情失败: {str(e)}").__dict__
|
|
738
|
+
|
|
739
|
+
async def delete_document(self):
    """Delete one document from a knowledge base.

    Body (JSON):
        - kb_id: knowledge base ID (required)
        - doc_id: document ID (required)

    Success is reported once ``kb_helper.delete_document`` returns; this
    handler does not check beforehand that the document exists.
    """
    try:
        manager = self._get_kb_manager()
        payload = await request.json

        kb_id = payload.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__
        doc_id = payload.get("doc_id")
        if not doc_id:
            return Response().error("缺少参数 doc_id").__dict__

        helper = await manager.get_kb(kb_id)
        if not helper:
            return Response().error("知识库不存在").__dict__

        await helper.delete_document(doc_id)
        return Response().ok(message="删除文档成功").__dict__
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"删除文档失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"删除文档失败: {str(e)}").__dict__
|
|
770
|
+
|
|
771
|
+
async def delete_chunk(self):
    """Delete a single text chunk from a document in a knowledge base.

    JSON body:
        - kb_id: knowledge base ID (required)
        - chunk_id: chunk ID (required)
        - doc_id: ID of the document the chunk belongs to (required)

    Returns:
        The serialized ``Response`` dict: ``ok`` with a success message,
        or ``error`` describing what went wrong.
    """
    try:
        kb_manager = self._get_kb_manager()
        data = await request.json
        # A missing/non-object JSON body has no ``.get``; reject it clearly.
        if not isinstance(data, dict):
            return Response().error("请求体必须是 JSON 对象").__dict__

        kb_id = data.get("kb_id")
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__
        chunk_id = data.get("chunk_id")
        if not chunk_id:
            return Response().error("缺少参数 chunk_id").__dict__
        doc_id = data.get("doc_id")
        if not doc_id:
            return Response().error("缺少参数 doc_id").__dict__

        kb_helper = await kb_manager.get_kb(kb_id)
        if not kb_helper:
            return Response().error("知识库不存在").__dict__

        await kb_helper.delete_chunk(chunk_id, doc_id)
        return Response().ok(message="删除文本块成功").__dict__

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"删除文本块失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"删除文本块失败: {str(e)}").__dict__
|
|
805
|
+
|
|
806
|
+
async def list_chunks(self):
    """List the chunks of a document, one page at a time.

    Query parameters:
        - kb_id: knowledge base ID (required)
        - doc_id: document ID (required)
        - page: 1-based page number (default 1; values below 1 become 1)
        - page_size: items per page (default 100; values below 1 become 1)

    Returns:
        The serialized ``Response`` dict whose data holds ``items``,
        ``page``, ``page_size`` and ``total`` (chunk count of the document).
    """
    try:
        kb_manager = self._get_kb_manager()
        kb_id = request.args.get("kb_id")
        doc_id = request.args.get("doc_id")
        page = request.args.get("page", 1, type=int)
        page_size = request.args.get("page_size", 100, type=int)
        if not kb_id:
            return Response().error("缺少参数 kb_id").__dict__
        if not doc_id:
            return Response().error("缺少参数 doc_id").__dict__

        # Clamp pagination so zero/negative input can never turn into a
        # negative offset or limit handed to the storage layer.
        page = max(page, 1)
        page_size = max(page_size, 1)

        kb_helper = await kb_manager.get_kb(kb_id)
        if not kb_helper:
            return Response().error("知识库不存在").__dict__

        chunk_list = await kb_helper.get_chunks_by_doc_id(
            doc_id=doc_id, offset=(page - 1) * page_size, limit=page_size
        )
        total = await kb_helper.get_chunk_count_by_doc_id(doc_id)
        return (
            Response()
            .ok(
                data={
                    "items": chunk_list,
                    "page": page,
                    "page_size": page_size,
                    "total": total,
                }
            )
            .__dict__
        )
    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"获取块列表失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"获取块列表失败: {str(e)}").__dict__
|
|
850
|
+
|
|
851
|
+
# ===== 检索 API =====
|
|
852
|
+
|
|
853
|
+
async def retrieve(self):
    """Run a retrieval query against one or more knowledge bases.

    JSON body:
        - query: query text (required)
        - kb_names: non-empty list of knowledge base names, forwarded to
          ``kb_manager.retrieve`` (required)
        - top_k: number of final results (optional, default 5)
        - debug: when truthy, also attach a base64 t-SNE visualization
          under ``visualization`` (optional, default False)

    Returns:
        The serialized ``Response`` dict whose data holds ``results``,
        ``total`` and ``query``; in debug mode it may also carry
        ``visualization`` or ``visualization_error``.
    """
    try:
        kb_manager = self._get_kb_manager()
        data = await request.json
        # A missing/non-object JSON body has no ``.get``; reject it clearly.
        if not isinstance(data, dict):
            return Response().error("请求体必须是 JSON 对象").__dict__

        query = data.get("query")
        kb_names = data.get("kb_names")
        debug = data.get("debug", False)

        if not query:
            return Response().error("缺少参数 query").__dict__
        if not kb_names or not isinstance(kb_names, list):
            return Response().error("缺少参数 kb_names 或格式错误").__dict__

        top_k = data.get("top_k", 5)

        results = await kb_manager.retrieve(
            query=query,
            kb_names=kb_names,
            top_m_final=top_k,
        )
        result_list = []
        if results:
            result_list = results["results"]

        response_data = {
            "results": result_list,
            "total": len(result_list),
            "query": query,
        }

        # Debug mode: a visualization failure must not fail the retrieval
        # itself, so it is reported inside the successful response.
        if debug:
            try:
                img_base64 = await generate_tsne_visualization(
                    query, kb_names, kb_manager
                )
                if img_base64:
                    response_data["visualization"] = img_base64
            except Exception as e:
                logger.error(f"生成 t-SNE 可视化失败: {e}")
                logger.error(traceback.format_exc())
                response_data["visualization_error"] = str(e)

        return Response().ok(response_data).__dict__

    except ValueError as e:
        return Response().error(str(e)).__dict__
    except Exception as e:
        logger.error(f"检索失败: {e}")
        logger.error(traceback.format_exc())
        return Response().error(f"检索失败: {str(e)}").__dict__
|
|
913
|
+
|
|
914
|
+
# ===== 会话知识库配置 API =====
|
|
915
|
+
|
|
916
|
+
async def get_session_kb_config(self):
    """Return the knowledge base configuration stored for a session.

    Query parameters:
        - session_id: session ID (required)

    Returns:
        The serialized ``Response`` dict whose data is the stored config
        (``kb_ids``, ``top_k``, ``enable_rerank``). When nothing, or an
        empty value, is stored, the default
        ``{"kb_ids": [], "top_k": 5, "enable_rerank": True}`` is returned.
    """
    try:
        from astrbot.core import sp

        session_id = request.args.get("session_id")
        if not session_id:
            return Response().error("缺少参数 session_id").__dict__

        stored = await sp.session_get(session_id, "kb_config", default={})
        logger.debug(f"[KB配置] 读取到配置: session_id={session_id}")

        # Any falsy stored value falls back to the defaults.
        result = (
            stored
            if stored
            else {"kb_ids": [], "top_k": 5, "enable_rerank": True}
        )
        return Response().ok(result).__dict__

    except Exception as e:
        logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True)
        return Response().error(f"获取会话知识库配置失败: {str(e)}").__dict__
|
|
949
|
+
|
|
950
|
+
async def set_session_kb_config(self):
    """Store the knowledge base configuration for a session.

    JSON body:
        - scope: configuration scope (only "session" is supported)
        - scope_id: session ID (required)
        - kb_ids: list of knowledge base IDs (required; an empty list
          explicitly clears the selection)
        - top_k: number of results to return (optional, default 5)
        - enable_rerank: whether to rerank (optional, default True)

    Only IDs that resolve via ``get_kb`` are persisted; duplicates are
    collapsed, keeping first-seen order. If every supplied ID is invalid,
    nothing is saved and an error is returned.

    Returns:
        The serialized ``Response`` dict; on success its data lists the
        ``valid_ids`` and ``invalid_ids``.
    """
    try:
        from astrbot.core import sp

        data = await request.json
        # A missing/non-object JSON body has no ``.get``; reject it clearly.
        if not isinstance(data, dict):
            return Response().error("请求体必须是 JSON 对象").__dict__

        scope = data.get("scope")
        scope_id = data.get("scope_id")
        kb_ids = data.get("kb_ids", [])
        top_k = data.get("top_k", 5)
        enable_rerank = data.get("enable_rerank", True)

        if scope != "session":
            return Response().error("目前仅支持 session 范围的配置").__dict__
        if not scope_id:
            return Response().error("缺少参数 scope_id").__dict__
        if not isinstance(kb_ids, list):
            return Response().error("kb_ids 必须是列表").__dict__

        # Resolve each distinct ID once. List membership (rather than a set)
        # keeps this safe even if the client sends unhashable junk values.
        kb_mgr = self._get_kb_manager()
        invalid_ids = []
        valid_ids = []
        for kb_id in kb_ids:
            if kb_id in valid_ids or kb_id in invalid_ids:
                continue
            kb_helper = await kb_mgr.get_kb(kb_id)
            if kb_helper:
                valid_ids.append(kb_id)
            else:
                invalid_ids.append(kb_id)
                logger.warning(f"[KB配置] 知识库不存在: {kb_id}")

        if invalid_ids:
            logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}")

        # An empty list is allowed (explicitly use no knowledge base); only a
        # non-empty list with no valid entry is rejected.
        if kb_ids and not valid_ids:
            return Response().error(f"所有提供的知识库ID都无效: {kb_ids}").__dict__

        # Persist only the IDs that were verified to exist.
        config = {
            "kb_ids": valid_ids,
            "top_k": top_k,
            "enable_rerank": enable_rerank,
        }
        await sp.session_put(scope_id, "kb_config", config)

        # Read back to confirm the write actually took effect.
        verify_config = await sp.session_get(scope_id, "kb_config", default={})
        if verify_config != config:
            logger.error("[KB配置] 配置保存失败,验证不匹配")
            return Response().error("配置保存失败").__dict__

        return (
            Response()
            .ok(
                {"valid_ids": valid_ids, "invalid_ids": invalid_ids},
                "保存知识库配置成功",
            )
            .__dict__
        )

    except Exception as e:
        logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True)
        return Response().error(f"设置会话知识库配置失败: {str(e)}").__dict__
|
|
1034
|
+
|
|
1035
|
+
async def delete_session_kb_config(self):
    """Remove the stored knowledge base configuration of a session.

    JSON body:
        - scope: configuration scope (only "session" is supported)
        - scope_id: session ID (required)

    Returns:
        The serialized ``Response`` dict: ``ok`` with a success message,
        or ``error`` describing what went wrong.
    """
    try:
        from astrbot.core import sp

        data = await request.json
        # A missing/non-object JSON body has no ``.get``; reject it clearly.
        if not isinstance(data, dict):
            return Response().error("请求体必须是 JSON 对象").__dict__

        scope = data.get("scope")
        scope_id = data.get("scope_id")

        if scope != "session":
            return Response().error("目前仅支持 session 范围的配置").__dict__
        if not scope_id:
            return Response().error("缺少参数 scope_id").__dict__

        await sp.session_remove(scope_id, "kb_config")
        return Response().ok(message="删除知识库配置成功").__dict__

    except Exception as e:
        # Same logging style as the sibling session-config endpoints.
        logger.error(f"删除会话知识库配置失败: {e}", exc_info=True)
        return Response().error(f"删除会话知识库配置失败: {str(e)}").__dict__
|