AstrBot 4.5.0__py3-none-any.whl → 4.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +10 -11
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -36
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +7 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__main__.py +5 -5
- astrbot/cli/commands/__init__.py +3 -3
- astrbot/cli/commands/cmd_conf.py +19 -16
- astrbot/cli/commands/cmd_init.py +3 -2
- astrbot/cli/commands/cmd_plug.py +8 -10
- astrbot/cli/commands/cmd_run.py +5 -6
- astrbot/cli/utils/__init__.py +6 -6
- astrbot/cli/utils/basic.py +14 -14
- astrbot/cli/utils/plugin.py +24 -15
- astrbot/cli/utils/version_comparator.py +10 -12
- astrbot/core/__init__.py +8 -6
- astrbot/core/agent/agent.py +3 -2
- astrbot/core/agent/handoff.py +6 -2
- astrbot/core/agent/hooks.py +9 -6
- astrbot/core/agent/mcp_client.py +50 -15
- astrbot/core/agent/message.py +168 -0
- astrbot/core/agent/response.py +2 -1
- astrbot/core/agent/run_context.py +2 -3
- astrbot/core/agent/runners/base.py +10 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
- astrbot/core/agent/tool.py +60 -41
- astrbot/core/agent/tool_executor.py +9 -3
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astrbot_config_mgr.py +29 -9
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +28 -26
- astrbot/core/config/default.py +44 -6
- astrbot/core/conversation_mgr.py +105 -36
- astrbot/core/core_lifecycle.py +68 -54
- astrbot/core/db/__init__.py +33 -18
- astrbot/core/db/migration/helper.py +18 -13
- astrbot/core/db/migration/migra_3_to_4.py +53 -34
- astrbot/core/db/migration/migra_45_to_46.py +1 -1
- astrbot/core/db/migration/shared_preferences_v3.py +2 -1
- astrbot/core/db/migration/sqlite_v3.py +26 -23
- astrbot/core/db/po.py +27 -18
- astrbot/core/db/sqlite.py +74 -45
- astrbot/core/db/vec_db/base.py +10 -14
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
- astrbot/core/event_bus.py +8 -6
- astrbot/core/file_token_service.py +6 -5
- astrbot/core/initial_loader.py +7 -5
- astrbot/core/knowledge_base/chunking/__init__.py +1 -3
- astrbot/core/knowledge_base/chunking/base.py +1 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
- astrbot/core/knowledge_base/chunking/recursive.py +16 -10
- astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
- astrbot/core/knowledge_base/kb_helper.py +30 -17
- astrbot/core/knowledge_base/kb_mgr.py +6 -7
- astrbot/core/knowledge_base/models.py +10 -4
- astrbot/core/knowledge_base/parsers/__init__.py +3 -5
- astrbot/core/knowledge_base/parsers/base.py +1 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
- astrbot/core/knowledge_base/parsers/util.py +1 -1
- astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
- astrbot/core/knowledge_base/retrieval/manager.py +17 -14
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
- astrbot/core/log.py +21 -13
- astrbot/core/message/components.py +123 -217
- astrbot/core/message/message_event_result.py +24 -24
- astrbot/core/persona_mgr.py +20 -11
- astrbot/core/pipeline/__init__.py +7 -7
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +4 -1
- astrbot/core/pipeline/context_utils.py +77 -7
- astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
- astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
- astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
- astrbot/core/pipeline/process_stage/stage.py +13 -10
- astrbot/core/pipeline/process_stage/utils.py +6 -5
- astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
- astrbot/core/pipeline/respond/stage.py +23 -20
- astrbot/core/pipeline/result_decorate/stage.py +31 -23
- astrbot/core/pipeline/scheduler.py +12 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -8
- astrbot/core/pipeline/stage.py +10 -4
- astrbot/core/pipeline/waking_check/stage.py +24 -18
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +76 -110
- astrbot/core/platform/astrbot_message.py +11 -13
- astrbot/core/platform/manager.py +16 -15
- astrbot/core/platform/message_session.py +5 -3
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +4 -4
- astrbot/core/platform/register.py +8 -8
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +47 -29
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
- astrbot/core/platform/sources/discord/client.py +9 -6
- astrbot/core/platform/sources/discord/components.py +18 -14
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
- astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
- astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
- astrbot/core/platform/sources/lark/lark_event.py +21 -14
- astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
- astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
- astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
- astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
- astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
- astrbot/core/platform/sources/satori/satori_event.py +34 -25
- astrbot/core/platform/sources/slack/client.py +11 -9
- astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
- astrbot/core/platform/sources/slack/slack_event.py +34 -24
- astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
- astrbot/core/platform/sources/telegram/tg_event.py +32 -18
- astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
- astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
- astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
- astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
- astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
- astrbot/core/platform_message_history_mgr.py +5 -3
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +61 -75
- astrbot/core/provider/func_tool_manager.py +59 -55
- astrbot/core/provider/manager.py +40 -22
- astrbot/core/provider/provider.py +72 -46
- astrbot/core/provider/register.py +7 -7
- astrbot/core/provider/sources/anthropic_source.py +48 -30
- astrbot/core/provider/sources/azure_tts_source.py +17 -13
- astrbot/core/provider/sources/coze_api_client.py +27 -17
- astrbot/core/provider/sources/coze_source.py +104 -87
- astrbot/core/provider/sources/dashscope_source.py +18 -11
- astrbot/core/provider/sources/dashscope_tts.py +36 -23
- astrbot/core/provider/sources/dify_source.py +25 -20
- astrbot/core/provider/sources/edge_tts_source.py +21 -17
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
- astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
- astrbot/core/provider/sources/gemini_source.py +72 -58
- astrbot/core/provider/sources/gemini_tts_source.py +8 -6
- astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
- astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
- astrbot/core/provider/sources/openai_embedding_source.py +6 -8
- astrbot/core/provider/sources/openai_source.py +102 -69
- astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
- astrbot/core/provider/sources/volcengine_tts.py +38 -31
- astrbot/core/provider/sources/whisper_api_source.py +14 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/star/__init__.py +16 -11
- astrbot/core/star/config.py +10 -15
- astrbot/core/star/context.py +109 -84
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +30 -28
- astrbot/core/star/filter/command_group.py +27 -24
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +4 -2
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -19
- astrbot/core/star/register/star.py +6 -2
- astrbot/core/star/register/star_handler.py +96 -73
- astrbot/core/star/session_llm_manager.py +48 -14
- astrbot/core/star/session_plugin_manager.py +29 -15
- astrbot/core/star/star.py +1 -2
- astrbot/core/star/star_handler.py +13 -8
- astrbot/core/star/star_manager.py +151 -59
- astrbot/core/star/star_tools.py +44 -37
- astrbot/core/star/updator.py +10 -10
- astrbot/core/umop_config_router.py +10 -4
- astrbot/core/updator.py +13 -5
- astrbot/core/utils/astrbot_path.py +3 -5
- astrbot/core/utils/dify_api_client.py +33 -15
- astrbot/core/utils/io.py +66 -42
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +7 -7
- astrbot/core/utils/path_util.py +15 -16
- astrbot/core/utils/pip_installer.py +5 -5
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +45 -20
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/network_strategy.py +35 -26
- astrbot/core/utils/t2i/renderer.py +11 -5
- astrbot/core/utils/t2i/template_manager.py +14 -15
- astrbot/core/utils/tencent_record_helper.py +19 -13
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +43 -40
- astrbot/dashboard/routes/__init__.py +18 -18
- astrbot/dashboard/routes/auth.py +10 -8
- astrbot/dashboard/routes/chat.py +30 -21
- astrbot/dashboard/routes/config.py +92 -75
- astrbot/dashboard/routes/conversation.py +46 -39
- astrbot/dashboard/routes/file.py +4 -2
- astrbot/dashboard/routes/knowledge_base.py +47 -40
- astrbot/dashboard/routes/log.py +9 -4
- astrbot/dashboard/routes/persona.py +19 -16
- astrbot/dashboard/routes/plugin.py +69 -55
- astrbot/dashboard/routes/route.py +3 -1
- astrbot/dashboard/routes/session_management.py +130 -116
- astrbot/dashboard/routes/stat.py +34 -34
- astrbot/dashboard/routes/t2i.py +15 -12
- astrbot/dashboard/routes/tools.py +47 -52
- astrbot/dashboard/routes/update.py +32 -28
- astrbot/dashboard/server.py +30 -26
- astrbot/dashboard/utils.py +8 -4
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/METADATA +4 -2
- astrbot-4.5.2.dist-info/RECORD +261 -0
- astrbot-4.5.0.dist-info/RECORD +0 -258
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,17 +1,18 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
2
4
|
import json
|
|
3
5
|
import os
|
|
4
|
-
import
|
|
6
|
+
from collections.abc import Awaitable, Callable
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
5
9
|
import aiohttp
|
|
6
10
|
|
|
7
|
-
from typing import Dict, List, Awaitable, Callable, Any
|
|
8
11
|
from astrbot import logger
|
|
9
12
|
from astrbot.core import sp
|
|
10
|
-
|
|
13
|
+
from astrbot.core.agent.mcp_client import MCPClient, MCPTool
|
|
14
|
+
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
|
11
15
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
12
|
-
from astrbot.core.agent.mcp_client import MCPClient
|
|
13
|
-
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
|
14
|
-
|
|
15
16
|
|
|
16
17
|
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
|
17
18
|
|
|
@@ -30,7 +31,7 @@ FuncTool = FunctionTool
|
|
|
30
31
|
|
|
31
32
|
def _prepare_config(config: dict) -> dict:
|
|
32
33
|
"""准备配置,处理嵌套格式"""
|
|
33
|
-
if
|
|
34
|
+
if config.get("mcpServers"):
|
|
34
35
|
first_key = next(iter(config["mcpServers"]))
|
|
35
36
|
config = config["mcpServers"][first_key]
|
|
36
37
|
config.pop("active", None)
|
|
@@ -72,8 +73,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
72
73
|
) as response:
|
|
73
74
|
if response.status == 200:
|
|
74
75
|
return True, ""
|
|
75
|
-
|
|
76
|
-
return False, f"HTTP {response.status}: {response.reason}"
|
|
76
|
+
return False, f"HTTP {response.status}: {response.reason}"
|
|
77
77
|
else:
|
|
78
78
|
async with session.get(
|
|
79
79
|
url,
|
|
@@ -85,8 +85,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
85
85
|
) as response:
|
|
86
86
|
if response.status == 200:
|
|
87
87
|
return True, ""
|
|
88
|
-
|
|
89
|
-
return False, f"HTTP {response.status}: {response.reason}"
|
|
88
|
+
return False, f"HTTP {response.status}: {response.reason}"
|
|
90
89
|
|
|
91
90
|
except asyncio.TimeoutError:
|
|
92
91
|
return False, f"连接超时: {timeout}秒"
|
|
@@ -96,10 +95,10 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
|
|
96
95
|
|
|
97
96
|
class FunctionToolManager:
|
|
98
97
|
def __init__(self) -> None:
|
|
99
|
-
self.func_list:
|
|
100
|
-
self.mcp_client_dict:
|
|
98
|
+
self.func_list: list[FuncTool] = []
|
|
99
|
+
self.mcp_client_dict: dict[str, MCPClient] = {}
|
|
101
100
|
"""MCP 服务列表"""
|
|
102
|
-
self.mcp_client_event:
|
|
101
|
+
self.mcp_client_event: dict[str, asyncio.Event] = {}
|
|
103
102
|
|
|
104
103
|
def empty(self) -> bool:
|
|
105
104
|
return len(self.func_list) == 0
|
|
@@ -150,14 +149,12 @@ class FunctionToolManager:
|
|
|
150
149
|
func_args=func_args,
|
|
151
150
|
desc=desc,
|
|
152
151
|
handler=handler,
|
|
153
|
-
)
|
|
152
|
+
),
|
|
154
153
|
)
|
|
155
154
|
logger.info(f"添加函数调用工具: {name}")
|
|
156
155
|
|
|
157
156
|
def remove_func(self, name: str) -> None:
|
|
158
|
-
"""
|
|
159
|
-
删除一个函数调用工具。
|
|
160
|
-
"""
|
|
157
|
+
"""删除一个函数调用工具。"""
|
|
161
158
|
for i, f in enumerate(self.func_list):
|
|
162
159
|
if f.name == name:
|
|
163
160
|
self.func_list.pop(i)
|
|
@@ -202,16 +199,16 @@ class FunctionToolManager:
|
|
|
202
199
|
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
|
|
203
200
|
return
|
|
204
201
|
|
|
205
|
-
mcp_server_json_obj:
|
|
206
|
-
open(mcp_json_file,
|
|
202
|
+
mcp_server_json_obj: dict[str, dict] = json.load(
|
|
203
|
+
open(mcp_json_file, encoding="utf-8"),
|
|
207
204
|
)["mcpServers"]
|
|
208
205
|
|
|
209
|
-
for name in mcp_server_json_obj
|
|
206
|
+
for name in mcp_server_json_obj:
|
|
210
207
|
cfg = mcp_server_json_obj[name]
|
|
211
208
|
if cfg.get("active", True):
|
|
212
209
|
event = asyncio.Event()
|
|
213
210
|
asyncio.create_task(
|
|
214
|
-
self._init_mcp_client_task_wrapper(name, cfg, event)
|
|
211
|
+
self._init_mcp_client_task_wrapper(name, cfg, event),
|
|
215
212
|
)
|
|
216
213
|
self.mcp_client_event[name] = event
|
|
217
214
|
|
|
@@ -257,18 +254,15 @@ class FunctionToolManager:
|
|
|
257
254
|
self.func_list = [
|
|
258
255
|
f
|
|
259
256
|
for f in self.func_list
|
|
260
|
-
if not (f
|
|
257
|
+
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
|
261
258
|
]
|
|
262
259
|
|
|
263
260
|
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
|
|
264
261
|
for tool in mcp_client.tools:
|
|
265
|
-
func_tool =
|
|
266
|
-
|
|
267
|
-
parameters=tool.inputSchema,
|
|
268
|
-
description=tool.description,
|
|
269
|
-
origin="mcp",
|
|
270
|
-
mcp_server_name=name,
|
|
262
|
+
func_tool = MCPTool(
|
|
263
|
+
mcp_tool=tool,
|
|
271
264
|
mcp_client=mcp_client,
|
|
265
|
+
mcp_server_name=name,
|
|
272
266
|
)
|
|
273
267
|
self.func_list.append(func_tool)
|
|
274
268
|
|
|
@@ -287,7 +281,7 @@ class FunctionToolManager:
|
|
|
287
281
|
self.func_list = [
|
|
288
282
|
f
|
|
289
283
|
for f in self.func_list
|
|
290
|
-
if not (f
|
|
284
|
+
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
|
291
285
|
]
|
|
292
286
|
logger.info(f"已关闭 MCP 服务 {name}")
|
|
293
287
|
|
|
@@ -325,9 +319,11 @@ class FunctionToolManager:
|
|
|
325
319
|
event (asyncio.Event): Event to signal when the MCP client is ready.
|
|
326
320
|
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
|
327
321
|
timeout (int): Timeout for the initialization.
|
|
322
|
+
|
|
328
323
|
Raises:
|
|
329
324
|
TimeoutError: If the initialization does not complete within the specified timeout.
|
|
330
325
|
Exception: If there is an error during initialization.
|
|
326
|
+
|
|
331
327
|
"""
|
|
332
328
|
if not event:
|
|
333
329
|
event = asyncio.Event()
|
|
@@ -336,7 +332,7 @@ class FunctionToolManager:
|
|
|
336
332
|
if name in self.mcp_client_dict:
|
|
337
333
|
return
|
|
338
334
|
asyncio.create_task(
|
|
339
|
-
self._init_mcp_client_task_wrapper(name, config, event, ready_future)
|
|
335
|
+
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
|
|
340
336
|
)
|
|
341
337
|
try:
|
|
342
338
|
await asyncio.wait_for(ready_future, timeout=timeout)
|
|
@@ -349,13 +345,16 @@ class FunctionToolManager:
|
|
|
349
345
|
raise exc
|
|
350
346
|
|
|
351
347
|
async def disable_mcp_server(
|
|
352
|
-
self,
|
|
348
|
+
self,
|
|
349
|
+
name: str | None = None,
|
|
350
|
+
timeout: float = 10,
|
|
353
351
|
) -> None:
|
|
354
352
|
"""Disable an MCP server by its name.
|
|
355
353
|
|
|
356
354
|
Args:
|
|
357
355
|
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
|
|
358
356
|
timeout (int): Timeout.
|
|
357
|
+
|
|
359
358
|
"""
|
|
360
359
|
if name:
|
|
361
360
|
if name not in self.mcp_client_event:
|
|
@@ -372,7 +371,7 @@ class FunctionToolManager:
|
|
|
372
371
|
self.func_list = [
|
|
373
372
|
f
|
|
374
373
|
for f in self.func_list
|
|
375
|
-
if f
|
|
374
|
+
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
|
|
376
375
|
]
|
|
377
376
|
else:
|
|
378
377
|
running_events = [
|
|
@@ -386,30 +385,26 @@ class FunctionToolManager:
|
|
|
386
385
|
finally:
|
|
387
386
|
self.mcp_client_event.clear()
|
|
388
387
|
self.mcp_client_dict.clear()
|
|
389
|
-
self.func_list = [
|
|
388
|
+
self.func_list = [
|
|
389
|
+
f for f in self.func_list if not isinstance(f, MCPTool)
|
|
390
|
+
]
|
|
390
391
|
|
|
391
392
|
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
|
392
|
-
"""
|
|
393
|
-
获得 OpenAI API 风格的**已经激活**的工具描述
|
|
394
|
-
"""
|
|
393
|
+
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
|
|
395
394
|
tools = [f for f in self.func_list if f.active]
|
|
396
395
|
toolset = ToolSet(tools)
|
|
397
396
|
return toolset.openai_schema(
|
|
398
|
-
omit_empty_parameter_field=omit_empty_parameter_field
|
|
397
|
+
omit_empty_parameter_field=omit_empty_parameter_field,
|
|
399
398
|
)
|
|
400
399
|
|
|
401
400
|
def get_func_desc_anthropic_style(self) -> list:
|
|
402
|
-
"""
|
|
403
|
-
获得 Anthropic API 风格的**已经激活**的工具描述
|
|
404
|
-
"""
|
|
401
|
+
"""获得 Anthropic API 风格的**已经激活**的工具描述"""
|
|
405
402
|
tools = [f for f in self.func_list if f.active]
|
|
406
403
|
toolset = ToolSet(tools)
|
|
407
404
|
return toolset.anthropic_schema()
|
|
408
405
|
|
|
409
406
|
def get_func_desc_google_genai_style(self) -> dict:
|
|
410
|
-
"""
|
|
411
|
-
获得 Google GenAI API 风格的**已经激活**的工具描述
|
|
412
|
-
"""
|
|
407
|
+
"""获得 Google GenAI API 风格的**已经激活**的工具描述"""
|
|
413
408
|
tools = [f for f in self.func_list if f.active]
|
|
414
409
|
toolset = ToolSet(tools)
|
|
415
410
|
return toolset.google_schema()
|
|
@@ -418,13 +413,18 @@ class FunctionToolManager:
|
|
|
418
413
|
"""停用一个已经注册的函数调用工具。
|
|
419
414
|
|
|
420
415
|
Returns:
|
|
421
|
-
如果没找到,会返回 False
|
|
416
|
+
如果没找到,会返回 False
|
|
417
|
+
|
|
418
|
+
"""
|
|
422
419
|
func_tool = self.get_func(name)
|
|
423
420
|
if func_tool is not None:
|
|
424
421
|
func_tool.active = False
|
|
425
422
|
|
|
426
423
|
inactivated_llm_tools: list = sp.get(
|
|
427
|
-
"inactivated_llm_tools",
|
|
424
|
+
"inactivated_llm_tools",
|
|
425
|
+
[],
|
|
426
|
+
scope="global",
|
|
427
|
+
scope_id="global",
|
|
428
428
|
)
|
|
429
429
|
if name not in inactivated_llm_tools:
|
|
430
430
|
inactivated_llm_tools.append(name)
|
|
@@ -445,13 +445,16 @@ class FunctionToolManager:
|
|
|
445
445
|
if func_tool.handler_module_path in star_map:
|
|
446
446
|
if not star_map[func_tool.handler_module_path].activated:
|
|
447
447
|
raise ValueError(
|
|
448
|
-
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
|
448
|
+
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。",
|
|
449
449
|
)
|
|
450
450
|
|
|
451
451
|
func_tool.active = True
|
|
452
452
|
|
|
453
453
|
inactivated_llm_tools: list = sp.get(
|
|
454
|
-
"inactivated_llm_tools",
|
|
454
|
+
"inactivated_llm_tools",
|
|
455
|
+
[],
|
|
456
|
+
scope="global",
|
|
457
|
+
scope_id="global",
|
|
455
458
|
)
|
|
456
459
|
if name in inactivated_llm_tools:
|
|
457
460
|
inactivated_llm_tools.remove(name)
|
|
@@ -479,7 +482,7 @@ class FunctionToolManager:
|
|
|
479
482
|
return DEFAULT_MCP_CONFIG
|
|
480
483
|
|
|
481
484
|
try:
|
|
482
|
-
with open(self.mcp_config_path,
|
|
485
|
+
with open(self.mcp_config_path, encoding="utf-8") as f:
|
|
483
486
|
return json.load(f)
|
|
484
487
|
except Exception as e:
|
|
485
488
|
logger.error(f"加载 MCP 配置失败: {e}")
|
|
@@ -509,7 +512,8 @@ class FunctionToolManager:
|
|
|
509
512
|
if response.status == 200:
|
|
510
513
|
data = await response.json()
|
|
511
514
|
mcp_server_list = data.get("data", {}).get(
|
|
512
|
-
"mcp_server_list",
|
|
515
|
+
"mcp_server_list",
|
|
516
|
+
[],
|
|
513
517
|
)
|
|
514
518
|
local_mcp_config = self.load_mcp_config()
|
|
515
519
|
|
|
@@ -541,23 +545,23 @@ class FunctionToolManager:
|
|
|
541
545
|
self.enable_mcp_server(
|
|
542
546
|
name=name,
|
|
543
547
|
config=local_mcp_config["mcpServers"][name],
|
|
544
|
-
)
|
|
548
|
+
),
|
|
545
549
|
)
|
|
546
550
|
await asyncio.gather(*tasks)
|
|
547
551
|
logger.info(
|
|
548
|
-
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器"
|
|
552
|
+
f"从 ModelScope 同步了 {synced_count} 个 MCP 服务器",
|
|
549
553
|
)
|
|
550
554
|
else:
|
|
551
555
|
logger.warning("没有找到可用的 ModelScope MCP 服务器")
|
|
552
556
|
else:
|
|
553
557
|
raise Exception(
|
|
554
|
-
f"ModelScope API 请求失败: HTTP {response.status}"
|
|
558
|
+
f"ModelScope API 请求失败: HTTP {response.status}",
|
|
555
559
|
)
|
|
556
560
|
|
|
557
561
|
except aiohttp.ClientError as e:
|
|
558
|
-
raise Exception(f"网络连接错误: {
|
|
562
|
+
raise Exception(f"网络连接错误: {e!s}")
|
|
559
563
|
except Exception as e:
|
|
560
|
-
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {
|
|
564
|
+
raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}")
|
|
561
565
|
|
|
562
566
|
def __str__(self):
|
|
563
567
|
return str(self.func_list)
|
astrbot/core/provider/manager.py
CHANGED
|
@@ -5,16 +5,16 @@ from astrbot.core import logger, sp
|
|
|
5
5
|
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|
6
6
|
from astrbot.core.db import BaseDatabase
|
|
7
7
|
|
|
8
|
+
from ..persona_mgr import PersonaManager
|
|
8
9
|
from .entities import ProviderType
|
|
9
10
|
from .provider import (
|
|
11
|
+
EmbeddingProvider,
|
|
10
12
|
Provider,
|
|
13
|
+
RerankProvider,
|
|
11
14
|
STTProvider,
|
|
12
15
|
TTSProvider,
|
|
13
|
-
EmbeddingProvider,
|
|
14
|
-
RerankProvider,
|
|
15
16
|
)
|
|
16
17
|
from .register import llm_tools, provider_cls_map
|
|
17
|
-
from ..persona_mgr import PersonaManager
|
|
18
18
|
|
|
19
19
|
|
|
20
20
|
class ProviderManager:
|
|
@@ -76,7 +76,10 @@ class ProviderManager:
|
|
|
76
76
|
return self.persona_mgr.selected_default_persona_v3
|
|
77
77
|
|
|
78
78
|
async def set_provider(
|
|
79
|
-
self,
|
|
79
|
+
self,
|
|
80
|
+
provider_id: str,
|
|
81
|
+
provider_type: ProviderType,
|
|
82
|
+
umo: str | None = None,
|
|
80
83
|
):
|
|
81
84
|
"""设置提供商。
|
|
82
85
|
|
|
@@ -86,6 +89,7 @@ class ProviderManager:
|
|
|
86
89
|
umo (str, optional): 用户会话 ID,用于提供商会话隔离。
|
|
87
90
|
|
|
88
91
|
Version 4.0.0: 这个版本下已经默认隔离提供商
|
|
92
|
+
|
|
89
93
|
"""
|
|
90
94
|
if provider_id not in self.inst_map:
|
|
91
95
|
raise ValueError(f"提供商 {provider_id} 不存在,无法设置。")
|
|
@@ -100,17 +104,20 @@ class ProviderManager:
|
|
|
100
104
|
|
|
101
105
|
prov = self.inst_map[provider_id]
|
|
102
106
|
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
|
103
|
-
prov,
|
|
107
|
+
prov,
|
|
108
|
+
TTSProvider,
|
|
104
109
|
):
|
|
105
110
|
self.curr_tts_provider_inst = prov
|
|
106
111
|
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
|
107
112
|
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
|
108
|
-
prov,
|
|
113
|
+
prov,
|
|
114
|
+
STTProvider,
|
|
109
115
|
):
|
|
110
116
|
self.curr_stt_provider_inst = prov
|
|
111
117
|
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
|
112
118
|
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
|
113
|
-
prov,
|
|
119
|
+
prov,
|
|
120
|
+
Provider,
|
|
114
121
|
):
|
|
115
122
|
self.curr_provider_inst = prov
|
|
116
123
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
|
@@ -120,7 +127,9 @@ class ProviderManager:
|
|
|
120
127
|
return self.inst_map.get(provider_id)
|
|
121
128
|
|
|
122
129
|
def get_using_provider(
|
|
123
|
-
self,
|
|
130
|
+
self,
|
|
131
|
+
provider_type: ProviderType,
|
|
132
|
+
umo=None,
|
|
124
133
|
) -> Provider | STTProvider | TTSProvider | None:
|
|
125
134
|
"""获取正在使用的提供商实例。
|
|
126
135
|
|
|
@@ -130,6 +139,7 @@ class ProviderManager:
|
|
|
130
139
|
|
|
131
140
|
Returns:
|
|
132
141
|
Provider: 正在使用的提供商实例。
|
|
142
|
+
|
|
133
143
|
"""
|
|
134
144
|
provider = None
|
|
135
145
|
if umo:
|
|
@@ -219,7 +229,7 @@ class ProviderManager:
|
|
|
219
229
|
return
|
|
220
230
|
|
|
221
231
|
logger.info(
|
|
222
|
-
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ..."
|
|
232
|
+
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...",
|
|
223
233
|
)
|
|
224
234
|
|
|
225
235
|
# 动态导入
|
|
@@ -259,6 +269,10 @@ class ProviderManager:
|
|
|
259
269
|
from .sources.whisper_selfhosted_source import (
|
|
260
270
|
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
|
261
271
|
)
|
|
272
|
+
case "xinference_stt":
|
|
273
|
+
from .sources.xinference_stt_provider import (
|
|
274
|
+
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
|
275
|
+
)
|
|
262
276
|
case "openai_tts_api":
|
|
263
277
|
from .sources.openai_tts_api_source import (
|
|
264
278
|
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
|
@@ -311,20 +325,24 @@ class ProviderManager:
|
|
|
311
325
|
from .sources.vllm_rerank_source import (
|
|
312
326
|
VLLMRerankProvider as VLLMRerankProvider,
|
|
313
327
|
)
|
|
328
|
+
case "xinference_rerank":
|
|
329
|
+
from .sources.xinference_rerank_source import (
|
|
330
|
+
XinferenceRerankProvider as XinferenceRerankProvider,
|
|
331
|
+
)
|
|
314
332
|
except (ImportError, ModuleNotFoundError) as e:
|
|
315
333
|
logger.critical(
|
|
316
|
-
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
|
334
|
+
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
|
317
335
|
)
|
|
318
336
|
return
|
|
319
337
|
except Exception as e:
|
|
320
338
|
logger.critical(
|
|
321
|
-
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因"
|
|
339
|
+
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因",
|
|
322
340
|
)
|
|
323
341
|
return
|
|
324
342
|
|
|
325
343
|
if provider_config["type"] not in provider_cls_map:
|
|
326
344
|
logger.error(
|
|
327
|
-
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。"
|
|
345
|
+
f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。",
|
|
328
346
|
)
|
|
329
347
|
return
|
|
330
348
|
|
|
@@ -350,7 +368,7 @@ class ProviderManager:
|
|
|
350
368
|
):
|
|
351
369
|
self.curr_stt_provider_inst = inst
|
|
352
370
|
logger.info(
|
|
353
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
|
|
371
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
|
|
354
372
|
)
|
|
355
373
|
if not self.curr_stt_provider_inst:
|
|
356
374
|
self.curr_stt_provider_inst = inst
|
|
@@ -366,7 +384,7 @@ class ProviderManager:
|
|
|
366
384
|
if self.provider_settings.get("provider_id") == provider_config["id"]:
|
|
367
385
|
self.curr_tts_provider_inst = inst
|
|
368
386
|
logger.info(
|
|
369
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
|
|
387
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
|
|
370
388
|
)
|
|
371
389
|
if not self.curr_tts_provider_inst:
|
|
372
390
|
self.curr_tts_provider_inst = inst
|
|
@@ -389,7 +407,7 @@ class ProviderManager:
|
|
|
389
407
|
):
|
|
390
408
|
self.curr_provider_inst = inst
|
|
391
409
|
logger.info(
|
|
392
|
-
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
|
|
410
|
+
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
|
|
393
411
|
)
|
|
394
412
|
if not self.curr_provider_inst:
|
|
395
413
|
self.curr_provider_inst = inst
|
|
@@ -408,10 +426,10 @@ class ProviderManager:
|
|
|
408
426
|
self.inst_map[provider_config["id"]] = inst
|
|
409
427
|
except Exception as e:
|
|
410
428
|
logger.error(
|
|
411
|
-
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
|
429
|
+
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
|
412
430
|
)
|
|
413
431
|
raise Exception(
|
|
414
|
-
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}"
|
|
432
|
+
f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}",
|
|
415
433
|
)
|
|
416
434
|
|
|
417
435
|
async def reload(self, provider_config: dict):
|
|
@@ -431,7 +449,7 @@ class ProviderManager:
|
|
|
431
449
|
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
|
|
432
450
|
self.curr_provider_inst = self.provider_insts[0]
|
|
433
451
|
logger.info(
|
|
434
|
-
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
|
452
|
+
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
|
|
435
453
|
)
|
|
436
454
|
|
|
437
455
|
if len(self.stt_provider_insts) == 0:
|
|
@@ -439,7 +457,7 @@ class ProviderManager:
|
|
|
439
457
|
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
|
|
440
458
|
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
|
441
459
|
logger.info(
|
|
442
|
-
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
|
460
|
+
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
|
|
443
461
|
)
|
|
444
462
|
|
|
445
463
|
if len(self.tts_provider_insts) == 0:
|
|
@@ -447,7 +465,7 @@ class ProviderManager:
|
|
|
447
465
|
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
|
|
448
466
|
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
|
449
467
|
logger.info(
|
|
450
|
-
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
|
468
|
+
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
|
|
451
469
|
)
|
|
452
470
|
|
|
453
471
|
def get_insts(self):
|
|
@@ -456,7 +474,7 @@ class ProviderManager:
|
|
|
456
474
|
async def terminate_provider(self, provider_id: str):
|
|
457
475
|
if provider_id in self.inst_map:
|
|
458
476
|
logger.info(
|
|
459
|
-
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ..."
|
|
477
|
+
f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...",
|
|
460
478
|
)
|
|
461
479
|
|
|
462
480
|
if self.inst_map[provider_id] in self.provider_insts:
|
|
@@ -483,7 +501,7 @@ class ProviderManager:
|
|
|
483
501
|
await self.inst_map[provider_id].terminate() # type: ignore
|
|
484
502
|
|
|
485
503
|
logger.info(
|
|
486
|
-
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
|
|
504
|
+
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})",
|
|
487
505
|
)
|
|
488
506
|
del self.inst_map[provider_id]
|
|
489
507
|
|