AstrBot 4.5.6__py3-none-any.whl → 4.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54):
  1. astrbot/api/all.py +2 -1
  2. astrbot/api/provider/__init__.py +2 -1
  3. astrbot/core/agent/run_context.py +7 -2
  4. astrbot/core/agent/runners/base.py +7 -0
  5. astrbot/core/agent/runners/tool_loop_agent_runner.py +51 -3
  6. astrbot/core/agent/tool.py +5 -6
  7. astrbot/core/astr_agent_context.py +13 -8
  8. astrbot/core/astr_agent_hooks.py +36 -0
  9. astrbot/core/astr_agent_run_util.py +80 -0
  10. astrbot/core/astr_agent_tool_exec.py +246 -0
  11. astrbot/core/config/default.py +53 -7
  12. astrbot/core/exceptions.py +9 -0
  13. astrbot/core/pipeline/context.py +1 -2
  14. astrbot/core/pipeline/context_utils.py +0 -65
  15. astrbot/core/pipeline/process_stage/method/llm_request.py +239 -491
  16. astrbot/core/pipeline/respond/stage.py +21 -20
  17. astrbot/core/platform/platform_metadata.py +3 -0
  18. astrbot/core/platform/register.py +2 -0
  19. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -0
  20. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +16 -5
  21. astrbot/core/platform/sources/discord/discord_platform_adapter.py +4 -1
  22. astrbot/core/platform/sources/discord/discord_platform_event.py +16 -7
  23. astrbot/core/platform/sources/lark/lark_adapter.py +4 -1
  24. astrbot/core/platform/sources/misskey/misskey_adapter.py +4 -1
  25. astrbot/core/platform/sources/satori/satori_adapter.py +2 -2
  26. astrbot/core/platform/sources/slack/slack_adapter.py +2 -0
  27. astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -0
  28. astrbot/core/platform/sources/webchat/webchat_event.py +8 -1
  29. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +4 -1
  30. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +16 -0
  31. astrbot/core/platform/sources/wecom/wecom_adapter.py +2 -1
  32. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +4 -1
  33. astrbot/core/provider/__init__.py +2 -2
  34. astrbot/core/provider/entities.py +40 -18
  35. astrbot/core/provider/func_tool_manager.py +15 -6
  36. astrbot/core/provider/manager.py +4 -1
  37. astrbot/core/provider/provider.py +7 -22
  38. astrbot/core/provider/register.py +2 -0
  39. astrbot/core/provider/sources/anthropic_source.py +0 -2
  40. astrbot/core/provider/sources/coze_source.py +0 -2
  41. astrbot/core/provider/sources/dashscope_source.py +1 -3
  42. astrbot/core/provider/sources/dify_source.py +0 -2
  43. astrbot/core/provider/sources/gemini_source.py +31 -3
  44. astrbot/core/provider/sources/groq_source.py +15 -0
  45. astrbot/core/provider/sources/openai_source.py +67 -21
  46. astrbot/core/provider/sources/zhipu_source.py +1 -6
  47. astrbot/core/star/context.py +197 -45
  48. astrbot/core/star/register/star_handler.py +30 -10
  49. astrbot/dashboard/routes/chat.py +5 -0
  50. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/METADATA +2 -2
  51. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/RECORD +54 -49
  52. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/WHEEL +0 -0
  53. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/entry_points.txt +0 -0
  54. {astrbot-4.5.6.dist-info → astrbot-4.5.7.dist-info}/licenses/LICENSE +0 -0
astrbot/api/all.py CHANGED
@@ -36,7 +36,8 @@ from astrbot.core.star.config import *
36
36
 
37
37
 
38
38
  # provider
39
- from astrbot.core.provider import Provider, Personality, ProviderMetaData
39
+ from astrbot.core.provider import Provider, ProviderMetaData
40
+ from astrbot.core.db.po import Personality
40
41
 
41
42
  # platform
