AstrBot 4.5.6__py3-none-any.whl → 4.5.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. astrbot/api/all.py +2 -1
  2. astrbot/api/provider/__init__.py +2 -1
  3. astrbot/core/agent/message.py +1 -1
  4. astrbot/core/agent/run_context.py +7 -2
  5. astrbot/core/agent/runners/base.py +7 -0
  6. astrbot/core/agent/runners/tool_loop_agent_runner.py +51 -3
  7. astrbot/core/agent/tool.py +5 -6
  8. astrbot/core/astr_agent_context.py +13 -8
  9. astrbot/core/astr_agent_hooks.py +36 -0
  10. astrbot/core/astr_agent_run_util.py +80 -0
  11. astrbot/core/astr_agent_tool_exec.py +246 -0
  12. astrbot/core/config/default.py +53 -7
  13. astrbot/core/exceptions.py +9 -0
  14. astrbot/core/pipeline/context.py +1 -2
  15. astrbot/core/pipeline/context_utils.py +0 -65
  16. astrbot/core/pipeline/process_stage/method/llm_request.py +239 -491
  17. astrbot/core/pipeline/respond/stage.py +21 -20
  18. astrbot/core/platform/platform_metadata.py +3 -0
  19. astrbot/core/platform/register.py +2 -0
  20. astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +2 -0
  21. astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +16 -5
  22. astrbot/core/platform/sources/discord/discord_platform_adapter.py +4 -1
  23. astrbot/core/platform/sources/discord/discord_platform_event.py +16 -7
  24. astrbot/core/platform/sources/lark/lark_adapter.py +4 -1
  25. astrbot/core/platform/sources/misskey/misskey_adapter.py +4 -1
  26. astrbot/core/platform/sources/satori/satori_adapter.py +2 -2
  27. astrbot/core/platform/sources/slack/slack_adapter.py +2 -0
  28. astrbot/core/platform/sources/webchat/webchat_adapter.py +3 -0
  29. astrbot/core/platform/sources/webchat/webchat_event.py +8 -1
  30. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +4 -1
  31. astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +16 -0
  32. astrbot/core/platform/sources/wecom/wecom_adapter.py +2 -1
  33. astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +4 -1
  34. astrbot/core/provider/__init__.py +2 -2
  35. astrbot/core/provider/entities.py +40 -18
  36. astrbot/core/provider/func_tool_manager.py +15 -6
  37. astrbot/core/provider/manager.py +4 -1
  38. astrbot/core/provider/provider.py +7 -22
  39. astrbot/core/provider/register.py +2 -0
  40. astrbot/core/provider/sources/anthropic_source.py +0 -2
  41. astrbot/core/provider/sources/coze_source.py +0 -2
  42. astrbot/core/provider/sources/dashscope_source.py +1 -3
  43. astrbot/core/provider/sources/dify_source.py +0 -2
  44. astrbot/core/provider/sources/gemini_source.py +31 -3
  45. astrbot/core/provider/sources/groq_source.py +15 -0
  46. astrbot/core/provider/sources/openai_source.py +67 -21
  47. astrbot/core/provider/sources/zhipu_source.py +1 -6
  48. astrbot/core/star/context.py +197 -45
  49. astrbot/core/star/register/star_handler.py +30 -10
  50. astrbot/dashboard/routes/chat.py +5 -0
  51. {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/METADATA +2 -2
  52. {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/RECORD +55 -50
  53. {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/WHEEL +0 -0
  54. {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/entry_points.txt +0 -0
  55. {astrbot-4.5.6.dist-info → astrbot-4.5.8.dist-info}/licenses/LICENSE +0 -0
@@ -3,20 +3,10 @@
3
3
  import asyncio
4
4
  import copy
5
5
  import json
6
- import traceback
7
6
  from collections.abc import AsyncGenerator
8
- from typing import Any
9
-
10
- from mcp.types import CallToolResult
11
7
 
12
8
  from astrbot.core import logger
13
- from astrbot.core.agent.handoff import HandoffTool
14
- from astrbot.core.agent.hooks import BaseAgentRunHooks
15
- from astrbot.core.agent.mcp_client import MCPTool
16
- from astrbot.core.agent.run_context import ContextWrapper
17
- from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
18
- from astrbot.core.agent.tool import FunctionTool, ToolSet
19
- from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
9
+ from astrbot.core.agent.tool import ToolSet
20
10
  from astrbot.core.astr_agent_context import AstrAgentContext
21
11
  from astrbot.core.conversation_mgr import Conversation
22
12
  from astrbot.core.message.components import Image
@@ -31,324 +21,19 @@ from astrbot.core.provider.entities import (
31
21
  LLMResponse,
32
22
  ProviderRequest,
33
23
  )
34
- from astrbot.core.provider.register import llm_tools
35
24
  from astrbot.core.star.session_llm_manager import SessionServiceManager
36
25
  from astrbot.core.star.star_handler import EventType, star_map
37
26
  from astrbot.core.utils.metrics import Metric
27
+ from astrbot.core.utils.session_lock import session_lock_manager
38
28
 
39
- from ...context import PipelineContext, call_event_hook, call_local_llm_tool
29
+ from ....astr_agent_context import AgentContextWrapper
30
+ from ....astr_agent_hooks import MAIN_AGENT_HOOKS
31
+ from ....astr_agent_run_util import AgentRunner, run_agent
32
+ from ....astr_agent_tool_exec import FunctionToolExecutor
33
+ from ...context import PipelineContext, call_event_hook
40
34
  from ..stage import Stage
41
35
  from ..utils import inject_kb_context
42
36
 
43
- try:
44
- import mcp
45
- except (ModuleNotFoundError, ImportError):
46
- logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
47
-
48
-
49
- AgentContextWrapper = ContextWrapper[AstrAgentContext]
50
- AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
51
-
52
-
53
- class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
54
- @classmethod
55
- async def execute(cls, tool, run_context, **tool_args):
56
- """执行函数调用。
57
-
58
- Args:
59
- event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
60
- **kwargs: 函数调用的参数。
61
-
62
- Returns:
63
- AsyncGenerator[None | mcp.types.CallToolResult, None]
64
-
65
- """
66
- if isinstance(tool, HandoffTool):
67
- async for r in cls._execute_handoff(tool, run_context, **tool_args):
68
- yield r
69
- return
70
-
71
- elif isinstance(tool, MCPTool):
72
- async for r in cls._execute_mcp(tool, run_context, **tool_args):
73
- yield r
74
- return
75
-
76
- else:
77
- async for r in cls._execute_local(tool, run_context, **tool_args):
78
- yield r
79
- return
80
-
81
- @classmethod
82
- async def _execute_handoff(
83
- cls,
84
- tool: HandoffTool,
85
- run_context: ContextWrapper[AstrAgentContext],
86
- **tool_args,
87
- ):
88
- input_ = tool_args.get("input", "agent")
89
- agent_runner = AgentRunner()
90
-
91
- # make toolset for the agent
92
- tools = tool.agent.tools
93
- if tools:
94
- toolset = ToolSet()
95
- for t in tools:
96
- if isinstance(t, str):
97
- _t = llm_tools.get_func(t)
98
- if _t:
99
- toolset.add_tool(_t)
100
- elif isinstance(t, FunctionTool):
101
- toolset.add_tool(t)
102
- else:
103
- toolset = None
104
-
105
- request = ProviderRequest(
106
- prompt=input_,
107
- system_prompt=tool.description or "",
108
- image_urls=[], # 暂时不传递原始 agent 的上下文
109
- contexts=[], # 暂时不传递原始 agent 的上下文
110
- func_tool=toolset,
111
- )
112
- astr_agent_ctx = AstrAgentContext(
113
- provider=run_context.context.provider,
114
- first_provider_request=run_context.context.first_provider_request,
115
- curr_provider_request=request,
116
- streaming=run_context.context.streaming,
117
- event=run_context.context.event,
118
- )
119
-
120
- event = run_context.context.event
121
-
122
- logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
123
- await event.send(
124
- MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
125
- )
126
-
127
- await agent_runner.reset(
128
- provider=run_context.context.provider,
129
- request=request,
130
- run_context=AgentContextWrapper(
131
- context=astr_agent_ctx,
132
- tool_call_timeout=run_context.tool_call_timeout,
133
- ),
134
- tool_executor=FunctionToolExecutor(),
135
- agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
136
- streaming=run_context.context.streaming,
137
- )
138
-
139
- async for _ in run_agent(agent_runner, 15, True):
140
- pass
141
-
142
- if agent_runner.done():
143
- llm_response = agent_runner.get_final_llm_resp()
144
-
145
- if not llm_response:
146
- text_content = mcp.types.TextContent(
147
- type="text",
148
- text=f"error when deligate task to {tool.agent.name}",
149
- )
150
- yield mcp.types.CallToolResult(content=[text_content])
151
- return
152
-
153
- logger.debug(
154
- f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
155
- )
156
-
157
- result = (
158
- f"Agent {tool.agent.name} respond with: {llm_response.completion_text}\n\n"
159
- "Note: If the result is error or need user provide more information, please provide more information to the agent(you can ask user for more information first)."
160
- )
161
-
162
- text_content = mcp.types.TextContent(
163
- type="text",
164
- text=result,
165
- )
166
- yield mcp.types.CallToolResult(content=[text_content])
167
- else:
168
- text_content = mcp.types.TextContent(
169
- type="text",
170
- text=f"error when deligate task to {tool.agent.name}",
171
- )
172
- yield mcp.types.CallToolResult(content=[text_content])
173
- return
174
-
175
- @classmethod
176
- async def _execute_local(
177
- cls,
178
- tool: FunctionTool,
179
- run_context: ContextWrapper[AstrAgentContext],
180
- **tool_args,
181
- ):
182
- event = run_context.context.event
183
- if not event:
184
- raise ValueError("Event must be provided for local function tools.")
185
-
186
- is_override_call = False
187
- for ty in type(tool).mro():
188
- if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
189
- logger.debug(f"Found call in: {ty}")
190
- is_override_call = True
191
- break
192
-
193
- # 检查 tool 下有没有 run 方法
194
- if not tool.handler and not hasattr(tool, "run") and not is_override_call:
195
- raise ValueError("Tool must have a valid handler or override 'run' method.")
196
-
197
- awaitable = None
198
- method_name = ""
199
- if tool.handler:
200
- awaitable = tool.handler
201
- method_name = "decorator_handler"
202
- elif is_override_call:
203
- awaitable = tool.call
204
- method_name = "call"
205
- elif hasattr(tool, "run"):
206
- awaitable = getattr(tool, "run")
207
- method_name = "run"
208
- if awaitable is None:
209
- raise ValueError("Tool must have a valid handler or override 'run' method.")
210
-
211
- wrapper = call_local_llm_tool(
212
- context=run_context,
213
- handler=awaitable,
214
- method_name=method_name,
215
- **tool_args,
216
- )
217
- while True:
218
- try:
219
- resp = await asyncio.wait_for(
220
- anext(wrapper),
221
- timeout=run_context.tool_call_timeout,
222
- )
223
- if resp is not None:
224
- if isinstance(resp, mcp.types.CallToolResult):
225
- yield resp
226
- else:
227
- text_content = mcp.types.TextContent(
228
- type="text",
229
- text=str(resp),
230
- )
231
- yield mcp.types.CallToolResult(content=[text_content])
232
- else:
233
- # NOTE: Tool 在这里直接请求发送消息给用户
234
- # TODO: 是否需要判断 event.get_result() 是否为空?
235
- # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
236
- if res := run_context.context.event.get_result():
237
- if res.chain:
238
- try:
239
- await event.send(
240
- MessageChain(
241
- chain=res.chain,
242
- type="tool_direct_result",
243
- )
244
- )
245
- except Exception as e:
246
- logger.error(
247
- f"Tool 直接发送消息失败: {e}",
248
- exc_info=True,
249
- )
250
- yield None
251
- except asyncio.TimeoutError:
252
- raise Exception(
253
- f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
254
- )
255
- except StopAsyncIteration:
256
- break
257
-
258
- @classmethod
259
- async def _execute_mcp(
260
- cls,
261
- tool: FunctionTool,
262
- run_context: ContextWrapper[AstrAgentContext],
263
- **tool_args,
264
- ):
265
- res = await tool.call(run_context, **tool_args)
266
- if not res:
267
- return
268
- yield res
269
-
270
-
271
- class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
272
- async def on_agent_done(self, run_context, llm_response):
273
- # 执行事件钩子
274
- await call_event_hook(
275
- run_context.context.event,
276
- EventType.OnLLMResponseEvent,
277
- llm_response,
278
- )
279
-
280
- async def on_tool_end(
281
- self,
282
- run_context: ContextWrapper[AstrAgentContext],
283
- tool: FunctionTool[Any],
284
- tool_args: dict | None,
285
- tool_result: CallToolResult | None,
286
- ):
287
- run_context.context.event.clear_result()
288
-
289
-
290
- MAIN_AGENT_HOOKS = MainAgentHooks()
291
-
292
-
293
- async def run_agent(
294
- agent_runner: AgentRunner,
295
- max_step: int = 30,
296
- show_tool_use: bool = True,
297
- ) -> AsyncGenerator[MessageChain, None]:
298
- step_idx = 0
299
- astr_event = agent_runner.run_context.context.event
300
- while step_idx < max_step:
301
- step_idx += 1
302
- try:
303
- async for resp in agent_runner.step():
304
- if astr_event.is_stopped():
305
- return
306
- if resp.type == "tool_call_result":
307
- msg_chain = resp.data["chain"]
308
- if msg_chain.type == "tool_direct_result":
309
- # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
310
- resp.data["chain"].type = "tool_call_result"
311
- await astr_event.send(resp.data["chain"])
312
- continue
313
- # 对于其他情况,暂时先不处理
314
- continue
315
- elif resp.type == "tool_call":
316
- if agent_runner.streaming:
317
- # 用来标记流式响应需要分节
318
- yield MessageChain(chain=[], type="break")
319
- if show_tool_use or astr_event.get_platform_name() == "webchat":
320
- resp.data["chain"].type = "tool_call"
321
- await astr_event.send(resp.data["chain"])
322
- continue
323
-
324
- if not agent_runner.streaming:
325
- content_typ = (
326
- ResultContentType.LLM_RESULT
327
- if resp.type == "llm_result"
328
- else ResultContentType.GENERAL_RESULT
329
- )
330
- astr_event.set_result(
331
- MessageEventResult(
332
- chain=resp.data["chain"].chain,
333
- result_content_type=content_typ,
334
- ),
335
- )
336
- yield
337
- astr_event.clear_result()
338
- elif resp.type == "streaming_delta":
339
- yield resp.data["chain"] # MessageChain
340
- if agent_runner.done():
341
- break
342
-
343
- except Exception as e:
344
- logger.error(traceback.format_exc())
345
- err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
346
- if agent_runner.streaming:
347
- yield MessageChain().message(err_msg)
348
- else:
349
- astr_event.set_result(MessageEventResult().message(err_msg))
350
- return
351
-
352
37
 
353
38
  class LLMRequestSubStage(Stage):
354
39
  async def initialize(self, ctx: PipelineContext) -> None:
@@ -363,11 +48,15 @@ class LLMRequestSubStage(Stage):
363
48
  self.max_context_length - 1,
364
49
  )
365
50
  self.streaming_response: bool = settings["streaming_response"]
51
+ self.unsupported_streaming_strategy: str = settings[
52
+ "unsupported_streaming_strategy"
53
+ ]
366
54
  self.max_step: int = settings.get("max_agent_step", 30)
367
55
  self.tool_call_timeout: int = settings.get("tool_call_timeout", 60)
368
56
  if isinstance(self.max_step, bool): # workaround: #2622
369
57
  self.max_step = 30
370
58
  self.show_tool_use: bool = settings.get("show_tool_use_status", True)
59
+ self.show_reasoning = settings.get("display_reasoning_text", False)
371
60
 
372
61
  for bwp in self.bot_wake_prefixs:
373
62
  if self.provider_wake_prefix.startswith(bwp):
@@ -406,63 +95,12 @@ class LLMRequestSubStage(Stage):
406
95
  raise RuntimeError("无法创建新的对话。")
407
96
  return conversation
408
97
 
409
- async def process(
98
+ async def _apply_kb_context(
410
99
  self,
411
100
  event: AstrMessageEvent,
412
- _nested: bool = False,
413
- ) -> None | AsyncGenerator[None, None]:
414
- req: ProviderRequest | None = None
415
-
416
- if not self.ctx.astrbot_config["provider_settings"]["enable"]:
417
- logger.debug("未启用 LLM 能力,跳过处理。")
418
- return
419
-
420
- # 检查会话级别的LLM启停状态
421
- if not SessionServiceManager.should_process_llm_request(event):
422
- logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
423
- return
424
-
425
- provider = self._select_provider(event)
426
- if provider is None:
427
- return
428
- if not isinstance(provider, Provider):
429
- logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
430
- return
431
-
432
- if event.get_extra("provider_request"):
433
- req = event.get_extra("provider_request")
434
- assert isinstance(req, ProviderRequest), (
435
- "provider_request 必须是 ProviderRequest 类型。"
436
- )
437
-
438
- if req.conversation:
439
- req.contexts = json.loads(req.conversation.history)
440
-
441
- else:
442
- req = ProviderRequest(prompt="", image_urls=[])
443
- if sel_model := event.get_extra("selected_model"):
444
- req.model = sel_model
445
- if self.provider_wake_prefix:
446
- if not event.message_str.startswith(self.provider_wake_prefix):
447
- return
448
- req.prompt = event.message_str[len(self.provider_wake_prefix) :]
449
- # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
450
- # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
451
- for comp in event.message_obj.message:
452
- if isinstance(comp, Image):
453
- image_path = await comp.convert_to_file_path()
454
- req.image_urls.append(image_path)
455
-
456
- conversation = await self._get_session_conv(event)
457
- req.conversation = conversation
458
- req.contexts = json.loads(conversation.history)
459
-
460
- event.set_extra("provider_request", req)
461
-
462
- if not req.prompt and not req.image_urls:
463
- return
464
-
465
- # 应用知识库
101
+ req: ProviderRequest,
102
+ ):
103
+ """应用知识库上下文到请求中"""
466
104
  try:
467
105
  await inject_kb_context(
468
106
  umo=event.unified_msg_origin,
@@ -472,43 +110,40 @@ class LLMRequestSubStage(Stage):
472
110
  except Exception as e:
473
111
  logger.error(f"调用知识库时遇到问题: {e}")
474
112
 
475
- # 执行请求 LLM 前事件钩子。
476
- if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
477
- return
478
-
479
- if isinstance(req.contexts, str):
480
- req.contexts = json.loads(req.contexts)
481
-
482
- # max context length
483
- if (
484
- self.max_context_length != -1 # -1 为不限制
485
- and len(req.contexts) // 2 > self.max_context_length
486
- ):
487
- logger.debug("上下文长度超过限制,将截断。")
488
- req.contexts = req.contexts[
489
- -(self.max_context_length - self.dequeue_context_length + 1) * 2 :
490
- ]
491
- # 找到第一个role 为 user 的索引,确保上下文格式正确
492
- index = next(
493
- (
494
- i
495
- for i, item in enumerate(req.contexts)
496
- if item.get("role") == "user"
497
- ),
498
- None,
499
- )
500
- if index is not None and index > 0:
501
- req.contexts = req.contexts[index:]
502
-
503
- # session_id
504
- if not req.session_id:
505
- req.session_id = event.unified_msg_origin
113
+ def _truncate_contexts(
114
+ self,
115
+ contexts: list[dict],
116
+ ) -> list[dict]:
117
+ """截断上下文列表,确保不超过最大长度"""
118
+ if self.max_context_length == -1:
119
+ return contexts
120
+
121
+ if len(contexts) // 2 <= self.max_context_length:
122
+ return contexts
123
+
124
+ truncated_contexts = contexts[
125
+ -(self.max_context_length - self.dequeue_context_length + 1) * 2 :
126
+ ]
127
+ # 找到第一个role user 的索引,确保上下文格式正确
128
+ index = next(
129
+ (
130
+ i
131
+ for i, item in enumerate(truncated_contexts)
132
+ if item.get("role") == "user"
133
+ ),
134
+ None,
135
+ )
136
+ if index is not None and index > 0:
137
+ truncated_contexts = truncated_contexts[index:]
506
138
 
507
- # fix messages
508
- req.contexts = self.fix_messages(req.contexts)
139
+ return truncated_contexts
509
140
 
510
- # check provider modalities
511
- # 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。
141
+ def _modalities_fix(
142
+ self,
143
+ provider: Provider,
144
+ req: ProviderRequest,
145
+ ):
146
+ """检查提供商的模态能力,清理请求中的不支持内容"""
512
147
  if req.image_urls:
513
148
  provider_cfg = provider.provider_config.get("modalities", ["image"])
514
149
  if "image" not in provider_cfg:
@@ -522,7 +157,13 @@ class LLMRequestSubStage(Stage):
522
157
  f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
523
158
  )
524
159
  req.func_tool = None
525
- # 插件可用性设置
160
+
161
+ def _plugin_tool_fix(
162
+ self,
163
+ event: AstrMessageEvent,
164
+ req: ProviderRequest,
165
+ ):
166
+ """根据事件中的插件设置,过滤请求中的工具列表"""
526
167
  if event.plugins_name is not None and req.func_tool:
527
168
  new_tool_set = ToolSet()
528
169
  for tool in req.func_tool.tools:
@@ -536,80 +177,6 @@ class LLMRequestSubStage(Stage):
536
177
  new_tool_set.add_tool(tool)
537
178
  req.func_tool = new_tool_set
538
179
 
539
- # 备份 req.contexts
540
- backup_contexts = copy.deepcopy(req.contexts)
541
-
542
- # run agent
543
- agent_runner = AgentRunner()
544
- logger.debug(
545
- f"handle provider[id: {provider.provider_config['id']}] request: {req}",
546
- )
547
- astr_agent_ctx = AstrAgentContext(
548
- provider=provider,
549
- first_provider_request=req,
550
- curr_provider_request=req,
551
- streaming=self.streaming_response,
552
- event=event,
553
- )
554
- await agent_runner.reset(
555
- provider=provider,
556
- request=req,
557
- run_context=AgentContextWrapper(
558
- context=astr_agent_ctx,
559
- tool_call_timeout=self.tool_call_timeout,
560
- ),
561
- tool_executor=FunctionToolExecutor(),
562
- agent_hooks=MAIN_AGENT_HOOKS,
563
- streaming=self.streaming_response,
564
- )
565
-
566
- if self.streaming_response:
567
- # 流式响应
568
- event.set_result(
569
- MessageEventResult()
570
- .set_result_content_type(ResultContentType.STREAMING_RESULT)
571
- .set_async_stream(
572
- run_agent(agent_runner, self.max_step, self.show_tool_use),
573
- ),
574
- )
575
- yield
576
- if agent_runner.done():
577
- if final_llm_resp := agent_runner.get_final_llm_resp():
578
- if final_llm_resp.completion_text:
579
- chain = (
580
- MessageChain().message(final_llm_resp.completion_text).chain
581
- )
582
- elif final_llm_resp.result_chain:
583
- chain = final_llm_resp.result_chain.chain
584
- else:
585
- chain = MessageChain().chain
586
- event.set_result(
587
- MessageEventResult(
588
- chain=chain,
589
- result_content_type=ResultContentType.STREAMING_FINISH,
590
- ),
591
- )
592
- else:
593
- async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
594
- yield
595
-
596
- # 恢复备份的 contexts
597
- req.contexts = backup_contexts
598
-
599
- await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
600
-
601
- # 异步处理 WebChat 特殊情况
602
- if event.get_platform_name() == "webchat":
603
- asyncio.create_task(self._handle_webchat(event, req, provider))
604
-
605
- asyncio.create_task(
606
- Metric.upload(
607
- llm_tick=1,
608
- model_name=agent_runner.provider.get_model(),
609
- provider_type=agent_runner.provider.meta().type,
610
- ),
611
- )
612
-
613
180
  async def _handle_webchat(
614
181
  self,
615
182
  event: AstrMessageEvent,
@@ -657,9 +224,6 @@ class LLMRequestSubStage(Stage):
657
224
  ),
658
225
  )
659
226
  if llm_resp and llm_resp.completion_text:
660
- logger.debug(
661
- f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
662
- )
663
227
  title = llm_resp.completion_text.strip()
664
228
  if not title or "<None>" in title:
665
229
  return
@@ -687,6 +251,9 @@ class LLMRequestSubStage(Stage):
687
251
  logger.debug("LLM 响应为空,不保存记录。")
688
252
  return
689
253
 
254
+ if req.contexts is None:
255
+ req.contexts = []
256
+
690
257
  # 历史上下文
691
258
  messages = copy.deepcopy(req.contexts)
692
259
  # 这一轮对话请求的用户输入
@@ -706,7 +273,7 @@ class LLMRequestSubStage(Stage):
706
273
  history=messages,
707
274
  )
708
275
 
709
- def fix_messages(self, messages: list[dict]) -> list[dict]:
276
+ def _fix_messages(self, messages: list[dict]) -> list[dict]:
710
277
  """验证并且修复上下文"""
711
278
  fixed_messages = []
712
279
  for message in messages:
@@ -721,3 +288,184 @@ class LLMRequestSubStage(Stage):
721
288
  else:
722
289
  fixed_messages.append(message)
723
290
  return fixed_messages
291
+
292
+ async def process(
293
+ self,
294
+ event: AstrMessageEvent,
295
+ _nested: bool = False,
296
+ ) -> None | AsyncGenerator[None, None]:
297
+ req: ProviderRequest | None = None
298
+
299
+ if not self.ctx.astrbot_config["provider_settings"]["enable"]:
300
+ logger.debug("未启用 LLM 能力,跳过处理。")
301
+ return
302
+
303
+ # 检查会话级别的LLM启停状态
304
+ if not SessionServiceManager.should_process_llm_request(event):
305
+ logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
306
+ return
307
+
308
+ provider = self._select_provider(event)
309
+ if provider is None:
310
+ return
311
+ if not isinstance(provider, Provider):
312
+ logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
313
+ return
314
+
315
+ streaming_response = self.streaming_response
316
+ if (enable_streaming := event.get_extra("enable_streaming")) is not None:
317
+ streaming_response = bool(enable_streaming)
318
+
319
+ logger.debug("ready to request llm provider")
320
+ async with session_lock_manager.acquire_lock(event.unified_msg_origin):
321
+ logger.debug("acquired session lock for llm request")
322
+ if event.get_extra("provider_request"):
323
+ req = event.get_extra("provider_request")
324
+ assert isinstance(req, ProviderRequest), (
325
+ "provider_request 必须是 ProviderRequest 类型。"
326
+ )
327
+
328
+ if req.conversation:
329
+ req.contexts = json.loads(req.conversation.history)
330
+
331
+ else:
332
+ req = ProviderRequest()
333
+ req.prompt = ""
334
+ req.image_urls = []
335
+ if sel_model := event.get_extra("selected_model"):
336
+ req.model = sel_model
337
+ if self.provider_wake_prefix and not event.message_str.startswith(
338
+ self.provider_wake_prefix
339
+ ):
340
+ return
341
+
342
+ req.prompt = event.message_str[len(self.provider_wake_prefix) :]
343
+ # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
344
+ # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
345
+ for comp in event.message_obj.message:
346
+ if isinstance(comp, Image):
347
+ image_path = await comp.convert_to_file_path()
348
+ req.image_urls.append(image_path)
349
+
350
+ conversation = await self._get_session_conv(event)
351
+ req.conversation = conversation
352
+ req.contexts = json.loads(conversation.history)
353
+
354
+ event.set_extra("provider_request", req)
355
+
356
+ if not req.prompt and not req.image_urls:
357
+ return
358
+
359
+ # apply knowledge base context
360
+ await self._apply_kb_context(event, req)
361
+
362
+ # call event hook
363
+ if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
364
+ return
365
+
366
+ # fix contexts json str
367
+ if isinstance(req.contexts, str):
368
+ req.contexts = json.loads(req.contexts)
369
+
370
+ # truncate contexts to fit max length
371
+ if req.contexts:
372
+ req.contexts = self._truncate_contexts(req.contexts)
373
+ self._fix_messages(req.contexts)
374
+
375
+ # session_id
376
+ if not req.session_id:
377
+ req.session_id = event.unified_msg_origin
378
+
379
+ # check provider modalities, if provider does not support image/tool_use, clear them in request.
380
+ self._modalities_fix(provider, req)
381
+
382
+ # filter tools, only keep tools from this pipeline's selected plugins
383
+ self._plugin_tool_fix(event, req)
384
+
385
+ stream_to_general = (
386
+ self.unsupported_streaming_strategy == "turn_off"
387
+ and not event.platform_meta.support_streaming_message
388
+ )
389
+ # 备份 req.contexts
390
+ backup_contexts = copy.deepcopy(req.contexts)
391
+
392
+ # run agent
393
+ agent_runner = AgentRunner()
394
+ logger.debug(
395
+ f"handle provider[id: {provider.provider_config['id']}] request: {req}",
396
+ )
397
+ astr_agent_ctx = AstrAgentContext(
398
+ context=self.ctx.plugin_manager.context,
399
+ event=event,
400
+ )
401
+ await agent_runner.reset(
402
+ provider=provider,
403
+ request=req,
404
+ run_context=AgentContextWrapper(
405
+ context=astr_agent_ctx,
406
+ tool_call_timeout=self.tool_call_timeout,
407
+ ),
408
+ tool_executor=FunctionToolExecutor(),
409
+ agent_hooks=MAIN_AGENT_HOOKS,
410
+ streaming=streaming_response,
411
+ )
412
+
413
+ if streaming_response and not stream_to_general:
414
+ # 流式响应
415
+ event.set_result(
416
+ MessageEventResult()
417
+ .set_result_content_type(ResultContentType.STREAMING_RESULT)
418
+ .set_async_stream(
419
+ run_agent(
420
+ agent_runner,
421
+ self.max_step,
422
+ self.show_tool_use,
423
+ show_reasoning=self.show_reasoning,
424
+ ),
425
+ ),
426
+ )
427
+ yield
428
+ if agent_runner.done():
429
+ if final_llm_resp := agent_runner.get_final_llm_resp():
430
+ if final_llm_resp.completion_text:
431
+ chain = (
432
+ MessageChain()
433
+ .message(final_llm_resp.completion_text)
434
+ .chain
435
+ )
436
+ elif final_llm_resp.result_chain:
437
+ chain = final_llm_resp.result_chain.chain
438
+ else:
439
+ chain = MessageChain().chain
440
+ event.set_result(
441
+ MessageEventResult(
442
+ chain=chain,
443
+ result_content_type=ResultContentType.STREAMING_FINISH,
444
+ ),
445
+ )
446
+ else:
447
+ async for _ in run_agent(
448
+ agent_runner,
449
+ self.max_step,
450
+ self.show_tool_use,
451
+ stream_to_general,
452
+ show_reasoning=self.show_reasoning,
453
+ ):
454
+ yield
455
+
456
+ # 恢复备份的 contexts
457
+ req.contexts = backup_contexts
458
+
459
+ await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
460
+
461
+ # 异步处理 WebChat 特殊情况
462
+ if event.get_platform_name() == "webchat":
463
+ asyncio.create_task(self._handle_webchat(event, req, provider))
464
+
465
+ asyncio.create_task(
466
+ Metric.upload(
467
+ llm_tick=1,
468
+ model_name=agent_runner.provider.get_model(),
469
+ provider_type=agent_runner.provider.meta().type,
470
+ ),
471
+ )