AstrBot 3.5.6__py3-none-any.whl → 4.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. astrbot/api/__init__.py +16 -4
  2. astrbot/api/all.py +2 -1
  3. astrbot/api/event/__init__.py +5 -6
  4. astrbot/api/event/filter/__init__.py +37 -34
  5. astrbot/api/platform/__init__.py +7 -8
  6. astrbot/api/provider/__init__.py +8 -7
  7. astrbot/api/star/__init__.py +3 -4
  8. astrbot/api/util/__init__.py +2 -2
  9. astrbot/cli/__init__.py +1 -0
  10. astrbot/cli/__main__.py +18 -197
  11. astrbot/cli/commands/__init__.py +6 -0
  12. astrbot/cli/commands/cmd_conf.py +209 -0
  13. astrbot/cli/commands/cmd_init.py +56 -0
  14. astrbot/cli/commands/cmd_plug.py +245 -0
  15. astrbot/cli/commands/cmd_run.py +62 -0
  16. astrbot/cli/utils/__init__.py +18 -0
  17. astrbot/cli/utils/basic.py +76 -0
  18. astrbot/cli/utils/plugin.py +246 -0
  19. astrbot/cli/utils/version_comparator.py +90 -0
  20. astrbot/core/__init__.py +17 -19
  21. astrbot/core/agent/agent.py +14 -0
  22. astrbot/core/agent/handoff.py +38 -0
  23. astrbot/core/agent/hooks.py +30 -0
  24. astrbot/core/agent/mcp_client.py +385 -0
  25. astrbot/core/agent/message.py +175 -0
  26. astrbot/core/agent/response.py +14 -0
  27. astrbot/core/agent/run_context.py +22 -0
  28. astrbot/core/agent/runners/__init__.py +3 -0
  29. astrbot/core/agent/runners/base.py +65 -0
  30. astrbot/core/agent/runners/coze/coze_agent_runner.py +367 -0
  31. astrbot/core/agent/runners/coze/coze_api_client.py +324 -0
  32. astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py +403 -0
  33. astrbot/core/agent/runners/dify/dify_agent_runner.py +336 -0
  34. astrbot/core/agent/runners/dify/dify_api_client.py +195 -0
  35. astrbot/core/agent/runners/tool_loop_agent_runner.py +400 -0
  36. astrbot/core/agent/tool.py +285 -0
  37. astrbot/core/agent/tool_executor.py +17 -0
  38. astrbot/core/astr_agent_context.py +19 -0
  39. astrbot/core/astr_agent_hooks.py +36 -0
  40. astrbot/core/astr_agent_run_util.py +80 -0
  41. astrbot/core/astr_agent_tool_exec.py +246 -0
  42. astrbot/core/astrbot_config_mgr.py +275 -0
  43. astrbot/core/config/__init__.py +2 -2
  44. astrbot/core/config/astrbot_config.py +60 -20
  45. astrbot/core/config/default.py +1972 -453
  46. astrbot/core/config/i18n_utils.py +110 -0
  47. astrbot/core/conversation_mgr.py +285 -75
  48. astrbot/core/core_lifecycle.py +167 -62
  49. astrbot/core/db/__init__.py +305 -102
  50. astrbot/core/db/migration/helper.py +69 -0
  51. astrbot/core/db/migration/migra_3_to_4.py +357 -0
  52. astrbot/core/db/migration/migra_45_to_46.py +44 -0
  53. astrbot/core/db/migration/migra_webchat_session.py +131 -0
  54. astrbot/core/db/migration/shared_preferences_v3.py +48 -0
  55. astrbot/core/db/migration/sqlite_v3.py +497 -0
  56. astrbot/core/db/po.py +259 -55
  57. astrbot/core/db/sqlite.py +773 -528
  58. astrbot/core/db/vec_db/base.py +73 -0
  59. astrbot/core/db/vec_db/faiss_impl/__init__.py +3 -0
  60. astrbot/core/db/vec_db/faiss_impl/document_storage.py +392 -0
  61. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +93 -0
  62. astrbot/core/db/vec_db/faiss_impl/sqlite_init.sql +17 -0
  63. astrbot/core/db/vec_db/faiss_impl/vec_db.py +204 -0
  64. astrbot/core/event_bus.py +26 -22
  65. astrbot/core/exceptions.py +9 -0
  66. astrbot/core/file_token_service.py +98 -0
  67. astrbot/core/initial_loader.py +19 -10
  68. astrbot/core/knowledge_base/chunking/__init__.py +9 -0
  69. astrbot/core/knowledge_base/chunking/base.py +25 -0
  70. astrbot/core/knowledge_base/chunking/fixed_size.py +59 -0
  71. astrbot/core/knowledge_base/chunking/recursive.py +161 -0
  72. astrbot/core/knowledge_base/kb_db_sqlite.py +301 -0
  73. astrbot/core/knowledge_base/kb_helper.py +642 -0
  74. astrbot/core/knowledge_base/kb_mgr.py +330 -0
  75. astrbot/core/knowledge_base/models.py +120 -0
  76. astrbot/core/knowledge_base/parsers/__init__.py +13 -0
  77. astrbot/core/knowledge_base/parsers/base.py +51 -0
  78. astrbot/core/knowledge_base/parsers/markitdown_parser.py +26 -0
  79. astrbot/core/knowledge_base/parsers/pdf_parser.py +101 -0
  80. astrbot/core/knowledge_base/parsers/text_parser.py +42 -0
  81. astrbot/core/knowledge_base/parsers/url_parser.py +103 -0
  82. astrbot/core/knowledge_base/parsers/util.py +13 -0
  83. astrbot/core/knowledge_base/prompts.py +65 -0
  84. astrbot/core/knowledge_base/retrieval/__init__.py +14 -0
  85. astrbot/core/knowledge_base/retrieval/hit_stopwords.txt +767 -0
  86. astrbot/core/knowledge_base/retrieval/manager.py +276 -0
  87. astrbot/core/knowledge_base/retrieval/rank_fusion.py +142 -0
  88. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +136 -0
  89. astrbot/core/log.py +21 -15
  90. astrbot/core/message/components.py +413 -287
  91. astrbot/core/message/message_event_result.py +35 -24
  92. astrbot/core/persona_mgr.py +192 -0
  93. astrbot/core/pipeline/__init__.py +14 -14
  94. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  95. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  96. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +13 -14
  97. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +2 -1
  98. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  99. astrbot/core/pipeline/context.py +7 -1
  100. astrbot/core/pipeline/context_utils.py +107 -0
  101. astrbot/core/pipeline/preprocess_stage/stage.py +63 -36
  102. astrbot/core/pipeline/process_stage/method/agent_request.py +48 -0
  103. astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +464 -0
  104. astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +202 -0
  105. astrbot/core/pipeline/process_stage/method/star_request.py +26 -32
  106. astrbot/core/pipeline/process_stage/stage.py +21 -15
  107. astrbot/core/pipeline/process_stage/utils.py +125 -0
  108. astrbot/core/pipeline/rate_limit_check/stage.py +34 -36
  109. astrbot/core/pipeline/respond/stage.py +142 -101
  110. astrbot/core/pipeline/result_decorate/stage.py +124 -57
  111. astrbot/core/pipeline/scheduler.py +21 -16
  112. astrbot/core/pipeline/session_status_check/stage.py +37 -0
  113. astrbot/core/pipeline/stage.py +11 -76
  114. astrbot/core/pipeline/waking_check/stage.py +69 -33
  115. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  116. astrbot/core/platform/__init__.py +6 -6
  117. astrbot/core/platform/astr_message_event.py +107 -129
  118. astrbot/core/platform/astrbot_message.py +32 -12
  119. astrbot/core/platform/manager.py +62 -18
  120. astrbot/core/platform/message_session.py +30 -0
  121. astrbot/core/platform/platform.py +16 -24
  122. astrbot/core/platform/platform_metadata.py +9 -4
  123. astrbot/core/platform/register.py +12 -7
  124. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +136 -60
  125. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +126 -46
  126. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +63 -31
  127. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +30 -26
  128. astrbot/core/platform/sources/discord/client.py +129 -0
  129. astrbot/core/platform/sources/discord/components.py +139 -0
  130. astrbot/core/platform/sources/discord/discord_platform_adapter.py +473 -0
  131. astrbot/core/platform/sources/discord/discord_platform_event.py +313 -0
  132. astrbot/core/platform/sources/lark/lark_adapter.py +27 -18
  133. astrbot/core/platform/sources/lark/lark_event.py +39 -13
  134. astrbot/core/platform/sources/misskey/misskey_adapter.py +770 -0
  135. astrbot/core/platform/sources/misskey/misskey_api.py +964 -0
  136. astrbot/core/platform/sources/misskey/misskey_event.py +163 -0
  137. astrbot/core/platform/sources/misskey/misskey_utils.py +550 -0
  138. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +149 -33
  139. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  140. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  141. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  142. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +14 -8
  143. astrbot/core/platform/sources/satori/satori_adapter.py +792 -0
  144. astrbot/core/platform/sources/satori/satori_event.py +432 -0
  145. astrbot/core/platform/sources/slack/client.py +164 -0
  146. astrbot/core/platform/sources/slack/slack_adapter.py +416 -0
  147. astrbot/core/platform/sources/slack/slack_event.py +253 -0
  148. astrbot/core/platform/sources/telegram/tg_adapter.py +100 -43
  149. astrbot/core/platform/sources/telegram/tg_event.py +136 -36
  150. astrbot/core/platform/sources/webchat/webchat_adapter.py +72 -22
  151. astrbot/core/platform/sources/webchat/webchat_event.py +46 -22
  152. astrbot/core/platform/sources/webchat/webchat_queue_mgr.py +35 -0
  153. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +926 -0
  154. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +178 -0
  155. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +159 -0
  156. astrbot/core/platform/sources/wecom/wecom_adapter.py +169 -27
  157. astrbot/core/platform/sources/wecom/wecom_event.py +162 -77
  158. astrbot/core/platform/sources/wecom/wecom_kf.py +279 -0
  159. astrbot/core/platform/sources/wecom/wecom_kf_message.py +196 -0
  160. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +297 -0
  161. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +15 -0
  162. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +19 -0
  163. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +472 -0
  164. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +417 -0
  165. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +152 -0
  166. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +153 -0
  167. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +168 -0
  168. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +209 -0
  169. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +306 -0
  170. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +186 -0
  171. astrbot/core/platform_message_history_mgr.py +49 -0
  172. astrbot/core/provider/__init__.py +2 -3
  173. astrbot/core/provider/entites.py +8 -8
  174. astrbot/core/provider/entities.py +154 -98
  175. astrbot/core/provider/func_tool_manager.py +446 -458
  176. astrbot/core/provider/manager.py +345 -207
  177. astrbot/core/provider/provider.py +188 -73
  178. astrbot/core/provider/register.py +9 -7
  179. astrbot/core/provider/sources/anthropic_source.py +295 -115
  180. astrbot/core/provider/sources/azure_tts_source.py +224 -0
  181. astrbot/core/provider/sources/bailian_rerank_source.py +236 -0
  182. astrbot/core/provider/sources/dashscope_tts.py +138 -14
  183. astrbot/core/provider/sources/edge_tts_source.py +24 -19
  184. astrbot/core/provider/sources/fishaudio_tts_api_source.py +58 -13
  185. astrbot/core/provider/sources/gemini_embedding_source.py +61 -0
  186. astrbot/core/provider/sources/gemini_source.py +310 -132
  187. astrbot/core/provider/sources/gemini_tts_source.py +81 -0
  188. astrbot/core/provider/sources/groq_source.py +15 -0
  189. astrbot/core/provider/sources/gsv_selfhosted_source.py +151 -0
  190. astrbot/core/provider/sources/gsvi_tts_source.py +14 -7
  191. astrbot/core/provider/sources/minimax_tts_api_source.py +159 -0
  192. astrbot/core/provider/sources/openai_embedding_source.py +40 -0
  193. astrbot/core/provider/sources/openai_source.py +241 -145
  194. astrbot/core/provider/sources/openai_tts_api_source.py +18 -7
  195. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  196. astrbot/core/provider/sources/vllm_rerank_source.py +71 -0
  197. astrbot/core/provider/sources/volcengine_tts.py +115 -0
  198. astrbot/core/provider/sources/whisper_api_source.py +18 -13
  199. astrbot/core/provider/sources/whisper_selfhosted_source.py +19 -12
  200. astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
  201. astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
  202. astrbot/core/provider/sources/zhipu_source.py +6 -73
  203. astrbot/core/star/__init__.py +43 -11
  204. astrbot/core/star/config.py +17 -18
  205. astrbot/core/star/context.py +362 -138
  206. astrbot/core/star/filter/__init__.py +4 -3
  207. astrbot/core/star/filter/command.py +111 -35
  208. astrbot/core/star/filter/command_group.py +46 -34
  209. astrbot/core/star/filter/custom_filter.py +6 -5
  210. astrbot/core/star/filter/event_message_type.py +4 -2
  211. astrbot/core/star/filter/permission.py +4 -2
  212. astrbot/core/star/filter/platform_adapter_type.py +45 -12
  213. astrbot/core/star/filter/regex.py +4 -2
  214. astrbot/core/star/register/__init__.py +19 -15
  215. astrbot/core/star/register/star.py +41 -13
  216. astrbot/core/star/register/star_handler.py +236 -86
  217. astrbot/core/star/session_llm_manager.py +280 -0
  218. astrbot/core/star/session_plugin_manager.py +170 -0
  219. astrbot/core/star/star.py +36 -43
  220. astrbot/core/star/star_handler.py +47 -85
  221. astrbot/core/star/star_manager.py +442 -260
  222. astrbot/core/star/star_tools.py +167 -45
  223. astrbot/core/star/updator.py +17 -20
  224. astrbot/core/umop_config_router.py +106 -0
  225. astrbot/core/updator.py +38 -13
  226. astrbot/core/utils/astrbot_path.py +39 -0
  227. astrbot/core/utils/command_parser.py +1 -1
  228. astrbot/core/utils/io.py +119 -60
  229. astrbot/core/utils/log_pipe.py +1 -1
  230. astrbot/core/utils/metrics.py +11 -10
  231. astrbot/core/utils/migra_helper.py +73 -0
  232. astrbot/core/utils/path_util.py +63 -62
  233. astrbot/core/utils/pip_installer.py +37 -15
  234. astrbot/core/utils/session_lock.py +29 -0
  235. astrbot/core/utils/session_waiter.py +19 -20
  236. astrbot/core/utils/shared_preferences.py +174 -34
  237. astrbot/core/utils/t2i/__init__.py +4 -1
  238. astrbot/core/utils/t2i/local_strategy.py +386 -238
  239. astrbot/core/utils/t2i/network_strategy.py +109 -49
  240. astrbot/core/utils/t2i/renderer.py +29 -14
  241. astrbot/core/utils/t2i/template/astrbot_powershell.html +184 -0
  242. astrbot/core/utils/t2i/template_manager.py +111 -0
  243. astrbot/core/utils/tencent_record_helper.py +115 -1
  244. astrbot/core/utils/version_comparator.py +10 -13
  245. astrbot/core/zip_updator.py +112 -65
  246. astrbot/dashboard/routes/__init__.py +20 -13
  247. astrbot/dashboard/routes/auth.py +20 -9
  248. astrbot/dashboard/routes/chat.py +297 -141
  249. astrbot/dashboard/routes/config.py +652 -55
  250. astrbot/dashboard/routes/conversation.py +107 -37
  251. astrbot/dashboard/routes/file.py +26 -0
  252. astrbot/dashboard/routes/knowledge_base.py +1244 -0
  253. astrbot/dashboard/routes/log.py +27 -2
  254. astrbot/dashboard/routes/persona.py +202 -0
  255. astrbot/dashboard/routes/plugin.py +197 -139
  256. astrbot/dashboard/routes/route.py +27 -7
  257. astrbot/dashboard/routes/session_management.py +354 -0
  258. astrbot/dashboard/routes/stat.py +85 -18
  259. astrbot/dashboard/routes/static_file.py +5 -2
  260. astrbot/dashboard/routes/t2i.py +233 -0
  261. astrbot/dashboard/routes/tools.py +184 -120
  262. astrbot/dashboard/routes/update.py +59 -36
  263. astrbot/dashboard/server.py +96 -36
  264. astrbot/dashboard/utils.py +165 -0
  265. astrbot-4.7.0.dist-info/METADATA +294 -0
  266. astrbot-4.7.0.dist-info/RECORD +274 -0
  267. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/WHEEL +1 -1
  268. astrbot/core/db/plugin/sqlite_impl.py +0 -112
  269. astrbot/core/db/sqlite_init.sql +0 -50
  270. astrbot/core/pipeline/platform_compatibility/stage.py +0 -56
  271. astrbot/core/pipeline/process_stage/method/llm_request.py +0 -606
  272. astrbot/core/platform/sources/gewechat/client.py +0 -806
  273. astrbot/core/platform/sources/gewechat/downloader.py +0 -55
  274. astrbot/core/platform/sources/gewechat/gewechat_event.py +0 -255
  275. astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +0 -103
  276. astrbot/core/platform/sources/gewechat/xml_data_parser.py +0 -110
  277. astrbot/core/provider/sources/dashscope_source.py +0 -203
  278. astrbot/core/provider/sources/dify_source.py +0 -281
  279. astrbot/core/provider/sources/llmtuner_source.py +0 -132
  280. astrbot/core/rag/embedding/openai_source.py +0 -20
  281. astrbot/core/rag/knowledge_db_mgr.py +0 -94
  282. astrbot/core/rag/store/__init__.py +0 -9
  283. astrbot/core/rag/store/chroma_db.py +0 -42
  284. astrbot/core/utils/dify_api_client.py +0 -152
  285. astrbot-3.5.6.dist-info/METADATA +0 -249
  286. astrbot-3.5.6.dist-info/RECORD +0 -158
  287. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/entry_points.txt +0 -0
  288. {astrbot-3.5.6.dist-info → astrbot-4.7.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,385 @@
1
+ import asyncio
2
+ import logging
3
+ from contextlib import AsyncExitStack
4
+ from datetime import timedelta
5
+ from typing import Generic
6
+
7
+ from tenacity import (
8
+ before_sleep_log,
9
+ retry,
10
+ retry_if_exception_type,
11
+ stop_after_attempt,
12
+ wait_exponential,
13
+ )
14
+
15
+ from astrbot import logger
16
+ from astrbot.core.agent.run_context import ContextWrapper
17
+ from astrbot.core.utils.log_pipe import LogPipe
18
+
19
+ from .run_context import TContext
20
+ from .tool import FunctionTool
21
+
22
+ try:
23
+ import anyio
24
+ import mcp
25
+ from mcp.client.sse import sse_client
26
+ except (ModuleNotFoundError, ImportError):
27
+ logger.warning(
28
+ "Warning: Missing 'mcp' dependency, MCP services will be unavailable."
29
+ )
30
+
31
+ try:
32
+ from mcp.client.streamable_http import streamablehttp_client
33
+ except (ModuleNotFoundError, ImportError):
34
+ logger.warning(
35
+ "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.",
36
+ )
37
+
38
+
39
+ def _prepare_config(config: dict) -> dict:
40
+ """Prepare configuration, handle nested format"""
41
+ if config.get("mcpServers"):
42
+ first_key = next(iter(config["mcpServers"]))
43
+ config = config["mcpServers"][first_key]
44
+ config.pop("active", None)
45
+ return config
46
+
47
+
48
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
    """Probe an HTTP-based MCP server before opening a real session.

    For ``streamable_http`` a JSON-RPC ``initialize`` request is POSTed;
    for any other transport a plain GET is issued. Only HTTP 200 counts
    as reachable.

    Args:
        config: Server configuration (flat or nested); must contain ``url``
            and one of ``transport`` / ``type``.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, <error description>)``.
        Errors (including a missing transport field) are reported through
        the returned tuple rather than raised.
    """
    import aiohttp

    cfg = _prepare_config(config.copy())

    url = cfg["url"]
    headers = cfg.get("headers", {})
    timeout = cfg.get("timeout", 10)

    try:
        # "transport" takes precedence over "type" when both are present.
        for field in ("transport", "type"):
            if field in cfg:
                transport_type = cfg[field]
                break
        else:
            raise Exception("MCP connection config missing transport or type field")

        accept_header = {"Accept": "application/json, text/event-stream"}
        request_timeout = aiohttp.ClientTimeout(total=timeout)

        async with aiohttp.ClientSession() as http:
            if transport_type == "streamable_http":
                initialize_request = {
                    "jsonrpc": "2.0",
                    "method": "initialize",
                    "id": 0,
                    "params": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {},
                        "clientInfo": {"name": "test-client", "version": "1.2.3"},
                    },
                }
                pending = http.post(
                    url,
                    headers={
                        **headers,
                        "Content-Type": "application/json",
                        **accept_header,
                    },
                    json=initialize_request,
                    timeout=request_timeout,
                )
            else:
                pending = http.get(
                    url,
                    headers={**headers, **accept_header},
                    timeout=request_timeout,
                )

            async with pending as resp:
                if resp.status != 200:
                    return False, f"HTTP {resp.status}: {resp.reason}"
                return True, ""

    except asyncio.TimeoutError:
        return False, f"Connection timeout: {timeout} seconds"
    except Exception as exc:
        return False, f"{exc!s}"
