AstrBot 4.9.2__py3-none-any.whl → 4.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/message.py +6 -4
- astrbot/core/agent/response.py +22 -1
- astrbot/core/agent/run_context.py +1 -1
- astrbot/core/agent/runners/tool_loop_agent_runner.py +99 -20
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astr_agent_run_util.py +42 -3
- astrbot/core/astr_agent_tool_exec.py +34 -4
- astrbot/core/config/default.py +127 -184
- astrbot/core/core_lifecycle.py +3 -0
- astrbot/core/db/__init__.py +72 -0
- astrbot/core/db/po.py +59 -0
- astrbot/core/db/sqlite.py +240 -0
- astrbot/core/message/components.py +4 -5
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +6 -1
- astrbot/core/pipeline/respond/stage.py +1 -1
- astrbot/core/platform/sources/telegram/tg_event.py +9 -0
- astrbot/core/platform/sources/webchat/webchat_event.py +22 -18
- astrbot/core/provider/entities.py +41 -0
- astrbot/core/provider/manager.py +203 -93
- astrbot/core/provider/sources/anthropic_source.py +55 -11
- astrbot/core/provider/sources/gemini_source.py +84 -33
- astrbot/core/provider/sources/openai_source.py +21 -6
- astrbot/core/star/command_management.py +449 -0
- astrbot/core/star/context.py +4 -0
- astrbot/core/star/filter/command.py +1 -0
- astrbot/core/star/filter/command_group.py +1 -0
- astrbot/core/star/star_handler.py +4 -0
- astrbot/core/star/star_manager.py +2 -0
- astrbot/core/utils/llm_metadata.py +63 -0
- astrbot/core/utils/migra_helper.py +93 -0
- astrbot/dashboard/routes/__init__.py +2 -0
- astrbot/dashboard/routes/chat.py +56 -13
- astrbot/dashboard/routes/command.py +82 -0
- astrbot/dashboard/routes/config.py +291 -33
- astrbot/dashboard/routes/stat.py +96 -0
- astrbot/dashboard/routes/tools.py +20 -4
- astrbot/dashboard/server.py +1 -0
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/METADATA +2 -2
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/RECORD +43 -40
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/WHEEL +0 -0
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.9.2.dist-info → astrbot-4.10.0.dist-info}/licenses/LICENSE +0 -0
astrbot/dashboard/routes/chat.py
CHANGED
|
@@ -227,16 +227,19 @@ class ChatRoute(Route):
|
|
|
227
227
|
text: str,
|
|
228
228
|
media_parts: list,
|
|
229
229
|
reasoning: str,
|
|
230
|
+
agent_stats: dict,
|
|
230
231
|
):
|
|
231
232
|
"""保存 bot 消息到历史记录,返回保存的记录"""
|
|
232
233
|
bot_message_parts = []
|
|
234
|
+
bot_message_parts.extend(media_parts)
|
|
233
235
|
if text:
|
|
234
236
|
bot_message_parts.append({"type": "plain", "text": text})
|
|
235
|
-
bot_message_parts.extend(media_parts)
|
|
236
237
|
|
|
237
238
|
new_his = {"type": "bot", "message": bot_message_parts}
|
|
238
239
|
if reasoning:
|
|
239
240
|
new_his["reasoning"] = reasoning
|
|
241
|
+
if agent_stats:
|
|
242
|
+
new_his["agent_stats"] = agent_stats
|
|
240
243
|
|
|
241
244
|
record = await self.platform_history_mgr.insert(
|
|
242
245
|
platform_id="webchat",
|
|
@@ -294,7 +297,8 @@ class ChatRoute(Route):
|
|
|
294
297
|
accumulated_parts = []
|
|
295
298
|
accumulated_text = ""
|
|
296
299
|
accumulated_reasoning = ""
|
|
297
|
-
|
|
300
|
+
tool_calls = {}
|
|
301
|
+
agent_stats = {}
|
|
298
302
|
try:
|
|
299
303
|
async with track_conversation(self.running_convs, webchat_conv_id):
|
|
300
304
|
while True:
|
|
@@ -314,6 +318,16 @@ class ChatRoute(Route):
|
|
|
314
318
|
result_text = result["data"]
|
|
315
319
|
msg_type = result.get("type")
|
|
316
320
|
streaming = result.get("streaming", False)
|
|
321
|
+
chain_type = result.get("chain_type")
|
|
322
|
+
|
|
323
|
+
if chain_type == "agent_stats":
|
|
324
|
+
stats_info = {
|
|
325
|
+
"type": "agent_stats",
|
|
326
|
+
"data": json.loads(result_text),
|
|
327
|
+
}
|
|
328
|
+
yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n"
|
|
329
|
+
agent_stats = stats_info["data"]
|
|
330
|
+
continue
|
|
317
331
|
|
|
318
332
|
# 发送 SSE 数据
|
|
319
333
|
try:
|
|
@@ -335,11 +349,35 @@ class ChatRoute(Route):
|
|
|
335
349
|
|
|
336
350
|
# 累积消息部分
|
|
337
351
|
if msg_type == "plain":
|
|
338
|
-
chain_type = result.get("chain_type"
|
|
339
|
-
if chain_type == "
|
|
352
|
+
chain_type = result.get("chain_type")
|
|
353
|
+
if chain_type == "tool_call":
|
|
354
|
+
tool_call = json.loads(result_text)
|
|
355
|
+
tool_calls[tool_call.get("id")] = tool_call
|
|
356
|
+
if accumulated_text:
|
|
357
|
+
# 如果累积了文本,则先保存文本
|
|
358
|
+
accumulated_parts.append(
|
|
359
|
+
{"type": "plain", "text": accumulated_text}
|
|
360
|
+
)
|
|
361
|
+
accumulated_text = ""
|
|
362
|
+
elif chain_type == "tool_call_result":
|
|
363
|
+
tcr = json.loads(result_text)
|
|
364
|
+
tc_id = tcr.get("id")
|
|
365
|
+
if tc_id in tool_calls:
|
|
366
|
+
tool_calls[tc_id]["result"] = tcr.get("result")
|
|
367
|
+
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
|
368
|
+
accumulated_parts.append(
|
|
369
|
+
{
|
|
370
|
+
"type": "tool_call",
|
|
371
|
+
"tool_calls": [tool_calls[tc_id]],
|
|
372
|
+
}
|
|
373
|
+
)
|
|
374
|
+
tool_calls.pop(tc_id, None)
|
|
375
|
+
elif chain_type == "reasoning":
|
|
340
376
|
accumulated_reasoning += result_text
|
|
341
|
-
|
|
377
|
+
elif streaming:
|
|
342
378
|
accumulated_text += result_text
|
|
379
|
+
else:
|
|
380
|
+
accumulated_text = result_text
|
|
343
381
|
elif msg_type == "image":
|
|
344
382
|
filename = result_text.replace("[IMAGE]", "")
|
|
345
383
|
part = await self._create_attachment_from_file(
|
|
@@ -367,15 +405,20 @@ class ChatRoute(Route):
|
|
|
367
405
|
if msg_type == "end":
|
|
368
406
|
break
|
|
369
407
|
elif (
|
|
370
|
-
(streaming and msg_type == "complete")
|
|
371
|
-
or
|
|
372
|
-
or msg_type == "break"
|
|
408
|
+
(streaming and msg_type == "complete") or not streaming
|
|
409
|
+
# or msg_type == "break"
|
|
373
410
|
):
|
|
411
|
+
if (
|
|
412
|
+
chain_type == "tool_call"
|
|
413
|
+
or chain_type == "tool_call_result"
|
|
414
|
+
):
|
|
415
|
+
continue
|
|
374
416
|
saved_record = await self._save_bot_message(
|
|
375
417
|
webchat_conv_id,
|
|
376
418
|
accumulated_text,
|
|
377
419
|
accumulated_parts,
|
|
378
420
|
accumulated_reasoning,
|
|
421
|
+
agent_stats,
|
|
379
422
|
)
|
|
380
423
|
# 发送保存的消息信息给前端
|
|
381
424
|
if saved_record and not client_disconnected:
|
|
@@ -390,11 +433,11 @@ class ChatRoute(Route):
|
|
|
390
433
|
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
|
|
391
434
|
except Exception:
|
|
392
435
|
pass
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
436
|
+
accumulated_parts = []
|
|
437
|
+
accumulated_text = ""
|
|
438
|
+
accumulated_reasoning = ""
|
|
439
|
+
# tool_calls = {}
|
|
440
|
+
agent_stats = {}
|
|
398
441
|
except BaseException as e:
|
|
399
442
|
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
|
400
443
|
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from quart import request
|
|
2
|
+
|
|
3
|
+
from astrbot.core.star.command_management import (
|
|
4
|
+
list_command_conflicts,
|
|
5
|
+
list_commands,
|
|
6
|
+
)
|
|
7
|
+
from astrbot.core.star.command_management import (
|
|
8
|
+
rename_command as rename_command_service,
|
|
9
|
+
)
|
|
10
|
+
from astrbot.core.star.command_management import (
|
|
11
|
+
toggle_command as toggle_command_service,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
from .route import Response, Route, RouteContext
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CommandRoute(Route):
    """Dashboard API for inspecting and managing registered plugin commands.

    Routes:
        GET  /commands            -- every command plus summary counts.
        GET  /commands/conflicts  -- the result of ``list_command_conflicts()``.
        POST /commands/toggle     -- enable/disable a command handler.
        POST /commands/rename     -- rename a command handler.
    """

    def __init__(self, context: RouteContext) -> None:
        super().__init__(context)
        self.routes = {
            "/commands": ("GET", self.get_commands),
            "/commands/conflicts": ("GET", self.get_conflicts),
            "/commands/toggle": ("POST", self.toggle_command),
            "/commands/rename": ("POST", self.rename_command),
        }
        self.register_routes()

    @staticmethod
    async def _read_json_object() -> dict:
        """Return the request's JSON body as a dict.

        A missing or malformed body, or a JSON value that is not an object
        (array, string, number), yields ``{}``. The handlers then answer with
        their "required fields" error instead of crashing on ``.get``.
        """
        data = await request.get_json(silent=True)
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _to_bool(value) -> bool:
        """Interpret a client-supplied flag as a boolean.

        Strings are compared case-insensitively (ignoring surrounding
        whitespace) against ``"1"``, ``"true"``, ``"yes"`` and ``"on"``.
        Any other value is converted with ``bool()``.
        """
        if isinstance(value, str):
            return value.strip().lower() in ("1", "true", "yes", "on")
        return bool(value)

    async def get_commands(self):
        """List all commands with counts of total, disabled and conflicting ones."""
        commands = await list_commands()
        summary = {
            "total": len(commands),
            "disabled": sum(1 for cmd in commands if not cmd["enabled"]),
            "conflicts": sum(1 for cmd in commands if cmd.get("has_conflict")),
        }
        return Response().ok({"items": commands, "summary": summary}).__dict__

    async def get_conflicts(self):
        """Return the command conflicts reported by ``list_command_conflicts()``."""
        conflicts = await list_command_conflicts()
        return Response().ok(conflicts).__dict__

    async def toggle_command(self):
        """Enable or disable one handler.

        Expects a JSON object with ``handler_full_name`` and ``enabled``.
        Returns the handler's refreshed listing entry on success. A
        ``ValueError`` raised by the toggle service is returned as an error
        message.
        """
        data = await self._read_json_object()
        handler_full_name = data.get("handler_full_name")
        enabled = data.get("enabled")

        if handler_full_name is None or enabled is None:
            return Response().error("handler_full_name 与 enabled 均为必填。").__dict__

        try:
            await toggle_command_service(handler_full_name, self._to_bool(enabled))
        except ValueError as exc:
            return Response().error(str(exc)).__dict__

        payload = await _get_command_payload(handler_full_name)
        return Response().ok(payload).__dict__

    async def rename_command(self):
        """Rename the command of one handler.

        Expects a JSON object with non-empty ``handler_full_name`` and
        ``new_name``. Returns the handler's refreshed listing entry on
        success. A ``ValueError`` raised by the rename service is returned as
        an error message.
        """
        data = await self._read_json_object()
        handler_full_name = data.get("handler_full_name")
        new_name = data.get("new_name")

        if not handler_full_name or not new_name:
            return Response().error("handler_full_name 与 new_name 均为必填。").__dict__

        try:
            await rename_command_service(handler_full_name, new_name)
        except ValueError as exc:
            return Response().error(str(exc)).__dict__

        payload = await _get_command_payload(handler_full_name)
        return Response().ok(payload).__dict__
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
async def _get_command_payload(handler_full_name: str):
    """Fetch the listing entry of a single command handler.

    Scans ``list_commands()`` and returns the first entry whose
    ``handler_full_name`` matches. Returns an empty dict when none does.
    """
    all_commands = await list_commands()
    return next(
        (
            entry
            for entry in all_commands
            if entry["handler_full_name"] == handler_full_name
        ),
        {},
    )
|
|
@@ -6,7 +6,7 @@ from typing import Any
|
|
|
6
6
|
|
|
7
7
|
from quart import request
|
|
8
8
|
|
|
9
|
-
from astrbot.core import file_token_service, logger
|
|
9
|
+
from astrbot.core import astrbot_config, file_token_service, logger
|
|
10
10
|
from astrbot.core.config.astrbot_config import AstrBotConfig
|
|
11
11
|
from astrbot.core.config.default import (
|
|
12
12
|
CONFIG_METADATA_2,
|
|
@@ -21,6 +21,7 @@ from astrbot.core.platform.register import platform_cls_map, platform_registry
|
|
|
21
21
|
from astrbot.core.provider import Provider
|
|
22
22
|
from astrbot.core.provider.register import provider_registry
|
|
23
23
|
from astrbot.core.star.star import star_registry
|
|
24
|
+
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
|
24
25
|
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
|
|
25
26
|
|
|
26
27
|
from .route import Response, Route, RouteContext
|
|
@@ -179,13 +180,149 @@ class ConfigRoute(Route):
|
|
|
179
180
|
"/config/provider/new": ("POST", self.post_new_provider),
|
|
180
181
|
"/config/provider/update": ("POST", self.post_update_provider),
|
|
181
182
|
"/config/provider/delete": ("POST", self.post_delete_provider),
|
|
183
|
+
"/config/provider/template": ("GET", self.get_provider_template),
|
|
182
184
|
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
|
183
185
|
"/config/provider/list": ("GET", self.get_provider_config_list),
|
|
184
186
|
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
|
185
187
|
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
|
|
188
|
+
"/config/provider_sources/<provider_source_id>/models": (
|
|
189
|
+
"GET",
|
|
190
|
+
self.get_provider_source_models,
|
|
191
|
+
),
|
|
192
|
+
"/config/provider_sources/<provider_source_id>/update": (
|
|
193
|
+
"POST",
|
|
194
|
+
self.update_provider_source,
|
|
195
|
+
),
|
|
196
|
+
"/config/provider_sources/<provider_source_id>/delete": (
|
|
197
|
+
"POST",
|
|
198
|
+
self.delete_provider_source,
|
|
199
|
+
),
|
|
186
200
|
}
|
|
187
201
|
self.register_routes()
|
|
188
202
|
|
|
203
|
+
async def delete_provider_source(self, provider_source_id: str):
    """Delete a provider source and the providers that reference it.

    Args:
        provider_source_id: ID of the provider source, taken from the URL.

    Returns:
        A serialized ``Response`` dict. It is an error when the source does
        not exist, or when removing dependent providers or saving the config
        raises; otherwise it is a success message.
    """
    provider_sources = self.config.get("provider_sources", [])
    target_idx = next(
        (
            i
            for i, ps in enumerate(provider_sources)
            if ps.get("id") == provider_source_id
        ),
        -1,
    )

    if target_idx == -1:
        return Response().error("未找到对应的 provider source").__dict__

    # Remove the source and write the list back (``.get`` may have returned
    # a fresh list when the key was absent).
    del provider_sources[target_idx]
    self.config["provider_sources"] = provider_sources

    try:
        # Remove providers that reference the deleted source. This call used
        # to sit outside the try, so any exception it raised escaped the
        # handler and the (already mutated) config was never saved. Now it is
        # reported to the client like a save failure.
        await self.core_lifecycle.provider_manager.delete_provider(
            provider_source_id=provider_source_id
        )
        save_config(self.config, self.config, is_core=True)
    except Exception as e:
        logger.error(traceback.format_exc())
        return Response().error(str(e)).__dict__

    return Response().ok(message="删除 provider source 成功").__dict__
|
|
237
|
+
|
|
238
|
+
async def update_provider_source(self, provider_source_id: str):
    """Create or update a provider source, then reload the providers using it.

    The JSON body is either the source config itself or an object carrying it
    under ``"config"``. If the config has no ``id``, the URL's
    ``provider_source_id`` is used. If no source with ``provider_source_id``
    exists, the config is appended as a new source. Otherwise it replaces
    that entry, and every provider whose ``provider_source_id`` equals the old
    ID is re-pointed to the config's ID and reloaded.

    Args:
        provider_source_id: ID of the source being edited (from the URL).

    Returns:
        A serialized ``Response`` dict. Bad input, an ID collision with
        another source, or a save failure yields an error. If saving succeeds
        but some providers fail to reload, the error lists those providers.
    """
    post_data = await request.json
    if not post_data:
        return Response().error("缺少配置数据").__dict__
    # A JSON array/scalar body has no ``.get``; reject it instead of raising.
    if not isinstance(post_data, dict):
        return Response().error("缺少或错误的配置数据").__dict__

    new_source_config = post_data.get("config") or post_data
    if not isinstance(new_source_config, dict):
        return Response().error("缺少或错误的配置数据").__dict__

    original_id = provider_source_id

    # Ensure the stored config always carries an id.
    if not new_source_config.get("id"):
        new_source_config["id"] = original_id
    new_id = new_source_config["id"]

    provider_sources = self.config.get("provider_sources", [])

    # Changing the ID to one already used by another source is rejected.
    if new_id != original_id and any(
        ps.get("id") == new_id for ps in provider_sources
    ):
        return (
            Response()
            .error(
                f"Provider source ID '{new_id}' exists already, please try another ID.",
            )
            .__dict__
        )

    # Replace the existing source, or append the config as a new one.
    target_idx = next(
        (i for i, ps in enumerate(provider_sources) if ps.get("id") == original_id),
        -1,
    )
    if target_idx == -1:
        old_id = original_id
        provider_sources.append(new_source_config)
    else:
        old_id = provider_sources[target_idx].get("id")
        provider_sources[target_idx] = new_source_config

    # Re-point providers that referenced the old source ID.
    affected_providers = []
    for provider in self.config.get("provider", []):
        if provider.get("provider_source_id") == old_id:
            provider["provider_source_id"] = new_id
            affected_providers.append(provider)

    self.config["provider_sources"] = provider_sources

    try:
        save_config(self.config, self.config, is_core=True)
    except Exception as e:
        logger.error(traceback.format_exc())
        return Response().error(str(e)).__dict__

    # Reload affected providers so the new source settings take effect.
    reload_errors = []
    prov_mgr = self.core_lifecycle.provider_manager
    for provider in affected_providers:
        try:
            await prov_mgr.reload(provider)
        except Exception as e:
            logger.error(traceback.format_exc())
            reload_errors.append(f"{provider.get('id')}: {e}")

    if reload_errors:
        return (
            Response()
            .error("更新成功,但部分提供商重载失败: " + ", ".join(reload_errors))
            .__dict__
        )

    return Response().ok(message="更新 provider source 成功").__dict__
|
|
314
|
+
|
|
315
|
+
async def get_provider_template(self):
    """Return the provider config schema together with the current config.

    The response data contains:
        config_schema: ``{"provider": <provider metadata from CONFIG_METADATA_2>}``
        providers: the configured providers (``[]`` if the key is absent).
        provider_sources: the configured provider sources (``[]`` if absent).
    """
    provider_schema = CONFIG_METADATA_2["provider_group"]["metadata"]["provider"]
    data = {
        "config_schema": {"provider": provider_schema},
        # Use ``.get`` like the other provider_source handlers so configs
        # missing these keys do not raise KeyError.
        "providers": astrbot_config.get("provider", []),
        "provider_sources": astrbot_config.get("provider_sources", []),
    }
    return Response().ok(data=data).__dict__
|
|
325
|
+
|
|
189
326
|
async def get_uc_table(self):
|
|
190
327
|
"""获取 UMOP 配置路由表"""
|
|
191
328
|
return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__
|
|
@@ -433,9 +570,25 @@ class ConfigRoute(Route):
|
|
|
433
570
|
return Response().error("缺少参数 provider_type").__dict__
|
|
434
571
|
provider_type_ls = provider_type.split(",")
|
|
435
572
|
provider_list = []
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
573
|
+
ps = self.core_lifecycle.provider_manager.providers_config
|
|
574
|
+
p_source_pt = {
|
|
575
|
+
psrc["id"]: psrc["provider_type"]
|
|
576
|
+
for psrc in self.core_lifecycle.provider_manager.provider_sources_config
|
|
577
|
+
}
|
|
578
|
+
for provider in ps:
|
|
579
|
+
ps_id = provider.get("provider_source_id", None)
|
|
580
|
+
if (
|
|
581
|
+
ps_id
|
|
582
|
+
and ps_id in p_source_pt
|
|
583
|
+
and p_source_pt[ps_id] in provider_type_ls
|
|
584
|
+
):
|
|
585
|
+
# chat
|
|
586
|
+
prov = self.core_lifecycle.provider_manager.get_merged_provider_config(
|
|
587
|
+
provider
|
|
588
|
+
)
|
|
589
|
+
provider_list.append(prov)
|
|
590
|
+
elif not ps_id and provider.get("provider_type", None) in provider_type_ls:
|
|
591
|
+
# agent runner, embedding, etc
|
|
439
592
|
provider_list.append(provider)
|
|
440
593
|
return Response().ok(provider_list).__dict__
|
|
441
594
|
|
|
@@ -458,9 +611,18 @@ class ConfigRoute(Route):
|
|
|
458
611
|
|
|
459
612
|
try:
|
|
460
613
|
models = await provider.get_models()
|
|
614
|
+
models = models or []
|
|
615
|
+
|
|
616
|
+
metadata_map = {}
|
|
617
|
+
for model_id in models:
|
|
618
|
+
meta = LLM_METADATAS.get(model_id)
|
|
619
|
+
if meta:
|
|
620
|
+
metadata_map[model_id] = meta
|
|
621
|
+
|
|
461
622
|
ret = {
|
|
462
623
|
"models": models,
|
|
463
624
|
"provider_id": provider_id,
|
|
625
|
+
"model_metadata": metadata_map,
|
|
464
626
|
}
|
|
465
627
|
return Response().ok(ret).__dict__
|
|
466
628
|
except Exception as e:
|
|
@@ -522,6 +684,100 @@ class ConfigRoute(Route):
|
|
|
522
684
|
logger.error(traceback.format_exc())
|
|
523
685
|
return Response().error(f"获取嵌入维度失败: {e!s}").__dict__
|
|
524
686
|
|
|
687
|
+
async def get_provider_source_models(self, provider_source_id: str):
    """List the models offered by a provider source.

    A temporary provider instance is built from the source config. Its
    coroutine ``initialize()`` is awaited if present, then ``get_models()``
    is awaited. The instance's coroutine ``terminate()`` (if present) is
    always awaited afterwards, even when fetching the models fails.

    Args:
        provider_source_id: ID of the provider source, taken from the URL.

    Returns:
        A serialized ``Response`` dict with ``models`` and ``model_metadata``
        (entries of ``LLM_METADATAS`` for the returned model IDs), or an
        error message.
    """
    try:
        from astrbot.core.provider.register import provider_cls_map

        provider_source = next(
            (
                ps
                for ps in self.config.get("provider_sources", [])
                if ps.get("id") == provider_source_id
            ),
            None,
        )
        if not provider_source:
            return (
                Response()
                .error(f"未找到 ID 为 {provider_source_id} 的 provider_source")
                .__dict__
            )

        provider_type = provider_source.get("type", None)
        if not provider_type:
            return Response().error("provider_source 缺少 type 字段").__dict__

        try:
            self.core_lifecycle.provider_manager.dynamic_import_provider(
                provider_type
            )
        except ImportError as e:
            logger.error(traceback.format_exc())
            return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__

        if provider_type not in provider_cls_map:
            return (
                Response()
                .error(f"未找到适用于 {provider_type} 的提供商适配器")
                .__dict__
            )

        cls_type = provider_cls_map[provider_type].cls_type
        if not cls_type:
            return Response().error(f"无法找到 {provider_type} 的类").__dict__

        # Only chat-style ``Provider`` subclasses expose get_models().
        if not issubclass(cls_type, Provider):
            return (
                Response()
                .error(f"提供商 {provider_type} 不支持获取模型列表")
                .__dict__
            )

        inst = cls_type(provider_source, {})
        try:
            init_fn = getattr(inst, "initialize", None)
            if inspect.iscoroutinefunction(init_fn):
                await init_fn()
            models = await inst.get_models() or []
        finally:
            # Always release the temporary instance, even if initialization
            # or model listing raised. A terminate failure is only logged so
            # it neither hides the original error nor discards fetched models.
            terminate_fn = getattr(inst, "terminate", None)
            if inspect.iscoroutinefunction(terminate_fn):
                try:
                    await terminate_fn()
                except Exception:
                    logger.error(traceback.format_exc())

        metadata_map = {
            model_id: meta
            for model_id in models
            if (meta := LLM_METADATAS.get(model_id))
        }

        logger.info(
            f"获取到 provider_source {provider_source_id} 的模型列表: {models}",
        )

        return (
            Response()
            .ok({"models": models, "model_metadata": metadata_map})
            .__dict__
        )
    except Exception as e:
        logger.error(traceback.format_exc())
        return Response().error(f"获取模型列表失败: {e!s}").__dict__
|
|
780
|
+
|
|
525
781
|
async def get_platform_list(self):
|
|
526
782
|
"""获取所有平台的列表"""
|
|
527
783
|
platform_list = []
|
|
@@ -533,7 +789,15 @@ class ConfigRoute(Route):
|
|
|
533
789
|
data = await request.json
|
|
534
790
|
config = data.get("config", None)
|
|
535
791
|
conf_id = data.get("conf_id", None)
|
|
792
|
+
|
|
536
793
|
try:
|
|
794
|
+
# 不更新 provider_sources, provider, platform
|
|
795
|
+
# 这些配置有单独的接口进行更新
|
|
796
|
+
if conf_id == "default":
|
|
797
|
+
no_update_keys = ["provider_sources", "provider", "platform"]
|
|
798
|
+
for key in no_update_keys:
|
|
799
|
+
config[key] = self.acm.default_conf[key]
|
|
800
|
+
|
|
537
801
|
await self._save_astrbot_configs(config, conf_id)
|
|
538
802
|
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
|
|
539
803
|
return Response().ok(None, "保存成功~").__dict__
|
|
@@ -573,28 +837,30 @@ class ConfigRoute(Route):
|
|
|
573
837
|
|
|
574
838
|
async def post_new_provider(self):
|
|
575
839
|
new_provider_config = await request.json
|
|
576
|
-
|
|
840
|
+
|
|
577
841
|
try:
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
new_provider_config,
|
|
842
|
+
await self.core_lifecycle.provider_manager.create_provider(
|
|
843
|
+
new_provider_config
|
|
581
844
|
)
|
|
582
845
|
except Exception as e:
|
|
583
846
|
return Response().error(str(e)).__dict__
|
|
584
|
-
return Response().ok(None, "
|
|
847
|
+
return Response().ok(None, "新增服务提供商配置成功").__dict__
|
|
585
848
|
|
|
586
849
|
async def post_update_platform(self):
|
|
587
850
|
update_platform_config = await request.json
|
|
588
|
-
|
|
851
|
+
origin_platform_id = update_platform_config.get("id", None)
|
|
589
852
|
new_config = update_platform_config.get("config", None)
|
|
590
|
-
if not
|
|
853
|
+
if not origin_platform_id or not new_config:
|
|
591
854
|
return Response().error("参数错误").__dict__
|
|
592
855
|
|
|
856
|
+
if origin_platform_id != new_config.get("id", None):
|
|
857
|
+
return Response().error("机器人名称不允许修改").__dict__
|
|
858
|
+
|
|
593
859
|
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid
|
|
594
860
|
ensure_platform_webhook_config(new_config)
|
|
595
861
|
|
|
596
862
|
for i, platform in enumerate(self.config["platform"]):
|
|
597
|
-
if platform["id"] ==
|
|
863
|
+
if platform["id"] == origin_platform_id:
|
|
598
864
|
self.config["platform"][i] = new_config
|
|
599
865
|
break
|
|
600
866
|
else:
|
|
@@ -609,21 +875,15 @@ class ConfigRoute(Route):
|
|
|
609
875
|
|
|
610
876
|
async def post_update_provider(self):
|
|
611
877
|
update_provider_config = await request.json
|
|
612
|
-
|
|
878
|
+
origin_provider_id = update_provider_config.get("id", None)
|
|
613
879
|
new_config = update_provider_config.get("config", None)
|
|
614
|
-
if not
|
|
880
|
+
if not origin_provider_id or not new_config:
|
|
615
881
|
return Response().error("参数错误").__dict__
|
|
616
882
|
|
|
617
|
-
for i, provider in enumerate(self.config["provider"]):
|
|
618
|
-
if provider["id"] == provider_id:
|
|
619
|
-
self.config["provider"][i] = new_config
|
|
620
|
-
break
|
|
621
|
-
else:
|
|
622
|
-
return Response().error("未找到对应服务提供商").__dict__
|
|
623
|
-
|
|
624
883
|
try:
|
|
625
|
-
|
|
626
|
-
|
|
884
|
+
await self.core_lifecycle.provider_manager.update_provider(
|
|
885
|
+
origin_provider_id, new_config
|
|
886
|
+
)
|
|
627
887
|
except Exception as e:
|
|
628
888
|
return Response().error(str(e)).__dict__
|
|
629
889
|
return Response().ok(None, "更新成功,已经实时生效~").__dict__
|
|
@@ -646,19 +906,17 @@ class ConfigRoute(Route):
|
|
|
646
906
|
|
|
647
907
|
async def post_delete_provider(self):
|
|
648
908
|
provider_id = await request.json
|
|
649
|
-
provider_id = provider_id.get("id")
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
break
|
|
654
|
-
else:
|
|
655
|
-
return Response().error("未找到对应服务提供商").__dict__
|
|
909
|
+
provider_id = provider_id.get("id", "")
|
|
910
|
+
if not provider_id:
|
|
911
|
+
return Response().error("缺少参数 id").__dict__
|
|
912
|
+
|
|
656
913
|
try:
|
|
657
|
-
|
|
658
|
-
|
|
914
|
+
await self.core_lifecycle.provider_manager.delete_provider(
|
|
915
|
+
provider_id=provider_id
|
|
916
|
+
)
|
|
659
917
|
except Exception as e:
|
|
660
918
|
return Response().error(str(e)).__dict__
|
|
661
|
-
return Response().ok(None, "
|
|
919
|
+
return Response().ok(None, "删除成功,已经实时生效。").__dict__
|
|
662
920
|
|
|
663
921
|
async def get_llm_tools(self):
|
|
664
922
|
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
|