42
43
  from astrbot.core.platform import (
@@ -1,4 +1,5 @@
1
- from astrbot.core.provider import Personality, Provider, STTProvider
1
+ from astrbot.core.db.po import Personality
2
+ from astrbot.core.provider import Provider, STTProvider
2
3
  from astrbot.core.provider.entities import (
3
4
  LLMResponse,
4
5
  ProviderMetaData,
@@ -1,16 +1,21 @@
1
- from dataclasses import dataclass
2
1
  from typing import Any, Generic
3
2
 
3
+ from pydantic import Field
4
+ from pydantic.dataclasses import dataclass
4
5
  from typing_extensions import TypeVar
5
6
 
7
+ from .message import Message
8
+
6
9
  TContext = TypeVar("TContext", default=Any)
7
10
 
8
11
 
9
- @dataclass
12
+ @dataclass(config={"arbitrary_types_allowed": True})
10
13
  class ContextWrapper(Generic[TContext]):
11
14
  """A context for running an agent, which can be used to pass additional data or state."""
12
15
 
13
16
  context: TContext
17
+ messages: list[Message] = Field(default_factory=list)
18
+ """This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
14
19
  tool_call_timeout: int = 60 # Default tool call timeout in seconds
15
20
 
16
21
 
@@ -40,6 +40,13 @@ class BaseAgentRunner(T.Generic[TContext]):
40
40
  """Process a single step of the agent."""
41
41
  ...
42
42
 
43
+ @abc.abstractmethod
44
+ async def step_until_done(
45
+ self, max_step: int
46
+ ) -> T.AsyncGenerator[AgentResponse, None]:
47
+ """Process steps until the agent is done."""
48
+ ...
49
+
43
50
  @abc.abstractmethod
44
51
  def done(self) -> bool:
45
52
  """Check if the agent has completed its task.
@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
23
23
  from astrbot.core.provider.provider import Provider
24
24
 
25
25
  from ..hooks import BaseAgentRunHooks
26
- from ..message import AssistantMessageSegment, ToolCallMessageSegment
26
+ from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
27
27
  from ..response import AgentResponseData
28
28
  from ..run_context import ContextWrapper, TContext
29
29
  from ..tool_executor import BaseFunctionToolExecutor
@@ -55,6 +55,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
55
55
  self.agent_hooks = agent_hooks
56
56
  self.run_context = run_context
57
57
 
58
+ messages = []
59
+ # append existing messages in the run context
60
+ for msg in request.contexts:
61
+ messages.append(Message.model_validate(msg))
62
+ if request.prompt is not None:
63
+ m = await request.assemble_context()
64
+ messages.append(Message.model_validate(m))
65
+ if request.system_prompt:
66
+ messages.insert(
67
+ 0,
68
+ Message(role="system", content=request.system_prompt),
69
+ )
70
+ self.run_context.messages = messages
71
+
58
72
  def _transition_state(self, new_state: AgentState) -> None:
59
73
  """转换 Agent 状态"""
60
74
  if self._state != new_state:
@@ -96,13 +110,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
96
110
  type="streaming_delta",
97
111
  data=AgentResponseData(chain=llm_response.result_chain),
98
112
  )
99
- else:
113
+ elif llm_response.completion_text:
100
114
  yield AgentResponse(
101
115
  type="streaming_delta",
102
116
  data=AgentResponseData(
103
117
  chain=MessageChain().message(llm_response.completion_text),
104
118
  ),
105
119
  )
120
+ elif llm_response.reasoning_content:
121
+ yield AgentResponse(
122
+ type="streaming_delta",
123
+ data=AgentResponseData(
124
+ chain=MessageChain(type="reasoning").message(
125
+ llm_response.reasoning_content,
126
+ ),
127
+ ),
128
+ )
106
129
  continue
107
130
  llm_resp_result = llm_response
108
131
  break # got final response
@@ -130,6 +153,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
130
153
  # 如果没有工具调用,转换到完成状态
131
154
  self.final_llm_resp = llm_resp
132
155
  self._transition_state(AgentState.DONE)
156
+ # record the final assistant message
157
+ self.run_context.messages.append(
158
+ Message(
159
+ role="assistant",
160
+ content=llm_resp.completion_text or "",
161
+ ),
162
+ )
133
163
  try:
134
164
  await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
135
165
  except Exception as e:
@@ -156,13 +186,16 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
156
186
  yield AgentResponse(
157
187
  type="tool_call",
158
188
  data=AgentResponseData(
159
- chain=MessageChain().message(f"🔨 调用工具: {tool_call_name}"),
189
+ chain=MessageChain(type="tool_call").message(
190
+ f"🔨 调用工具: {tool_call_name}"
191
+ ),
160
192
  ),
161
193
  )
162
194
  async for result in self._handle_function_tools(self.req, llm_resp):
163
195
  if isinstance(result, list):
164
196
  tool_call_result_blocks = result
165
197
  elif isinstance(result, MessageChain):
198
+ result.type = "tool_call_result"
166
199
  yield AgentResponse(
167
200
  type="tool_call_result",
168
201
  data=AgentResponseData(chain=result),
@@ -175,8 +208,23 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
175
208
  ),
176
209
  tool_calls_result=tool_call_result_blocks,
177
210
  )
