AstrBot 4.5.6__py3-none-any.whl → 4.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/all.py +2 -1
- astrbot/api/provider/__init__.py +2 -1
- astrbot/core/agent/message.py +1 -1
- astrbot/core/agent/run_context.py +7 -2
- astrbot/core/agent/runners/base.py +7 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +51 -3
- astrbot/core/agent/tool.py +5 -6
- astrbot/core/astr_agent_context.py +13 -8
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/config/default.py +53 -7
- astrbot/core/exceptions.py +9 -0
- astrbot/core/pipeline/context.py +1 -2
- astrbot/core/pipeline/context_utils.py +0 -65
- astrbot/core/pipeline/process_stage/method/llm_request.py +239 -491
- astrbot/core/pipeline/respond/stage.py +21 -20
- astrbot/core/platform/platform_metadata.py +3 -0
- astrbot/core/platform/register.py +2 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -0
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +16 -5
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +4 -1
- astrbot/core/platform/sources/discord/discord_platform_event.py +16 -7
- astrbot/core/platform/sources/lark/lark_adapter.py +4 -1
- astrbot/core/platform/sources/misskey/misskey_adapter.py +4 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +2 -2
- astrbot/core/platform/sources/slack/slack_adapter.py +2 -0
- astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -0
- astrbot/core/platform/sources/webchat/webchat_event.py +8 -1
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +4 -1
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +16 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +2 -1
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +4 -1
- astrbot/core/provider/__init__.py +2 -2
- astrbot/core/provider/entities.py +40 -18
- astrbot/core/provider/func_tool_manager.py +15 -6
- astrbot/core/provider/manager.py +4 -1
- astrbot/core/provider/provider.py +7 -22
- astrbot/core/provider/register.py +2 -0
- astrbot/core/provider/sources/anthropic_source.py +0 -2
- astrbot/core/provider/sources/coze_source.py +0 -2
- astrbot/core/provider/sources/dashscope_source.py +1 -3
- astrbot/core/provider/sources/dify_source.py +0 -2
- astrbot/core/provider/sources/gemini_source.py +31 -3
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/openai_source.py +67 -21
- astrbot/core/provider/sources/zhipu_source.py +1 -6
- astrbot/core/star/context.py +197 -45
- astrbot/core/star/register/star_handler.py +30 -10
- astrbot/dashboard/routes/chat.py +5 -0
- {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/METADATA +2 -2
- {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/RECORD +55 -50
- {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/WHEEL +0 -0
- {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/licenses/LICENSE +0 -0
astrbot/api/all.py
CHANGED
|
@@ -36,7 +36,8 @@ from astrbot.core.star.config import *
|
|
|
36
36
|
|
|
37
37
|
|
|
38
38
|
# provider
|
|
39
|
-
from astrbot.core.provider import Provider,
|
|
39
|
+
from astrbot.core.provider import Provider, ProviderMetaData
|
|
40
|
+
from astrbot.core.db.po import Personality
|
|
40
41
|
|
|
41
42
|
# platform
|
|
42
43
|
from astrbot.core.platform import (
|
astrbot/api/provider/__init__.py
CHANGED
astrbot/core/agent/message.py
CHANGED
|
@@ -1,16 +1,21 @@
|
|
|
1
|
-
from dataclasses import dataclass
|
|
2
1
|
from typing import Any, Generic
|
|
3
2
|
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
from pydantic.dataclasses import dataclass
|
|
4
5
|
from typing_extensions import TypeVar
|
|
5
6
|
|
|
7
|
+
from .message import Message
|
|
8
|
+
|
|
6
9
|
TContext = TypeVar("TContext", default=Any)
|
|
7
10
|
|
|
8
11
|
|
|
9
|
-
@dataclass
|
|
12
|
+
@dataclass(config={"arbitrary_types_allowed": True})
|
|
10
13
|
class ContextWrapper(Generic[TContext]):
|
|
11
14
|
"""A context for running an agent, which can be used to pass additional data or state."""
|
|
12
15
|
|
|
13
16
|
context: TContext
|
|
17
|
+
messages: list[Message] = Field(default_factory=list)
|
|
18
|
+
"""This field stores the llm message context for the agent run, agent runners will maintain this field automatically."""
|
|
14
19
|
tool_call_timeout: int = 60 # Default tool call timeout in seconds
|
|
15
20
|
|
|
16
21
|
|
|
@@ -40,6 +40,13 @@ class BaseAgentRunner(T.Generic[TContext]):
|
|
|
40
40
|
"""Process a single step of the agent."""
|
|
41
41
|
...
|
|
42
42
|
|
|
43
|
+
@abc.abstractmethod
|
|
44
|
+
async def step_until_done(
|
|
45
|
+
self, max_step: int
|
|
46
|
+
) -> T.AsyncGenerator[AgentResponse, None]:
|
|
47
|
+
"""Process steps until the agent is done."""
|
|
48
|
+
...
|
|
49
|
+
|
|
43
50
|
@abc.abstractmethod
|
|
44
51
|
def done(self) -> bool:
|
|
45
52
|
"""Check if the agent has completed its task.
|
|
@@ -23,7 +23,7 @@ from astrbot.core.provider.entities import (
|
|
|
23
23
|
from astrbot.core.provider.provider import Provider
|
|
24
24
|
|
|
25
25
|
from ..hooks import BaseAgentRunHooks
|
|
26
|
-
from ..message import AssistantMessageSegment, ToolCallMessageSegment
|
|
26
|
+
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
|
27
27
|
from ..response import AgentResponseData
|
|
28
28
|
from ..run_context import ContextWrapper, TContext
|
|
29
29
|
from ..tool_executor import BaseFunctionToolExecutor
|
|
@@ -55,6 +55,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|
|
55
55
|
self.agent_hooks = agent_hooks
|
|
56
56
|
self.run_context = run_context
|
|
57
57
|
|
|
58
|
+
messages = []
|
|
59
|
+
# append existing messages in the run context
|
|
60
|
+
for msg in request.contexts:
|
|
61
|
+
messages.append(Message.model_validate(msg))
|
|
62
|
+
if request.prompt is not None:
|
|
63
|
+
m = await request.assemble_context()
|
|
64
|
+
messages.append(Message.model_validate(m))
|
|
65
|
+
if request.system_prompt:
|
|
66
|
+
messages.insert(
|
|
67
|
+
0,
|
|
68
|
+
Message(role="system", content=request.system_prompt),
|
|
69
|
+
)
|
|
70
|
+
self.run_context.messages = messages
|
|
71
|
+
|
|
58
72
|
def _transition_state(self, new_state: AgentState) -> None:
|
|
59
73
|
"""转换 Agent 状态"""
|
|
60
74
|
if self._state != new_state:
|
|
@@ -96,13 +110,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|
|
96
110
|
type="streaming_delta",
|
|
97
111
|
data=AgentResponseData(chain=llm_response.result_chain),
|
|
98
112
|
)
|
|
99
|
-
|
|
113
|
+
elif llm_response.completion_text:
|
|
100
114
|
yield AgentResponse(
|
|
101
115
|
type="streaming_delta",
|
|
102
116
|
data=AgentResponseData(
|
|
103
117
|
chain=MessageChain().message(llm_response.completion_text),
|
|
104
118
|
),
|
|
105
119
|
)
|
|
120
|
+
elif llm_response.reasoning_content:
|
|
121
|
+
yield AgentResponse(
|
|
122
|
+
type="streaming_delta",
|
|
123
|
+
data=AgentResponseData(
|
|
124
|
+
chain=MessageChain(type="reasoning").message(
|
|
125
|
+
llm_response.reasoning_content,
|
|
126
|
+
),
|
|
127
|
+
),
|
|
128
|
+
)
|
|
106
129
|
continue
|
|
107
130
|
llm_resp_result = llm_response
|
|
108
131
|
break # got final response
|
|
@@ -130,6 +153,13 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|
|
130
153
|
# 如果没有工具调用,转换到完成状态
|
|
131
154
|
self.final_llm_resp = llm_resp
|
|
132
155
|
self._transition_state(AgentState.DONE)
|
|
156
|
+
# record the final assistant message
|
|
157
|
+
self.run_context.messages.append(
|
|
158
|
+
Message(
|
|
159
|
+
role="assistant",
|
|
160
|
+
content=llm_resp.completion_text or "",
|
|
161
|
+
),
|
|
162
|
+
)
|
|
133
163
|
try:
|
|
134
164
|
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
|
135
165
|
except Exception as e:
|
|
@@ -156,13 +186,16 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|
|
156
186
|
yield AgentResponse(
|
|
157
187
|
type="tool_call",
|
|
158
188
|
data=AgentResponseData(
|
|
159
|
-
chain=MessageChain().message(
|
|
189
|
+
chain=MessageChain(type="tool_call").message(
|
|
190
|
+
f"🔨 调用工具: {tool_call_name}"
|
|
191
|
+
),
|
|
160
192
|
),
|
|
161
193
|
)
|
|
162
194
|
async for result in self._handle_function_tools(self.req, llm_resp):
|
|
163
195
|
if isinstance(result, list):
|
|
164
196
|
tool_call_result_blocks = result
|
|
165
197
|
elif isinstance(result, MessageChain):
|
|
198
|
+
result.type = "tool_call_result"
|
|
166
199
|
yield AgentResponse(
|
|
167
200
|
type="tool_call_result",
|
|
168
201
|
data=AgentResponseData(chain=result),
|
|
@@ -175,8 +208,23 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|
|
175
208
|
),
|
|
176
209
|
tool_calls_result=tool_call_result_blocks,
|
|
177
210
|
)
|
|
211
|
+
# record the assistant message with tool calls
|
|
212
|
+
self.run_context.messages.extend(
|
|
213
|
+
tool_calls_result.to_openai_messages_model()
|
|
214
|
+
)
|
|
215
|
+
|
|
178
216
|
self.req.append_tool_calls_result(tool_calls_result)
|
|
179
217
|
|
|
218
|
+
async def step_until_done(
|
|
219
|
+
self, max_step: int
|
|
220
|
+
) -> T.AsyncGenerator[AgentResponse, None]:
|
|
221
|
+
"""Process steps until the agent is done."""
|
|
222
|
+
step_count = 0
|
|
223
|
+
while not self.done() and step_count < max_step:
|
|
224
|
+
step_count += 1
|
|
225
|
+
async for resp in self.step():
|
|
226
|
+
yield resp
|
|
227
|
+
|
|
180
228
|
async def _handle_function_tools(
|
|
181
229
|
self,
|
|
182
230
|
req: ProviderRequest,
|
astrbot/core/agent/tool.py
CHANGED
|
@@ -4,12 +4,13 @@ from typing import Any, Generic
|
|
|
4
4
|
import jsonschema
|
|
5
5
|
import mcp
|
|
6
6
|
from deprecated import deprecated
|
|
7
|
-
from pydantic import model_validator
|
|
7
|
+
from pydantic import Field, model_validator
|
|
8
8
|
from pydantic.dataclasses import dataclass
|
|
9
9
|
|
|
10
10
|
from .run_context import ContextWrapper, TContext
|
|
11
11
|
|
|
12
12
|
ParametersType = dict[str, Any]
|
|
13
|
+
ToolExecResult = str | mcp.types.CallToolResult
|
|
13
14
|
|
|
14
15
|
|
|
15
16
|
@dataclass
|
|
@@ -55,15 +56,14 @@ class FunctionTool(ToolSchema, Generic[TContext]):
|
|
|
55
56
|
def __repr__(self):
|
|
56
57
|
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
|
|
57
58
|
|
|
58
|
-
async def call(
|
|
59
|
-
self, context: ContextWrapper[TContext], **kwargs
|
|
60
|
-
) -> str | mcp.types.CallToolResult:
|
|
59
|
+
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
|
|
61
60
|
"""Run the tool with the given arguments. The handler field has priority."""
|
|
62
61
|
raise NotImplementedError(
|
|
63
62
|
"FunctionTool.call() must be implemented by subclasses or set a handler."
|
|
64
63
|
)
|
|
65
64
|
|
|
66
65
|
|
|
66
|
+
@dataclass
|
|
67
67
|
class ToolSet:
|
|
68
68
|
"""A set of function tools that can be used in function calling.
|
|
69
69
|
|
|
@@ -71,8 +71,7 @@ class ToolSet:
|
|
|
71
71
|
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
|
|
72
72
|
"""
|
|
73
73
|
|
|
74
|
-
|
|
75
|
-
self.tools: list[FunctionTool] = tools or []
|
|
74
|
+
tools: list[FunctionTool] = Field(default_factory=list)
|
|
76
75
|
|
|
77
76
|
def empty(self) -> bool:
|
|
78
77
|
"""Check if the tool set is empty."""
|
|
@@ -1,14 +1,19 @@
|
|
|
1
|
-
from
|
|
1
|
+
from pydantic import Field
|
|
2
|
+
from pydantic.dataclasses import dataclass
|
|
2
3
|
|
|
4
|
+
from astrbot.core.agent.run_context import ContextWrapper
|
|
3
5
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
4
|
-
from astrbot.core.
|
|
5
|
-
from astrbot.core.provider.entities import ProviderRequest
|
|
6
|
+
from astrbot.core.star.context import Context
|
|
6
7
|
|
|
7
8
|
|
|
8
|
-
@dataclass
|
|
9
|
+
@dataclass(config={"arbitrary_types_allowed": True})
|
|
9
10
|
class AstrAgentContext:
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
curr_provider_request: ProviderRequest
|
|
13
|
-
streaming: bool
|
|
11
|
+
context: Context
|
|
12
|
+
"""The star context instance"""
|
|
14
13
|
event: AstrMessageEvent
|
|
14
|
+
"""The message event associated with the agent context."""
|
|
15
|
+
extra: dict[str, str] = Field(default_factory=dict)
|
|
16
|
+
"""Customized extra data."""
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
from mcp.types import CallToolResult
|
|
4
|
+
|
|
5
|
+
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
|
6
|
+
from astrbot.core.agent.run_context import ContextWrapper
|
|
7
|
+
from astrbot.core.agent.tool import FunctionTool
|
|
8
|
+
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
9
|
+
from astrbot.core.pipeline.context_utils import call_event_hook
|
|
10
|
+
from astrbot.core.star.star_handler import EventType
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
|
14
|
+
async def on_agent_done(self, run_context, llm_response):
|
|
15
|
+
# 执行事件钩子
|
|
16
|
+
await call_event_hook(
|
|
17
|
+
run_context.context.event,
|
|
18
|
+
EventType.OnLLMResponseEvent,
|
|
19
|
+
llm_response,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
async def on_tool_end(
|
|
23
|
+
self,
|
|
24
|
+
run_context: ContextWrapper[AstrAgentContext],
|
|
25
|
+
tool: FunctionTool[Any],
|
|
26
|
+
tool_args: dict | None,
|
|
27
|
+
tool_result: CallToolResult | None,
|
|
28
|
+
):
|
|
29
|
+
run_context.context.event.clear_result()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class EmptyAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
|
33
|
+
pass
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
MAIN_AGENT_HOOKS = MainAgentHooks()
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
import traceback
|
|
2
|
+
from collections.abc import AsyncGenerator
|
|
3
|
+
|
|
4
|
+
from astrbot.core import logger
|
|
5
|
+
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
|
6
|
+
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
7
|
+
from astrbot.core.message.message_event_result import (
|
|
8
|
+
MessageChain,
|
|
9
|
+
MessageEventResult,
|
|
10
|
+
ResultContentType,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def run_agent(
|
|
17
|
+
agent_runner: AgentRunner,
|
|
18
|
+
max_step: int = 30,
|
|
19
|
+
show_tool_use: bool = True,
|
|
20
|
+
stream_to_general: bool = False,
|
|
21
|
+
show_reasoning: bool = False,
|
|
22
|
+
) -> AsyncGenerator[MessageChain | None, None]:
|
|
23
|
+
step_idx = 0
|
|
24
|
+
astr_event = agent_runner.run_context.context.event
|
|
25
|
+
while step_idx < max_step:
|
|
26
|
+
step_idx += 1
|
|
27
|
+
try:
|
|
28
|
+
async for resp in agent_runner.step():
|
|
29
|
+
if astr_event.is_stopped():
|
|
30
|
+
return
|
|
31
|
+
if resp.type == "tool_call_result":
|
|
32
|
+
msg_chain = resp.data["chain"]
|
|
33
|
+
if msg_chain.type == "tool_direct_result":
|
|
34
|
+
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
|
35
|
+
await astr_event.send(resp.data["chain"])
|
|
36
|
+
continue
|
|
37
|
+
# 对于其他情况,暂时先不处理
|
|
38
|
+
continue
|
|
39
|
+
elif resp.type == "tool_call":
|
|
40
|
+
if agent_runner.streaming:
|
|
41
|
+
# 用来标记流式响应需要分节
|
|
42
|
+
yield MessageChain(chain=[], type="break")
|
|
43
|
+
if show_tool_use:
|
|
44
|
+
await astr_event.send(resp.data["chain"])
|
|
45
|
+
continue
|
|
46
|
+
|
|
47
|
+
if stream_to_general and resp.type == "streaming_delta":
|
|
48
|
+
continue
|
|
49
|
+
|
|
50
|
+
if stream_to_general or not agent_runner.streaming:
|
|
51
|
+
content_typ = (
|
|
52
|
+
ResultContentType.LLM_RESULT
|
|
53
|
+
if resp.type == "llm_result"
|
|
54
|
+
else ResultContentType.GENERAL_RESULT
|
|
55
|
+
)
|
|
56
|
+
astr_event.set_result(
|
|
57
|
+
MessageEventResult(
|
|
58
|
+
chain=resp.data["chain"].chain,
|
|
59
|
+
result_content_type=content_typ,
|
|
60
|
+
),
|
|
61
|
+
)
|
|
62
|
+
yield
|
|
63
|
+
astr_event.clear_result()
|
|
64
|
+
elif resp.type == "streaming_delta":
|
|
65
|
+
chain = resp.data["chain"]
|
|
66
|
+
if chain.type == "reasoning" and not show_reasoning:
|
|
67
|
+
# display the reasoning content only when configured
|
|
68
|
+
continue
|
|
69
|
+
yield resp.data["chain"] # MessageChain
|
|
70
|
+
if agent_runner.done():
|
|
71
|
+
break
|
|
72
|
+
|
|
73
|
+
except Exception as e:
|
|
74
|
+
logger.error(traceback.format_exc())
|
|
75
|
+
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
|
76
|
+
if agent_runner.streaming:
|
|
77
|
+
yield MessageChain().message(err_msg)
|
|
78
|
+
else:
|
|
79
|
+
astr_event.set_result(MessageEventResult().message(err_msg))
|
|
80
|
+
return
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import inspect
|
|
3
|
+
import traceback
|
|
4
|
+
import typing as T
|
|
5
|
+
|
|
6
|
+
import mcp
|
|
7
|
+
|
|
8
|
+
from astrbot import logger
|
|
9
|
+
from astrbot.core.agent.handoff import HandoffTool
|
|
10
|
+
from astrbot.core.agent.mcp_client import MCPTool
|
|
11
|
+
from astrbot.core.agent.run_context import ContextWrapper
|
|
12
|
+
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
|
13
|
+
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
|
14
|
+
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
15
|
+
from astrbot.core.message.message_event_result import (
|
|
16
|
+
CommandResult,
|
|
17
|
+
MessageChain,
|
|
18
|
+
MessageEventResult,
|
|
19
|
+
)
|
|
20
|
+
from astrbot.core.provider.register import llm_tools
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
24
|
+
@classmethod
|
|
25
|
+
async def execute(cls, tool, run_context, **tool_args):
|
|
26
|
+
"""执行函数调用。
|
|
27
|
+
|
|
28
|
+
Args:
|
|
29
|
+
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
|
30
|
+
**kwargs: 函数调用的参数。
|
|
31
|
+
|
|
32
|
+
Returns:
|
|
33
|
+
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
|
34
|
+
|
|
35
|
+
"""
|
|
36
|
+
if isinstance(tool, HandoffTool):
|
|
37
|
+
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
|
38
|
+
yield r
|
|
39
|
+
return
|
|
40
|
+
|
|
41
|
+
elif isinstance(tool, MCPTool):
|
|
42
|
+
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
|
43
|
+
yield r
|
|
44
|
+
return
|
|
45
|
+
|
|
46
|
+
else:
|
|
47
|
+
async for r in cls._execute_local(tool, run_context, **tool_args):
|
|
48
|
+
yield r
|
|
49
|
+
return
|
|
50
|
+
|
|
51
|
+
@classmethod
|
|
52
|
+
async def _execute_handoff(
|
|
53
|
+
cls,
|
|
54
|
+
tool: HandoffTool,
|
|
55
|
+
run_context: ContextWrapper[AstrAgentContext],
|
|
56
|
+
**tool_args,
|
|
57
|
+
):
|
|
58
|
+
input_ = tool_args.get("input")
|
|
59
|
+
|
|
60
|
+
# make toolset for the agent
|
|
61
|
+
tools = tool.agent.tools
|
|
62
|
+
if tools:
|
|
63
|
+
toolset = ToolSet()
|
|
64
|
+
for t in tools:
|
|
65
|
+
if isinstance(t, str):
|
|
66
|
+
_t = llm_tools.get_func(t)
|
|
67
|
+
if _t:
|
|
68
|
+
toolset.add_tool(_t)
|
|
69
|
+
elif isinstance(t, FunctionTool):
|
|
70
|
+
toolset.add_tool(t)
|
|
71
|
+
else:
|
|
72
|
+
toolset = None
|
|
73
|
+
|
|
74
|
+
ctx = run_context.context.context
|
|
75
|
+
event = run_context.context.event
|
|
76
|
+
umo = event.unified_msg_origin
|
|
77
|
+
prov_id = await ctx.get_current_chat_provider_id(umo)
|
|
78
|
+
llm_resp = await ctx.tool_loop_agent(
|
|
79
|
+
event=event,
|
|
80
|
+
chat_provider_id=prov_id,
|
|
81
|
+
prompt=input_,
|
|
82
|
+
system_prompt=tool.agent.instructions,
|
|
83
|
+
tools=toolset,
|
|
84
|
+
max_steps=30,
|
|
85
|
+
run_hooks=tool.agent.run_hooks,
|
|
86
|
+
)
|
|
87
|
+
yield mcp.types.CallToolResult(
|
|
88
|
+
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
@classmethod
|
|
92
|
+
async def _execute_local(
|
|
93
|
+
cls,
|
|
94
|
+
tool: FunctionTool,
|
|
95
|
+
run_context: ContextWrapper[AstrAgentContext],
|
|
96
|
+
**tool_args,
|
|
97
|
+
):
|
|
98
|
+
event = run_context.context.event
|
|
99
|
+
if not event:
|
|
100
|
+
raise ValueError("Event must be provided for local function tools.")
|
|
101
|
+
|
|
102
|
+
is_override_call = False
|
|
103
|
+
for ty in type(tool).mro():
|
|
104
|
+
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
|
|
105
|
+
is_override_call = True
|
|
106
|
+
break
|
|
107
|
+
|
|
108
|
+
# 检查 tool 下有没有 run 方法
|
|
109
|
+
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
|
|
110
|
+
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
|
111
|
+
|
|
112
|
+
awaitable = None
|
|
113
|
+
method_name = ""
|
|
114
|
+
if tool.handler:
|
|
115
|
+
awaitable = tool.handler
|
|
116
|
+
method_name = "decorator_handler"
|
|
117
|
+
elif is_override_call:
|
|
118
|
+
awaitable = tool.call
|
|
119
|
+
method_name = "call"
|
|
120
|
+
elif hasattr(tool, "run"):
|
|
121
|
+
awaitable = getattr(tool, "run")
|
|
122
|
+
method_name = "run"
|
|
123
|
+
if awaitable is None:
|
|
124
|
+
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
|
125
|
+
|
|
126
|
+
wrapper = call_local_llm_tool(
|
|
127
|
+
context=run_context,
|
|
128
|
+
handler=awaitable,
|
|
129
|
+
method_name=method_name,
|
|
130
|
+
**tool_args,
|
|
131
|
+
)
|
|
132
|
+
while True:
|
|
133
|
+
try:
|
|
134
|
+
resp = await asyncio.wait_for(
|
|
135
|
+
anext(wrapper),
|
|
136
|
+
timeout=run_context.tool_call_timeout,
|
|
137
|
+
)
|
|
138
|
+
if resp is not None:
|
|
139
|
+
if isinstance(resp, mcp.types.CallToolResult):
|
|
140
|
+
yield resp
|
|
141
|
+
else:
|
|
142
|
+
text_content = mcp.types.TextContent(
|
|
143
|
+
type="text",
|
|
144
|
+
text=str(resp),
|
|
145
|
+
)
|
|
146
|
+
yield mcp.types.CallToolResult(content=[text_content])
|
|
147
|
+
else:
|
|
148
|
+
# NOTE: Tool 在这里直接请求发送消息给用户
|
|
149
|
+
# TODO: 是否需要判断 event.get_result() 是否为空?
|
|
150
|
+
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
|
151
|
+
if res := run_context.context.event.get_result():
|
|
152
|
+
if res.chain:
|
|
153
|
+
try:
|
|
154
|
+
await event.send(
|
|
155
|
+
MessageChain(
|
|
156
|
+
chain=res.chain,
|
|
157
|
+
type="tool_direct_result",
|
|
158
|
+
)
|
|
159
|
+
)
|
|
160
|
+
except Exception as e:
|
|
161
|
+
logger.error(
|
|
162
|
+
f"Tool 直接发送消息失败: {e}",
|
|
163
|
+
exc_info=True,
|
|
164
|
+
)
|
|
165
|
+
yield None
|
|
166
|
+
except asyncio.TimeoutError:
|
|
167
|
+
raise Exception(
|
|
168
|
+
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
|
|
169
|
+
)
|
|
170
|
+
except StopAsyncIteration:
|
|
171
|
+
break
|
|
172
|
+
|
|
173
|
+
@classmethod
|
|
174
|
+
async def _execute_mcp(
|
|
175
|
+
cls,
|
|
176
|
+
tool: FunctionTool,
|
|
177
|
+
run_context: ContextWrapper[AstrAgentContext],
|
|
178
|
+
**tool_args,
|
|
179
|
+
):
|
|
180
|
+
res = await tool.call(run_context, **tool_args)
|
|
181
|
+
if not res:
|
|
182
|
+
return
|
|
183
|
+
yield res
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
async def call_local_llm_tool(
|
|
187
|
+
context: ContextWrapper[AstrAgentContext],
|
|
188
|
+
handler: T.Callable[..., T.Awaitable[T.Any]],
|
|
189
|
+
method_name: str,
|
|
190
|
+
*args,
|
|
191
|
+
**kwargs,
|
|
192
|
+
) -> T.AsyncGenerator[T.Any, None]:
|
|
193
|
+
"""执行本地 LLM 工具的处理函数并处理其返回结果"""
|
|
194
|
+
ready_to_call = None # 一个协程或者异步生成器
|
|
195
|
+
|
|
196
|
+
trace_ = None
|
|
197
|
+
|
|
198
|
+
event = context.context.event
|
|
199
|
+
|
|
200
|
+
try:
|
|
201
|
+
if method_name == "run" or method_name == "decorator_handler":
|
|
202
|
+
ready_to_call = handler(event, *args, **kwargs)
|
|
203
|
+
elif method_name == "call":
|
|
204
|
+
ready_to_call = handler(context, *args, **kwargs)
|
|
205
|
+
else:
|
|
206
|
+
raise ValueError(f"未知的方法名: {method_name}")
|
|
207
|
+
except ValueError as e:
|
|
208
|
+
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
|
209
|
+
except TypeError:
|
|
210
|
+
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
|
211
|
+
except Exception as e:
|
|
212
|
+
trace_ = traceback.format_exc()
|
|
213
|
+
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
|
214
|
+
|
|
215
|
+
if not ready_to_call:
|
|
216
|
+
return
|
|
217
|
+
|
|
218
|
+
if inspect.isasyncgen(ready_to_call):
|
|
219
|
+
_has_yielded = False
|
|
220
|
+
try:
|
|
221
|
+
async for ret in ready_to_call:
|
|
222
|
+
# 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码
|
|
223
|
+
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
|
224
|
+
_has_yielded = True
|
|
225
|
+
if isinstance(ret, (MessageEventResult, CommandResult)):
|
|
226
|
+
# 如果返回值是 MessageEventResult, 设置结果并继续
|
|
227
|
+
event.set_result(ret)
|
|
228
|
+
yield
|
|
229
|
+
else:
|
|
230
|
+
# 如果返回值是 None, 则不设置结果并继续
|
|
231
|
+
# 继续执行后续阶段
|
|
232
|
+
yield ret
|
|
233
|
+
if not _has_yielded:
|
|
234
|
+
# 如果这个异步生成器没有执行到 yield 分支
|
|
235
|
+
yield
|
|
236
|
+
except Exception as e:
|
|
237
|
+
logger.error(f"Previous Error: {trace_}")
|
|
238
|
+
raise e
|
|
239
|
+
elif inspect.iscoroutine(ready_to_call):
|
|
240
|
+
# 如果只是一个协程, 直接执行
|
|
241
|
+
ret = await ready_to_call
|
|
242
|
+
if isinstance(ret, (MessageEventResult, CommandResult)):
|
|
243
|
+
event.set_result(ret)
|
|
244
|
+
yield
|
|
245
|
+
else:
|
|
246
|
+
yield ret
|