108
+
109
+
110
class MCPClient:
    """Client for one MCP server over stdio, SSE or Streamable HTTP.

    Owns the ``mcp.ClientSession`` and the ``AsyncExitStack`` holding its
    transport contexts, caches the server's tool list, and stores the
    connection config so the session can be re-established after the
    transport has been closed.
    """

    def __init__(self):
        # Current session; None until connected, or after a reconnect began.
        self.session: mcp.ClientSession | None = None
        # Owns the transport and session contexts of the current connection.
        self.exit_stack = AsyncExitStack()
        # Stacks replaced by reconnects. They are deliberately not closed here
        # because they may belong to a different task context (cancel scopes).
        self._old_exit_stacks: list[AsyncExitStack] = []
        # Transport context manager of the current HTTP connection, if any.
        self._streams_context = None

        self.name: str | None = None
        self.active: bool = True
        self.tools: list[mcp.Tool] = []
        self.server_errlogs: list[str] = []
        self.running_event = asyncio.Event()

        # Connection config kept for reconnection.
        self._mcp_server_config: dict | None = None
        self._server_name: str | None = None
        # Serializes reconnect attempts coming from concurrent tool calls.
        self._reconnect_lock = asyncio.Lock()
        # True while a reconnect is running; informational only.
        self._reconnecting: bool = False

    async def connect_to_server(self, mcp_server_config: dict, name: str):
        """Connect to an MCP server and initialize the client session.

        If the config contains ``url``, an HTTP transport is used after a
        connectivity probe:

        1. ``transport``/``type`` equal to ``"streamable_http"``: Streamable HTTP.
        2. Any other ``transport``/``type`` value: SSE.
        3. Neither ``transport`` nor ``type`` present: an exception is raised
           (there is no implicit SSE default).

        Without ``url``, the config is passed to ``mcp.StdioServerParameters``
        and connected through ``mcp.stdio_client``.

        Args:
            mcp_server_config: Configuration for the MCP server, flat or nested
                under ``mcpServers``. See
                https://modelcontextprotocol.io/quickstart/server
            name: Server name, used in log output and for reconnection.

        Raises:
            Exception: If the connectivity probe fails or the transport field
                is missing.
        """
        # Store config for reconnection.
        self._mcp_server_config = mcp_server_config
        self._server_name = name

        cfg = _prepare_config(mcp_server_config.copy())

        async def logging_callback(msg):
            # ClientSession awaits its logging callback, so this must be a
            # coroutine function. Route server log messages through the shared
            # logger (instead of stdout) and keep them for inspection.
            logger.error(f"MCP Server {name} Error: {msg}")
            self.server_errlogs.append(msg)

        if "url" in cfg:
            success, error_msg = await _quick_test_mcp_connection(cfg)
            if not success:
                raise Exception(error_msg)

            if "transport" in cfg:
                transport_type = cfg["transport"]
            elif "type" in cfg:
                transport_type = cfg["type"]
            else:
                raise Exception("MCP connection config missing transport or type field")

            if transport_type != "streamable_http":
                # SSE transport method
                self._streams_context = sse_client(
                    url=cfg["url"],
                    headers=cfg.get("headers", {}),
                    timeout=cfg.get("timeout", 5),
                    sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
                )
                streams = await self.exit_stack.enter_async_context(
                    self._streams_context,
                )

                # Create a new client session
                read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
                self.session = await self.exit_stack.enter_async_context(
                    mcp.ClientSession(
                        *streams,
                        read_timeout_seconds=read_timeout,
                        logging_callback=logging_callback,  # type: ignore
                    ),
                )
            else:
                timeout = timedelta(seconds=cfg.get("timeout", 30))
                sse_read_timeout = timedelta(
                    seconds=cfg.get("sse_read_timeout", 60 * 5),
                )
                self._streams_context = streamablehttp_client(
                    url=cfg["url"],
                    headers=cfg.get("headers", {}),
                    timeout=timeout,
                    sse_read_timeout=sse_read_timeout,
                    terminate_on_close=cfg.get("terminate_on_close", True),
                )
                read_s, write_s, _ = await self.exit_stack.enter_async_context(
                    self._streams_context,
                )

                # Create a new client session
                read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60))
                self.session = await self.exit_stack.enter_async_context(
                    mcp.ClientSession(
                        read_stream=read_s,
                        write_stream=write_s,
                        read_timeout_seconds=read_timeout,
                        logging_callback=logging_callback,  # type: ignore
                    ),
                )

        else:
            server_params = mcp.StdioServerParameters(
                **cfg,
            )

            def callback(msg: str):
                # Collect stderr lines forwarded by LogPipe.
                self.server_errlogs.append(msg)

            stdio_transport = await self.exit_stack.enter_async_context(
                mcp.stdio_client(
                    server_params,
                    errlog=LogPipe(
                        level=logging.ERROR,
                        logger=logger,
                        identifier=f"MCPServer-{name}",
                        callback=callback,
                    ),  # type: ignore
                ),
            )

            # Create a new client session
            self.session = await self.exit_stack.enter_async_context(
                mcp.ClientSession(*stdio_transport),
            )
        await self.session.initialize()

    # Annotation is quoted so the class can be defined when `mcp` is missing.
    async def list_tools_and_save(self) -> "mcp.ListToolsResult":
        """List all tools from the server and save them to ``self.tools``.

        Raises:
            Exception: If no session has been established.
        """
        if not self.session:
            raise Exception("MCP Client is not initialized")
        response = await self.session.list_tools()
        self.tools = response.tools
        return response

    async def _reconnect(self, failed_session=None) -> None:
        """Reconnect to the MCP server using the stored configuration.

        Reconnects are serialized with ``self._reconnect_lock``. If
        ``failed_session`` is given and ``self.session`` has meanwhile been
        replaced by a different session (another task reconnected while this
        one waited for the lock), no second reconnect is performed.

        Args:
            failed_session: The session whose call failed and triggered this
                reconnect, if known. ``None`` forces a reconnect.

        Raises:
            Exception: If the configuration is missing or reconnection fails.
        """
        async with self._reconnect_lock:
            # The `_reconnecting` flag is only changed while the lock is held,
            # so it cannot detect a concurrent reconnect; compare sessions.
            if (
                failed_session is not None
                and self.session is not None
                and self.session is not failed_session
            ):
                logger.debug(
                    f"MCP Client {self._server_name} was already reconnected, skipping"
                )
                return

            if not self._mcp_server_config or not self._server_name:
                raise Exception("Cannot reconnect: missing connection configuration")

            self._reconnecting = True
            try:
                logger.info(
                    f"Attempting to reconnect to MCP server {self._server_name}..."
                )

                # Save old exit_stack for later cleanup (don't close it now to avoid cancel scope issues)
                if self.exit_stack:
                    self._old_exit_stacks.append(self.exit_stack)

                # Mark old session as invalid
                self.session = None

                # Create new exit stack for new connection
                self.exit_stack = AsyncExitStack()

                # Reconnect using stored config
                await self.connect_to_server(self._mcp_server_config, self._server_name)
                await self.list_tools_and_save()

                logger.info(
                    f"Successfully reconnected to MCP server {self._server_name}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to reconnect to MCP server {self._server_name}: {e}"
                )
                raise
            finally:
                self._reconnecting = False

    # Return annotation is quoted so the class can be defined without `mcp`.
    async def call_tool_with_reconnect(
        self,
        tool_name: str,
        arguments: dict,
        read_timeout_seconds: timedelta,
    ) -> "mcp.types.CallToolResult":
        """Call an MCP tool, reconnecting once if the transport was closed.

        At most 2 attempts are made: when an attempt raises
        ``anyio.ClosedResourceError``, the client reconnects and the call is
        retried once more after a short exponential back-off.

        Args:
            tool_name: Tool name.
            arguments: Tool arguments.
            read_timeout_seconds: Read timeout for the call.

        Returns:
            The MCP tool call result.

        Raises:
            ValueError: If no MCP session is available.
            anyio.ClosedResourceError: If the final attempt still fails.
        """

        @retry(
            retry=retry_if_exception_type(anyio.ClosedResourceError),
            stop=stop_after_attempt(2),
            wait=wait_exponential(multiplier=1, min=1, max=3),
            before_sleep=before_sleep_log(logger, logging.WARNING),
            reraise=True,
        )
        async def _call_with_retry():
            # Capture the session used for this attempt so a concurrent
            # reconnect by another task can be detected in _reconnect.
            session = self.session
            if not session:
                raise ValueError("MCP session is not available for MCP function tools.")

            try:
                return await session.call_tool(
                    name=tool_name,
                    arguments=arguments,
                    read_timeout_seconds=read_timeout_seconds,
                )
            except anyio.ClosedResourceError:
                logger.warning(
                    f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..."
                )
                await self._reconnect(failed_session=session)
                # Re-raise so tenacity schedules the retry.
                raise

        return await _call_with_retry()

    async def cleanup(self):
        """Close the current connection and set ``running_event``.

        Exit stacks left over from reconnections are not closed, because they
        may belong to other task contexts; only their references are dropped.
        Errors while closing the current stack are logged at debug level.
        """
        try:
            await self.exit_stack.aclose()
        except Exception as e:
            logger.debug(f"Error closing current exit stack: {e}")

        # Drop references to old stacks; they are left to garbage collection.
        self._old_exit_stacks.clear()

        # Signal last, after the current stack is closed, to unblock waiters.
        self.running_event.set()
361
+
362
+
363
class MCPTool(FunctionTool, Generic[TContext]):
    """A function tool that forwards calls to a tool on an MCP server.

    The tool's name, description and JSON parameter schema are taken from
    the ``mcp.Tool`` definition reported by the server.
    """

    # `mcp` annotations are quoted: they are evaluated at definition time
    # otherwise, which breaks importing this module when `mcp` is missing.
    def __init__(
        self,
        mcp_tool: "mcp.Tool",
        mcp_client: MCPClient,
        mcp_server_name: str,
        **kwargs,
    ):
        """Wrap an MCP tool definition.

        Args:
            mcp_tool: Tool definition as listed by the MCP server.
            mcp_client: Connected client used to invoke the tool.
            mcp_server_name: Name of the server providing the tool.
            **kwargs: Accepted for call-site compatibility; currently unused.
        """
        super().__init__(
            name=mcp_tool.name,
            description=mcp_tool.description or "",
            parameters=mcp_tool.inputSchema,
        )
        self.mcp_tool = mcp_tool
        self.mcp_client = mcp_client
        self.mcp_server_name = mcp_server_name

    async def call(
        self, context: ContextWrapper[TContext], **kwargs
    ) -> "mcp.types.CallToolResult":
        """Invoke the MCP tool with ``kwargs`` as its arguments.

        The read timeout is ``context.tool_call_timeout`` seconds. Reconnection
        after a closed transport is delegated to
        ``MCPClient.call_tool_with_reconnect``.
        """
        return await self.mcp_client.call_tool_with_reconnect(
            tool_name=self.mcp_tool.name,
            arguments=kwargs,
            read_timeout_seconds=timedelta(seconds=context.tool_call_timeout),
        )
@@ -0,0 +1,175 @@
1
+ # Inspired by MoonshotAI/kosong, credits to MoonshotAI/kosong authors for the original implementation.
2
+ # License: Apache License 2.0
3
+
4
+ from typing import Any, ClassVar, Literal, cast
5
+
6
+ from pydantic import BaseModel, GetCoreSchemaHandler
7
+ from pydantic_core import core_schema
8
+
9
+
10
class ContentPart(BaseModel):
    """A part of the content in a message.

    Every subclass registers itself under the default value of its ``type``
    field. Validating a value against the base ``ContentPart`` type then works
    as follows:
      * an existing ``ContentPart`` instance is returned unchanged;
      * a dict carrying a ``type`` key is validated with the registered
        subclass for that type;
      * anything else, including a dict with an unknown or non-string
        ``type``, raises ``ValueError``. Pydantic reports this as a
        validation error instead of letting a raw ``KeyError`` escape.
    """

    __content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}

    type: str

    def __init_subclass__(cls, **kwargs: Any) -> None:
        super().__init_subclass__(**kwargs)

        # Subclasses must give ``type`` a string default; it is the dispatch key.
        type_value = getattr(cls, "type", None)
        if not isinstance(type_value, str):
            raise ValueError(
                f"ContentPart subclass {cls.__name__} must have a `type` field of type `str`"
            )

        cls.__content_part_registry[type_value] = cls

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Concrete subclasses use pydantic's normally generated schema.
        # The name check (rather than ``cls is ContentPart``) avoids looking up
        # the module-level name ``ContentPart``, which may not be bound yet
        # when this hook runs for the base class itself.
        if cls.__name__ != "ContentPart":
            return handler(source_type)

        registry = cls.__content_part_registry

        def validate_content_part(value: Any) -> Any:
            # Already-built parts (of any registered subclass) pass through.
            if isinstance(value, cls):
                return value

            # Dicts are dispatched on their ``type`` key.
            if isinstance(value, dict) and "type" in value:
                type_value: Any | None = cast(dict[str, Any], value).get("type")
                target_class = (
                    registry.get(type_value) if isinstance(type_value, str) else None
                )
                if target_class is None:
                    # Unknown or non-string type: raise ValueError so pydantic
                    # turns it into a ValidationError.
                    raise ValueError(f"Cannot validate {value} as ContentPart")
                return target_class.model_validate(value)

            raise ValueError(f"Cannot validate {value} as ContentPart")

        return core_schema.no_info_plain_validator_function(validate_content_part)
54
+
55
+
56
class TextPart(ContentPart):
    """A plain-text content part.

    >>> TextPart(text="Hello, world!").model_dump()
    {'type': 'text', 'text': 'Hello, world!'}
    """

    type: str = "text"  # registry key for this part kind
    text: str  # the text payload
64
+
65
+
66
class ImageURLPart(ContentPart):
    """An image content part, referenced by URL.

    ``image_url`` must be an :class:`ImageURLPart.ImageURL` (or a dict that
    validates as one), not a bare string.

    >>> ImageURLPart(image_url=ImageURLPart.ImageURL(url="http://example.com/image.jpg")).model_dump()
    {'type': 'image_url', 'image_url': {'url': 'http://example.com/image.jpg', 'id': None}}
    """

    class ImageURL(BaseModel):
        url: str
        """The URL of the image, can be data URI scheme like `data:image/png;base64,...`."""
        id: str | None = None
        """The ID of the image, to allow LLMs to distinguish different images."""

    type: str = "image_url"
    image_url: ImageURL
80
+
81
+
82
class AudioURLPart(ContentPart):
    """An audio content part, referenced by URL.

    >>> AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://example.com/audio.mp3")).model_dump()
    {'type': 'audio_url', 'audio_url': {'url': 'https://example.com/audio.mp3', 'id': None}}
    """

    class AudioURL(BaseModel):
        # URL of the audio; may be a data URI such as `data:audio/aac;base64,...`.
        url: str
        # Optional ID so LLMs can tell different audio clips apart.
        id: str | None = None

    type: str = "audio_url"  # registry key for this part kind
    audio_url: AudioURL
96
+
97
+
98
class ToolCall(BaseModel):
    """
    A tool call requested by the assistant.

    >>> ToolCall(
    ...     id="123",
    ...     function=ToolCall.FunctionBody(
    ...         name="function",
    ...         arguments="{}"
    ...     ),
    ... ).model_dump()
    {'type': 'function', 'id': '123', 'function': {'name': 'function', 'arguments': '{}'}}
    """

    class FunctionBody(BaseModel):
        # Name of the function/tool to invoke.
        name: str
        # Arguments as a string (presumably JSON-encoded -- confirm), or None.
        arguments: str | None

    type: Literal["function"] = "function"

    id: str
    """The ID of the tool call."""
    function: FunctionBody
    """The function body of the tool call."""
    extra_content: dict[str, Any] | None = None
    """Extra metadata for the tool call."""

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        """Dump the model, omitting ``extra_content`` when it is ``None``.

        A caller-supplied ``exclude`` is never mutated. A new set or dict is
        built that adds ``extra_content`` to it. The set form, the dict form
        (``{field: True}``) and an explicit ``exclude=None`` are all supported.

        NOTE(review): this override presumably only applies to direct calls,
        not when a ToolCall is serialized as a nested field of another model.
        Confirm before relying on it.
        """
        if self.extra_content is None:
            exclude = kwargs.get("exclude")
            if exclude is None:
                kwargs["exclude"] = {"extra_content"}
            elif isinstance(exclude, dict):
                kwargs["exclude"] = {**exclude, "extra_content": True}
            else:
                kwargs["exclude"] = {*exclude, "extra_content"}
        return super().model_dump(**kwargs)