211
+ # record the assistant message with tool calls
212
+ self.run_context.messages.extend(
213
+ tool_calls_result.to_openai_messages_model()
214
+ )
215
+
178
216
  self.req.append_tool_calls_result(tool_calls_result)
179
217
 
218
+ async def step_until_done(
219
+ self, max_step: int
220
+ ) -> T.AsyncGenerator[AgentResponse, None]:
221
+ """Process steps until the agent is done."""
222
+ step_count = 0
223
+ while not self.done() and step_count < max_step:
224
+ step_count += 1
225
+ async for resp in self.step():
226
+ yield resp
227
+
180
228
  async def _handle_function_tools(
181
229
  self,
182
230
  req: ProviderRequest,
@@ -4,12 +4,13 @@ from typing import Any, Generic
4
4
  import jsonschema
5
5
  import mcp
6
6
  from deprecated import deprecated
7
- from pydantic import model_validator
7
+ from pydantic import Field, model_validator
8
8
  from pydantic.dataclasses import dataclass
9
9
 
10
10
  from .run_context import ContextWrapper, TContext
11
11
 
12
12
  ParametersType = dict[str, Any]
13
+ ToolExecResult = str | mcp.types.CallToolResult
13
14
 
14
15
 
15
16
  @dataclass
@@ -55,15 +56,14 @@ class FunctionTool(ToolSchema, Generic[TContext]):
55
56
  def __repr__(self):
56
57
  return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
57
58
 
58
- async def call(
59
- self, context: ContextWrapper[TContext], **kwargs
60
- ) -> str | mcp.types.CallToolResult:
59
+ async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
61
60
  """Run the tool with the given arguments. The handler field has priority."""
62
61
  raise NotImplementedError(
63
62
  "FunctionTool.call() must be implemented by subclasses or set a handler."
64
63
  )
65
64
 
66
65
 
66
+ @dataclass
67
67
  class ToolSet:
68
68
  """A set of function tools that can be used in function calling.
69
69
 
@@ -71,8 +71,7 @@ class ToolSet:
71
71
  convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
72
72
  """
73
73
 
74
- def __init__(self, tools: list[FunctionTool] | None = None):
75
- self.tools: list[FunctionTool] = tools or []
74
+ tools: list[FunctionTool] = Field(default_factory=list)
76
75
 
77
76
  def empty(self) -> bool:
78
77
  """Check if the tool set is empty."""
@@ -1,14 +1,19 @@
1
- from dataclasses import dataclass
1
+ from pydantic import Field
2
+ from pydantic.dataclasses import dataclass
2
3
 
4
+ from astrbot.core.agent.run_context import ContextWrapper
3
5
  from astrbot.core.platform.astr_message_event import AstrMessageEvent
4
- from astrbot.core.provider import Provider
5
- from astrbot.core.provider.entities import ProviderRequest
6
+ from astrbot.core.star.context import Context
6
7
 
7
8
 
8
- @dataclass
9
+ @dataclass(config={"arbitrary_types_allowed": True})
9
10
  class AstrAgentContext:
10
- provider: Provider
11
- first_provider_request: ProviderRequest
12
- curr_provider_request: ProviderRequest
13
- streaming: bool
11
+ context: Context
12
+ """The star context instance"""
14
13
  event: AstrMessageEvent
14
+ """The message event associated with the agent context."""
15
+ extra: dict[str, str] = Field(default_factory=dict)
16
+ """Customized extra data."""
17
+
18
+
19
+ AgentContextWrapper = ContextWrapper[AstrAgentContext]
@@ -0,0 +1,36 @@
1
+ from typing import Any
2
+
3
+ from mcp.types import CallToolResult
4
+
5
+ from astrbot.core.agent.hooks import BaseAgentRunHooks
6
+ from astrbot.core.agent.run_context import ContextWrapper
7
+ from astrbot.core.agent.tool import FunctionTool
8
+ from astrbot.core.astr_agent_context import AstrAgentContext
9
+ from astrbot.core.pipeline.context_utils import call_event_hook
10
+ from astrbot.core.star.star_handler import EventType
11
+
12
+
13
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
    """Agent-run hooks used by the main AstrBot agent."""

    async def on_agent_done(self, run_context, llm_response):
        """Fire the ``OnLLMResponseEvent`` event hooks with the final LLM response."""
        astr_event = run_context.context.event
        await call_event_hook(astr_event, EventType.OnLLMResponseEvent, llm_response)

    async def on_tool_end(
        self,
        run_context: ContextWrapper[AstrAgentContext],
        tool: FunctionTool[Any],
        tool_args: dict | None,
        tool_result: CallToolResult | None,
    ):
        """Call ``clear_result()`` on the event after every tool call.

        NOTE(review): presumably this keeps a result set by the tool from being
        picked up again later in the run -- confirm.
        """
        astr_event = run_context.context.event
        astr_event.clear_result()
30
+
31
+
32
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
    """Hook set that overrides nothing; every hook is inherited from ``BaseAgentRunHooks``."""


# Module-level shared ``MainAgentHooks`` instance.
MAIN_AGENT_HOOKS = MainAgentHooks()
@@ -0,0 +1,80 @@
1
+ import traceback
2
+ from collections.abc import AsyncGenerator
3
+
4
+ from astrbot.core import logger
5
+ from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
6
+ from astrbot.core.astr_agent_context import AstrAgentContext
7
+ from astrbot.core.message.message_event_result import (
8
+ MessageChain,
9
+ MessageEventResult,
10
+ ResultContentType,
11
+ )
12
+
13
# Concrete runner type driven by ``run_agent``.
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]


