AstrBot 4.5.1__py3-none-any.whl → 4.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +10 -11
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -36
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +7 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__main__.py +5 -5
- astrbot/cli/commands/__init__.py +3 -3
- astrbot/cli/commands/cmd_conf.py +19 -16
- astrbot/cli/commands/cmd_init.py +3 -2
- astrbot/cli/commands/cmd_plug.py +8 -10
- astrbot/cli/commands/cmd_run.py +5 -6
- astrbot/cli/utils/__init__.py +6 -6
- astrbot/cli/utils/basic.py +14 -14
- astrbot/cli/utils/plugin.py +24 -15
- astrbot/cli/utils/version_comparator.py +10 -12
- astrbot/core/__init__.py +8 -6
- astrbot/core/agent/agent.py +3 -2
- astrbot/core/agent/handoff.py +6 -2
- astrbot/core/agent/hooks.py +9 -6
- astrbot/core/agent/mcp_client.py +50 -15
- astrbot/core/agent/message.py +168 -0
- astrbot/core/agent/response.py +2 -1
- astrbot/core/agent/run_context.py +2 -3
- astrbot/core/agent/runners/base.py +10 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
- astrbot/core/agent/tool.py +60 -41
- astrbot/core/agent/tool_executor.py +9 -3
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astrbot_config_mgr.py +29 -9
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +28 -26
- astrbot/core/config/default.py +4 -6
- astrbot/core/conversation_mgr.py +105 -36
- astrbot/core/core_lifecycle.py +68 -54
- astrbot/core/db/__init__.py +33 -18
- astrbot/core/db/migration/helper.py +12 -10
- astrbot/core/db/migration/migra_3_to_4.py +53 -34
- astrbot/core/db/migration/migra_45_to_46.py +1 -1
- astrbot/core/db/migration/shared_preferences_v3.py +2 -1
- astrbot/core/db/migration/sqlite_v3.py +26 -23
- astrbot/core/db/po.py +27 -18
- astrbot/core/db/sqlite.py +74 -45
- astrbot/core/db/vec_db/base.py +10 -14
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
- astrbot/core/event_bus.py +8 -6
- astrbot/core/file_token_service.py +6 -5
- astrbot/core/initial_loader.py +7 -5
- astrbot/core/knowledge_base/chunking/__init__.py +1 -3
- astrbot/core/knowledge_base/chunking/base.py +1 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
- astrbot/core/knowledge_base/chunking/recursive.py +16 -10
- astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
- astrbot/core/knowledge_base/kb_helper.py +30 -17
- astrbot/core/knowledge_base/kb_mgr.py +6 -7
- astrbot/core/knowledge_base/models.py +10 -4
- astrbot/core/knowledge_base/parsers/__init__.py +3 -5
- astrbot/core/knowledge_base/parsers/base.py +1 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
- astrbot/core/knowledge_base/parsers/util.py +1 -1
- astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
- astrbot/core/knowledge_base/retrieval/manager.py +17 -14
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
- astrbot/core/log.py +21 -13
- astrbot/core/message/components.py +123 -217
- astrbot/core/message/message_event_result.py +24 -24
- astrbot/core/persona_mgr.py +20 -11
- astrbot/core/pipeline/__init__.py +7 -7
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +4 -1
- astrbot/core/pipeline/context_utils.py +77 -7
- astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
- astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
- astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
- astrbot/core/pipeline/process_stage/stage.py +13 -10
- astrbot/core/pipeline/process_stage/utils.py +6 -5
- astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
- astrbot/core/pipeline/respond/stage.py +23 -20
- astrbot/core/pipeline/result_decorate/stage.py +31 -23
- astrbot/core/pipeline/scheduler.py +12 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -8
- astrbot/core/pipeline/stage.py +10 -4
- astrbot/core/pipeline/waking_check/stage.py +24 -18
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +76 -110
- astrbot/core/platform/astrbot_message.py +11 -13
- astrbot/core/platform/manager.py +16 -15
- astrbot/core/platform/message_session.py +5 -3
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +4 -4
- astrbot/core/platform/register.py +8 -8
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
- astrbot/core/platform/sources/discord/client.py +9 -6
- astrbot/core/platform/sources/discord/components.py +18 -14
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
- astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
- astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
- astrbot/core/platform/sources/lark/lark_event.py +21 -14
- astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
- astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
- astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
- astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
- astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
- astrbot/core/platform/sources/satori/satori_event.py +34 -25
- astrbot/core/platform/sources/slack/client.py +11 -9
- astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
- astrbot/core/platform/sources/slack/slack_event.py +34 -24
- astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
- astrbot/core/platform/sources/telegram/tg_event.py +32 -18
- astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
- astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
- astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
- astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
- astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
- astrbot/core/platform_message_history_mgr.py +5 -3
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +61 -75
- astrbot/core/provider/func_tool_manager.py +59 -55
- astrbot/core/provider/manager.py +32 -22
- astrbot/core/provider/provider.py +72 -46
- astrbot/core/provider/register.py +7 -7
- astrbot/core/provider/sources/anthropic_source.py +48 -30
- astrbot/core/provider/sources/azure_tts_source.py +17 -13
- astrbot/core/provider/sources/coze_api_client.py +27 -17
- astrbot/core/provider/sources/coze_source.py +104 -87
- astrbot/core/provider/sources/dashscope_source.py +18 -11
- astrbot/core/provider/sources/dashscope_tts.py +36 -23
- astrbot/core/provider/sources/dify_source.py +25 -20
- astrbot/core/provider/sources/edge_tts_source.py +21 -17
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
- astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
- astrbot/core/provider/sources/gemini_source.py +72 -58
- astrbot/core/provider/sources/gemini_tts_source.py +8 -6
- astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
- astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
- astrbot/core/provider/sources/openai_embedding_source.py +6 -8
- astrbot/core/provider/sources/openai_source.py +77 -69
- astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
- astrbot/core/provider/sources/volcengine_tts.py +38 -31
- astrbot/core/provider/sources/whisper_api_source.py +14 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
- astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
- astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
- astrbot/core/star/__init__.py +16 -11
- astrbot/core/star/config.py +10 -15
- astrbot/core/star/context.py +97 -75
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +30 -28
- astrbot/core/star/filter/command_group.py +27 -24
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +4 -2
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -19
- astrbot/core/star/register/star.py +6 -2
- astrbot/core/star/register/star_handler.py +96 -73
- astrbot/core/star/session_llm_manager.py +48 -14
- astrbot/core/star/session_plugin_manager.py +29 -15
- astrbot/core/star/star.py +1 -2
- astrbot/core/star/star_handler.py +13 -8
- astrbot/core/star/star_manager.py +151 -59
- astrbot/core/star/star_tools.py +44 -37
- astrbot/core/star/updator.py +10 -10
- astrbot/core/umop_config_router.py +10 -4
- astrbot/core/updator.py +13 -5
- astrbot/core/utils/astrbot_path.py +3 -5
- astrbot/core/utils/dify_api_client.py +33 -15
- astrbot/core/utils/io.py +66 -42
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +7 -7
- astrbot/core/utils/path_util.py +15 -16
- astrbot/core/utils/pip_installer.py +5 -5
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +45 -20
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/network_strategy.py +35 -26
- astrbot/core/utils/t2i/renderer.py +11 -5
- astrbot/core/utils/t2i/template_manager.py +14 -15
- astrbot/core/utils/tencent_record_helper.py +19 -13
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +43 -40
- astrbot/dashboard/routes/__init__.py +18 -18
- astrbot/dashboard/routes/auth.py +10 -8
- astrbot/dashboard/routes/chat.py +30 -21
- astrbot/dashboard/routes/config.py +92 -75
- astrbot/dashboard/routes/conversation.py +46 -39
- astrbot/dashboard/routes/file.py +4 -2
- astrbot/dashboard/routes/knowledge_base.py +47 -40
- astrbot/dashboard/routes/log.py +9 -4
- astrbot/dashboard/routes/persona.py +19 -16
- astrbot/dashboard/routes/plugin.py +69 -55
- astrbot/dashboard/routes/route.py +3 -1
- astrbot/dashboard/routes/session_management.py +130 -116
- astrbot/dashboard/routes/stat.py +34 -34
- astrbot/dashboard/routes/t2i.py +15 -12
- astrbot/dashboard/routes/tools.py +56 -53
- astrbot/dashboard/routes/update.py +32 -28
- astrbot/dashboard/server.py +30 -26
- astrbot/dashboard/utils.py +8 -4
- {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/METADATA +2 -1
- astrbot-4.5.3.dist-info/RECORD +261 -0
- astrbot-4.5.1.dist-info/RECORD +0 -260
- {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/WHEEL +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.1.dist-info → astrbot-4.5.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
2
4
|
import json
|
|
3
5
|
import os
|
|
4
|
-
import
|
|
6
|
+
from collections.abc import Awaitable, Callable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
5
9
|
import aiohttp
|
|
6
10
|
|
|
7
|
-
from typing import Dict, List, Awaitable, Callable, Any
|
|
8
11
|
from astrbot import logger
|
|
9
12
|
from astrbot.core import sp
|
|
10
|
-
|
|
13
|
+
from astrbot.core.agent.mcp_client import MCPClient, MCPTool
|
|
14
|
+
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
|
11
15
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
12
|
-
from astrbot.core.agent.mcp_client import MCPClient
|
|
13
|
-
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
|
14
|
-
|
|
15
16
|
|
|
16
17
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
|
17
18
|
|
|
@@ -30,7 +31,7 @@ FuncTool = FunctionTool
|
|
|
30
31
|
|
|
31
32
|
def _prepare_config(config: dict) -> dict:
|
|
32
33
|
"""准备配置,处理嵌套格式"""
|
|
33
|
-
if
|
|
34
|
+
if config.get("mcpServers"):
|
|
34
35
|
first_key = next(iter(config["mcpServers"]))
|
|
35
36
|
config = config["mcpServers"][first_key]
|
|
36
37
|
config.pop("active", None)
|
|
@@ -72,8 +73,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
72
73
|
) as response:
|
|
73
74
|
if response.status == 200:
|
|
74
75
|
return True, ""
|
|
75
|
-
|
|
76
|
-
return False, f"HTTP {response.status}: {response.reason}"
|
|
76
|
+
return False, f"HTTP {response.status}: {response.reason}"
|
|
77
77
|
else:
|
|
78
78
|
async with session.get(
|
|
79
79
|
url,
|
|
@@ -85,8 +85,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
85
85
|
) as response:
|
|
86
86
|
if response.status == 200:
|
|
87
87
|
return True, ""
|
|
88
|
-
|
|
89
|
-
return False, f"HTTP {response.status}: {response.reason}"
|
|
88
|
+
return False, f"HTTP {response.status}: {response.reason}"
|
|
90
89
|
|
|
91
90
|
except asyncio.TimeoutError:
|
|
92
91
|
return False, f"连接超时: {timeout}秒"
|
|
@@ -96,10 +95,10 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
96
95
|
|
|
97
96
|
class FunctionToolManager:
|
|
98
97
|
def __init__(self) -> None:
|
|
99
|
-
self.func_list:
|
|
100
|
-
self.mcp_client_dict:
|
|
98
|
+
self.func_list: list[FuncTool] = []
|
|
99
|
+
self.mcp_client_dict: dict[str, MCPClient] = {}
|
|
101
100
|
"""MCP 服务列表"""
|
|
102
|
-
self.mcp_client_event:
|
|
101
|
+
self.mcp_client_event: dict[str, asyncio.Event] = {}
|
|
103
102
|
|
|
104
103
|
def empty(self) -> bool:
|
|
105
104
|
return len(self.func_list) == 0
|
|
@@ -150,14 +149,12 @@ class FunctionToolManager:
|
|
|
150
149
|
func_args=func_args,
|
|
151
150
|
desc=desc,
|
|
152
151
|
handler=handler,
|
|
153
|
-
)
|
|
152
|
+
),
|
|
154
153
|
)
|
|
155
154
|
logger.info(f"添加函数调用工具: {name}")
|
|
156
155
|
|
|
157
156
|
def remove_func(self, name: str) -> None:
|
|
158
|
-
"""
|
|
159
|
-
删除一个函数调用工具。
|
|
160
|
-
"""
|
|
157
|
+
"""删除一个函数调用工具。"""
|
|
161
158
|
for i, f in enumerate(self.func_list):
|
|
162
159
|
if f.name == name:
|
|
163
160
|
self.func_list.pop(i)
|
|
@@ -202,16 +199,16 @@ class FunctionToolManager:
|
|
|
202
199
|
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
|
203
200
|
return
|
|
204
201
|
|
|
205
|
-
mcp_server_json_obj:
|
|
206
|
-
open(mcp_json_file,
|
|
202
|
+
mcp_server_json_obj: dict[str, dict] = json.load(
|
|
203
|
+
open(mcp_json_file, encoding="utf-8"),
|
|
207
204
|
)["mcpServers"]
|
|
208
205
|
|
|
209
|
-
for name in mcp_server_json_obj
|
|
206
|
+
for name in mcp_server_json_obj:
|
|
210
207
|
cfg = mcp_server_json_obj[name]
|
|
211
208
|
if cfg.get("active", True):
|
|
212
209
|
event = asyncio.Event()
|
|
213
210
|
asyncio.create_task(
|
|
214
|
-
self._init_mcp_client_task_wrapper(name, cfg, event)
|
|
211
|
+
self._init_mcp_client_task_wrapper(name, cfg, event),
|
|
215
212
|
)
|
|
216
213
|
self.mcp_client_event[name] = event
|
|
217
214
|
|
|
@@ -257,18 +254,15 @@ class FunctionToolManager:
|
|
|
257
254
|
self.func_list = [
|
|
258
255
|
f
|
|
259
256
|
for f in self.func_list
|
|
260
|
-
if not (f
|
|
257
|
+
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
|
261
258
|
]
|
|
262
259
|
|
|
263
260
|
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
|
264
261
|
for tool in mcp_client.tools:
|
|
265
|
-
func_tool =
|
|
266
|
-
|
|
267
|
-
parameters=tool.inputSchema,
|
|
268
|
-
description=tool.description,
|
|
269
|
-
origin="mcp",
|
|
270
|
-
mcp_server_name=name,
|
|
262
|
+
func_tool = MCPTool(
|
|
263
|
+
mcp_tool=tool,
|
|
271
264
|
mcp_client=mcp_client,
|
|
265
|
+
mcp_server_name=name,
|
|
272
266
|
)
|
|
273
267
|
self.func_list.append(func_tool)
|
|
274
268
|
|
|
@@ -287,7 +281,7 @@ class FunctionToolManager:
|
|
|
287
281
|
self.func_list = [
|
|
288
282
|
f
|
|
289
283
|
for f in self.func_list
|
|
290
|
-
if not (f
|
|
284
|
+
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
|
291
285
|
]
|
|
292
286
|
logger.info(f"已关闭 MCP 服务 {name}")
|
|
293
287
|
|
|
@@ -325,9 +319,11 @@ class FunctionToolManager:
|
|
|
325
319
|
event (asyncio.Event): Event to signal when the MCP client is ready.
|
|
326
320
|
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
|
327
321
|
timeout (int): Timeout for the initialization.
|
|
322
|
+
|
|
328
323
|
Raises:
|
|
329
324
|
TimeoutError: If the initialization does not complete within the specified timeout.
|
|
330
325
|
Exception: If there is an error during initialization.
|
|
326
|
+
|
|
331
327
|
"""
|
|
332
328
|
if not event:
|
|
333
329
|
event = asyncio.Event()
|
|
@@ -336,7 +332,7 @@ class FunctionToolManager:
|
|
|
336
332
|
if name in self.mcp_client_dict:
|
|
337
333
|
return
|
|
338
334
|
asyncio.create_task(
|
|
339
|
-
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
|
|
335
|
+
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
|
|
340
336
|
)
|
|
341
337
|
try:
|
|
342
338
|
await asyncio.wait_for(ready_future, timeout=timeout)
|
|
@@ -349,13 +345,16 @@ class FunctionToolManager:
|
|
|
349
345
|
raise exc
|
|
350
346
|
|
|
351
347
|
async def disable_mcp_server(
|
|
352
|
-
self,
|
|
348
|
+
self,
|
|
349
|
+
name: str | None = None,
|
|
350
|
+
timeout: float = 10,
|
|
353
351
|
) -> None:
|
|
354
352
|
"""Disable an MCP server by its name.
|
|
355
353
|
|
|
356
354
|
Args:
|
|
357
355
|
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
|
358
356
|
timeout (int): Timeout.
|
|
357
|
+
|
|
359
358
|
"""
|
|
360
359
|
if name:
|
|
361
360
|
if name not in self.mcp_client_event:
|
|
@@ -372,7 +371,7 @@ class FunctionToolManager:
|
|
|
372
371
|
self.func_list = [
|
|
373
372
|
f
|
|
374
373
|
for f in self.func_list
|
|
375
|
-
if f
|
|
374
|
+
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
|
376
375
|
]
|
|
377
376
|
else:
|
|
378
377
|
running_events = [
|
|
@@ -386,30 +385,26 @@ class FunctionToolManager:
|
|
|
386
385
|
finally:
|
|
387
386
|
self.mcp_client_event.clear()
|
|
388
387
|
self.mcp_client_dict.clear()
|
|
389
|
-
self.func_list = [
|
|
388
|
+
self.func_list = [
|
|
389
|
+
f for f in self.func_list if not isinstance(f, MCPTool)
|
|
390
|
+
]
|
|
390
391
|
|
|
391
392
|
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
|
392
|
-
"""
|
|
393
|
-
获得 OpenAI API 风格的**已经激活**的工具描述
|
|
394
|
-
"""
|
|
393
|
+
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
|
|
395
394
|
tools = [f for f in self.func_list if f.active]
|
|
396
395
|
toolset = ToolSet(tools)
|
|
397
396
|
return toolset.openai_schema(
|
|
398
|
-
omit_empty_parameter_field=omit_empty_parameter_field
|
|
397
|
+
omit_empty_parameter_field=omit_empty_parameter_field,
|
|
399
398
|
)
|
|
400
399
|
|
|
401
400
|
def get_func_desc_anthropic_style(self) -> list:
|
|
402
|
-
"""
|
|
403
|
-
获得 Anthropic API 风格的**已经激活**的工具描述
|
|
404
|
-
"""
|
|
401
|
+
"""获得 Anthropic API 风格的**已经激活**的工具描述"""
|
|
405
402
|
tools = [f for f in self.func_list if f.active]
|
|
406
403
|
toolset = ToolSet(tools)
|
|
407
404
|
return toolset.anthropic_schema()
|
|
408
405
|
|
|
409
406
|
def get_func_desc_google_genai_style(self) -> dict:
|
|
410
|
-
"""
|
|
411
|
-
获得 Google GenAI API 风格的**已经激活**的工具描述
|
|
412
|
-
"""
|
|
407
|
+
"""获得 Google GenAI API 风格的**已经激活**的工具描述"""
|
|
413
408
|
tools = [f for f in self.func_list if f.active]
|
|
414
409
|
toolset = ToolSet(tools)
|
|
415
410
|
return toolset.google_schema()
|
|
@@ -418,13 +413,18 @@ class FunctionToolManager:
|
|
|
418
413
|
"""停用一个已经注册的函数调用工具。
|
|
419
414
|
|
|
420
415
|
Returns:
|
|
421
|
-
如果没找到,会返回 False
|
|
416
|
+
如果没找到,会返回 False
|
|
417
|
+
|
|
418
|
+
"""
|
|
422
419
|
func_tool = self.get_func(name)
|
|
423
420
|
if func_tool is not None:
|
|
424
421
|
func_tool.active = False
|
|
425
422
|
|
|
426
423
|
inactivated_llm_tools: list = sp.get(
|
|
427
|
-
"inactivated_llm_tools",
|
|
424
|
+
"inactivated_llm_tools",
|
|
425
|
+
[],
|
|
426
|
+
scope="global",
|
|
427
|
+
scope_id="global",
|
|
428
428
|
)
|
|
429
429
|
if name not in inactivated_llm_tools:
|
|
430
430
|
inactivated_llm_tools.append(name)
|
|
@@ -445,13 +445,16 @@ class FunctionToolManager:
|
|
|
445
445
|
if func_tool.handler_module_path in star_map:
|
|
446
446
|
if not star_map[func_tool.handler_module_path].activated:
|
|
447
447
|
raise ValueError(
|
|
448
|
-
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
|
448
|
+
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。",
|
|
449
449
|
)
|
|
450
450
|
|
|
451
451
|
func_tool.active = True
|
|
452
452
|
|
|
453
453
|
inactivated_llm_tools: list = sp.get(
|
|
454
|
-
"inactivated_llm_tools",
|
|
454
|
+
"inactivated_llm_tools",
|
|
455
|
+
[],
|
|
456
|
+
scope="global",
|
|
457
|
+
scope_id="global",
|
|
455
458
|
)
|
|
456
459
|
if name in inactivated_llm_tools:
|
|
457
460
|
inactivated_llm_tools.remove(name)
|
|
@@ -479,7 +482,7 @@ class FunctionToolManager:
|
|
|
479
482
|
return DEFAULT_MCP_CONFIG
|
|
480
483
|
|
|
481
484
|
try:
|
|
482
|
-
with open(self.mcp_config_path,
|
|
485
|
+
with open(self.mcp_config_path, encoding="utf-8") as f:
|
|
483
486
|
return json.load(f)
|
|
484
487
|
except Exception as e:
|
|
485
488
|
logger.error(f"加载 MCP 配置失败: {e}")
|
|
@@ -509,7 +512,8 @@ class FunctionToolManager:
|
|
|
509
512
|
if response.status == 200:
|
|
510
513
|
data = await response.json()
|
|
511
514
|
mcp_server_list = data.get("data", {}).get(
|
|
512
|
-
"mcp_server_list",
|
|
515
|
+
"mcp_server_list",
|
|
516
|
+
[],
|
|
513
517
|
)
|
|
514
518
|
local_mcp_config = self.load_mcp_config()
|
|
515
519
|
|
|
@@ -541,23 +545,23 @@ class FunctionToolManager:
|
|
|
541
545
|
self.enable_mcp_server(
|
|
542
546
|
name=name,
|
|
543
547
|
config=local_mcp_config["mcpServers"][name],
|
|
544
|
-
)
|
|
548
|
+
),
|
|
545
549
|
)
|
|
546
550
|
await asyncio.gather(*tasks)
|
|
547
551
|
logger.info(
|
|
548
|
-
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器"
|
|
552
|
+
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器",
|
|
549
553
|
)
|
|
550
554
|
else:
|
|
551
555
|
logger.warning("没有找到可用的 ModelScope MCP 服务器")
|
|
552
556
|
else:
|
|
553
557
|
raise Exception(
|
|
554
|
-
f"ModelScope API 请求失败: HTTP {response.status}"
|
|
558
|
+
f"ModelScope API 请求失败: HTTP {response.status}",
|
|
555
559
|
)
|
|
556
560
|
|
|
557
561
|
except aiohttp.ClientError as e:
|
|
558
|
-
raise Exception(f"网络连接错误: {
|
|
562
|
+
raise Exception(f"网络连接错误: {e!s}")
|
|
559
563
|
except Exception as e:
|
|
560
|
-
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {
|
|
564
|
+
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}")
|
|
561
565
|
|
|
562
566
|
def __str__(self):
|
|
563
567
|
return str(self.func_list)
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -5,16 +5,16 @@ from astrbot.core import logger, sp
|
|
|
5
5
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
6
6
|
from astrbot.core.db import BaseDatabase
|
|
7
7
|
|
|
8
|
+
from ..persona_mgr import PersonaManager
|
|
8
9
|
from .entities import ProviderType
|
|
9
10
|
from .provider import (
|
|
11
|
+
EmbeddingProvider,
|
|
10
12
|
Provider,
|
|
13
|
+
RerankProvider,
|
|
11
14
|
STTProvider,
|
|
12
15
|
TTSProvider,
|
|
13
|
-
EmbeddingProvider,
|
|
14
|
-
RerankProvider,
|
|
15
16
|
)
|
|
16
17
|
from .register import llm_tools, provider_cls_map
|
|
17
|
-
from ..persona_mgr import PersonaManager
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class ProviderManager:
|
|
@@ -76,7 +76,10 @@ class ProviderManager:
|
|
|
76
76
|
return self.persona_mgr.selected_default_persona_v3
|
|
77
77
|
|
|
78
78
|
async def set_provider(
|
|
79
|
-
self,
|
|
79
|
+
self,
|
|
80
|
+
provider_id: str,
|
|
81
|
+
provider_type: ProviderType,
|
|
82
|
+
umo: str | None = None,
|
|
80
83
|
):
|
|
81
84
|
"""设置提供商。
|
|
82
85
|
|
|
@@ -86,6 +89,7 @@ class ProviderManager:
|
|
|
86
89
|
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
|
87
90
|
|
|
88
91
|
Version 4.0.0: 这个版本下已经默认隔离提供商
|
|
92
|
+
|
|
89
93
|
"""
|
|
90
94
|
if provider_id not in self.inst_map:
|
|
91
95
|
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
|
@@ -100,17 +104,20 @@ class ProviderManager:
|
|
|
100
104
|
|
|
101
105
|
prov = self.inst_map[provider_id]
|
|
102
106
|
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
|
103
|
-
prov,
|
|
107
|
+
prov,
|
|
108
|
+
TTSProvider,
|
|
104
109
|
):
|
|
105
110
|
self.curr_tts_provider_inst = prov
|
|
106
111
|
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
|
107
112
|
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
|
108
|
-
prov,
|
|
113
|
+
prov,
|
|
114
|
+
STTProvider,
|
|
109
115
|
):
|
|
110
116
|
self.curr_stt_provider_inst = prov
|
|
111
117
|
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
|
112
118
|
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
|
113
|
-
prov,
|
|
119
|
+
prov,
|
|
120
|
+
Provider,
|
|
114
121
|
):
|
|
115
122
|
self.curr_provider_inst = prov
|
|
116
123
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
|
@@ -120,7 +127,9 @@ class ProviderManager:
|
|
|
120
127
|
return self.inst_map.get(provider_id)
|
|
121
128
|
|
|
122
129
|
def get_using_provider(
|
|
123
|
-
self,
|
|
130
|
+
self,
|
|
131
|
+
provider_type: ProviderType,
|
|
132
|
+
umo=None,
|
|
124
133
|
) -> Provider | STTProvider | TTSProvider | None:
|
|
125
134
|
"""获取正在使用的提供商实例。
|
|
126
135
|
|
|
@@ -130,6 +139,7 @@ class ProviderManager:
|
|
|
130
139
|
|
|
131
140
|
Returns:
|
|
132
141
|
Provider: 正在使用的提供商实例。
|
|
142
|
+
|
|
133
143
|
"""
|
|
134
144
|
provider = None
|
|
135
145
|
if umo:
|
|
@@ -219,7 +229,7 @@ class ProviderManager:
|
|
|
219
229
|
return
|
|
220
230
|
|
|
221
231
|
logger.info(
|
|
222
|
-
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ..."
|
|
232
|
+
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...",
|
|
223
233
|
)
|
|
224
234
|
|
|
225
235
|
# 动态导入
|
|
@@ -321,18 +331,18 @@ class ProviderManager:
|
|
|
321
331
|
)
|
|
322
332
|
except (ImportError, ModuleNotFoundError) as e:
|
|
323
333
|
logger.critical(
|
|
324
|
-
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
|
334
|
+
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
|
325
335
|
)
|
|
326
336
|
return
|
|
327
337
|
except Exception as e:
|
|
328
338
|
logger.critical(
|
|
329
|
-
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因"
|
|
339
|
+
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因",
|
|
330
340
|
)
|
|
331
341
|
return
|
|
332
342
|
|
|
333
343
|
if provider_config["type"] not in provider_cls_map:
|
|
334
344
|
logger.error(
|
|
335
|
-
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。"
|
|
345
|
+
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。",
|
|
336
346
|
)
|
|
337
347
|
return
|
|
338
348
|
|
|
@@ -358,7 +368,7 @@ class ProviderManager:
|
|
|
358
368
|
):
|
|
359
369
|
self.curr_stt_provider_inst = inst
|
|
360
370
|
logger.info(
|
|
361
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
|
371
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
|
362
372
|
)
|
|
363
373
|
if not self.curr_stt_provider_inst:
|
|
364
374
|
self.curr_stt_provider_inst = inst
|
|
@@ -374,7 +384,7 @@ class ProviderManager:
|
|
|
374
384
|
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
|
375
385
|
self.curr_tts_provider_inst = inst
|
|
376
386
|
logger.info(
|
|
377
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
|
387
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
|
378
388
|
)
|
|
379
389
|
if not self.curr_tts_provider_inst:
|
|
380
390
|
self.curr_tts_provider_inst = inst
|
|
@@ -397,7 +407,7 @@ class ProviderManager:
|
|
|
397
407
|
):
|
|
398
408
|
self.curr_provider_inst = inst
|
|
399
409
|
logger.info(
|
|
400
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
|
410
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
|
401
411
|
)
|
|
402
412
|
if not self.curr_provider_inst:
|
|
403
413
|
self.curr_provider_inst = inst
|
|
@@ -416,10 +426,10 @@ class ProviderManager:
|
|
|
416
426
|
self.inst_map[provider_config["id"]] = inst
|
|
417
427
|
except Exception as e:
|
|
418
428
|
logger.error(
|
|
419
|
-
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
|
429
|
+
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
|
420
430
|
)
|
|
421
431
|
raise Exception(
|
|
422
|
-
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
|
432
|
+
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
|
423
433
|
)
|
|
424
434
|
|
|
425
435
|
async def reload(self, provider_config: dict):
|
|
@@ -439,7 +449,7 @@ class ProviderManager:
|
|
|
439
449
|
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
|
440
450
|
self.curr_provider_inst = self.provider_insts[0]
|
|
441
451
|
logger.info(
|
|
442
|
-
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
|
452
|
+
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
|
443
453
|
)
|
|
444
454
|
|
|
445
455
|
if len(self.stt_provider_insts) == 0:
|
|
@@ -447,7 +457,7 @@ class ProviderManager:
|
|
|
447
457
|
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
|
448
458
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
|
449
459
|
logger.info(
|
|
450
|
-
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
|
460
|
+
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
|
451
461
|
)
|
|
452
462
|
|
|
453
463
|
if len(self.tts_provider_insts) == 0:
|
|
@@ -455,7 +465,7 @@ class ProviderManager:
|
|
|
455
465
|
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
|
456
466
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
|
457
467
|
logger.info(
|
|
458
|
-
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
|
468
|
+
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
|
459
469
|
)
|
|
460
470
|
|
|
461
471
|
def get_insts(self):
|
|
@@ -464,7 +474,7 @@ class ProviderManager:
|
|
|
464
474
|
async def terminate_provider(self, provider_id: str):
|
|
465
475
|
if provider_id in self.inst_map:
|
|
466
476
|
logger.info(
|
|
467
|
-
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..."
|
|
477
|
+
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...",
|
|
468
478
|
)
|
|
469
479
|
|
|
470
480
|
if self.inst_map[provider_id] in self.provider_insts:
|
|
@@ -491,7 +501,7 @@ class ProviderManager:
|
|
|
491
501
|
await self.inst_map[provider_id].terminate() # type: ignore
|
|
492
502
|
|
|
493
503
|
logger.info(
|
|
494
|
-
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
|
504
|
+
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})",
|
|
495
505
|
)
|
|
496
506
|
del self.inst_map[provider_id]
|
|
497
507
|
|