AstrBot 4.5.1__py3-none-any.whl → 4.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (244) hide show
  1. astrbot/api/__init__.py +10 -11
  2. astrbot/api/event/__init__.py +5 -6
  3. astrbot/api/event/filter/__init__.py +37 -36
  4. astrbot/api/platform/__init__.py +7 -8
  5. astrbot/api/provider/__init__.py +7 -7
  6. astrbot/api/star/__init__.py +3 -4
  7. astrbot/api/util/__init__.py +2 -2
  8. astrbot/cli/__main__.py +5 -5
  9. astrbot/cli/commands/__init__.py +3 -3
  10. astrbot/cli/commands/cmd_conf.py +19 -16
  11. astrbot/cli/commands/cmd_init.py +3 -2
  12. astrbot/cli/commands/cmd_plug.py +8 -10
  13. astrbot/cli/commands/cmd_run.py +5 -6
  14. astrbot/cli/utils/__init__.py +6 -6
  15. astrbot/cli/utils/basic.py +14 -14
  16. astrbot/cli/utils/plugin.py +24 -15
  17. astrbot/cli/utils/version_comparator.py +10 -12
  18. astrbot/core/__init__.py +8 -6
  19. astrbot/core/agent/agent.py +3 -2
  20. astrbot/core/agent/handoff.py +6 -2
  21. astrbot/core/agent/hooks.py +9 -6
  22. astrbot/core/agent/mcp_client.py +50 -15
  23. astrbot/core/agent/message.py +168 -0
  24. astrbot/core/agent/response.py +2 -1
  25. astrbot/core/agent/run_context.py +2 -3
  26. astrbot/core/agent/runners/base.py +10 -13
  27. astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
  28. astrbot/core/agent/tool.py +60 -41
  29. astrbot/core/agent/tool_executor.py +9 -3
  30. astrbot/core/astr_agent_context.py +3 -1
  31. astrbot/core/astrbot_config_mgr.py +29 -9
  32. astrbot/core/config/__init__.py +2 -2
  33. astrbot/core/config/astrbot_config.py +28 -26
  34. astrbot/core/config/default.py +4 -6
  35. astrbot/core/conversation_mgr.py +105 -36
  36. astrbot/core/core_lifecycle.py +68 -54
  37. astrbot/core/db/__init__.py +33 -18
  38. astrbot/core/db/migration/helper.py +12 -10
  39. astrbot/core/db/migration/migra_3_to_4.py +53 -34
  40. astrbot/core/db/migration/migra_45_to_46.py +1 -1
  41. astrbot/core/db/migration/shared_preferences_v3.py +2 -1
  42. astrbot/core/db/migration/sqlite_v3.py +26 -23
  43. astrbot/core/db/po.py +27 -18
  44. astrbot/core/db/sqlite.py +74 -45
  45. astrbot/core/db/vec_db/base.py +10 -14
  46. astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
  47. astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
  48. astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
  49. astrbot/core/event_bus.py +8 -6
  50. astrbot/core/file_token_service.py +6 -5
  51. astrbot/core/initial_loader.py +7 -5
  52. astrbot/core/knowledge_base/chunking/__init__.py +1 -3
  53. astrbot/core/knowledge_base/chunking/base.py +1 -0
  54. astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
  55. astrbot/core/knowledge_base/chunking/recursive.py +16 -10
  56. astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
  57. astrbot/core/knowledge_base/kb_helper.py +30 -17
  58. astrbot/core/knowledge_base/kb_mgr.py +6 -7
  59. astrbot/core/knowledge_base/models.py +10 -4
  60. astrbot/core/knowledge_base/parsers/__init__.py +3 -5
  61. astrbot/core/knowledge_base/parsers/base.py +1 -0
  62. astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
  63. astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
  64. astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
  65. astrbot/core/knowledge_base/parsers/util.py +1 -1
  66. astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
  67. astrbot/core/knowledge_base/retrieval/manager.py +17 -14
  68. astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
  69. astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
  70. astrbot/core/log.py +21 -13
  71. astrbot/core/message/components.py +123 -217
  72. astrbot/core/message/message_event_result.py +24 -24
  73. astrbot/core/persona_mgr.py +20 -11
  74. astrbot/core/pipeline/__init__.py +7 -7
  75. astrbot/core/pipeline/content_safety_check/stage.py +13 -9
  76. astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
  77. astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
  78. astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
  79. astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
  80. astrbot/core/pipeline/context.py +4 -1
  81. astrbot/core/pipeline/context_utils.py +77 -7
  82. astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
  83. astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
  84. astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
  85. astrbot/core/pipeline/process_stage/stage.py +13 -10
  86. astrbot/core/pipeline/process_stage/utils.py +6 -5
  87. astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
  88. astrbot/core/pipeline/respond/stage.py +23 -20
  89. astrbot/core/pipeline/result_decorate/stage.py +31 -23
  90. astrbot/core/pipeline/scheduler.py +12 -8
  91. astrbot/core/pipeline/session_status_check/stage.py +12 -8
  92. astrbot/core/pipeline/stage.py +10 -4
  93. astrbot/core/pipeline/waking_check/stage.py +24 -18
  94. astrbot/core/pipeline/whitelist_check/stage.py +10 -7
  95. astrbot/core/platform/__init__.py +6 -6
  96. astrbot/core/platform/astr_message_event.py +76 -110
  97. astrbot/core/platform/astrbot_message.py +11 -13
  98. astrbot/core/platform/manager.py +16 -15
  99. astrbot/core/platform/message_session.py +5 -3
  100. astrbot/core/platform/platform.py +16 -24
  101. astrbot/core/platform/platform_metadata.py +4 -4
  102. astrbot/core/platform/register.py +8 -8
  103. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
  104. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
  105. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +42 -27
  106. astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
  107. astrbot/core/platform/sources/discord/client.py +9 -6
  108. astrbot/core/platform/sources/discord/components.py +18 -14
  109. astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
  110. astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
  111. astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
  112. astrbot/core/platform/sources/lark/lark_event.py +21 -14
  113. astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
  114. astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
  115. astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
  116. astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
  117. astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
  118. astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
  119. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
  120. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
  121. astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
  122. astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
  123. astrbot/core/platform/sources/satori/satori_event.py +34 -25
  124. astrbot/core/platform/sources/slack/client.py +11 -9
  125. astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
  126. astrbot/core/platform/sources/slack/slack_event.py +34 -24
  127. astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
  128. astrbot/core/platform/sources/telegram/tg_event.py +32 -18
  129. astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
  130. astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
  131. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
  132. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
  133. astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
  134. astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
  135. astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
  136. astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
  137. astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
  138. astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
  139. astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
  140. astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
  141. astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
  142. astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
  143. astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
  144. astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
  145. astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
  146. astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
  147. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
  148. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
  149. astrbot/core/platform_message_history_mgr.py +5 -3
  150. astrbot/core/provider/__init__.py +2 -3
  151. astrbot/core/provider/entites.py +8 -8
  152. astrbot/core/provider/entities.py +61 -75
  153. astrbot/core/provider/func_tool_manager.py +59 -55
  154. astrbot/core/provider/manager.py +32 -22
  155. astrbot/core/provider/provider.py +72 -46
  156. astrbot/core/provider/register.py +7 -7
  157. astrbot/core/provider/sources/anthropic_source.py +48 -30
  158. astrbot/core/provider/sources/azure_tts_source.py +17 -13
  159. astrbot/core/provider/sources/coze_api_client.py +27 -17
  160. astrbot/core/provider/sources/coze_source.py +104 -87
  161. astrbot/core/provider/sources/dashscope_source.py +18 -11
  162. astrbot/core/provider/sources/dashscope_tts.py +36 -23
  163. astrbot/core/provider/sources/dify_source.py +25 -20
  164. astrbot/core/provider/sources/edge_tts_source.py +21 -17
  165. astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
  166. astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
  167. astrbot/core/provider/sources/gemini_source.py +72 -58
  168. astrbot/core/provider/sources/gemini_tts_source.py +8 -6
  169. astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
  170. astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
  171. astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
  172. astrbot/core/provider/sources/openai_embedding_source.py +6 -8
  173. astrbot/core/provider/sources/openai_source.py +77 -69
  174. astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
  175. astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
  176. astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
  177. astrbot/core/provider/sources/volcengine_tts.py +38 -31
  178. astrbot/core/provider/sources/whisper_api_source.py +14 -12
  179. astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
  180. astrbot/core/provider/sources/xinference_rerank_source.py +16 -8
  181. astrbot/core/provider/sources/xinference_stt_provider.py +35 -25
  182. astrbot/core/star/__init__.py +16 -11
  183. astrbot/core/star/config.py +10 -15
  184. astrbot/core/star/context.py +97 -75
  185. astrbot/core/star/filter/__init__.py +4 -3
  186. astrbot/core/star/filter/command.py +30 -28
  187. astrbot/core/star/filter/command_group.py +27 -24
  188. astrbot/core/star/filter/custom_filter.py +6 -5
  189. astrbot/core/star/filter/event_message_type.py +4 -2
  190. astrbot/core/star/filter/permission.py +4 -2
  191. astrbot/core/star/filter/platform_adapter_type.py +4 -2
  192. astrbot/core/star/filter/regex.py +4 -2
  193. astrbot/core/star/register/__init__.py +19 -19
  194. astrbot/core/star/register/star.py +6 -2
  195. astrbot/core/star/register/star_handler.py +96 -73
  196. astrbot/core/star/session_llm_manager.py +48 -14
  197. astrbot/core/star/session_plugin_manager.py +29 -15
  198. astrbot/core/star/star.py +1 -2
  199. astrbot/core/star/star_handler.py +13 -8
  200. astrbot/core/star/star_manager.py +151 -59
  201. astrbot/core/star/star_tools.py +44 -37
  202. astrbot/core/star/updator.py +10 -10
  203. astrbot/core/umop_config_router.py +10 -4
  204. astrbot/core/updator.py +13 -5
  205. astrbot/core/utils/astrbot_path.py +3 -5
  206. astrbot/core/utils/dify_api_client.py +33 -15
  207. astrbot/core/utils/io.py +66 -42
  208. astrbot/core/utils/log_pipe.py +1 -1
  209. astrbot/core/utils/metrics.py +7 -7
  210. astrbot/core/utils/path_util.py +15 -16
  211. astrbot/core/utils/pip_installer.py +5 -5
  212. astrbot/core/utils/session_waiter.py +19 -20
  213. astrbot/core/utils/shared_preferences.py +45 -20
  214. astrbot/core/utils/t2i/__init__.py +4 -1
  215. astrbot/core/utils/t2i/network_strategy.py +35 -26
  216. astrbot/core/utils/t2i/renderer.py +11 -5
  217. astrbot/core/utils/t2i/template_manager.py +14 -15
  218. astrbot/core/utils/tencent_record_helper.py +19 -13
  219. astrbot/core/utils/version_comparator.py +10 -13
  220. astrbot/core/zip_updator.py +43 -40
  221. astrbot/dashboard/routes/__init__.py +18 -18
  222. astrbot/dashboard/routes/auth.py +10 -8
  223. astrbot/dashboard/routes/chat.py +30 -21
  224. astrbot/dashboard/routes/config.py +92 -75
  225. astrbot/dashboard/routes/conversation.py +46 -39
  226. astrbot/dashboard/routes/file.py +4 -2
  227. astrbot/dashboard/routes/knowledge_base.py +47 -40
  228. astrbot/dashboard/routes/log.py +9 -4
  229. astrbot/dashboard/routes/persona.py +19 -16
  230. astrbot/dashboard/routes/plugin.py +69 -55
  231. astrbot/dashboard/routes/route.py +3 -1
  232. astrbot/dashboard/routes/session_management.py +130 -116
  233. astrbot/dashboard/routes/stat.py +34 -34
  234. astrbot/dashboard/routes/t2i.py +15 -12
  235. astrbot/dashboard/routes/tools.py +47 -52
  236. astrbot/dashboard/routes/update.py +32 -28
  237. astrbot/dashboard/server.py +30 -26
  238. astrbot/dashboard/utils.py +8 -4
  239. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/METADATA +2 -1
  240. astrbot-4.5.2.dist-info/RECORD +261 -0
  241. astrbot-4.5.1.dist-info/RECORD +0 -260
  242. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
  243. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
  244. {astrbot-4.5.1.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,15 +1,24 @@
1
- """
2
- 本地 Agent 模式的 LLM 调用 Stage
3
- """
1
+ """本地 Agent 模式的 LLM 调用 Stage"""
4
2
 
5
3
  import asyncio
6
4
  import copy
7
5
  import json
8
6
  import traceback
9
- from datetime import timedelta
10
7
  from collections.abc import AsyncGenerator
11
- from astrbot.core.conversation_mgr import Conversation
8
+ from typing import Any
9
+
10
+ from mcp.types import CallToolResult
11
+
12
12
  from astrbot.core import logger
13
+ from astrbot.core.agent.handoff import HandoffTool
14
+ from astrbot.core.agent.hooks import BaseAgentRunHooks
15
+ from astrbot.core.agent.mcp_client import MCPTool
16
+ from astrbot.core.agent.run_context import ContextWrapper
17
+ from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
18
+ from astrbot.core.agent.tool import FunctionTool, ToolSet
19
+ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
20
+ from astrbot.core.astr_agent_context import AstrAgentContext
21
+ from astrbot.core.conversation_mgr import Conversation
13
22
  from astrbot.core.message.components import Image
14
23
  from astrbot.core.message.message_event_result import (
15
24
  MessageChain,
@@ -22,21 +31,14 @@ from astrbot.core.provider.entities import (
22
31
  LLMResponse,
23
32
  ProviderRequest,
24
33
  )
25
- from astrbot.core.agent.hooks import BaseAgentRunHooks
26
- from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
27
- from astrbot.core.agent.run_context import ContextWrapper
28
- from astrbot.core.agent.tool import ToolSet, FunctionTool
29
- from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
30
- from astrbot.core.agent.handoff import HandoffTool
34
+ from astrbot.core.provider.register import llm_tools
31
35
  from astrbot.core.star.session_llm_manager import SessionServiceManager
32
- from astrbot.core.star.star_handler import EventType
36
+ from astrbot.core.star.star_handler import EventType, star_map
33
37
  from astrbot.core.utils.metrics import Metric
34
- from ...context import PipelineContext, call_event_hook, call_handler
38
+
39
+ from ...context import PipelineContext, call_event_hook, call_local_llm_tool
35
40
  from ..stage import Stage
36
41
  from ..utils import inject_kb_context
37
- from astrbot.core.provider.register import llm_tools
38
- from astrbot.core.star.star_handler import star_map
39
- from astrbot.core.astr_agent_context import AstrAgentContext
40
42
 
41
43
  try:
42
44
  import mcp
@@ -59,24 +61,23 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
59
61
 
60
62
  Returns:
61
63
  AsyncGenerator[None | mcp.types.CallToolResult, None]
64
+
62
65
  """
63
66
  if isinstance(tool, HandoffTool):
64
67
  async for r in cls._execute_handoff(tool, run_context, **tool_args):
65
68
  yield r
66
69
  return
67
70
 
68
- if tool.origin == "local":
69
- async for r in cls._execute_local(tool, run_context, **tool_args):
71
+ elif isinstance(tool, MCPTool):
72
+ async for r in cls._execute_mcp(tool, run_context, **tool_args):
70
73
  yield r
71
74
  return
72
75
 
73
- elif tool.origin == "mcp":
74
- async for r in cls._execute_mcp(tool, run_context, **tool_args):
76
+ else:
77
+ async for r in cls._execute_local(tool, run_context, **tool_args):
75
78
  yield r
76
79
  return
77
80
 
78
- raise Exception(f"Unknown function origin: {tool.origin}")
79
-
80
81
  @classmethod
81
82
  async def _execute_handoff(
82
83
  cls,
@@ -113,18 +114,22 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
113
114
  first_provider_request=run_context.context.first_provider_request,
114
115
  curr_provider_request=request,
115
116
  streaming=run_context.context.streaming,
117
+ event=run_context.context.event,
116
118
  )
117
119
 
120
+ event = run_context.context.event
121
+
118
122
  logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
119
- await run_context.event.send(
120
- MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
123
+ await event.send(
124
+ MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
121
125
  )
122
126
 
123
127
  await agent_runner.reset(
124
128
  provider=run_context.context.provider,
125
129
  request=request,
126
130
  run_context=AgentContextWrapper(
127
- context=astr_agent_ctx, event=run_context.event
131
+ context=astr_agent_ctx,
132
+ tool_call_timeout=run_context.tool_call_timeout,
128
133
  ),
129
134
  tool_executor=FunctionToolExecutor(),
130
135
  agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
@@ -146,7 +151,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
146
151
  return
147
152
 
148
153
  logger.debug(
149
- f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
154
+ f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
150
155
  )
151
156
 
152
157
  result = (
@@ -174,25 +179,46 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
174
179
  run_context: ContextWrapper[AstrAgentContext],
175
180
  **tool_args,
176
181
  ):
177
- if not run_context.event:
182
+ event = run_context.context.event
183
+ if not event:
178
184
  raise ValueError("Event must be provided for local function tools.")
179
185
 
180
- # 检查 tool 下有没有 run 方法
181
- if not tool.handler and not hasattr(tool, "run"):
182
- raise ValueError("Tool must have a valid handler or 'run' method.")
183
- awaitable = tool.handler or getattr(tool, "run")
186
+ is_override_call = False
187
+ for ty in type(tool).mro():
188
+ if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
189
+ logger.debug(f"Found call in: {ty}")
190
+ is_override_call = True
191
+ break
184
192
 
185
- wrapper = call_handler(
186
- event=run_context.event,
193
+ # 检查 tool 下有没有 run 方法
194
+ if not tool.handler and not hasattr(tool, "run") and not is_override_call:
195
+ raise ValueError("Tool must have a valid handler or override 'run' method.")
196
+
197
+ awaitable = None
198
+ method_name = ""
199
+ if tool.handler:
200
+ awaitable = tool.handler
201
+ method_name = "decorator_handler"
202
+ elif is_override_call:
203
+ awaitable = tool.call
204
+ method_name = "call"
205
+ elif hasattr(tool, "run"):
206
+ awaitable = getattr(tool, "run")
207
+ method_name = "run"
208
+ if awaitable is None:
209
+ raise ValueError("Tool must have a valid handler or override 'run' method.")
210
+
211
+ wrapper = call_local_llm_tool(
212
+ context=run_context,
187
213
  handler=awaitable,
214
+ method_name=method_name,
188
215
  **tool_args,
189
216
  )
190
- # async for resp in wrapper:
191
217
  while True:
192
218
  try:
193
219
  resp = await asyncio.wait_for(
194
220
  anext(wrapper),
195
- timeout=run_context.context.tool_call_timeout,
221
+ timeout=run_context.tool_call_timeout,
196
222
  )
197
223
  if resp is not None:
198
224
  if isinstance(resp, mcp.types.CallToolResult):
@@ -207,10 +233,24 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
207
233
  # NOTE: Tool 在这里直接请求发送消息给用户
208
234
  # TODO: 是否需要判断 event.get_result() 是否为空?
209
235
  # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
236
+ if res := run_context.context.event.get_result():
237
+ if res.chain:
238
+ try:
239
+ await event.send(
240
+ MessageChain(
241
+ chain=res.chain,
242
+ type="tool_direct_result",
243
+ )
244
+ )
245
+ except Exception as e:
246
+ logger.error(
247
+ f"Tool 直接发送消息失败: {e}",
248
+ exc_info=True,
249
+ )
210
250
  yield None
211
251
  except asyncio.TimeoutError:
212
252
  raise Exception(
213
- f"tool {tool.name} execution timeout after {run_context.context.tool_call_timeout} seconds."
253
+ f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
214
254
  )
215
255
  except StopAsyncIteration:
216
256
  break
@@ -222,19 +262,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
222
262
  run_context: ContextWrapper[AstrAgentContext],
223
263
  **tool_args,
224
264
  ):
225
- if not tool.mcp_client:
226
- raise ValueError("MCP client is not available for MCP function tools.")
227
-
228
- session = tool.mcp_client.session
229
- if not session:
230
- raise ValueError("MCP session is not available for MCP function tools.")
231
- res = await session.call_tool(
232
- name=tool.name,
233
- arguments=tool_args,
234
- read_timeout_seconds=timedelta(
235
- seconds=run_context.context.tool_call_timeout
236
- ),
237
- )
265
+ res = await tool.call(run_context, **tool_args)
238
266
  if not res:
239
267
  return
240
268
  yield res
@@ -244,18 +272,31 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
244
272
  async def on_agent_done(self, run_context, llm_response):
245
273
  # 执行事件钩子
246
274
  await call_event_hook(
247
- run_context.event, EventType.OnLLMResponseEvent, llm_response
275
+ run_context.context.event,
276
+ EventType.OnLLMResponseEvent,
277
+ llm_response,
248
278
  )
249
279
 
280
+ async def on_tool_end(
281
+ self,
282
+ run_context: ContextWrapper[AstrAgentContext],
283
+ tool: FunctionTool[Any],
284
+ tool_args: dict | None,
285
+ tool_result: CallToolResult | None,
286
+ ):
287
+ run_context.context.event.clear_result()
288
+
250
289
 
251
290
  MAIN_AGENT_HOOKS = MainAgentHooks()
252
291
 
253
292
 
254
293
  async def run_agent(
255
- agent_runner: AgentRunner, max_step: int = 30, show_tool_use: bool = True
294
+ agent_runner: AgentRunner,
295
+ max_step: int = 30,
296
+ show_tool_use: bool = True,
256
297
  ) -> AsyncGenerator[MessageChain, None]:
257
298
  step_idx = 0
258
- astr_event = agent_runner.run_context.event
299
+ astr_event = agent_runner.run_context.context.event
259
300
  while step_idx < max_step:
260
301
  step_idx += 1
261
302
  try:
@@ -290,19 +331,18 @@ async def run_agent(
290
331
  MessageEventResult(
291
332
  chain=resp.data["chain"].chain,
292
333
  result_content_type=content_typ,
293
- )
334
+ ),
294
335
  )
295
336
  yield
296
337
  astr_event.clear_result()
297
- else:
298
- if resp.type == "streaming_delta":
299
- yield resp.data["chain"] # MessageChain
338
+ elif resp.type == "streaming_delta":
339
+ yield resp.data["chain"] # MessageChain
300
340
  if agent_runner.done():
301
341
  break
302
342
 
303
343
  except Exception as e:
304
344
  logger.error(traceback.format_exc())
305
- err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
345
+ err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
306
346
  if agent_runner.streaming:
307
347
  yield MessageChain().message(err_msg)
308
348
  else:
@@ -332,7 +372,7 @@ class LLMRequestSubStage(Stage):
332
372
  for bwp in self.bot_wake_prefixs:
333
373
  if self.provider_wake_prefix.startswith(bwp):
334
374
  logger.info(
335
- f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。"
375
+ f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
336
376
  )
337
377
  self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
338
378
 
@@ -367,7 +407,9 @@ class LLMRequestSubStage(Stage):
367
407
  return conversation
368
408
 
369
409
  async def process(
370
- self, event: AstrMessageEvent, _nested: bool = False
410
+ self,
411
+ event: AstrMessageEvent,
412
+ _nested: bool = False,
371
413
  ) -> None | AsyncGenerator[None, None]:
372
414
  req: ProviderRequest | None = None
373
415
 
@@ -423,7 +465,9 @@ class LLMRequestSubStage(Stage):
423
465
  # 应用知识库
424
466
  try:
425
467
  await inject_kb_context(
426
- umo=event.unified_msg_origin, p_ctx=self.ctx, req=req
468
+ umo=event.unified_msg_origin,
469
+ p_ctx=self.ctx,
470
+ req=req,
427
471
  )
428
472
  except Exception as e:
429
473
  logger.error(f"调用知识库时遇到问题: {e}")
@@ -475,7 +519,7 @@ class LLMRequestSubStage(Stage):
475
519
  # 如果模型不支持工具使用,但请求中包含工具列表,则清空。
476
520
  if "tool_use" not in provider_cfg:
477
521
  logger.debug(
478
- f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
522
+ f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
479
523
  )
480
524
  req.func_tool = None
481
525
  # 插件可用性设置
@@ -498,19 +542,22 @@ class LLMRequestSubStage(Stage):
498
542
  # run agent
499
543
  agent_runner = AgentRunner()
500
544
  logger.debug(
501
- f"handle provider[id: {provider.provider_config['id']}] request: {req}"
545
+ f"handle provider[id: {provider.provider_config['id']}] request: {req}",
502
546
  )
503
547
  astr_agent_ctx = AstrAgentContext(
504
548
  provider=provider,
505
549
  first_provider_request=req,
506
550
  curr_provider_request=req,
507
551
  streaming=self.streaming_response,
508
- tool_call_timeout=self.tool_call_timeout,
552
+ event=event,
509
553
  )
510
554
  await agent_runner.reset(
511
555
  provider=provider,
512
556
  request=req,
513
- run_context=AgentContextWrapper(context=astr_agent_ctx, event=event),
557
+ run_context=AgentContextWrapper(
558
+ context=astr_agent_ctx,
559
+ tool_call_timeout=self.tool_call_timeout,
560
+ ),
514
561
  tool_executor=FunctionToolExecutor(),
515
562
  agent_hooks=MAIN_AGENT_HOOKS,
516
563
  streaming=self.streaming_response,
@@ -522,8 +569,8 @@ class LLMRequestSubStage(Stage):
522
569
  MessageEventResult()
523
570
  .set_result_content_type(ResultContentType.STREAMING_RESULT)
524
571
  .set_async_stream(
525
- run_agent(agent_runner, self.max_step, self.show_tool_use)
526
- )
572
+ run_agent(agent_runner, self.max_step, self.show_tool_use),
573
+ ),
527
574
  )
528
575
  yield
529
576
  if agent_runner.done():
@@ -540,7 +587,7 @@ class LLMRequestSubStage(Stage):
540
587
  MessageEventResult(
541
588
  chain=chain,
542
589
  result_content_type=ResultContentType.STREAMING_FINISH,
543
- )
590
+ ),
544
591
  )
545
592
  else:
546
593
  async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
@@ -560,17 +607,21 @@ class LLMRequestSubStage(Stage):
560
607
  llm_tick=1,
561
608
  model_name=agent_runner.provider.get_model(),
562
609
  provider_type=agent_runner.provider.meta().type,
563
- )
610
+ ),
564
611
  )
565
612
 
566
613
  async def _handle_webchat(
567
- self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
614
+ self,
615
+ event: AstrMessageEvent,
616
+ req: ProviderRequest,
617
+ prov: Provider,
568
618
  ):
569
619
  """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
570
620
  if not req.conversation:
571
621
  return
572
622
  conversation = await self.conv_manager.get_conversation(
573
- event.unified_msg_origin, req.conversation.cid
623
+ event.unified_msg_origin,
624
+ req.conversation.cid,
574
625
  )
575
626
  if conversation and not req.conversation.title:
576
627
  messages = json.loads(conversation.history)
@@ -607,7 +658,7 @@ class LLMRequestSubStage(Stage):
607
658
  )
608
659
  if llm_resp and llm_resp.completion_text:
609
660
  logger.debug(
610
- f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
661
+ f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
611
662
  )
612
663
  title = llm_resp.completion_text.strip()
613
664
  if not title or "<None>" in title:
@@ -650,7 +701,9 @@ class LLMRequestSubStage(Stage):
650
701
  messages.append({"role": "assistant", "content": llm_response.completion_text})
651
702
  messages = list(filter(lambda item: "_no_save" not in item, messages))
652
703
  await self.conv_manager.update_conversation(
653
- event.unified_msg_origin, req.conversation.cid, history=messages
704
+ event.unified_msg_origin,
705
+ req.conversation.cid,
706
+ history=messages,
654
707
  )
655
708
 
656
709
  def fix_messages(self, messages: list[dict]) -> list[dict]:
@@ -1,16 +1,17 @@
1
- """
2
- 本地 Agent 模式的 AstrBot 插件调用 Stage
3
- """
1
+ """本地 Agent 模式的 AstrBot 插件调用 Stage"""
2
+
3
+ import traceback
4
+ from collections.abc import AsyncGenerator
5
+ from typing import Any
4
6
 
5
- from ...context import PipelineContext, call_handler
6
- from ..stage import Stage
7
- from typing import Dict, Any, List, AsyncGenerator, Union
8
- from astrbot.core.platform.astr_message_event import AstrMessageEvent
9
- from astrbot.core.message.message_event_result import MessageEventResult
10
7
  from astrbot.core import logger
11
- from astrbot.core.star.star_handler import StarHandlerMetadata
8
+ from astrbot.core.message.message_event_result import MessageEventResult
9
+ from astrbot.core.platform.astr_message_event import AstrMessageEvent
12
10
  from astrbot.core.star.star import star_map
13
- import traceback
11
+ from astrbot.core.star.star_handler import StarHandlerMetadata
12
+
13
+ from ...context import PipelineContext, call_handler
14
+ from ..stage import Stage
14
15
 
15
16
 
16
17
  class StarRequestSubStage(Stage):
@@ -21,13 +22,14 @@ class StarRequestSubStage(Stage):
21
22
  self.ctx = ctx
22
23
 
23
24
  async def process(
24
- self, event: AstrMessageEvent
25
- ) -> Union[None, AsyncGenerator[None, None]]:
26
- activated_handlers: List[StarHandlerMetadata] = event.get_extra(
27
- "activated_handlers"
25
+ self,
26
+ event: AstrMessageEvent,
27
+ ) -> None | AsyncGenerator[None, None]:
28
+ activated_handlers: list[StarHandlerMetadata] = event.get_extra(
29
+ "activated_handlers",
28
30
  )
29
- handlers_parsed_params: Dict[str, Dict[str, Any]] = event.get_extra(
30
- "handlers_parsed_params"
31
+ handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra(
32
+ "handlers_parsed_params",
31
33
  )
32
34
  if not handlers_parsed_params:
33
35
  handlers_parsed_params = {}
@@ -37,7 +39,7 @@ class StarRequestSubStage(Stage):
37
39
  md = star_map.get(handler.handler_module_path)
38
40
  if not md:
39
41
  logger.warning(
40
- f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
42
+ f"Cannot find plugin for given handler module path: {handler.handler_module_path}",
41
43
  )
42
44
  continue
43
45
  logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
@@ -1,12 +1,14 @@
1
- from typing import List, Union, AsyncGenerator
2
- from ..stage import Stage, register_stage
1
+ from collections.abc import AsyncGenerator
2
+
3
+ from astrbot.core import logger
4
+ from astrbot.core.platform.astr_message_event import AstrMessageEvent
5
+ from astrbot.core.provider.entities import ProviderRequest
6
+ from astrbot.core.star.star_handler import StarHandlerMetadata
7
+
3
8
  from ..context import PipelineContext
9
+ from ..stage import Stage, register_stage
4
10
  from .method.llm_request import LLMRequestSubStage
5
11
  from .method.star_request import StarRequestSubStage
6
- from astrbot.core.platform.astr_message_event import AstrMessageEvent
7
- from astrbot.core.star.star_handler import StarHandlerMetadata
8
- from astrbot.core.provider.entities import ProviderRequest
9
- from astrbot.core import logger
10
12
 
11
13
 
12
14
  @register_stage
@@ -22,11 +24,12 @@ class ProcessStage(Stage):
22
24
  await self.star_request_sub_stage.initialize(ctx)
23
25
 
24
26
  async def process(
25
- self, event: AstrMessageEvent
26
- ) -> Union[None, AsyncGenerator[None, None]]:
27
+ self,
28
+ event: AstrMessageEvent,
29
+ ) -> None | AsyncGenerator[None, None]:
27
30
  """处理事件"""
28
- activated_handlers: List[StarHandlerMetadata] = event.get_extra(
29
- "activated_handlers"
31
+ activated_handlers: list[StarHandlerMetadata] = event.get_extra(
32
+ "activated_handlers",
30
33
  )
31
34
  # 有插件 Handler 被激活
32
35
  if activated_handlers:
@@ -1,6 +1,7 @@
1
- from ..context import PipelineContext
2
- from astrbot.core.provider.entities import ProviderRequest
3
1
  from astrbot.api import logger, sp
2
+ from astrbot.core.provider.entities import ProviderRequest
3
+
4
+ from ..context import PipelineContext
4
5
 
5
6
 
6
7
  async def inject_kb_context(
@@ -8,14 +9,14 @@ async def inject_kb_context(
8
9
  p_ctx: PipelineContext,
9
10
  req: ProviderRequest,
10
11
  ) -> None:
11
- """inject knowledge base context into the provider request
12
+ """Inject knowledge base context into the provider request
12
13
 
13
14
  Args:
14
15
  umo: Unique message object (session ID)
15
16
  p_ctx: Pipeline context
16
17
  req: Provider request
17
- """
18
18
 
19
+ """
19
20
  kb_mgr = p_ctx.plugin_manager.context.kb_manager
20
21
 
21
22
  # 1. 优先读取会话级配置
@@ -45,7 +46,7 @@ async def inject_kb_context(
45
46
 
46
47
  if invalid_kb_ids:
47
48
  logger.warning(
48
- f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}"
49
+ f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}",
49
50
  )
50
51
 
51
52
  if not kb_names:
@@ -1,18 +1,19 @@
1
1
  import asyncio
2
- from datetime import datetime, timedelta
3
2
  from collections import defaultdict, deque
4
- from typing import DefaultDict, Deque, Union, AsyncGenerator
5
- from ..stage import Stage, register_stage
6
- from ..context import PipelineContext
7
- from astrbot.core.platform.astr_message_event import AstrMessageEvent
3
+ from collections.abc import AsyncGenerator
4
+ from datetime import datetime, timedelta
5
+
8
6
  from astrbot.core import logger
9
7
  from astrbot.core.config.astrbot_config import RateLimitStrategy
8
+ from astrbot.core.platform.astr_message_event import AstrMessageEvent
9
+
10
+ from ..context import PipelineContext
11
+ from ..stage import Stage, register_stage
10
12
 
11
13
 
12
14
  @register_stage
13
15
  class RateLimitStage(Stage):
14
- """
15
- 检查是否需要限制消息发送的限流器。
16
+ """检查是否需要限制消息发送的限流器。
16
17
 
17
18
  使用 Fixed Window 算法。
18
19
  如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
@@ -20,32 +21,30 @@ class RateLimitStage(Stage):
20
21
 
21
22
  def __init__(self):
22
23
  # 存储每个会话的请求时间队列
23
- self.event_timestamps: DefaultDict[str, Deque[datetime]] = defaultdict(deque)
24
+ self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
24
25
  # 为每个会话设置一个锁,避免并发冲突
25
- self.locks: DefaultDict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
26
+ self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
26
27
  # 限流参数
27
28
  self.rate_limit_count: int = 0
28
29
  self.rate_limit_time: timedelta = timedelta(0)
29
30
 
30
31
  async def initialize(self, ctx: PipelineContext) -> None:
31
- """
32
- 初始化限流器,根据配置设置限流参数。
33
- """
32
+ """初始化限流器,根据配置设置限流参数。"""
34
33
  self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][
35
34
  "count"
36
35
  ]
37
36
  self.rate_limit_time = timedelta(
38
- seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"]
37
+ seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"],
39
38
  )
40
39
  self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][
41
40
  "strategy"
42
41
  ] # stall or discard