129
+
130
+
131
class ToolCallPart(BaseModel):
    """A fragment of a tool call.

    Presumably emitted while a tool call's arguments arrive piecemeal;
    confirm with the producers of this type.
    """

    # A piece of the tool call's arguments, or None.
    arguments_part: str | None = None
136
+
137
+
138
class Message(BaseModel):
    """One message in an LLM conversation.

    ``role`` identifies the speaker. ``content`` is either plain text or a list
    of :class:`ContentPart` items.
    """

    role: Literal["system", "user", "assistant", "tool"]

    # Plain text, or a list of content parts (text, image URL, audio URL, ...).
    content: str | list[ContentPart]
150
+
151
+
152
class AssistantMessageSegment(Message):
    """A message produced by the assistant, optionally carrying tool calls."""

    role: Literal["assistant"] = "assistant"
    # Tool calls requested in this turn, as ToolCall models or raw dicts.
    tool_calls: list[ToolCall] | list[dict] | None = None
157
+
158
+
159
class ToolCallMessageSegment(Message):
    """A message with the ``tool`` role, tied to a tool call by ID."""

    role: Literal["tool"] = "tool"
    # Presumably the ``id`` of the ToolCall this message answers -- confirm.
    tool_call_id: str
164
+
165
+
166
class UserMessageSegment(Message):
    """A message authored by the user (role fixed to ``user``)."""

    role: Literal["user"] = "user"
170
+
171
+
172
class SystemMessageSegment(Message):
    """A system-level message (role fixed to ``system``)."""

    role: Literal["system"] = "system"
@@ -0,0 +1,14 @@
1
+ import typing as T
2
+ from dataclasses import dataclass
3
+
4
+ from astrbot.core.message.message_event_result import MessageChain
5
+
6
+
7
class AgentResponseData(T.TypedDict):
    """Payload carried by an :class:`AgentResponse`."""

    chain: MessageChain  # the message chain produced for this response


@dataclass
class AgentResponse:
    """A single response yielded by an agent runner's step methods.

    ``type`` is a free-form label for the kind of response. ``data`` holds
    its payload.
    """

    type: str
    data: AgentResponseData
@@ -0,0 +1,22 @@
1
+ from typing import Any, Generic
2
+
3
+ from pydantic import Field
4
+ from pydantic.dataclasses import dataclass
5
+ from typing_extensions import TypeVar
6
+
7
+ from .message import Message
8
+
9
# Type of the caller-supplied value stored in ContextWrapper.context.
# Type checkers treat it as Any when left unparameterized (PEP 696 default).
TContext = TypeVar("TContext", default=Any)


@dataclass(config={"arbitrary_types_allowed": True})
class ContextWrapper(Generic[TContext]):
    """Per-run state handed to an agent runner and to the tools it calls.

    ``context`` carries arbitrary caller data. Agent runners maintain
    ``messages`` (the LLM message context for the run) automatically.
    """

    context: TContext
    # LLM message context for this run; maintained by the agent runner.
    messages: list[Message] = Field(default_factory=list)
    # Timeout for a single tool call, in seconds.
    tool_call_timeout: int = 60


# Convenience alias for runs that carry no caller context.
NoContext = ContextWrapper[None]
@@ -0,0 +1,3 @@
1
+ from .base import BaseAgentRunner
2
+
3
+ __all__ = ["BaseAgentRunner"]
@@ -0,0 +1,65 @@
1
+ import abc
2
+ import typing as T
3
+ from enum import Enum, auto
4
+
5
+ from astrbot import logger
6
+ from astrbot.core.provider.entities import LLMResponse
7
+
8
+ from ..hooks import BaseAgentRunHooks
9
+ from ..response import AgentResponse
10
+ from ..run_context import ContextWrapper, TContext
11
+
12
+
13
class AgentState(Enum):
    """Lifecycle states an agent runner moves through."""

    # Members are numbered 1..4 in declaration order by auto().
    IDLE = auto()  # initial state, before processing starts
    RUNNING = auto()  # currently processing
    DONE = auto()  # completed
    ERROR = auto()  # error state
20
+
21
+
22
class BaseAgentRunner(abc.ABC, T.Generic[TContext]):
    """Abstract base class for agent runners.

    A runner is first ``reset`` with a run context and hooks. It is then
    driven one step at a time (``step``) or until completion
    (``step_until_done``); both yield :class:`AgentResponse` objects.
    ``done`` reports completion and ``get_final_llm_resp`` returns the final
    LLM response.

    Inheriting from :class:`abc.ABC` makes the ``@abc.abstractmethod``
    markers effective. Without an ``ABCMeta`` metaclass they are ignored at
    runtime, and an incomplete subclass could be instantiated.
    """

    @abc.abstractmethod
    async def reset(
        self,
        run_context: ContextWrapper[TContext],
        agent_hooks: BaseAgentRunHooks[TContext],
        **kwargs: T.Any,
    ) -> None:
        """Reset the agent to its initial state.

        This method should be called before starting a new run.

        Args:
            run_context: Context wrapper for the upcoming run.
            agent_hooks: Hooks object for the run.
            **kwargs: Runner-specific options.
        """
        ...

    @abc.abstractmethod
    async def step(self) -> T.AsyncGenerator[AgentResponse, None]:
        """Process a single step of the agent, yielding its responses."""
        ...

    @abc.abstractmethod
    async def step_until_done(
        self, max_step: int
    ) -> T.AsyncGenerator[AgentResponse, None]:
        """Process steps until the agent is done, yielding responses.

        Args:
            max_step: Presumably the maximum number of steps to run. Confirm
                against implementations.
        """
        ...

    @abc.abstractmethod
    def done(self) -> bool:
        """Check if the agent has completed its task.

        Returns True if the agent is done, False otherwise.
        """
        ...

    @abc.abstractmethod
    def get_final_llm_resp(self) -> LLMResponse | None:
        """Return the final LLM response of the run, or None.

        This method should be called after the agent is done.
        """
        ...

    def _transition_state(self, new_state: AgentState) -> None:
        """Move the runner to ``new_state``, logging real changes at debug level.

        Safe to call before any state has been stored on the instance. In that
        case the previous state is reported as ``None`` rather than raising
        ``AttributeError``.
        """
        old_state = getattr(self, "_state", None)
        if old_state != new_state:
            logger.debug(f"Agent state transition: {old_state} -> {new_state}")
        self._state = new_state