AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +16 -4
- astrbot/api/all.py +2 -1
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -34
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +8 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__init__.py +1 -0
- astrbot/cli/__main__.py +18 -197
- astrbot/cli/commands/__init__.py +6 -0
- astrbot/cli/commands/cmd_conf.py +209 -0
- astrbot/cli/commands/cmd_init.py +56 -0
- astrbot/cli/commands/cmd_plug.py +245 -0
- astrbot/cli/commands/cmd_run.py +62 -0
- astrbot/cli/utils/__init__.py +18 -0
- astrbot/cli/utils/basic.py +76 -0
- astrbot/cli/utils/plugin.py +246 -0
- astrbot/cli/utils/version_comparator.py +90 -0
- astrbot/core/__init__.py +17 -19
- astrbot/core/agent/agent.py +14 -0
- astrbot/core/agent/handoff.py +38 -0
- astrbot/core/agent/hooks.py +30 -0
- astrbot/core/agent/mcp_client.py +385 -0
- astrbot/core/agent/message.py +175 -0
- astrbot/core/agent/response.py +14 -0
- astrbot/core/agent/run_context.py +22 -0
- astrbot/core/agent/runners/__init__.py +3 -0
- astrbot/core/agent/runners/base.py +65 -0
- astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
- astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
- astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
- astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
- astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
- astrbot/core/agent/tool.py +285 -0
- astrbot/core/agent/tool_executor.py +17 -0
- astrbot/core/astr_agent_context.py +19 -0
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/astrbot_config_mgr.py +275 -0
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +60 -20
- astrbot/core/config/default.py +1972 -453
- astrbot/core/config/i18n_utils.py +110 -0
- astrbot/core/conversation_mgr.py +285 -75
- astrbot/core/core_lifecycle.py +167 -62
- astrbot/core/db/__init__.py +305 -102
- astrbot/core/db/migration/helper.py +69 -0
- astrbot/core/db/migration/migra_3_to_4.py +357 -0
- astrbot/core/db/migration/migra_45_to_46.py +44 -0
- astrbot/core/db/migration/migra_webchat_session.py +131 -0
- astrbot/core/db/migration/shared_preferences_v3.py +48 -0
- astrbot/core/db/migration/sqlite_v3.py +497 -0
- astrbot/core/db/po.py +259 -55
- astrbot/core/db/sqlite.py +773 -528
- astrbot/core/db/vec_db/base.py +73 -0
- astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
- astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
- astrbot/core/event_bus.py +26 -22
- astrbot/core/exceptions.py +9 -0
- astrbot/core/file_token_service.py +98 -0
- astrbot/core/initial_loader.py +19 -10
- astrbot/core/knowledge_base/chunking/__init__.py +9 -0
- astrbot/core/knowledge_base/chunking/base.py +25 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
- astrbot/core/knowledge_base/chunking/recursive.py +161 -0
- astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
- astrbot/core/knowledge_base/kb_helper.py +642 -0
- astrbot/core/knowledge_base/kb_mgr.py +330 -0
- astrbot/core/knowledge_base/models.py +120 -0
- astrbot/core/knowledge_base/parsers/__init__.py +13 -0
- astrbot/core/knowledge_base/parsers/base.py +51 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
- astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
- astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
- astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
- astrbot/core/knowledge_base/parsers/util.py +13 -0
- astrbot/core/knowledge_base/prompts.py +65 -0
- astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
- astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
- astrbot/core/knowledge_base/retrieval/manager.py +276 -0
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
- astrbot/core/log.py +21 -15
- astrbot/core/message/components.py +413 -287
- astrbot/core/message/message_event_result.py +35 -24
- astrbot/core/persona_mgr.py +192 -0
- astrbot/core/pipeline/__init__.py +14 -14
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +7 -1
- astrbot/core/pipeline/context_utils.py +107 -0
- astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
- astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
- astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
- astrbot/core/pipeline/process_stage/stage.py +21 -15
- astrbot/core/pipeline/process_stage/utils.py +125 -0
- astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
- astrbot/core/pipeline/respond/stage.py +142 -101
- astrbot/core/pipeline/result_decorate/stage.py +124 -57
- astrbot/core/pipeline/scheduler.py +21 -16
- astrbot/core/pipeline/session_status_check/stage.py +37 -0
- astrbot/core/pipeline/stage.py +11 -76
- astrbot/core/pipeline/waking_check/stage.py +69 -33
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +107 -129
- astrbot/core/platform/astrbot_message.py +32 -12
- astrbot/core/platform/manager.py +62 -18
- astrbot/core/platform/message_session.py +30 -0
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +9 -4
- astrbot/core/platform/register.py +12 -7
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
- astrbot/core/platform/sources/discord/client.py +129 -0
- astrbot/core/platform/sources/discord/components.py +139 -0
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
- astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
- astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
- astrbot/core/platform/sources/lark/lark_event.py +39 -13
- astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
- astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
- astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
- astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
- astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
- astrbot/core/platform/sources/satori/satori_event.py +432 -0
- astrbot/core/platform/sources/slack/client.py +164 -0
- astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
- astrbot/core/platform/sources/slack/slack_event.py +253 -0
- astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
- astrbot/core/platform/sources/telegram/tg_event.py +136 -36
- astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
- astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
- astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
- astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
- astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
- astrbot/core/platform_message_history_mgr.py +49 -0
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +154 -98
- astrbot/core/provider/func_tool_manager.py +446 -458
- astrbot/core/provider/manager.py +345 -207
- astrbot/core/provider/provider.py +188 -73
- astrbot/core/provider/register.py +9 -7
- astrbot/core/provider/sources/anthropic_source.py +295 -115
- astrbot/core/provider/sources/azure_tts_source.py +224 -0
- astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
- astrbot/core/provider/sources/dashscope_tts.py +138 -14
- astrbot/core/provider/sources/edge_tts_source.py +24 -19
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
- astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
- astrbot/core/provider/sources/gemini_source.py +310 -132
- astrbot/core/provider/sources/gemini_tts_source.py +81 -0
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
- astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
- astrbot/core/provider/sources/openai_embedding_source.py +40 -0
- astrbot/core/provider/sources/openai_source.py +241 -145
- astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
- astrbot/core/provider/sources/volcengine_tts.py +115 -0
- astrbot/core/provider/sources/whisper_api_source.py +18 -13
- astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/provider/sources/zhipu_source.py +6 -73
- astrbot/core/star/__init__.py +43 -11
- astrbot/core/star/config.py +17 -18
- astrbot/core/star/context.py +362 -138
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +111 -35
- astrbot/core/star/filter/command_group.py +46 -34
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +45 -12
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -15
- astrbot/core/star/register/star.py +41 -13
- astrbot/core/star/register/star_handler.py +236 -86
- astrbot/core/star/session_llm_manager.py +280 -0
- astrbot/core/star/session_plugin_manager.py +170 -0
- astrbot/core/star/star.py +36 -43
- astrbot/core/star/star_handler.py +47 -85
- astrbot/core/star/star_manager.py +442 -260
- astrbot/core/star/star_tools.py +167 -45
- astrbot/core/star/updator.py +17 -20
- astrbot/core/umop_config_router.py +106 -0
- astrbot/core/updator.py +38 -13
- astrbot/core/utils/astrbot_path.py +39 -0
- astrbot/core/utils/command_parser.py +1 -1
- astrbot/core/utils/io.py +119 -60
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +11 -10
- astrbot/core/utils/migra_helper.py +73 -0
- astrbot/core/utils/path_util.py +63 -62
- astrbot/core/utils/pip_installer.py +37 -15
- astrbot/core/utils/session_lock.py +29 -0
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +174 -34
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/local_strategy.py +386 -238
- astrbot/core/utils/t2i/network_strategy.py +109 -49
- astrbot/core/utils/t2i/renderer.py +29 -14
- astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
- astrbot/core/utils/t2i/template_manager.py +111 -0
- astrbot/core/utils/tencent_record_helper.py +115 -1
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +112 -65
- astrbot/dashboard/routes/__init__.py +20 -13
- astrbot/dashboard/routes/auth.py +20 -9
- astrbot/dashboard/routes/chat.py +297 -141
- astrbot/dashboard/routes/config.py +652 -55
- astrbot/dashboard/routes/conversation.py +107 -37
- astrbot/dashboard/routes/file.py +26 -0
- astrbot/dashboard/routes/knowledge_base.py +1244 -0
- astrbot/dashboard/routes/log.py +27 -2
- astrbot/dashboard/routes/persona.py +202 -0
- astrbot/dashboard/routes/plugin.py +197 -139
- astrbot/dashboard/routes/route.py +27 -7
- astrbot/dashboard/routes/session_management.py +354 -0
- astrbot/dashboard/routes/stat.py +85 -18
- astrbot/dashboard/routes/static_file.py +5 -2
- astrbot/dashboard/routes/t2i.py +233 -0
- astrbot/dashboard/routes/tools.py +184 -120
- astrbot/dashboard/routes/update.py +59 -36
- astrbot/dashboard/server.py +96 -36
- astrbot/dashboard/utils.py +165 -0
- astrbot-4.7.0.dist-info/METADATA +294 -0
- astrbot-4.7.0.dist-info/RECORD +274 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
- astrbot/core/db/plugin/sqlite_impl.py +0 -112
- astrbot/core/db/sqlite_init.sql +0 -50
- astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
- astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
- astrbot/core/platform/sources/gewechat/client.py +0 -806
- astrbot/core/platform/sources/gewechat/downloader.py +0 -55
- astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
- astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
- astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
- astrbot/core/provider/sources/dashscope_source.py +0 -203
- astrbot/core/provider/sources/dify_source.py +0 -281
- astrbot/core/provider/sources/llmtuner_source.py +0 -132
- astrbot/core/rag/embedding/openai_source.py +0 -20
- astrbot/core/rag/knowledge_db_mgr.py +0 -94
- astrbot/core/rag/store/__init__.py +0 -9
- astrbot/core/rag/store/chroma_db.py +0 -42
- astrbot/core/utils/dify_api_client.py +0 -152
- astrbot-3.5.6.dist-info/METADATA +0 -249
- astrbot-3.5.6.dist-info/RECORD +0 -158
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
- {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,385 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import logging
|
|
3
|
+
from contextlib import AsyncExitStack
|
|
4
|
+
from datetime import timedelta
|
|
5
|
+
from typing import Generic
|
|
6
|
+
|
|
7
|
+
from tenacity import (
|
|
8
|
+
before_sleep_log,
|
|
9
|
+
retry,
|
|
10
|
+
retry_if_exception_type,
|
|
11
|
+
stop_after_attempt,
|
|
12
|
+
wait_exponential,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from astrbot import logger
|
|
16
|
+
from astrbot.core.agent.run_context import ContextWrapper
|
|
17
|
+
from astrbot.core.utils.log_pipe import LogPipe
|
|
18
|
+
|
|
19
|
+
from .run_context import TContext
|
|
20
|
+
from .tool import FunctionTool
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
import anyio
|
|
24
|
+
import mcp
|
|
25
|
+
from mcp.client.sse import sse_client
|
|
26
|
+
except (ModuleNotFoundError, ImportError):
|
|
27
|
+
logger.warning(
|
|
28
|
+
"Warning: Missing 'mcp' dependency, MCP services will be unavailable."
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
from mcp.client.streamable_http import streamablehttp_client
|
|
33
|
+
except (ModuleNotFoundError, ImportError):
|
|
34
|
+
logger.warning(
|
|
35
|
+
"Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
|
|
36
|
+
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _prepare_config(config: dict) -> dict:
|
|
40
|
+
"""Prepare configuration, handle nested format"""
|
|
41
|
+
if config.get("mcpServers"):
|
|
42
|
+
first_key = next(iter(config["mcpServers"]))
|
|
43
|
+
config = config["mcpServers"][first_key]
|
|
44
|
+
config.pop("active", None)
|
|
45
|
+
return config
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
    """Probe an HTTP-based MCP server and report whether it is reachable.

    For the ``streamable_http`` transport a JSON-RPC ``initialize`` request is
    POSTed to the server URL; for any other transport value a GET is issued.
    Only an HTTP 200 response counts as success.

    Args:
        config: MCP server configuration (flat or ``mcpServers``-wrapped);
            it must contain ``url``.

    Returns:
        ``(True, "")`` when the server answered with 200, otherwise
        ``(False, reason)`` describing the failure.

    Raises:
        KeyError: if the prepared configuration has no ``url`` key.
    """
    import aiohttp

    cfg = _prepare_config(config.copy())

    url = cfg["url"]
    user_headers = cfg.get("headers", {})
    timeout = cfg.get("timeout", 10)
    accept = "application/json, text/event-stream"

    try:
        # "transport" takes precedence over "type"; one of them is required.
        transport_key = next((key for key in ("transport", "type") if key in cfg), None)
        if transport_key is None:
            raise Exception("MCP connection config missing transport or type field")
        transport_type = cfg[transport_key]

        request_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession() as session:
            if transport_type == "streamable_http":
                initialize_request = {
                    "jsonrpc": "2.0",
                    "method": "initialize",
                    "id": 0,
                    "params": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {},
                        "clientInfo": {"name": "test-client", "version": "1.2.3"},
                    },
                }
                pending = session.post(
                    url,
                    headers={
                        **user_headers,
                        "Content-Type": "application/json",
                        "Accept": accept,
                    },
                    json=initialize_request,
                    timeout=request_timeout,
                )
            else:
                pending = session.get(
                    url,
                    headers={**user_headers, "Accept": accept},
                    timeout=request_timeout,
                )

            async with pending as response:
                if response.status != 200:
                    return False, f"HTTP {response.status}: {response.reason}"
                return True, ""

    except asyncio.TimeoutError:
        return False, f"Connection timeout: {timeout} seconds"
    except Exception as e:
        return False, f"{e!s}"
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class MCPClient:
    """Client for a single MCP server over stdio, SSE, or Streamable HTTP.

    Holds the active ``mcp.ClientSession``, the tools last listed from the
    server, error/log messages reported by the server, and the configuration
    needed to reconnect after the underlying stream has been closed.
    """

    def __init__(self):
        # Initialize session and client objects
        self.session: mcp.ClientSession | None = None
        self.exit_stack = AsyncExitStack()
        self._old_exit_stacks: list[AsyncExitStack] = []  # Track old stacks for cleanup

        self.name: str | None = None
        self.active: bool = True
        self.tools: list[mcp.Tool] = []
        self.server_errlogs: list[str] = []
        self.running_event = asyncio.Event()

        # Store connection config for reconnection
        self._mcp_server_config: dict | None = None
        self._server_name: str | None = None
        self._reconnect_lock = asyncio.Lock()  # Serializes reconnection attempts
        self._reconnecting: bool = False  # True while a reconnection runs (for logging/debugging)

    async def connect_to_server(self, mcp_server_config: dict, name: str):
        """Connect to an MCP server and initialize the client session.

        If the (unwrapped) config contains ``url``:

        1. The server is first probed with ``_quick_test_mcp_connection``; a
           failed probe raises ``Exception`` carrying the probe's error text.
        2. The transport is read from ``transport`` (or, failing that, ``type``).
           ``"streamable_http"`` selects a Streamable HTTP connection; any
           other value selects an SSE connection.
        3. If neither ``transport`` nor ``type`` is present, ``Exception`` is
           raised (there is no implicit default transport).

        Without ``url``, the config is passed to ``mcp.StdioServerParameters``
        and the server is started through ``mcp.stdio_client``.

        Args:
            mcp_server_config (dict): Configuration for the MCP server, flat or
                wrapped as ``{"mcpServers": {name: config}}``.
                See https://modelcontextprotocol.io/quickstart/server
            name (str): Server name, used in log messages and for reconnection.

        Raises:
            Exception: if the connectivity probe fails or the transport is missing.
        """
        # Store config for reconnection
        self._mcp_server_config = mcp_server_config
        self._server_name = name

        cfg = _prepare_config(mcp_server_config.copy())

        async def logging_callback(params) -> None:
            # mcp.ClientSession awaits its logging callback and passes a
            # LoggingMessageNotificationParams object, so this must be a
            # coroutine function; a plain function makes that await fail.
            level = getattr(params, "level", "unknown")
            msg = str(getattr(params, "data", params))
            logger.info(f"MCP Server {name} log ({level}): {msg}")
            self.server_errlogs.append(msg)

        if "url" in cfg:
            success, error_msg = await _quick_test_mcp_connection(cfg)
            if not success:
                raise Exception(error_msg)

            if "transport" in cfg:
                transport_type = cfg["transport"]
            elif "type" in cfg:
                transport_type = cfg["type"]
            else:
                raise Exception("MCP connection config missing transport or type field")

            if transport_type != "streamable_http":
                # SSE transport method
                self._streams_context = sse_client(
                    url=cfg["url"],
                    headers=cfg.get("headers", {}),
                    timeout=cfg.get("timeout", 5),
                    sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
                )
                streams = await self.exit_stack.enter_async_context(
                    self._streams_context,
                )

                # Create a new client session
                read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
                self.session = await self.exit_stack.enter_async_context(
                    mcp.ClientSession(
                        *streams,
                        read_timeout_seconds=read_timeout,
                        logging_callback=logging_callback,
                    ),
                )
            else:
                timeout = timedelta(seconds=cfg.get("timeout", 30))
                sse_read_timeout = timedelta(
                    seconds=cfg.get("sse_read_timeout", 60 * 5),
                )
                self._streams_context = streamablehttp_client(
                    url=cfg["url"],
                    headers=cfg.get("headers", {}),
                    timeout=timeout,
                    sse_read_timeout=sse_read_timeout,
                    terminate_on_close=cfg.get("terminate_on_close", True),
                )
                read_s, write_s, _ = await self.exit_stack.enter_async_context(
                    self._streams_context,
                )

                # Create a new client session
                read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
                self.session = await self.exit_stack.enter_async_context(
                    mcp.ClientSession(
                        read_stream=read_s,
                        write_stream=write_s,
                        read_timeout_seconds=read_timeout,
                        logging_callback=logging_callback,
                    ),
                )

        else:
            server_params = mcp.StdioServerParameters(
                **cfg,
            )

            def callback(msg: str):
                # Collect stderr lines forwarded by LogPipe as server error logs
                self.server_errlogs.append(msg)

            stdio_transport = await self.exit_stack.enter_async_context(
                mcp.stdio_client(
                    server_params,
                    errlog=LogPipe(
                        level=logging.ERROR,
                        logger=logger,
                        identifier=f"MCPServer-{name}",
                        callback=callback,
                    ),  # type: ignore
                ),
            )

            # Create a new client session
            self.session = await self.exit_stack.enter_async_context(
                mcp.ClientSession(*stdio_transport),
            )
        await self.session.initialize()

    async def list_tools_and_save(self) -> mcp.ListToolsResult:
        """List all tools from the server and save them to self.tools.

        Raises:
            Exception: if there is no active session.
        """
        if not self.session:
            raise Exception("MCP Client is not initialized")
        response = await self.session.list_tools()
        self.tools = response.tools
        return response

    async def _reconnect(self, failed_session: "mcp.ClientSession | None" = None) -> None:
        """Reconnect to the MCP server using the stored configuration.

        Reconnection attempts are serialized with ``self._reconnect_lock``.
        When ``failed_session`` is given and, once the lock is acquired,
        ``self.session`` already refers to a different session, another caller
        has reconnected in the meantime; this call then returns without tearing
        that fresh connection down again.

        Args:
            failed_session: The session the caller observed failing, if known.
                When ``None``, a reconnection is always performed.

        Raises:
            Exception: if no configuration is stored or reconnection fails.
        """
        async with self._reconnect_lock:
            # Another coroutine already replaced the session we saw fail.
            if (
                failed_session is not None
                and self.session is not None
                and self.session is not failed_session
            ):
                logger.debug(
                    f"MCP Client {self._server_name} was already reconnected, skipping"
                )
                return

            if not self._mcp_server_config or not self._server_name:
                raise Exception("Cannot reconnect: missing connection configuration")

            self._reconnecting = True
            try:
                logger.info(
                    f"Attempting to reconnect to MCP server {self._server_name}..."
                )

                # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
                if self.exit_stack:
                    self._old_exit_stacks.append(self.exit_stack)

                # Mark old session as invalid
                self.session = None

                # Create new exit stack for new connection
                self.exit_stack = AsyncExitStack()

                # Reconnect using stored config
                await self.connect_to_server(self._mcp_server_config, self._server_name)
                await self.list_tools_and_save()

                logger.info(
                    f"Successfully reconnected to MCP server {self._server_name}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to reconnect to MCP server {self._server_name}: {e}"
                )
                raise
            finally:
                self._reconnecting = False

    async def call_tool_with_reconnect(
        self,
        tool_name: str,
        arguments: dict,
        read_timeout_seconds: timedelta,
    ) -> mcp.types.CallToolResult:
        """Call an MCP tool, reconnecting once if the stream has been closed.

        At most 2 attempts are made: if the first call raises
        ``anyio.ClosedResourceError``, the client reconnects and the call is
        retried once on the new session.

        Args:
            tool_name: tool name
            arguments: tool arguments
            read_timeout_seconds: read timeout

        Returns:
            MCP tool call result

        Raises:
            ValueError: MCP session is not available
            anyio.ClosedResourceError: the call still failed on the final attempt
            Exception: whatever reconnection raised, if reconnecting failed
        """

        @retry(
            retry=retry_if_exception_type(anyio.ClosedResourceError),
            stop=stop_after_attempt(2),
            wait=wait_exponential(multiplier=1, min=1, max=3),
            before_sleep=before_sleep_log(logger, logging.WARNING),
            reraise=True,
        )
        async def _call_with_retry():
            session = self.session
            if not session:
                raise ValueError("MCP session is not available for MCP function tools.")

            try:
                return await session.call_tool(
                    name=tool_name,
                    arguments=arguments,
                    read_timeout_seconds=read_timeout_seconds,
                )
            except anyio.ClosedResourceError:
                logger.warning(
                    f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
                )
                # Pass the failed session so concurrent callers hitting the same
                # closed stream trigger only one reconnection.
                await self._reconnect(session)
                # Reraise the exception to trigger tenacity retry
                raise

        return await _call_with_retry()

    async def cleanup(self):
        """Close the current exit stack and set ``running_event``.

        Exit stacks replaced during reconnection are not closed here (they may
        belong to a different task context); only their references are dropped.
        """
        # Close current exit stack
        try:
            await self.exit_stack.aclose()
        except Exception as e:
            logger.debug(f"Error closing current exit stack: {e}")

        # Don't close old exit stacks as they may be in different task contexts;
        # just clear the list to release references.
        self._old_exit_stacks.clear()

        # Finally set running_event to unblock any waiting tasks
        self.running_event.set()
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
class MCPTool(FunctionTool, Generic[TContext]):
    """Function tool that forwards invocations to a tool exposed by an MCP server.

    The tool's name, description and parameter schema are copied from the MCP
    tool definition; invocations are delegated to
    ``MCPClient.call_tool_with_reconnect``.
    """

    def __init__(
        self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
    ):
        # Extra keyword arguments are accepted but not used.
        tool_description = mcp_tool.description or ""
        super().__init__(
            name=mcp_tool.name,
            description=tool_description,
            parameters=mcp_tool.inputSchema,
        )
        self.mcp_tool = mcp_tool
        self.mcp_client = mcp_client
        self.mcp_server_name = mcp_server_name

    async def call(
        self, context: ContextWrapper[TContext], **kwargs
    ) -> mcp.types.CallToolResult:
        """Invoke the remote MCP tool, passing ``kwargs`` as its arguments.

        The read timeout comes from ``context.tool_call_timeout`` (in seconds).
        """
        timeout = timedelta(seconds=context.tool_call_timeout)
        client = self.mcp_client
        return await client.call_tool_with_reconnect(
            tool_name=self.mcp_tool.name,
            arguments=kwargs,
            read_timeout_seconds=timeout,
        )
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
# Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
|
|
2
|
+
# License: Apache License 2.0
|
|
3
|
+
|
|
4
|
+
from typing import Any, ClassVar, Literal, cast
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, GetCoreSchemaHandler
|
|
7
|
+
from pydantic_core import core_schema
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ContentPart(BaseModel):
    """A part of the content in a message.

    Every subclass must give ``type`` a string default. The subclass is
    registered under that value, so validating a field annotated as the base
    ``ContentPart`` dispatches a ``{"type": ..., ...}`` dict to the matching
    subclass.
    """

    __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}

    type: str

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        invalid_subclass_error_msg = f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"

        # Reads the class-level default assigned to `type` in the subclass body.
        type_value = getattr(cls, "type", None)
        if not isinstance(type_value, str):
            raise ValueError(invalid_subclass_error_msg)

        cls.__content_part_registry[type_value] = cls

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Subclasses use pydantic's default model schema. The name check (rather
        # than `cls is ContentPart`) is used because the schema for the base class
        # is built while the class is still being created.
        if cls.__name__ != "ContentPart":
            return handler(source_type)

        def validate_content_part(value: Any) -> Any:
            # Already an instance of ContentPart (or a subclass): accept as-is.
            if isinstance(value, cls):
                return value

            # A dict with a `type` field: dispatch to the registered subclass.
            if isinstance(value, dict) and "type" in value:
                type_value: Any | None = cast(dict[str, Any], value).get("type")
                if not isinstance(type_value, str):
                    raise ValueError(f"Cannot validate {value} as ContentPart")
                target_class = cls.__content_part_registry.get(type_value)
                if target_class is None:
                    # Raise ValueError (not KeyError) so pydantic reports a
                    # proper ValidationError for unknown part types.
                    raise ValueError(
                        f"Unknown ContentPart type {type_value!r}: cannot validate {value}"
                    )
                return target_class.model_validate(value)

            raise ValueError(f"Cannot validate {value} as ContentPart")

        return core_schema.no_info_plain_validator_function(validate_content_part)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class TextPart(ContentPart):
    """
    A plain-text content part.

    >>> TextPart(text="Hello, world!").model_dump()
    {'type': 'text', 'text': 'Hello, world!'}
    """

    # Tag under which ContentPart.__init_subclass__ registers this class.
    type: str = "text"
    # The text payload.
    text: str
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class ImageURLPart(ContentPart):
    """
    An image content part referenced by URL.

    ``image_url`` is a nested ``ImageURLPart.ImageURL`` model, not a bare string.

    >>> ImageURLPart(
    ...     image_url=ImageURLPart.ImageURL(url="http://example.com/image.jpg")
    ... ).model_dump()
    {'type': 'image_url', 'image_url': {'url': 'http://example.com/image.jpg', 'id': None}}
    """

    class ImageURL(BaseModel):
        url: str
        """The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
        id: str | None = None
        """The ID of the image, to allow LLMs to distinguish different images."""

    # Tag under which ContentPart.__init_subclass__ registers this class.
    type: str = "image_url"
    image_url: ImageURL
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class AudioURLPart(ContentPart):
    """
    An audio content part referenced by URL.

    >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
    {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
    """

    class AudioURL(BaseModel):
        url: str
        """The URL of the audio, can be data URI scheme like `data:audio/aac;base64,...`."""
        id: str | None = None
        """The ID of the audio, to allow LLMs to distinguish different audios."""

    # Tag under which ContentPart.__init_subclass__ registers this class.
    type: str = "audio_url"
    # Nested model (url + optional id), not a bare string.
    audio_url: AudioURL
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ToolCall(BaseModel):
    """
    A tool call requested by the assistant.

    >>> ToolCall(
    ...     id="123",
    ...     function=ToolCall.FunctionBody(
    ...         name="function",
    ...         arguments="{}"
    ...     ),
    ... ).model_dump()
    {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
    """

    class FunctionBody(BaseModel):
        # Name of the function/tool to invoke.
        name: str
        # Raw argument string as sent by the model (may be None).
        arguments: str | None

    type: Literal["function"] = "function"

    id: str
    """The ID of the tool call."""
    function: FunctionBody
    """The function body of the tool call."""
    extra_content: dict[str, Any] | None = None
    """Extra metadata for the tool call."""

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump the model, dropping ``extra_content`` when it is ``None``.

        A caller-supplied ``exclude`` is merged with the extra exclusion and is
        never mutated in place. It may be omitted, ``None``, a set of field
        names, or pydantic's nested-dict form. All other keyword arguments
        are passed through to ``BaseModel.model_dump`` unchanged.
        """
        if self.extra_content is None:
            exclude = kwargs.get("exclude")
            if exclude is None:
                kwargs["exclude"] = {"extra_content"}
            elif isinstance(exclude, dict):
                # Nested-dict form: ``True`` excludes the whole field.
                kwargs["exclude"] = {**exclude, "extra_content": True}
            else:
                # Build a new set so the caller's collection is left untouched.
                kwargs["exclude"] = {*exclude, "extra_content"}
        return super().model_dump(**kwargs)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
class ToolCallPart(BaseModel):
    """A part of the tool call."""

    # NOTE(review): presumably an incremental (streamed) fragment of a tool
    # call's argument string; TODO confirm against the streaming code.
    arguments_part: str | None = None
    """A part of the arguments of the tool call."""
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class Message(BaseModel):
    """A message in a conversation."""

    # Author of the message. The *MessageSegment subclasses in this module
    # narrow this to a single literal with a default.
    role: Literal[
        "system",
        "user",
        "assistant",
        "tool",
    ]

    # Either a plain string or a list of content parts. List items go through
    # ContentPart's custom schema, which dispatches dicts on their "type" key.
    content: str | list[ContentPart]
    """The content of the message."""
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class AssistantMessageSegment(Message):
    """A message segment from the assistant."""

    role: Literal["assistant"] = "assistant"
    # Tool calls requested in this turn, as ToolCall models or raw dicts; None when absent.
    tool_calls: list[ToolCall] | list[dict] | None = None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
class ToolCallMessageSegment(Message):
    """A message segment representing a tool call."""

    role: Literal["tool"] = "tool"
    # Presumably the ``ToolCall.id`` this tool result answers; TODO confirm with callers.
    tool_call_id: str
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class UserMessageSegment(Message):
    """A message segment from the user."""

    # Fixed role; only ``content`` needs to be supplied.
    role: Literal["user"] = "user"
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class SystemMessageSegment(Message):
    """A message segment from the system."""

    # Fixed role; only ``content`` needs to be supplied.
    role: Literal["system"] = "system"
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
import typing as T
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from astrbot.core.message.message_event_result import MessageChain
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AgentResponseData(T.TypedDict):
    """Payload stored in ``AgentResponse.data``."""

    # The message chain carried by the response.
    chain: MessageChain
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class AgentResponse:
    """One item emitted by an agent runner's ``step`` / ``step_until_done``."""

    # Response kind; the valid set of values is not defined here — TODO confirm with the runners.
    type: str
    # Payload holding the message chain.
    data: AgentResponseData
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from typing import Any, Generic
|
|
2
|
+
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
from pydantic.dataclasses import dataclass
|
|
5
|
+
from typing_extensions import TypeVar
|
|
6
|
+
|
|
7
|
+
from .message import Message
|
|
8
|
+
|
|
9
|
+
# Type variable for the caller-supplied context object; defaults to Any when unparameterised
# (the ``default=`` argument comes from typing_extensions.TypeVar).
TContext = TypeVar("TContext", default=Any)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Pydantic dataclass; arbitrary_types_allowed presumably lets ``context`` hold
# objects pydantic has no schema for.
@dataclass(config={"arbitrary_types_allowed": True})
class ContextWrapper(Generic[TContext]):
    """A context for running an agent, which can be used to pass additional data or state."""

    # Caller-supplied payload, typed by the generic parameter.
    context: TContext
    messages: list[Message] = Field(default_factory=list)
    """This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
    tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Alias for a ContextWrapper whose ``context`` payload is None.
NoContext = ContextWrapper[None]
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import typing as T
|
|
3
|
+
from enum import Enum, auto
|
|
4
|
+
|
|
5
|
+
from astrbot import logger
|
|
6
|
+
from astrbot.core.provider.entities import LLMResponse
|
|
7
|
+
|
|
8
|
+
from ..hooks import BaseAgentRunHooks
|
|
9
|
+
from ..response import AgentResponse
|
|
10
|
+
from ..run_context import ContextWrapper, TContext
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AgentState(Enum):
    """Lifecycle states an agent runner moves through."""

    # Explicit values equal what ``auto()`` assigns: 1, 2, 3, ... in declaration order.
    IDLE = 1  # not started yet
    RUNNING = 2  # a run is in progress
    DONE = 3  # the run has finished
    ERROR = 4  # the run stopped on an error
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class BaseAgentRunner(T.Generic[TContext]):
    """Common interface for agent runners.

    NOTE(review): ``@abc.abstractmethod`` is only enforced by ``abc.ABCMeta``.
    This class derives only from ``T.Generic``, so the decorators below
    document the intended interface but do not prevent an incomplete subclass
    from being instantiated.
    """

    @abc.abstractmethod
    async def reset(
        self,
        run_context: ContextWrapper[TContext],
        agent_hooks: BaseAgentRunHooks[TContext],
        **kwargs: T.Any,
    ) -> None:
        """Put the runner back into its starting state.

        Must be invoked before each new run.
        """

    @abc.abstractmethod
    async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
        """Advance the agent by exactly one step."""

    @abc.abstractmethod
    async def step_until_done(
        self, max_step: int
    ) -> T.AsyncGenerator[AgentResponse, None]:
        """Keep stepping until the agent reports completion."""

    @abc.abstractmethod
    def done(self) -> bool:
        """Tell whether the agent has finished its task (True) or not (False)."""

    @abc.abstractmethod
    def get_final_llm_resp(self) -> LLMResponse | None:
        """Return the agent's final LLM response.

        Only meaningful once the agent is done.
        """

    def _transition_state(self, new_state: AgentState) -> None:
        """Switch to ``new_state``, logging only when the state actually changes.

        Reads ``self._state``, which this base class never initialises;
        subclasses must set it before calling this.
        """
        previous = self._state
        if previous == new_state:
            return
        logger.debug(f"Agent state transition: {previous} -> {new_state}")
        self._state = new_state
|