AstrBot 4.5.3__py3-none-any.whl → 4.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/all.py +2 -1
- astrbot/api/provider/__init__.py +2 -1
- astrbot/core/agent/run_context.py +7 -2
- astrbot/core/agent/runners/base.py +7 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +51 -3
- astrbot/core/agent/tool.py +5 -6
- astrbot/core/astr_agent_context.py +13 -8
- astrbot/core/astr_agent_hooks.py +36 -0
- astrbot/core/astr_agent_run_util.py +80 -0
- astrbot/core/astr_agent_tool_exec.py +246 -0
- astrbot/core/config/default.py +53 -7
- astrbot/core/exceptions.py +9 -0
- astrbot/core/pipeline/context.py +1 -2
- astrbot/core/pipeline/context_utils.py +0 -65
- astrbot/core/pipeline/process_stage/method/llm_request.py +239 -491
- astrbot/core/pipeline/respond/stage.py +21 -20
- astrbot/core/platform/platform_metadata.py +3 -0
- astrbot/core/platform/register.py +2 -0
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -0
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +16 -5
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +4 -1
- astrbot/core/platform/sources/discord/discord_platform_event.py +16 -7
- astrbot/core/platform/sources/lark/lark_adapter.py +4 -1
- astrbot/core/platform/sources/misskey/misskey_adapter.py +4 -1
- astrbot/core/platform/sources/satori/satori_adapter.py +2 -2
- astrbot/core/platform/sources/slack/slack_adapter.py +2 -0
- astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -0
- astrbot/core/platform/sources/webchat/webchat_event.py +8 -1
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +4 -1
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +16 -0
- astrbot/core/platform/sources/wecom/wecom_adapter.py +2 -1
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +4 -1
- astrbot/core/provider/__init__.py +2 -2
- astrbot/core/provider/entities.py +40 -18
- astrbot/core/provider/func_tool_manager.py +15 -6
- astrbot/core/provider/manager.py +4 -1
- astrbot/core/provider/provider.py +7 -22
- astrbot/core/provider/register.py +2 -0
- astrbot/core/provider/sources/anthropic_source.py +0 -2
- astrbot/core/provider/sources/coze_source.py +0 -2
- astrbot/core/provider/sources/dashscope_source.py +1 -3
- astrbot/core/provider/sources/dify_source.py +0 -2
- astrbot/core/provider/sources/gemini_source.py +31 -3
- astrbot/core/provider/sources/groq_source.py +15 -0
- astrbot/core/provider/sources/openai_source.py +67 -21
- astrbot/core/provider/sources/zhipu_source.py +1 -6
- astrbot/core/star/context.py +197 -45
- astrbot/core/star/register/star_handler.py +30 -10
- astrbot/dashboard/routes/chat.py +5 -0
- {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/METADATA +55 -65
- {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/RECORD +54 -49
- {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/WHEEL +0 -0
- {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.3.dist-info → astrbot-4.5.7.dist-info}/licenses/LICENSE +0 -0
|
@@ -3,20 +3,10 @@
|
|
|
3
3
|
import asyncio
|
|
4
4
|
import copy
|
|
5
5
|
import json
|
|
6
|
-
import traceback
|
|
7
6
|
from collections.abc import AsyncGenerator
|
|
8
|
-
from typing import Any
|
|
9
|
-
|
|
10
|
-
from mcp.types import CallToolResult
|
|
11
7
|
|
|
12
8
|
from astrbot.core import logger
|
|
13
|
-
from astrbot.core.agent.
|
|
14
|
-
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
|
15
|
-
from astrbot.core.agent.mcp_client import MCPTool
|
|
16
|
-
from astrbot.core.agent.run_context import ContextWrapper
|
|
17
|
-
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
|
18
|
-
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
|
19
|
-
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
|
9
|
+
from astrbot.core.agent.tool import ToolSet
|
|
20
10
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
21
11
|
from astrbot.core.conversation_mgr import Conversation
|
|
22
12
|
from astrbot.core.message.components import Image
|
|
@@ -31,324 +21,19 @@ from astrbot.core.provider.entities import (
|
|
|
31
21
|
LLMResponse,
|
|
32
22
|
ProviderRequest,
|
|
33
23
|
)
|
|
34
|
-
from astrbot.core.provider.register import llm_tools
|
|
35
24
|
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
|
36
25
|
from astrbot.core.star.star_handler import EventType, star_map
|
|
37
26
|
from astrbot.core.utils.metrics import Metric
|
|
27
|
+
from astrbot.core.utils.session_lock import session_lock_manager
|
|
38
28
|
|
|
39
|
-
from
|
|
29
|
+
from ....astr_agent_context import AgentContextWrapper
|
|
30
|
+
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
|
|
31
|
+
from ....astr_agent_run_util import AgentRunner, run_agent
|
|
32
|
+
from ....astr_agent_tool_exec import FunctionToolExecutor
|
|
33
|
+
from ...context import PipelineContext, call_event_hook
|
|
40
34
|
from ..stage import Stage
|
|
41
35
|
from ..utils import inject_kb_context
|
|
42
36
|
|
|
43
|
-
try:
|
|
44
|
-
import mcp
|
|
45
|
-
except (ModuleNotFoundError, ImportError):
|
|
46
|
-
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
|
50
|
-
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
54
|
-
@classmethod
|
|
55
|
-
async def execute(cls, tool, run_context, **tool_args):
|
|
56
|
-
"""执行函数调用。
|
|
57
|
-
|
|
58
|
-
Args:
|
|
59
|
-
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
|
60
|
-
**kwargs: 函数调用的参数。
|
|
61
|
-
|
|
62
|
-
Returns:
|
|
63
|
-
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
|
64
|
-
|
|
65
|
-
"""
|
|
66
|
-
if isinstance(tool, HandoffTool):
|
|
67
|
-
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
|
68
|
-
yield r
|
|
69
|
-
return
|
|
70
|
-
|
|
71
|
-
elif isinstance(tool, MCPTool):
|
|
72
|
-
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
|
73
|
-
yield r
|
|
74
|
-
return
|
|
75
|
-
|
|
76
|
-
else:
|
|
77
|
-
async for r in cls._execute_local(tool, run_context, **tool_args):
|
|
78
|
-
yield r
|
|
79
|
-
return
|
|
80
|
-
|
|
81
|
-
@classmethod
|
|
82
|
-
async def _execute_handoff(
|
|
83
|
-
cls,
|
|
84
|
-
tool: HandoffTool,
|
|
85
|
-
run_context: ContextWrapper[AstrAgentContext],
|
|
86
|
-
**tool_args,
|
|
87
|
-
):
|
|
88
|
-
input_ = tool_args.get("input", "agent")
|
|
89
|
-
agent_runner = AgentRunner()
|
|
90
|
-
|
|
91
|
-
# make toolset for the agent
|
|
92
|
-
tools = tool.agent.tools
|
|
93
|
-
if tools:
|
|
94
|
-
toolset = ToolSet()
|
|
95
|
-
for t in tools:
|
|
96
|
-
if isinstance(t, str):
|
|
97
|
-
_t = llm_tools.get_func(t)
|
|
98
|
-
if _t:
|
|
99
|
-
toolset.add_tool(_t)
|
|
100
|
-
elif isinstance(t, FunctionTool):
|
|
101
|
-
toolset.add_tool(t)
|
|
102
|
-
else:
|
|
103
|
-
toolset = None
|
|
104
|
-
|
|
105
|
-
request = ProviderRequest(
|
|
106
|
-
prompt=input_,
|
|
107
|
-
system_prompt=tool.description or "",
|
|
108
|
-
image_urls=[], # 暂时不传递原始 agent 的上下文
|
|
109
|
-
contexts=[], # 暂时不传递原始 agent 的上下文
|
|
110
|
-
func_tool=toolset,
|
|
111
|
-
)
|
|
112
|
-
astr_agent_ctx = AstrAgentContext(
|
|
113
|
-
provider=run_context.context.provider,
|
|
114
|
-
first_provider_request=run_context.context.first_provider_request,
|
|
115
|
-
curr_provider_request=request,
|
|
116
|
-
streaming=run_context.context.streaming,
|
|
117
|
-
event=run_context.context.event,
|
|
118
|
-
)
|
|
119
|
-
|
|
120
|
-
event = run_context.context.event
|
|
121
|
-
|
|
122
|
-
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
|
|
123
|
-
await event.send(
|
|
124
|
-
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
await agent_runner.reset(
|
|
128
|
-
provider=run_context.context.provider,
|
|
129
|
-
request=request,
|
|
130
|
-
run_context=AgentContextWrapper(
|
|
131
|
-
context=astr_agent_ctx,
|
|
132
|
-
tool_call_timeout=run_context.tool_call_timeout,
|
|
133
|
-
),
|
|
134
|
-
tool_executor=FunctionToolExecutor(),
|
|
135
|
-
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
|
136
|
-
streaming=run_context.context.streaming,
|
|
137
|
-
)
|
|
138
|
-
|
|
139
|
-
async for _ in run_agent(agent_runner, 15, True):
|
|
140
|
-
pass
|
|
141
|
-
|
|
142
|
-
if agent_runner.done():
|
|
143
|
-
llm_response = agent_runner.get_final_llm_resp()
|
|
144
|
-
|
|
145
|
-
if not llm_response:
|
|
146
|
-
text_content = mcp.types.TextContent(
|
|
147
|
-
type="text",
|
|
148
|
-
text=f"error when deligate task to {tool.agent.name}",
|
|
149
|
-
)
|
|
150
|
-
yield mcp.types.CallToolResult(content=[text_content])
|
|
151
|
-
return
|
|
152
|
-
|
|
153
|
-
logger.debug(
|
|
154
|
-
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
|
|
155
|
-
)
|
|
156
|
-
|
|
157
|
-
result = (
|
|
158
|
-
f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
|
|
159
|
-
"Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
|
|
160
|
-
)
|
|
161
|
-
|
|
162
|
-
text_content = mcp.types.TextContent(
|
|
163
|
-
type="text",
|
|
164
|
-
text=result,
|
|
165
|
-
)
|
|
166
|
-
yield mcp.types.CallToolResult(content=[text_content])
|
|
167
|
-
else:
|
|
168
|
-
text_content = mcp.types.TextContent(
|
|
169
|
-
type="text",
|
|
170
|
-
text=f"error when deligate task to {tool.agent.name}",
|
|
171
|
-
)
|
|
172
|
-
yield mcp.types.CallToolResult(content=[text_content])
|
|
173
|
-
return
|
|
174
|
-
|
|
175
|
-
@classmethod
|
|
176
|
-
async def _execute_local(
|
|
177
|
-
cls,
|
|
178
|
-
tool: FunctionTool,
|
|
179
|
-
run_context: ContextWrapper[AstrAgentContext],
|
|
180
|
-
**tool_args,
|
|
181
|
-
):
|
|
182
|
-
event = run_context.context.event
|
|
183
|
-
if not event:
|
|
184
|
-
raise ValueError("Event must be provided for local function tools.")
|
|
185
|
-
|
|
186
|
-
is_override_call = False
|
|
187
|
-
for ty in type(tool).mro():
|
|
188
|
-
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
|
|
189
|
-
logger.debug(f"Found call in: {ty}")
|
|
190
|
-
is_override_call = True
|
|
191
|
-
break
|
|
192
|
-
|
|
193
|
-
# 检查 tool 下有没有 run 方法
|
|
194
|
-
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
|
|
195
|
-
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
|
196
|
-
|
|
197
|
-
awaitable = None
|
|
198
|
-
method_name = ""
|
|
199
|
-
if tool.handler:
|
|
200
|
-
awaitable = tool.handler
|
|
201
|
-
method_name = "decorator_handler"
|
|
202
|
-
elif is_override_call:
|
|
203
|
-
awaitable = tool.call
|
|
204
|
-
method_name = "call"
|
|
205
|
-
elif hasattr(tool, "run"):
|
|
206
|
-
awaitable = getattr(tool, "run")
|
|
207
|
-
method_name = "run"
|
|
208
|
-
if awaitable is None:
|
|
209
|
-
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
|
210
|
-
|
|
211
|
-
wrapper = call_local_llm_tool(
|
|
212
|
-
context=run_context,
|
|
213
|
-
handler=awaitable,
|
|
214
|
-
method_name=method_name,
|
|
215
|
-
**tool_args,
|
|
216
|
-
)
|
|
217
|
-
while True:
|
|
218
|
-
try:
|
|
219
|
-
resp = await asyncio.wait_for(
|
|
220
|
-
anext(wrapper),
|
|
221
|
-
timeout=run_context.tool_call_timeout,
|
|
222
|
-
)
|
|
223
|
-
if resp is not None:
|
|
224
|
-
if isinstance(resp, mcp.types.CallToolResult):
|
|
225
|
-
yield resp
|
|
226
|
-
else:
|
|
227
|
-
text_content = mcp.types.TextContent(
|
|
228
|
-
type="text",
|
|
229
|
-
text=str(resp),
|
|
230
|
-
)
|
|
231
|
-
yield mcp.types.CallToolResult(content=[text_content])
|
|
232
|
-
else:
|
|
233
|
-
# NOTE: Tool 在这里直接请求发送消息给用户
|
|
234
|
-
# TODO: 是否需要判断 event.get_result() 是否为空?
|
|
235
|
-
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
|
236
|
-
if res := run_context.context.event.get_result():
|
|
237
|
-
if res.chain:
|
|
238
|
-
try:
|
|
239
|
-
await event.send(
|
|
240
|
-
MessageChain(
|
|
241
|
-
chain=res.chain,
|
|
242
|
-
type="tool_direct_result",
|
|
243
|
-
)
|
|
244
|
-
)
|
|
245
|
-
except Exception as e:
|
|
246
|
-
logger.error(
|
|
247
|
-
f"Tool 直接发送消息失败: {e}",
|
|
248
|
-
exc_info=True,
|
|
249
|
-
)
|
|
250
|
-
yield None
|
|
251
|
-
except asyncio.TimeoutError:
|
|
252
|
-
raise Exception(
|
|
253
|
-
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
|
|
254
|
-
)
|
|
255
|
-
except StopAsyncIteration:
|
|
256
|
-
break
|
|
257
|
-
|
|
258
|
-
@classmethod
|
|
259
|
-
async def _execute_mcp(
|
|
260
|
-
cls,
|
|
261
|
-
tool: FunctionTool,
|
|
262
|
-
run_context: ContextWrapper[AstrAgentContext],
|
|
263
|
-
**tool_args,
|
|
264
|
-
):
|
|
265
|
-
res = await tool.call(run_context, **tool_args)
|
|
266
|
-
if not res:
|
|
267
|
-
return
|
|
268
|
-
yield res
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
|
272
|
-
async def on_agent_done(self, run_context, llm_response):
|
|
273
|
-
# 执行事件钩子
|
|
274
|
-
await call_event_hook(
|
|
275
|
-
run_context.context.event,
|
|
276
|
-
EventType.OnLLMResponseEvent,
|
|
277
|
-
llm_response,
|
|
278
|
-
)
|
|
279
|
-
|
|
280
|
-
async def on_tool_end(
|
|
281
|
-
self,
|
|
282
|
-
run_context: ContextWrapper[AstrAgentContext],
|
|
283
|
-
tool: FunctionTool[Any],
|
|
284
|
-
tool_args: dict | None,
|
|
285
|
-
tool_result: CallToolResult | None,
|
|
286
|
-
):
|
|
287
|
-
run_context.context.event.clear_result()
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
MAIN_AGENT_HOOKS = MainAgentHooks()
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
async def run_agent(
|
|
294
|
-
agent_runner: AgentRunner,
|
|
295
|
-
max_step: int = 30,
|
|
296
|
-
show_tool_use: bool = True,
|
|
297
|
-
) -> AsyncGenerator[MessageChain, None]:
|
|
298
|
-
step_idx = 0
|
|
299
|
-
astr_event = agent_runner.run_context.context.event
|
|
300
|
-
while step_idx < max_step:
|
|
301
|
-
step_idx += 1
|
|
302
|
-
try:
|
|
303
|
-
async for resp in agent_runner.step():
|
|
304
|
-
if astr_event.is_stopped():
|
|
305
|
-
return
|
|
306
|
-
if resp.type == "tool_call_result":
|
|
307
|
-
msg_chain = resp.data["chain"]
|
|
308
|
-
if msg_chain.type == "tool_direct_result":
|
|
309
|
-
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
|
310
|
-
resp.data["chain"].type = "tool_call_result"
|
|
311
|
-
await astr_event.send(resp.data["chain"])
|
|
312
|
-
continue
|
|
313
|
-
# 对于其他情况,暂时先不处理
|
|
314
|
-
continue
|
|
315
|
-
elif resp.type == "tool_call":
|
|
316
|
-
if agent_runner.streaming:
|
|
317
|
-
# 用来标记流式响应需要分节
|
|
318
|
-
yield MessageChain(chain=[], type="break")
|
|
319
|
-
if show_tool_use or astr_event.get_platform_name() == "webchat":
|
|
320
|
-
resp.data["chain"].type = "tool_call"
|
|
321
|
-
await astr_event.send(resp.data["chain"])
|
|
322
|
-
continue
|
|
323
|
-
|
|
324
|
-
if not agent_runner.streaming:
|
|
325
|
-
content_typ = (
|
|
326
|
-
ResultContentType.LLM_RESULT
|
|
327
|
-
if resp.type == "llm_result"
|
|
328
|
-
else ResultContentType.GENERAL_RESULT
|
|
329
|
-
)
|
|
330
|
-
astr_event.set_result(
|
|
331
|
-
MessageEventResult(
|
|
332
|
-
chain=resp.data["chain"].chain,
|
|
333
|
-
result_content_type=content_typ,
|
|
334
|
-
),
|
|
335
|
-
)
|
|
336
|
-
yield
|
|
337
|
-
astr_event.clear_result()
|
|
338
|
-
elif resp.type == "streaming_delta":
|
|
339
|
-
yield resp.data["chain"] # MessageChain
|
|
340
|
-
if agent_runner.done():
|
|
341
|
-
break
|
|
342
|
-
|
|
343
|
-
except Exception as e:
|
|
344
|
-
logger.error(traceback.format_exc())
|
|
345
|
-
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
|
346
|
-
if agent_runner.streaming:
|
|
347
|
-
yield MessageChain().message(err_msg)
|
|
348
|
-
else:
|
|
349
|
-
astr_event.set_result(MessageEventResult().message(err_msg))
|
|
350
|
-
return
|
|
351
|
-
|
|
352
37
|
|
|
353
38
|
class LLMRequestSubStage(Stage):
|
|
354
39
|
async def initialize(self, ctx: PipelineContext) -> None:
|
|
@@ -363,11 +48,15 @@ class LLMRequestSubStage(Stage):
|
|
|
363
48
|
self.max_context_length - 1,
|
|
364
49
|
)
|
|
365
50
|
self.streaming_response: bool = settings["streaming_response"]
|
|
51
|
+
self.unsupported_streaming_strategy: str = settings[
|
|
52
|
+
"unsupported_streaming_strategy"
|
|
53
|
+
]
|
|
366
54
|
self.max_step: int = settings.get("max_agent_step", 30)
|
|
367
55
|
self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
|
|
368
56
|
if isinstance(self.max_step, bool): # workaround: #2622
|
|
369
57
|
self.max_step = 30
|
|
370
58
|
self.show_tool_use: bool = settings.get("show_tool_use_status", True)
|
|
59
|
+
self.show_reasoning = settings.get("display_reasoning_text", False)
|
|
371
60
|
|
|
372
61
|
for bwp in self.bot_wake_prefixs:
|
|
373
62
|
if self.provider_wake_prefix.startswith(bwp):
|
|
@@ -406,63 +95,12 @@ class LLMRequestSubStage(Stage):
|
|
|
406
95
|
raise RuntimeError("无法创建新的对话。")
|
|
407
96
|
return conversation
|
|
408
97
|
|
|
409
|
-
async def
|
|
98
|
+
async def _apply_kb_context(
|
|
410
99
|
self,
|
|
411
100
|
event: AstrMessageEvent,
|
|
412
|
-
|
|
413
|
-
)
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
|
417
|
-
logger.debug("未启用 LLM 能力,跳过处理。")
|
|
418
|
-
return
|
|
419
|
-
|
|
420
|
-
# 检查会话级别的LLM启停状态
|
|
421
|
-
if not SessionServiceManager.should_process_llm_request(event):
|
|
422
|
-
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
|
423
|
-
return
|
|
424
|
-
|
|
425
|
-
provider = self._select_provider(event)
|
|
426
|
-
if provider is None:
|
|
427
|
-
return
|
|
428
|
-
if not isinstance(provider, Provider):
|
|
429
|
-
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
|
430
|
-
return
|
|
431
|
-
|
|
432
|
-
if event.get_extra("provider_request"):
|
|
433
|
-
req = event.get_extra("provider_request")
|
|
434
|
-
assert isinstance(req, ProviderRequest), (
|
|
435
|
-
"provider_request 必须是 ProviderRequest 类型。"
|
|
436
|
-
)
|
|
437
|
-
|
|
438
|
-
if req.conversation:
|
|
439
|
-
req.contexts = json.loads(req.conversation.history)
|
|
440
|
-
|
|
441
|
-
else:
|
|
442
|
-
req = ProviderRequest(prompt="", image_urls=[])
|
|
443
|
-
if sel_model := event.get_extra("selected_model"):
|
|
444
|
-
req.model = sel_model
|
|
445
|
-
if self.provider_wake_prefix:
|
|
446
|
-
if not event.message_str.startswith(self.provider_wake_prefix):
|
|
447
|
-
return
|
|
448
|
-
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
|
449
|
-
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
|
450
|
-
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
|
451
|
-
for comp in event.message_obj.message:
|
|
452
|
-
if isinstance(comp, Image):
|
|
453
|
-
image_path = await comp.convert_to_file_path()
|
|
454
|
-
req.image_urls.append(image_path)
|
|
455
|
-
|
|
456
|
-
conversation = await self._get_session_conv(event)
|
|
457
|
-
req.conversation = conversation
|
|
458
|
-
req.contexts = json.loads(conversation.history)
|
|
459
|
-
|
|
460
|
-
event.set_extra("provider_request", req)
|
|
461
|
-
|
|
462
|
-
if not req.prompt and not req.image_urls:
|
|
463
|
-
return
|
|
464
|
-
|
|
465
|
-
# 应用知识库
|
|
101
|
+
req: ProviderRequest,
|
|
102
|
+
):
|
|
103
|
+
"""应用知识库上下文到请求中"""
|
|
466
104
|
try:
|
|
467
105
|
await inject_kb_context(
|
|
468
106
|
umo=event.unified_msg_origin,
|
|
@@ -472,43 +110,40 @@ class LLMRequestSubStage(Stage):
|
|
|
472
110
|
except Exception as e:
|
|
473
111
|
logger.error(f"调用知识库时遇到问题: {e}")
|
|
474
112
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
if (
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
(
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
if index is not None and index > 0:
|
|
501
|
-
req.contexts = req.contexts[index:]
|
|
502
|
-
|
|
503
|
-
# session_id
|
|
504
|
-
if not req.session_id:
|
|
505
|
-
req.session_id = event.unified_msg_origin
|
|
113
|
+
def _truncate_contexts(
|
|
114
|
+
self,
|
|
115
|
+
contexts: list[dict],
|
|
116
|
+
) -> list[dict]:
|
|
117
|
+
"""截断上下文列表,确保不超过最大长度"""
|
|
118
|
+
if self.max_context_length == -1:
|
|
119
|
+
return contexts
|
|
120
|
+
|
|
121
|
+
if len(contexts) // 2 <= self.max_context_length:
|
|
122
|
+
return contexts
|
|
123
|
+
|
|
124
|
+
truncated_contexts = contexts[
|
|
125
|
+
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
|
126
|
+
]
|
|
127
|
+
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
|
128
|
+
index = next(
|
|
129
|
+
(
|
|
130
|
+
i
|
|
131
|
+
for i, item in enumerate(truncated_contexts)
|
|
132
|
+
if item.get("role") == "user"
|
|
133
|
+
),
|
|
134
|
+
None,
|
|
135
|
+
)
|
|
136
|
+
if index is not None and index > 0:
|
|
137
|
+
truncated_contexts = truncated_contexts[index:]
|
|
506
138
|
|
|
507
|
-
|
|
508
|
-
req.contexts = self.fix_messages(req.contexts)
|
|
139
|
+
return truncated_contexts
|
|
509
140
|
|
|
510
|
-
|
|
511
|
-
|
|
141
|
+
def _modalities_fix(
|
|
142
|
+
self,
|
|
143
|
+
provider: Provider,
|
|
144
|
+
req: ProviderRequest,
|
|
145
|
+
):
|
|
146
|
+
"""检查提供商的模态能力,清理请求中的不支持内容"""
|
|
512
147
|
if req.image_urls:
|
|
513
148
|
provider_cfg = provider.provider_config.get("modalities", ["image"])
|
|
514
149
|
if "image" not in provider_cfg:
|
|
@@ -522,7 +157,13 @@ class LLMRequestSubStage(Stage):
|
|
|
522
157
|
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
|
|
523
158
|
)
|
|
524
159
|
req.func_tool = None
|
|
525
|
-
|
|
160
|
+
|
|
161
|
+
def _plugin_tool_fix(
|
|
162
|
+
self,
|
|
163
|
+
event: AstrMessageEvent,
|
|
164
|
+
req: ProviderRequest,
|
|
165
|
+
):
|
|
166
|
+
"""根据事件中的插件设置,过滤请求中的工具列表"""
|
|
526
167
|
if event.plugins_name is not None and req.func_tool:
|
|
527
168
|
new_tool_set = ToolSet()
|
|
528
169
|
for tool in req.func_tool.tools:
|
|
@@ -536,80 +177,6 @@ class LLMRequestSubStage(Stage):
|
|
|
536
177
|
new_tool_set.add_tool(tool)
|
|
537
178
|
req.func_tool = new_tool_set
|
|
538
179
|
|
|
539
|
-
# 备份 req.contexts
|
|
540
|
-
backup_contexts = copy.deepcopy(req.contexts)
|
|
541
|
-
|
|
542
|
-
# run agent
|
|
543
|
-
agent_runner = AgentRunner()
|
|
544
|
-
logger.debug(
|
|
545
|
-
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
|
546
|
-
)
|
|
547
|
-
astr_agent_ctx = AstrAgentContext(
|
|
548
|
-
provider=provider,
|
|
549
|
-
first_provider_request=req,
|
|
550
|
-
curr_provider_request=req,
|
|
551
|
-
streaming=self.streaming_response,
|
|
552
|
-
event=event,
|
|
553
|
-
)
|
|
554
|
-
await agent_runner.reset(
|
|
555
|
-
provider=provider,
|
|
556
|
-
request=req,
|
|
557
|
-
run_context=AgentContextWrapper(
|
|
558
|
-
context=astr_agent_ctx,
|
|
559
|
-
tool_call_timeout=self.tool_call_timeout,
|
|
560
|
-
),
|
|
561
|
-
tool_executor=FunctionToolExecutor(),
|
|
562
|
-
agent_hooks=MAIN_AGENT_HOOKS,
|
|
563
|
-
streaming=self.streaming_response,
|
|
564
|
-
)
|
|
565
|
-
|
|
566
|
-
if self.streaming_response:
|
|
567
|
-
# 流式响应
|
|
568
|
-
event.set_result(
|
|
569
|
-
MessageEventResult()
|
|
570
|
-
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
|
571
|
-
.set_async_stream(
|
|
572
|
-
run_agent(agent_runner, self.max_step, self.show_tool_use),
|
|
573
|
-
),
|
|
574
|
-
)
|
|
575
|
-
yield
|
|
576
|
-
if agent_runner.done():
|
|
577
|
-
if final_llm_resp := agent_runner.get_final_llm_resp():
|
|
578
|
-
if final_llm_resp.completion_text:
|
|
579
|
-
chain = (
|
|
580
|
-
MessageChain().message(final_llm_resp.completion_text).chain
|
|
581
|
-
)
|
|
582
|
-
elif final_llm_resp.result_chain:
|
|
583
|
-
chain = final_llm_resp.result_chain.chain
|
|
584
|
-
else:
|
|
585
|
-
chain = MessageChain().chain
|
|
586
|
-
event.set_result(
|
|
587
|
-
MessageEventResult(
|
|
588
|
-
chain=chain,
|
|
589
|
-
result_content_type=ResultContentType.STREAMING_FINISH,
|
|
590
|
-
),
|
|
591
|
-
)
|
|
592
|
-
else:
|
|
593
|
-
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
|
594
|
-
yield
|
|
595
|
-
|
|
596
|
-
# 恢复备份的 contexts
|
|
597
|
-
req.contexts = backup_contexts
|
|
598
|
-
|
|
599
|
-
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
|
600
|
-
|
|
601
|
-
# 异步处理 WebChat 特殊情况
|
|
602
|
-
if event.get_platform_name() == "webchat":
|
|
603
|
-
asyncio.create_task(self._handle_webchat(event, req, provider))
|
|
604
|
-
|
|
605
|
-
asyncio.create_task(
|
|
606
|
-
Metric.upload(
|
|
607
|
-
llm_tick=1,
|
|
608
|
-
model_name=agent_runner.provider.get_model(),
|
|
609
|
-
provider_type=agent_runner.provider.meta().type,
|
|
610
|
-
),
|
|
611
|
-
)
|
|
612
|
-
|
|
613
180
|
async def _handle_webchat(
|
|
614
181
|
self,
|
|
615
182
|
event: AstrMessageEvent,
|
|
@@ -657,9 +224,6 @@ class LLMRequestSubStage(Stage):
|
|
|
657
224
|
),
|
|
658
225
|
)
|
|
659
226
|
if llm_resp and llm_resp.completion_text:
|
|
660
|
-
logger.debug(
|
|
661
|
-
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
|
|
662
|
-
)
|
|
663
227
|
title = llm_resp.completion_text.strip()
|
|
664
228
|
if not title or "<None>" in title:
|
|
665
229
|
return
|
|
@@ -687,6 +251,9 @@ class LLMRequestSubStage(Stage):
|
|
|
687
251
|
logger.debug("LLM 响应为空,不保存记录。")
|
|
688
252
|
return
|
|
689
253
|
|
|
254
|
+
if req.contexts is None:
|
|
255
|
+
req.contexts = []
|
|
256
|
+
|
|
690
257
|
# 历史上下文
|
|
691
258
|
messages = copy.deepcopy(req.contexts)
|
|
692
259
|
# 这一轮对话请求的用户输入
|
|
@@ -706,7 +273,7 @@ class LLMRequestSubStage(Stage):
|
|
|
706
273
|
history=messages,
|
|
707
274
|
)
|
|
708
275
|
|
|
709
|
-
def
|
|
276
|
+
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
|
710
277
|
"""验证并且修复上下文"""
|
|
711
278
|
fixed_messages = []
|
|
712
279
|
for message in messages:
|
|
@@ -721,3 +288,184 @@ class LLMRequestSubStage(Stage):
|
|
|
721
288
|
else:
|
|
722
289
|
fixed_messages.append(message)
|
|
723
290
|
return fixed_messages
|
|
291
|
+
|
|
292
|
+
async def process(
|
|
293
|
+
self,
|
|
294
|
+
event: AstrMessageEvent,
|
|
295
|
+
_nested: bool = False,
|
|
296
|
+
) -> None | AsyncGenerator[None, None]:
|
|
297
|
+
req: ProviderRequest | None = None
|
|
298
|
+
|
|
299
|
+
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
|
300
|
+
logger.debug("未启用 LLM 能力,跳过处理。")
|
|
301
|
+
return
|
|
302
|
+
|
|
303
|
+
# 检查会话级别的LLM启停状态
|
|
304
|
+
if not SessionServiceManager.should_process_llm_request(event):
|
|
305
|
+
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
|
306
|
+
return
|
|
307
|
+
|
|
308
|
+
provider = self._select_provider(event)
|
|
309
|
+
if provider is None:
|
|
310
|
+
return
|
|
311
|
+
if not isinstance(provider, Provider):
|
|
312
|
+
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
|
313
|
+
return
|
|
314
|
+
|
|
315
|
+
streaming_response = self.streaming_response
|
|
316
|
+
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
|
317
|
+
streaming_response = bool(enable_streaming)
|
|
318
|
+
|
|
319
|
+
logger.debug("ready to request llm provider")
|
|
320
|
+
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
|
321
|
+
logger.debug("acquired session lock for llm request")
|
|
322
|
+
if event.get_extra("provider_request"):
|
|
323
|
+
req = event.get_extra("provider_request")
|
|
324
|
+
assert isinstance(req, ProviderRequest), (
|
|
325
|
+
"provider_request 必须是 ProviderRequest 类型。"
|
|
326
|
+
)
|
|
327
|
+
|
|
328
|
+
if req.conversation:
|
|
329
|
+
req.contexts = json.loads(req.conversation.history)
|
|
330
|
+
|
|
331
|
+
else:
|
|
332
|
+
req = ProviderRequest()
|
|
333
|
+
req.prompt = ""
|
|
334
|
+
req.image_urls = []
|
|
335
|
+
if sel_model := event.get_extra("selected_model"):
|
|
336
|
+
req.model = sel_model
|
|
337
|
+
if self.provider_wake_prefix and not event.message_str.startswith(
|
|
338
|
+
self.provider_wake_prefix
|
|
339
|
+
):
|
|
340
|
+
return
|
|
341
|
+
|
|
342
|
+
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
|
343
|
+
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
|
344
|
+
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
|
345
|
+
for comp in event.message_obj.message:
|
|
346
|
+
if isinstance(comp, Image):
|
|
347
|
+
image_path = await comp.convert_to_file_path()
|
|
348
|
+
req.image_urls.append(image_path)
|
|
349
|
+
|
|
350
|
+
conversation = await self._get_session_conv(event)
|
|
351
|
+
req.conversation = conversation
|
|
352
|
+
req.contexts = json.loads(conversation.history)
|
|
353
|
+
|
|
354
|
+
event.set_extra("provider_request", req)
|
|
355
|
+
|
|
356
|
+
if not req.prompt and not req.image_urls:
|
|
357
|
+
return
|
|
358
|
+
|
|
359
|
+
# apply knowledge base context
|
|
360
|
+
await self._apply_kb_context(event, req)
|
|
361
|
+
|
|
362
|
+
# call event hook
|
|
363
|
+
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
|
364
|
+
return
|
|
365
|
+
|
|
366
|
+
# fix contexts json str
|
|
367
|
+
if isinstance(req.contexts, str):
|
|
368
|
+
req.contexts = json.loads(req.contexts)
|
|
369
|
+
|
|
370
|
+
# truncate contexts to fit max length
|
|
371
|
+
if req.contexts:
|
|
372
|
+
req.contexts = self._truncate_contexts(req.contexts)
|
|
373
|
+
self._fix_messages(req.contexts)
|
|
374
|
+
|
|
375
|
+
# session_id
|
|
376
|
+
if not req.session_id:
|
|
377
|
+
req.session_id = event.unified_msg_origin
|
|
378
|
+
|
|
379
|
+
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
|
380
|
+
self._modalities_fix(provider, req)
|
|
381
|
+
|
|
382
|
+
# filter tools, only keep tools from this pipeline's selected plugins
|
|
383
|
+
self._plugin_tool_fix(event, req)
|
|
384
|
+
|
|
385
|
+
stream_to_general = (
|
|
386
|
+
self.unsupported_streaming_strategy == "turn_off"
|
|
387
|
+
and not event.platform_meta.support_streaming_message
|
|
388
|
+
)
|
|
389
|
+
# 备份 req.contexts
|
|
390
|
+
backup_contexts = copy.deepcopy(req.contexts)
|
|
391
|
+
|
|
392
|
+
# run agent
|
|
393
|
+
agent_runner = AgentRunner()
|
|
394
|
+
logger.debug(
|
|
395
|
+
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
|
396
|
+
)
|
|
397
|
+
astr_agent_ctx = AstrAgentContext(
|
|
398
|
+
context=self.ctx.plugin_manager.context,
|
|
399
|
+
event=event,
|
|
400
|
+
)
|
|
401
|
+
await agent_runner.reset(
|
|
402
|
+
provider=provider,
|
|
403
|
+
request=req,
|
|
404
|
+
run_context=AgentContextWrapper(
|
|
405
|
+
context=astr_agent_ctx,
|
|
406
|
+
tool_call_timeout=self.tool_call_timeout,
|
|
407
|
+
),
|
|
408
|
+
tool_executor=FunctionToolExecutor(),
|
|
409
|
+
agent_hooks=MAIN_AGENT_HOOKS,
|
|
410
|
+
streaming=streaming_response,
|
|
411
|
+
)
|
|
412
|
+
|
|
413
|
+
if streaming_response and not stream_to_general:
|
|
414
|
+
# 流式响应
|
|
415
|
+
event.set_result(
|
|
416
|
+
MessageEventResult()
|
|
417
|
+
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
|
418
|
+
.set_async_stream(
|
|
419
|
+
run_agent(
|
|
420
|
+
agent_runner,
|
|
421
|
+
self.max_step,
|
|
422
|
+
self.show_tool_use,
|
|
423
|
+
show_reasoning=self.show_reasoning,
|
|
424
|
+
),
|
|
425
|
+
),
|
|
426
|
+
)
|
|
427
|
+
yield
|
|
428
|
+
if agent_runner.done():
|
|
429
|
+
if final_llm_resp := agent_runner.get_final_llm_resp():
|
|
430
|
+
if final_llm_resp.completion_text:
|
|
431
|
+
chain = (
|
|
432
|
+
MessageChain()
|
|
433
|
+
.message(final_llm_resp.completion_text)
|
|
434
|
+
.chain
|
|
435
|
+
)
|
|
436
|
+
elif final_llm_resp.result_chain:
|
|
437
|
+
chain = final_llm_resp.result_chain.chain
|
|
438
|
+
else:
|
|
439
|
+
chain = MessageChain().chain
|
|
440
|
+
event.set_result(
|
|
441
|
+
MessageEventResult(
|
|
442
|
+
chain=chain,
|
|
443
|
+
result_content_type=ResultContentType.STREAMING_FINISH,
|
|
444
|
+
),
|
|
445
|
+
)
|
|
446
|
+
else:
|
|
447
|
+
async for _ in run_agent(
|
|
448
|
+
agent_runner,
|
|
449
|
+
self.max_step,
|
|
450
|
+
self.show_tool_use,
|
|
451
|
+
stream_to_general,
|
|
452
|
+
show_reasoning=self.show_reasoning,
|
|
453
|
+
):
|
|
454
|
+
yield
|
|
455
|
+
|
|
456
|
+
# 恢复备份的 contexts
|
|
457
|
+
req.contexts = backup_contexts
|
|
458
|
+
|
|
459
|
+
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
|
460
|
+
|
|
461
|
+
# 异步处理 WebChat 特殊情况
|
|
462
|
+
if event.get_platform_name() == "webchat":
|
|
463
|
+
asyncio.create_task(self._handle_webchat(event, req, provider))
|
|
464
|
+
|
|
465
|
+
asyncio.create_task(
|
|
466
|
+
Metric.upload(
|
|
467
|
+
llm_tick=1,
|
|
468
|
+
model_name=agent_runner.provider.get_model(),
|
|
469
|
+
provider_type=agent_runner.provider.meta().type,
|
|
470
|
+
),
|
|
471
|
+
)
|