async def run_agent(
    agent_runner: AgentRunner,
    max_step: int = 30,
    show_tool_use: bool = True,
    stream_to_general: bool = False,
    show_reasoning: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
    """Drive ``agent_runner`` step by step and relay its output to the message event.

    Args:
        agent_runner: The runner to drive; its event is ``run_context.context.event``.
        max_step: Maximum number of ``agent_runner.step()`` rounds to run.
        show_tool_use: When true, the "tool call" notice chain is sent to the
            user through ``event.send``.
        stream_to_general: When true, ``streaming_delta`` responses are dropped
            and every other response is delivered as a regular event result.
        show_reasoning: When false, streaming chains of type ``"reasoning"`` are
            not yielded.

    Yields:
        ``MessageChain`` objects in streaming mode (including an empty
        ``"break"`` chain before each tool call), or ``None`` after a
        non-streaming result has been set on the event; that result is cleared
        when the generator is resumed.

    Errors raised while stepping are logged. The error message is yielded
    (streaming) or set as the event result (non-streaming), and the generator
    stops.
    """
    step_idx = 0
    astr_event = agent_runner.run_context.context.event
    while step_idx < max_step:
        step_idx += 1
        try:
            async for resp in agent_runner.step():
                if astr_event.is_stopped():
                    return
                if resp.type == "tool_call_result":
                    msg_chain = resp.data["chain"]
                    if msg_chain.type == "tool_direct_result":
                        # tool_direct_result marks content an LLM tool wants sent to the user directly.
                        await astr_event.send(msg_chain)
                    # Other tool results are not handled here for now.
                    continue
                elif resp.type == "tool_call":
                    if agent_runner.streaming:
                        # Marks a section break in the streaming response.
                        yield MessageChain(chain=[], type="break")
                    if show_tool_use:
                        await astr_event.send(resp.data["chain"])
                    continue

                if stream_to_general and resp.type == "streaming_delta":
                    continue

                if stream_to_general or not agent_runner.streaming:
                    content_typ = (
                        ResultContentType.LLM_RESULT
                        if resp.type == "llm_result"
                        else ResultContentType.GENERAL_RESULT
                    )
                    astr_event.set_result(
                        MessageEventResult(
                            chain=resp.data["chain"].chain,
                            result_content_type=content_typ,
                        ),
                    )
                    # Hand control back so the caller can consume the result, then reset it.
                    yield
                    astr_event.clear_result()
                elif resp.type == "streaming_delta":
                    chain = resp.data["chain"]
                    if chain.type == "reasoning" and not show_reasoning:
                        # Display reasoning content only when configured.
                        continue
                    yield chain  # MessageChain
            if agent_runner.done():
                break

        except Exception as e:
            logger.error(traceback.format_exc())
            err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
            if agent_runner.streaming:
                yield MessageChain().message(err_msg)
            else:
                astr_event.set_result(MessageEventResult().message(err_msg))
            return

    if not agent_runner.done():
        # The step budget ran out before the agent finished; make the truncation visible.
        logger.warning(
            f"Agent stopped after reaching max_step={max_step} without finishing.",
        )
@@ -0,0 +1,246 @@
1
+ import asyncio
2
+ import inspect
3
+ import traceback
4
+ import typing as T
5
+
6
+ import mcp
7
+
8
+ from astrbot import logger
9
+ from astrbot.core.agent.handoff import HandoffTool
10
+ from astrbot.core.agent.mcp_client import MCPTool
11
+ from astrbot.core.agent.run_context import ContextWrapper
12
+ from astrbot.core.agent.tool import FunctionTool, ToolSet
13
+ from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
14
+ from astrbot.core.astr_agent_context import AstrAgentContext
15
+ from astrbot.core.message.message_event_result import (
16
+ CommandResult,
17
+ MessageChain,
18
+ MessageEventResult,
19
+ )
20
+ from astrbot.core.provider.register import llm_tools
21
+
22
+
23
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
    """Function-tool executor for AstrBot agent runs.

    ``execute`` dispatches on the tool type: ``HandoffTool`` instances run a
    sub-agent, ``MCPTool`` instances are delegated to their own ``call``, and
    every other tool is run locally (decorator handler, overridden ``call``,
    or ``run`` method).
    """

    @classmethod
    async def execute(cls, tool, run_context, **tool_args):
        """Execute one function-tool call.

        Args:
            tool: The tool to run (``HandoffTool``, ``MCPTool`` or a local
                ``FunctionTool``).
            run_context: The agent run context. ``run_context.context.event``
                is the message event; local tools require it to be set.
            **tool_args: Arguments for the tool call.

        Yields:
            ``mcp.types.CallToolResult`` objects, or ``None`` when a local tool
            returned nothing. Before a ``None`` is yielded, any result chain
            the tool set on the event is sent to the user directly.

        """
        if isinstance(tool, HandoffTool):
            results = cls._execute_handoff(tool, run_context, **tool_args)
        elif isinstance(tool, MCPTool):
            results = cls._execute_mcp(tool, run_context, **tool_args)
        else:
            results = cls._execute_local(tool, run_context, **tool_args)
        async for r in results:
            yield r

    @classmethod
    async def _execute_handoff(
        cls,
        tool: HandoffTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        """Run the handoff's sub-agent and yield its final text as a tool result."""
        input_ = tool_args.get("input")

        # Build the sub-agent's tool set. String entries are looked up in the
        # ``llm_tools`` registry; names that cannot be resolved are skipped.
        tools = tool.agent.tools
        if tools:
            toolset = ToolSet()
            for t in tools:
                if isinstance(t, str):
                    _t = llm_tools.get_func(t)
                    if _t:
                        toolset.add_tool(_t)
                elif isinstance(t, FunctionTool):
                    toolset.add_tool(t)
        else:
            toolset = None

        ctx = run_context.context.context
        event = run_context.context.event
        umo = event.unified_msg_origin
        prov_id = await ctx.get_current_chat_provider_id(umo)
        llm_resp = await ctx.tool_loop_agent(
            event=event,
            chat_provider_id=prov_id,
            prompt=input_,
            system_prompt=tool.agent.instructions,
            tools=toolset,
            max_steps=30,
            run_hooks=tool.agent.run_hooks,
        )
        # TextContent requires a str; guard against an empty/None completion,
        # as the agent runner does when recording the final assistant message.
        yield mcp.types.CallToolResult(
            content=[
                mcp.types.TextContent(
                    type="text",
                    text=llm_resp.completion_text or "",
                ),
            ],
        )

    @classmethod
    async def _execute_local(
        cls,
        tool: FunctionTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        """Run a local tool and normalize each of its results.

        Raises:
            ValueError: If no event is set, or the tool has no handler, no
                overridden ``call`` and no ``run`` method.
            Exception: If producing any single result exceeds
                ``run_context.tool_call_timeout`` seconds.
        """
        event = run_context.context.event
        if not event:
            raise ValueError("Event must be provided for local function tools.")

        # ``call`` counts as overridden when some class in the MRO defines a
        # ``call`` other than ``FunctionTool.call``.
        is_override_call = False
        for ty in type(tool).mro():
            if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
                is_override_call = True
                break

        # Entry-point priority: decorator handler, then overridden call(), then run().
        if tool.handler:
            awaitable = tool.handler
            method_name = "decorator_handler"
        elif is_override_call:
            awaitable = tool.call
            method_name = "call"
        elif hasattr(tool, "run"):
            awaitable = getattr(tool, "run")
            method_name = "run"
        else:
            raise ValueError("Tool must have a valid handler or override 'run' method.")

        wrapper = call_local_llm_tool(
            context=run_context,
            handler=awaitable,
            method_name=method_name,
            **tool_args,
        )
        while True:
            try:
                # The timeout applies to each individual result the tool produces.
                resp = await asyncio.wait_for(
                    anext(wrapper),
                    timeout=run_context.tool_call_timeout,
                )
            except asyncio.TimeoutError as e:
                raise Exception(
                    f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
                ) from e
            except StopAsyncIteration:
                break

            if resp is not None:
                if isinstance(resp, mcp.types.CallToolResult):
                    yield resp
                else:
                    text_content = mcp.types.TextContent(
                        type="text",
                        text=str(resp),
                    )
                    yield mcp.types.CallToolResult(content=[text_content])
                continue

            # NOTE: a ``None`` result means the tool asked to send a message to
            # the user directly; forward whatever result chain is on the event.
            # TODO: should an empty ``event.get_result()`` (nothing sent and no
            # return value) be reported back as a special TextContent such as
            # "the tool returned no content"?
            if res := event.get_result():
                if res.chain:
                    try:
                        await event.send(
                            MessageChain(
                                chain=res.chain,
                                type="tool_direct_result",
                            ),
                        )
                    except Exception as e:
                        logger.error(
                            f"Tool 直接发送消息失败: {e}",
                            exc_info=True,
                        )
            yield None

    @classmethod
    async def _execute_mcp(
        cls,
        tool: FunctionTool,
        run_context: ContextWrapper[AstrAgentContext],
        **tool_args,
    ):
        """Delegate to the MCP tool's own ``call`` and yield its result if truthy."""
        res = await tool.call(run_context, **tool_args)
        if not res:
            return
        yield res
184
+
185
+
186
async def call_local_llm_tool(
    context: ContextWrapper[AstrAgentContext],
    handler: T.Callable[..., T.Awaitable[T.Any]],
    method_name: str,
    *args,
    **kwargs,
) -> T.AsyncGenerator[T.Any, None]:
    """Run a local LLM tool handler and relay its results.

    ``run`` and ``decorator_handler`` handlers receive the message event as
    their first argument; ``call`` handlers receive the run context.

    ``MessageEventResult``/``CommandResult`` values are stored on the event
    with ``set_result`` and a ``None`` is yielded in their place; any other
    value is yielded unchanged. An async generator that yields nothing still
    produces one ``None``.

    Errors raised while *invoking* the handler are logged and the generator
    ends without yielding. Errors raised while iterating an async-generator
    handler are logged with their traceback and re-raised.
    """
    event = context.context.event
    ready_to_call = None  # a coroutine or an async generator

    try:
        if method_name in ("run", "decorator_handler"):
            ready_to_call = handler(event, *args, **kwargs)
        elif method_name == "call":
            ready_to_call = handler(context, *args, **kwargs)
        else:
            raise ValueError(f"未知的方法名: {method_name}")
    except ValueError as e:
        logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
    except TypeError:
        logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
    except Exception as e:
        trace_ = traceback.format_exc()
        logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")

    if ready_to_call is None:
        return

    if inspect.isasyncgen(ready_to_call):
        has_yielded = False
        try:
            async for ret in ready_to_call:
                has_yielded = True
                if isinstance(ret, (MessageEventResult, CommandResult)):
                    # Store the result on the event; the caller decides how to send it.
                    event.set_result(ret)
                    yield
                else:
                    yield ret
            if not has_yielded:
                # The generator finished without yielding anything.
                yield
        except Exception as e:
            # Previously this logged a traceback variable that is always None on
            # this path; log the actual failure instead, then propagate it.
            logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
            raise
    elif inspect.iscoroutine(ready_to_call):
        # A plain coroutine: await it once.
        ret = await ready_to_call
        if isinstance(ret, (MessageEventResult, CommandResult)):
            event.set_result(ret)
            yield
        else:
            yield ret
    # NOTE(review): a handler that returns neither an async generator nor a
    # coroutine (e.g. a synchronous function) yields nothing here, so its
    # return value is dropped.