43
42
 
44
43
  async def process(
45
- self, event: AstrMessageEvent
46
- ) -> Union[None, AsyncGenerator[None, None]]:
47
- """
48
- 检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
44
+ self,
45
+ event: AstrMessageEvent,
46
+ ) -> None | AsyncGenerator[None, None]:
47
+ """检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
49
48
 
50
49
  Args:
51
50
  event (AstrMessageEvent): 当前消息事件。
@@ -53,6 +52,7 @@ class RateLimitStage(Stage):
53
52
 
54
53
  Returns:
55
54
  MessageEventResult: 继续或停止事件处理的结果。
55
+
56
56
  """
57
57
  session_id = event.session_id
58
58
  now = datetime.now()
@@ -66,32 +66,33 @@ class RateLimitStage(Stage):
66
66
  if len(timestamps) < self.rate_limit_count:
67
67
  timestamps.append(now)
68
68
  break
69
- else:
70
- next_window_time = timestamps[0] + self.rate_limit_time
71
- stall_duration = (next_window_time - now).total_seconds() + 0.3
72
-
73
- match self.rl_strategy:
74
- case RateLimitStrategy.STALL.value:
75
- logger.info(
76
- f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。"
77
- )
78
- await asyncio.sleep(stall_duration)
79
- now = datetime.now()
80
- case RateLimitStrategy.DISCARD.value:
81
- logger.info(
82
- f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。"
83
- )
84
- return event.stop_event()
69
+ next_window_time = timestamps[0] + self.rate_limit_time
70
+ stall_duration = (next_window_time - now).total_seconds() + 0.3
71
+
72
+ match self.rl_strategy:
73
+ case RateLimitStrategy.STALL.value:
74
+ logger.info(
75
+ f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。",
76
+ )
77
+ await asyncio.sleep(stall_duration)
78
+ now = datetime.now()
79
+ case RateLimitStrategy.DISCARD.value:
80
+ logger.info(
81
+ f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。",
82
+ )
83
+ return event.stop_event()
85
84
 
86
85
  def _remove_expired_timestamps(
87
- self, timestamps: Deque[datetime], now: datetime
86
+ self,
87
+ timestamps: deque[datetime],
88
+ now: datetime,
88
89
  ) -> None:
89
- """
90
- 移除时间窗口外的时间戳。
90
+ """移除时间窗口外的时间戳。
91
91
 
92
92
  Args:
93
93
  timestamps (Deque[datetime]): 当前会话的时间戳队列。
94
94
  now (datetime): 当前时间,用于计算过期时间。
95
+
95
96
  """
96
97
  expiry_threshold: datetime = now - self.rate_limit_time
97
98
  while timestamps and timestamps[0] < expiry_threshold: