AstrBot 4.5.0__py3-none-any.whl → 4.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/__init__.py +10 -11
- astrbot/api/event/__init__.py +5 -6
- astrbot/api/event/filter/__init__.py +37 -36
- astrbot/api/platform/__init__.py +7 -8
- astrbot/api/provider/__init__.py +7 -7
- astrbot/api/star/__init__.py +3 -4
- astrbot/api/util/__init__.py +2 -2
- astrbot/cli/__main__.py +5 -5
- astrbot/cli/commands/__init__.py +3 -3
- astrbot/cli/commands/cmd_conf.py +19 -16
- astrbot/cli/commands/cmd_init.py +3 -2
- astrbot/cli/commands/cmd_plug.py +8 -10
- astrbot/cli/commands/cmd_run.py +5 -6
- astrbot/cli/utils/__init__.py +6 -6
- astrbot/cli/utils/basic.py +14 -14
- astrbot/cli/utils/plugin.py +24 -15
- astrbot/cli/utils/version_comparator.py +10 -12
- astrbot/core/__init__.py +8 -6
- astrbot/core/agent/agent.py +3 -2
- astrbot/core/agent/handoff.py +6 -2
- astrbot/core/agent/hooks.py +9 -6
- astrbot/core/agent/mcp_client.py +50 -15
- astrbot/core/agent/message.py +168 -0
- astrbot/core/agent/response.py +2 -1
- astrbot/core/agent/run_context.py +2 -3
- astrbot/core/agent/runners/base.py +10 -13
- astrbot/core/agent/runners/tool_loop_agent_runner.py +52 -51
- astrbot/core/agent/tool.py +60 -41
- astrbot/core/agent/tool_executor.py +9 -3
- astrbot/core/astr_agent_context.py +3 -1
- astrbot/core/astrbot_config_mgr.py +29 -9
- astrbot/core/config/__init__.py +2 -2
- astrbot/core/config/astrbot_config.py +28 -26
- astrbot/core/config/default.py +44 -6
- astrbot/core/conversation_mgr.py +105 -36
- astrbot/core/core_lifecycle.py +68 -54
- astrbot/core/db/__init__.py +33 -18
- astrbot/core/db/migration/helper.py +18 -13
- astrbot/core/db/migration/migra_3_to_4.py +53 -34
- astrbot/core/db/migration/migra_45_to_46.py +1 -1
- astrbot/core/db/migration/shared_preferences_v3.py +2 -1
- astrbot/core/db/migration/sqlite_v3.py +26 -23
- astrbot/core/db/po.py +27 -18
- astrbot/core/db/sqlite.py +74 -45
- astrbot/core/db/vec_db/base.py +10 -14
- astrbot/core/db/vec_db/faiss_impl/document_storage.py +90 -77
- astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +9 -3
- astrbot/core/db/vec_db/faiss_impl/vec_db.py +36 -31
- astrbot/core/event_bus.py +8 -6
- astrbot/core/file_token_service.py +6 -5
- astrbot/core/initial_loader.py +7 -5
- astrbot/core/knowledge_base/chunking/__init__.py +1 -3
- astrbot/core/knowledge_base/chunking/base.py +1 -0
- astrbot/core/knowledge_base/chunking/fixed_size.py +2 -0
- astrbot/core/knowledge_base/chunking/recursive.py +16 -10
- astrbot/core/knowledge_base/kb_db_sqlite.py +50 -48
- astrbot/core/knowledge_base/kb_helper.py +30 -17
- astrbot/core/knowledge_base/kb_mgr.py +6 -7
- astrbot/core/knowledge_base/models.py +10 -4
- astrbot/core/knowledge_base/parsers/__init__.py +3 -5
- astrbot/core/knowledge_base/parsers/base.py +1 -0
- astrbot/core/knowledge_base/parsers/markitdown_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/pdf_parser.py +2 -1
- astrbot/core/knowledge_base/parsers/text_parser.py +1 -0
- astrbot/core/knowledge_base/parsers/util.py +1 -1
- astrbot/core/knowledge_base/retrieval/__init__.py +6 -8
- astrbot/core/knowledge_base/retrieval/manager.py +17 -14
- astrbot/core/knowledge_base/retrieval/rank_fusion.py +7 -3
- astrbot/core/knowledge_base/retrieval/sparse_retriever.py +11 -5
- astrbot/core/log.py +21 -13
- astrbot/core/message/components.py +123 -217
- astrbot/core/message/message_event_result.py +24 -24
- astrbot/core/persona_mgr.py +20 -11
- astrbot/core/pipeline/__init__.py +7 -7
- astrbot/core/pipeline/content_safety_check/stage.py +13 -9
- astrbot/core/pipeline/content_safety_check/strategies/__init__.py +1 -2
- astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +12 -13
- astrbot/core/pipeline/content_safety_check/strategies/keywords.py +1 -0
- astrbot/core/pipeline/content_safety_check/strategies/strategy.py +6 -6
- astrbot/core/pipeline/context.py +4 -1
- astrbot/core/pipeline/context_utils.py +77 -7
- astrbot/core/pipeline/preprocess_stage/stage.py +12 -9
- astrbot/core/pipeline/process_stage/method/llm_request.py +125 -72
- astrbot/core/pipeline/process_stage/method/star_request.py +19 -17
- astrbot/core/pipeline/process_stage/stage.py +13 -10
- astrbot/core/pipeline/process_stage/utils.py +6 -5
- astrbot/core/pipeline/rate_limit_check/stage.py +37 -36
- astrbot/core/pipeline/respond/stage.py +23 -20
- astrbot/core/pipeline/result_decorate/stage.py +31 -23
- astrbot/core/pipeline/scheduler.py +12 -8
- astrbot/core/pipeline/session_status_check/stage.py +12 -8
- astrbot/core/pipeline/stage.py +10 -4
- astrbot/core/pipeline/waking_check/stage.py +24 -18
- astrbot/core/pipeline/whitelist_check/stage.py +10 -7
- astrbot/core/platform/__init__.py +6 -6
- astrbot/core/platform/astr_message_event.py +76 -110
- astrbot/core/platform/astrbot_message.py +11 -13
- astrbot/core/platform/manager.py +16 -15
- astrbot/core/platform/message_session.py +5 -3
- astrbot/core/platform/platform.py +16 -24
- astrbot/core/platform/platform_metadata.py +4 -4
- astrbot/core/platform/register.py +8 -8
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +23 -15
- astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +51 -33
- astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +47 -29
- astrbot/core/platform/sources/dingtalk/dingtalk_event.py +7 -3
- astrbot/core/platform/sources/discord/client.py +9 -6
- astrbot/core/platform/sources/discord/components.py +18 -14
- astrbot/core/platform/sources/discord/discord_platform_adapter.py +45 -30
- astrbot/core/platform/sources/discord/discord_platform_event.py +38 -30
- astrbot/core/platform/sources/lark/lark_adapter.py +23 -17
- astrbot/core/platform/sources/lark/lark_event.py +21 -14
- astrbot/core/platform/sources/misskey/misskey_adapter.py +107 -67
- astrbot/core/platform/sources/misskey/misskey_api.py +153 -129
- astrbot/core/platform/sources/misskey/misskey_event.py +20 -15
- astrbot/core/platform/sources/misskey/misskey_utils.py +74 -62
- astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +63 -44
- astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +41 -26
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +36 -17
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_event.py +3 -1
- astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +12 -7
- astrbot/core/platform/sources/satori/satori_adapter.py +56 -38
- astrbot/core/platform/sources/satori/satori_event.py +34 -25
- astrbot/core/platform/sources/slack/client.py +11 -9
- astrbot/core/platform/sources/slack/slack_adapter.py +52 -36
- astrbot/core/platform/sources/slack/slack_event.py +34 -24
- astrbot/core/platform/sources/telegram/tg_adapter.py +38 -18
- astrbot/core/platform/sources/telegram/tg_event.py +32 -18
- astrbot/core/platform/sources/webchat/webchat_adapter.py +27 -17
- astrbot/core/platform/sources/webchat/webchat_event.py +14 -10
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +115 -120
- astrbot/core/platform/sources/wechatpadpro/wechatpadpro_message_event.py +9 -8
- astrbot/core/platform/sources/wechatpadpro/xml_data_parser.py +15 -16
- astrbot/core/platform/sources/wecom/wecom_adapter.py +35 -18
- astrbot/core/platform/sources/wecom/wecom_event.py +55 -48
- astrbot/core/platform/sources/wecom/wecom_kf.py +34 -44
- astrbot/core/platform/sources/wecom/wecom_kf_message.py +26 -10
- astrbot/core/platform/sources/wecom_ai_bot/WXBizJsonMsgCrypt.py +18 -10
- astrbot/core/platform/sources/wecom_ai_bot/__init__.py +3 -5
- astrbot/core/platform/sources/wecom_ai_bot/ierror.py +0 -1
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +61 -37
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_api.py +67 -28
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +8 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_queue_mgr.py +18 -9
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +14 -12
- astrbot/core/platform/sources/wecom_ai_bot/wecomai_utils.py +22 -12
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +40 -26
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +47 -45
- astrbot/core/platform_message_history_mgr.py +5 -3
- astrbot/core/provider/__init__.py +2 -3
- astrbot/core/provider/entites.py +8 -8
- astrbot/core/provider/entities.py +61 -75
- astrbot/core/provider/func_tool_manager.py +59 -55
- astrbot/core/provider/manager.py +40 -22
- astrbot/core/provider/provider.py +72 -46
- astrbot/core/provider/register.py +7 -7
- astrbot/core/provider/sources/anthropic_source.py +48 -30
- astrbot/core/provider/sources/azure_tts_source.py +17 -13
- astrbot/core/provider/sources/coze_api_client.py +27 -17
- astrbot/core/provider/sources/coze_source.py +104 -87
- astrbot/core/provider/sources/dashscope_source.py +18 -11
- astrbot/core/provider/sources/dashscope_tts.py +36 -23
- astrbot/core/provider/sources/dify_source.py +25 -20
- astrbot/core/provider/sources/edge_tts_source.py +21 -17
- astrbot/core/provider/sources/fishaudio_tts_api_source.py +22 -14
- astrbot/core/provider/sources/gemini_embedding_source.py +12 -13
- astrbot/core/provider/sources/gemini_source.py +72 -58
- astrbot/core/provider/sources/gemini_tts_source.py +8 -6
- astrbot/core/provider/sources/gsv_selfhosted_source.py +17 -14
- astrbot/core/provider/sources/gsvi_tts_source.py +11 -7
- astrbot/core/provider/sources/minimax_tts_api_source.py +50 -40
- astrbot/core/provider/sources/openai_embedding_source.py +6 -8
- astrbot/core/provider/sources/openai_source.py +102 -69
- astrbot/core/provider/sources/openai_tts_api_source.py +14 -6
- astrbot/core/provider/sources/sensevoice_selfhosted_source.py +13 -11
- astrbot/core/provider/sources/vllm_rerank_source.py +10 -4
- astrbot/core/provider/sources/volcengine_tts.py +38 -31
- astrbot/core/provider/sources/whisper_api_source.py +14 -12
- astrbot/core/provider/sources/whisper_selfhosted_source.py +15 -11
- astrbot/core/provider/sources/xinference_rerank_source.py +116 -0
- astrbot/core/provider/sources/xinference_stt_provider.py +197 -0
- astrbot/core/star/__init__.py +16 -11
- astrbot/core/star/config.py +10 -15
- astrbot/core/star/context.py +109 -84
- astrbot/core/star/filter/__init__.py +4 -3
- astrbot/core/star/filter/command.py +30 -28
- astrbot/core/star/filter/command_group.py +27 -24
- astrbot/core/star/filter/custom_filter.py +6 -5
- astrbot/core/star/filter/event_message_type.py +4 -2
- astrbot/core/star/filter/permission.py +4 -2
- astrbot/core/star/filter/platform_adapter_type.py +4 -2
- astrbot/core/star/filter/regex.py +4 -2
- astrbot/core/star/register/__init__.py +19 -19
- astrbot/core/star/register/star.py +6 -2
- astrbot/core/star/register/star_handler.py +96 -73
- astrbot/core/star/session_llm_manager.py +48 -14
- astrbot/core/star/session_plugin_manager.py +29 -15
- astrbot/core/star/star.py +1 -2
- astrbot/core/star/star_handler.py +13 -8
- astrbot/core/star/star_manager.py +151 -59
- astrbot/core/star/star_tools.py +44 -37
- astrbot/core/star/updator.py +10 -10
- astrbot/core/umop_config_router.py +10 -4
- astrbot/core/updator.py +13 -5
- astrbot/core/utils/astrbot_path.py +3 -5
- astrbot/core/utils/dify_api_client.py +33 -15
- astrbot/core/utils/io.py +66 -42
- astrbot/core/utils/log_pipe.py +1 -1
- astrbot/core/utils/metrics.py +7 -7
- astrbot/core/utils/path_util.py +15 -16
- astrbot/core/utils/pip_installer.py +5 -5
- astrbot/core/utils/session_waiter.py +19 -20
- astrbot/core/utils/shared_preferences.py +45 -20
- astrbot/core/utils/t2i/__init__.py +4 -1
- astrbot/core/utils/t2i/network_strategy.py +35 -26
- astrbot/core/utils/t2i/renderer.py +11 -5
- astrbot/core/utils/t2i/template_manager.py +14 -15
- astrbot/core/utils/tencent_record_helper.py +19 -13
- astrbot/core/utils/version_comparator.py +10 -13
- astrbot/core/zip_updator.py +43 -40
- astrbot/dashboard/routes/__init__.py +18 -18
- astrbot/dashboard/routes/auth.py +10 -8
- astrbot/dashboard/routes/chat.py +30 -21
- astrbot/dashboard/routes/config.py +92 -75
- astrbot/dashboard/routes/conversation.py +46 -39
- astrbot/dashboard/routes/file.py +4 -2
- astrbot/dashboard/routes/knowledge_base.py +47 -40
- astrbot/dashboard/routes/log.py +9 -4
- astrbot/dashboard/routes/persona.py +19 -16
- astrbot/dashboard/routes/plugin.py +69 -55
- astrbot/dashboard/routes/route.py +3 -1
- astrbot/dashboard/routes/session_management.py +130 -116
- astrbot/dashboard/routes/stat.py +34 -34
- astrbot/dashboard/routes/t2i.py +15 -12
- astrbot/dashboard/routes/tools.py +47 -52
- astrbot/dashboard/routes/update.py +32 -28
- astrbot/dashboard/server.py +30 -26
- astrbot/dashboard/utils.py +8 -4
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/METADATA +4 -2
- astrbot-4.5.2.dist-info/RECORD +261 -0
- astrbot-4.5.0.dist-info/RECORD +0 -258
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/WHEEL +0 -0
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/entry_points.txt +0 -0
- {astrbot-4.5.0.dist-info → astrbot-4.5.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -1,15 +1,24 @@
|
|
|
1
|
-
"""
|
|
2
|
-
本地 Agent 模式的 LLM 调用 Stage
|
|
3
|
-
"""
|
|
1
|
+
"""本地 Agent 模式的 LLM 调用 Stage"""
|
|
4
2
|
|
|
5
3
|
import asyncio
|
|
6
4
|
import copy
|
|
7
5
|
import json
|
|
8
6
|
import traceback
|
|
9
|
-
from datetime import timedelta
|
|
10
7
|
from collections.abc import AsyncGenerator
|
|
11
|
-
from
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from mcp.types import CallToolResult
|
|
11
|
+
|
|
12
12
|
from astrbot.core import logger
|
|
13
|
+
from astrbot.core.agent.handoff import HandoffTool
|
|
14
|
+
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
|
15
|
+
from astrbot.core.agent.mcp_client import MCPTool
|
|
16
|
+
from astrbot.core.agent.run_context import ContextWrapper
|
|
17
|
+
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
|
18
|
+
from astrbot.core.agent.tool import FunctionTool, ToolSet
|
|
19
|
+
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
|
20
|
+
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
21
|
+
from astrbot.core.conversation_mgr import Conversation
|
|
13
22
|
from astrbot.core.message.components import Image
|
|
14
23
|
from astrbot.core.message.message_event_result import (
|
|
15
24
|
MessageChain,
|
|
@@ -22,21 +31,14 @@ from astrbot.core.provider.entities import (
|
|
|
22
31
|
LLMResponse,
|
|
23
32
|
ProviderRequest,
|
|
24
33
|
)
|
|
25
|
-
from astrbot.core.
|
|
26
|
-
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
|
27
|
-
from astrbot.core.agent.run_context import ContextWrapper
|
|
28
|
-
from astrbot.core.agent.tool import ToolSet, FunctionTool
|
|
29
|
-
from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor
|
|
30
|
-
from astrbot.core.agent.handoff import HandoffTool
|
|
34
|
+
from astrbot.core.provider.register import llm_tools
|
|
31
35
|
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
|
32
|
-
from astrbot.core.star.star_handler import EventType
|
|
36
|
+
from astrbot.core.star.star_handler import EventType, star_map
|
|
33
37
|
from astrbot.core.utils.metrics import Metric
|
|
34
|
-
|
|
38
|
+
|
|
39
|
+
from ...context import PipelineContext, call_event_hook, call_local_llm_tool
|
|
35
40
|
from ..stage import Stage
|
|
36
41
|
from ..utils import inject_kb_context
|
|
37
|
-
from astrbot.core.provider.register import llm_tools
|
|
38
|
-
from astrbot.core.star.star_handler import star_map
|
|
39
|
-
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
40
42
|
|
|
41
43
|
try:
|
|
42
44
|
import mcp
|
|
@@ -59,24 +61,23 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
|
59
61
|
|
|
60
62
|
Returns:
|
|
61
63
|
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
|
64
|
+
|
|
62
65
|
"""
|
|
63
66
|
if isinstance(tool, HandoffTool):
|
|
64
67
|
async for r in cls._execute_handoff(tool, run_context, **tool_args):
|
|
65
68
|
yield r
|
|
66
69
|
return
|
|
67
70
|
|
|
68
|
-
|
|
69
|
-
async for r in cls.
|
|
71
|
+
elif isinstance(tool, MCPTool):
|
|
72
|
+
async for r in cls._execute_mcp(tool, run_context, **tool_args):
|
|
70
73
|
yield r
|
|
71
74
|
return
|
|
72
75
|
|
|
73
|
-
|
|
74
|
-
async for r in cls.
|
|
76
|
+
else:
|
|
77
|
+
async for r in cls._execute_local(tool, run_context, **tool_args):
|
|
75
78
|
yield r
|
|
76
79
|
return
|
|
77
80
|
|
|
78
|
-
raise Exception(f"Unknown function origin: {tool.origin}")
|
|
79
|
-
|
|
80
81
|
@classmethod
|
|
81
82
|
async def _execute_handoff(
|
|
82
83
|
cls,
|
|
@@ -113,18 +114,22 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
|
113
114
|
first_provider_request=run_context.context.first_provider_request,
|
|
114
115
|
curr_provider_request=request,
|
|
115
116
|
streaming=run_context.context.streaming,
|
|
117
|
+
event=run_context.context.event,
|
|
116
118
|
)
|
|
117
119
|
|
|
120
|
+
event = run_context.context.event
|
|
121
|
+
|
|
118
122
|
logger.debug(f"正在将任务委托给 Agent: {tool.agent.name}, input: {input_}")
|
|
119
|
-
await
|
|
120
|
-
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name)
|
|
123
|
+
await event.send(
|
|
124
|
+
MessageChain().message("✨ 正在将任务委托给 Agent: " + tool.agent.name),
|
|
121
125
|
)
|
|
122
126
|
|
|
123
127
|
await agent_runner.reset(
|
|
124
128
|
provider=run_context.context.provider,
|
|
125
129
|
request=request,
|
|
126
130
|
run_context=AgentContextWrapper(
|
|
127
|
-
context=astr_agent_ctx,
|
|
131
|
+
context=astr_agent_ctx,
|
|
132
|
+
tool_call_timeout=run_context.tool_call_timeout,
|
|
128
133
|
),
|
|
129
134
|
tool_executor=FunctionToolExecutor(),
|
|
130
135
|
agent_hooks=tool.agent.run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
|
@@ -146,7 +151,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
|
146
151
|
return
|
|
147
152
|
|
|
148
153
|
logger.debug(
|
|
149
|
-
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
|
154
|
+
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}",
|
|
150
155
|
)
|
|
151
156
|
|
|
152
157
|
result = (
|
|
@@ -174,25 +179,46 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
|
174
179
|
run_context: ContextWrapper[AstrAgentContext],
|
|
175
180
|
**tool_args,
|
|
176
181
|
):
|
|
177
|
-
|
|
182
|
+
event = run_context.context.event
|
|
183
|
+
if not event:
|
|
178
184
|
raise ValueError("Event must be provided for local function tools.")
|
|
179
185
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
186
|
+
is_override_call = False
|
|
187
|
+
for ty in type(tool).mro():
|
|
188
|
+
if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call:
|
|
189
|
+
logger.debug(f"Found call in: {ty}")
|
|
190
|
+
is_override_call = True
|
|
191
|
+
break
|
|
184
192
|
|
|
185
|
-
|
|
186
|
-
|
|
193
|
+
# 检查 tool 下有没有 run 方法
|
|
194
|
+
if not tool.handler and not hasattr(tool, "run") and not is_override_call:
|
|
195
|
+
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
|
196
|
+
|
|
197
|
+
awaitable = None
|
|
198
|
+
method_name = ""
|
|
199
|
+
if tool.handler:
|
|
200
|
+
awaitable = tool.handler
|
|
201
|
+
method_name = "decorator_handler"
|
|
202
|
+
elif is_override_call:
|
|
203
|
+
awaitable = tool.call
|
|
204
|
+
method_name = "call"
|
|
205
|
+
elif hasattr(tool, "run"):
|
|
206
|
+
awaitable = getattr(tool, "run")
|
|
207
|
+
method_name = "run"
|
|
208
|
+
if awaitable is None:
|
|
209
|
+
raise ValueError("Tool must have a valid handler or override 'run' method.")
|
|
210
|
+
|
|
211
|
+
wrapper = call_local_llm_tool(
|
|
212
|
+
context=run_context,
|
|
187
213
|
handler=awaitable,
|
|
214
|
+
method_name=method_name,
|
|
188
215
|
**tool_args,
|
|
189
216
|
)
|
|
190
|
-
# async for resp in wrapper:
|
|
191
217
|
while True:
|
|
192
218
|
try:
|
|
193
219
|
resp = await asyncio.wait_for(
|
|
194
220
|
anext(wrapper),
|
|
195
|
-
timeout=run_context.
|
|
221
|
+
timeout=run_context.tool_call_timeout,
|
|
196
222
|
)
|
|
197
223
|
if resp is not None:
|
|
198
224
|
if isinstance(resp, mcp.types.CallToolResult):
|
|
@@ -207,10 +233,24 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
|
207
233
|
# NOTE: Tool 在这里直接请求发送消息给用户
|
|
208
234
|
# TODO: 是否需要判断 event.get_result() 是否为空?
|
|
209
235
|
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
|
236
|
+
if res := run_context.context.event.get_result():
|
|
237
|
+
if res.chain:
|
|
238
|
+
try:
|
|
239
|
+
await event.send(
|
|
240
|
+
MessageChain(
|
|
241
|
+
chain=res.chain,
|
|
242
|
+
type="tool_direct_result",
|
|
243
|
+
)
|
|
244
|
+
)
|
|
245
|
+
except Exception as e:
|
|
246
|
+
logger.error(
|
|
247
|
+
f"Tool 直接发送消息失败: {e}",
|
|
248
|
+
exc_info=True,
|
|
249
|
+
)
|
|
210
250
|
yield None
|
|
211
251
|
except asyncio.TimeoutError:
|
|
212
252
|
raise Exception(
|
|
213
|
-
f"tool {tool.name} execution timeout after {run_context.
|
|
253
|
+
f"tool {tool.name} execution timeout after {run_context.tool_call_timeout} seconds.",
|
|
214
254
|
)
|
|
215
255
|
except StopAsyncIteration:
|
|
216
256
|
break
|
|
@@ -222,19 +262,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|
|
222
262
|
run_context: ContextWrapper[AstrAgentContext],
|
|
223
263
|
**tool_args,
|
|
224
264
|
):
|
|
225
|
-
|
|
226
|
-
raise ValueError("MCP client is not available for MCP function tools.")
|
|
227
|
-
|
|
228
|
-
session = tool.mcp_client.session
|
|
229
|
-
if not session:
|
|
230
|
-
raise ValueError("MCP session is not available for MCP function tools.")
|
|
231
|
-
res = await session.call_tool(
|
|
232
|
-
name=tool.name,
|
|
233
|
-
arguments=tool_args,
|
|
234
|
-
read_timeout_seconds=timedelta(
|
|
235
|
-
seconds=run_context.context.tool_call_timeout
|
|
236
|
-
),
|
|
237
|
-
)
|
|
265
|
+
res = await tool.call(run_context, **tool_args)
|
|
238
266
|
if not res:
|
|
239
267
|
return
|
|
240
268
|
yield res
|
|
@@ -244,18 +272,31 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
|
|
244
272
|
async def on_agent_done(self, run_context, llm_response):
|
|
245
273
|
# 执行事件钩子
|
|
246
274
|
await call_event_hook(
|
|
247
|
-
run_context.event,
|
|
275
|
+
run_context.context.event,
|
|
276
|
+
EventType.OnLLMResponseEvent,
|
|
277
|
+
llm_response,
|
|
248
278
|
)
|
|
249
279
|
|
|
280
|
+
async def on_tool_end(
|
|
281
|
+
self,
|
|
282
|
+
run_context: ContextWrapper[AstrAgentContext],
|
|
283
|
+
tool: FunctionTool[Any],
|
|
284
|
+
tool_args: dict | None,
|
|
285
|
+
tool_result: CallToolResult | None,
|
|
286
|
+
):
|
|
287
|
+
run_context.context.event.clear_result()
|
|
288
|
+
|
|
250
289
|
|
|
251
290
|
MAIN_AGENT_HOOKS = MainAgentHooks()
|
|
252
291
|
|
|
253
292
|
|
|
254
293
|
async def run_agent(
|
|
255
|
-
agent_runner: AgentRunner,
|
|
294
|
+
agent_runner: AgentRunner,
|
|
295
|
+
max_step: int = 30,
|
|
296
|
+
show_tool_use: bool = True,
|
|
256
297
|
) -> AsyncGenerator[MessageChain, None]:
|
|
257
298
|
step_idx = 0
|
|
258
|
-
astr_event = agent_runner.run_context.event
|
|
299
|
+
astr_event = agent_runner.run_context.context.event
|
|
259
300
|
while step_idx < max_step:
|
|
260
301
|
step_idx += 1
|
|
261
302
|
try:
|
|
@@ -290,19 +331,18 @@ async def run_agent(
|
|
|
290
331
|
MessageEventResult(
|
|
291
332
|
chain=resp.data["chain"].chain,
|
|
292
333
|
result_content_type=content_typ,
|
|
293
|
-
)
|
|
334
|
+
),
|
|
294
335
|
)
|
|
295
336
|
yield
|
|
296
337
|
astr_event.clear_result()
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
yield resp.data["chain"] # MessageChain
|
|
338
|
+
elif resp.type == "streaming_delta":
|
|
339
|
+
yield resp.data["chain"] # MessageChain
|
|
300
340
|
if agent_runner.done():
|
|
301
341
|
break
|
|
302
342
|
|
|
303
343
|
except Exception as e:
|
|
304
344
|
logger.error(traceback.format_exc())
|
|
305
|
-
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {
|
|
345
|
+
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
|
|
306
346
|
if agent_runner.streaming:
|
|
307
347
|
yield MessageChain().message(err_msg)
|
|
308
348
|
else:
|
|
@@ -332,7 +372,7 @@ class LLMRequestSubStage(Stage):
|
|
|
332
372
|
for bwp in self.bot_wake_prefixs:
|
|
333
373
|
if self.provider_wake_prefix.startswith(bwp):
|
|
334
374
|
logger.info(
|
|
335
|
-
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。"
|
|
375
|
+
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
|
|
336
376
|
)
|
|
337
377
|
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
|
|
338
378
|
|
|
@@ -367,7 +407,9 @@ class LLMRequestSubStage(Stage):
|
|
|
367
407
|
return conversation
|
|
368
408
|
|
|
369
409
|
async def process(
|
|
370
|
-
self,
|
|
410
|
+
self,
|
|
411
|
+
event: AstrMessageEvent,
|
|
412
|
+
_nested: bool = False,
|
|
371
413
|
) -> None | AsyncGenerator[None, None]:
|
|
372
414
|
req: ProviderRequest | None = None
|
|
373
415
|
|
|
@@ -423,7 +465,9 @@ class LLMRequestSubStage(Stage):
|
|
|
423
465
|
# 应用知识库
|
|
424
466
|
try:
|
|
425
467
|
await inject_kb_context(
|
|
426
|
-
umo=event.unified_msg_origin,
|
|
468
|
+
umo=event.unified_msg_origin,
|
|
469
|
+
p_ctx=self.ctx,
|
|
470
|
+
req=req,
|
|
427
471
|
)
|
|
428
472
|
except Exception as e:
|
|
429
473
|
logger.error(f"调用知识库时遇到问题: {e}")
|
|
@@ -475,7 +519,7 @@ class LLMRequestSubStage(Stage):
|
|
|
475
519
|
# 如果模型不支持工具使用,但请求中包含工具列表,则清空。
|
|
476
520
|
if "tool_use" not in provider_cfg:
|
|
477
521
|
logger.debug(
|
|
478
|
-
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。"
|
|
522
|
+
f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。",
|
|
479
523
|
)
|
|
480
524
|
req.func_tool = None
|
|
481
525
|
# 插件可用性设置
|
|
@@ -498,19 +542,22 @@ class LLMRequestSubStage(Stage):
|
|
|
498
542
|
# run agent
|
|
499
543
|
agent_runner = AgentRunner()
|
|
500
544
|
logger.debug(
|
|
501
|
-
f"handle provider[id: {provider.provider_config['id']}] request: {req}"
|
|
545
|
+
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
|
502
546
|
)
|
|
503
547
|
astr_agent_ctx = AstrAgentContext(
|
|
504
548
|
provider=provider,
|
|
505
549
|
first_provider_request=req,
|
|
506
550
|
curr_provider_request=req,
|
|
507
551
|
streaming=self.streaming_response,
|
|
508
|
-
|
|
552
|
+
event=event,
|
|
509
553
|
)
|
|
510
554
|
await agent_runner.reset(
|
|
511
555
|
provider=provider,
|
|
512
556
|
request=req,
|
|
513
|
-
run_context=AgentContextWrapper(
|
|
557
|
+
run_context=AgentContextWrapper(
|
|
558
|
+
context=astr_agent_ctx,
|
|
559
|
+
tool_call_timeout=self.tool_call_timeout,
|
|
560
|
+
),
|
|
514
561
|
tool_executor=FunctionToolExecutor(),
|
|
515
562
|
agent_hooks=MAIN_AGENT_HOOKS,
|
|
516
563
|
streaming=self.streaming_response,
|
|
@@ -522,8 +569,8 @@ class LLMRequestSubStage(Stage):
|
|
|
522
569
|
MessageEventResult()
|
|
523
570
|
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
|
524
571
|
.set_async_stream(
|
|
525
|
-
run_agent(agent_runner, self.max_step, self.show_tool_use)
|
|
526
|
-
)
|
|
572
|
+
run_agent(agent_runner, self.max_step, self.show_tool_use),
|
|
573
|
+
),
|
|
527
574
|
)
|
|
528
575
|
yield
|
|
529
576
|
if agent_runner.done():
|
|
@@ -540,7 +587,7 @@ class LLMRequestSubStage(Stage):
|
|
|
540
587
|
MessageEventResult(
|
|
541
588
|
chain=chain,
|
|
542
589
|
result_content_type=ResultContentType.STREAMING_FINISH,
|
|
543
|
-
)
|
|
590
|
+
),
|
|
544
591
|
)
|
|
545
592
|
else:
|
|
546
593
|
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
|
@@ -560,17 +607,21 @@ class LLMRequestSubStage(Stage):
|
|
|
560
607
|
llm_tick=1,
|
|
561
608
|
model_name=agent_runner.provider.get_model(),
|
|
562
609
|
provider_type=agent_runner.provider.meta().type,
|
|
563
|
-
)
|
|
610
|
+
),
|
|
564
611
|
)
|
|
565
612
|
|
|
566
613
|
async def _handle_webchat(
|
|
567
|
-
self,
|
|
614
|
+
self,
|
|
615
|
+
event: AstrMessageEvent,
|
|
616
|
+
req: ProviderRequest,
|
|
617
|
+
prov: Provider,
|
|
568
618
|
):
|
|
569
619
|
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
|
570
620
|
if not req.conversation:
|
|
571
621
|
return
|
|
572
622
|
conversation = await self.conv_manager.get_conversation(
|
|
573
|
-
event.unified_msg_origin,
|
|
623
|
+
event.unified_msg_origin,
|
|
624
|
+
req.conversation.cid,
|
|
574
625
|
)
|
|
575
626
|
if conversation and not req.conversation.title:
|
|
576
627
|
messages = json.loads(conversation.history)
|
|
@@ -607,7 +658,7 @@ class LLMRequestSubStage(Stage):
|
|
|
607
658
|
)
|
|
608
659
|
if llm_resp and llm_resp.completion_text:
|
|
609
660
|
logger.debug(
|
|
610
|
-
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}"
|
|
661
|
+
f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}",
|
|
611
662
|
)
|
|
612
663
|
title = llm_resp.completion_text.strip()
|
|
613
664
|
if not title or "<None>" in title:
|
|
@@ -650,7 +701,9 @@ class LLMRequestSubStage(Stage):
|
|
|
650
701
|
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
|
651
702
|
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
|
652
703
|
await self.conv_manager.update_conversation(
|
|
653
|
-
event.unified_msg_origin,
|
|
704
|
+
event.unified_msg_origin,
|
|
705
|
+
req.conversation.cid,
|
|
706
|
+
history=messages,
|
|
654
707
|
)
|
|
655
708
|
|
|
656
709
|
def fix_messages(self, messages: list[dict]) -> list[dict]:
|
|
@@ -1,16 +1,17 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
"""本地 Agent 模式的 AstrBot 插件调用 Stage"""
|
|
2
|
+
|
|
3
|
+
import traceback
|
|
4
|
+
from collections.abc import AsyncGenerator
|
|
5
|
+
from typing import Any
|
|
4
6
|
|
|
5
|
-
from ...context import PipelineContext, call_handler
|
|
6
|
-
from ..stage import Stage
|
|
7
|
-
from typing import Dict, Any, List, AsyncGenerator, Union
|
|
8
|
-
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
9
|
-
from astrbot.core.message.message_event_result import MessageEventResult
|
|
10
7
|
from astrbot.core import logger
|
|
11
|
-
from astrbot.core.
|
|
8
|
+
from astrbot.core.message.message_event_result import MessageEventResult
|
|
9
|
+
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
12
10
|
from astrbot.core.star.star import star_map
|
|
13
|
-
import
|
|
11
|
+
from astrbot.core.star.star_handler import StarHandlerMetadata
|
|
12
|
+
|
|
13
|
+
from ...context import PipelineContext, call_handler
|
|
14
|
+
from ..stage import Stage
|
|
14
15
|
|
|
15
16
|
|
|
16
17
|
class StarRequestSubStage(Stage):
|
|
@@ -21,13 +22,14 @@ class StarRequestSubStage(Stage):
|
|
|
21
22
|
self.ctx = ctx
|
|
22
23
|
|
|
23
24
|
async def process(
|
|
24
|
-
self,
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
25
|
+
self,
|
|
26
|
+
event: AstrMessageEvent,
|
|
27
|
+
) -> None | AsyncGenerator[None, None]:
|
|
28
|
+
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
|
29
|
+
"activated_handlers",
|
|
28
30
|
)
|
|
29
|
-
handlers_parsed_params:
|
|
30
|
-
"handlers_parsed_params"
|
|
31
|
+
handlers_parsed_params: dict[str, dict[str, Any]] = event.get_extra(
|
|
32
|
+
"handlers_parsed_params",
|
|
31
33
|
)
|
|
32
34
|
if not handlers_parsed_params:
|
|
33
35
|
handlers_parsed_params = {}
|
|
@@ -37,7 +39,7 @@ class StarRequestSubStage(Stage):
|
|
|
37
39
|
md = star_map.get(handler.handler_module_path)
|
|
38
40
|
if not md:
|
|
39
41
|
logger.warning(
|
|
40
|
-
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
|
|
42
|
+
f"Cannot find plugin for given handler module path: {handler.handler_module_path}",
|
|
41
43
|
)
|
|
42
44
|
continue
|
|
43
45
|
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
|
|
@@ -1,12 +1,14 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
1
|
+
from collections.abc import AsyncGenerator
|
|
2
|
+
|
|
3
|
+
from astrbot.core import logger
|
|
4
|
+
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
5
|
+
from astrbot.core.provider.entities import ProviderRequest
|
|
6
|
+
from astrbot.core.star.star_handler import StarHandlerMetadata
|
|
7
|
+
|
|
3
8
|
from ..context import PipelineContext
|
|
9
|
+
from ..stage import Stage, register_stage
|
|
4
10
|
from .method.llm_request import LLMRequestSubStage
|
|
5
11
|
from .method.star_request import StarRequestSubStage
|
|
6
|
-
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
7
|
-
from astrbot.core.star.star_handler import StarHandlerMetadata
|
|
8
|
-
from astrbot.core.provider.entities import ProviderRequest
|
|
9
|
-
from astrbot.core import logger
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
@register_stage
|
|
@@ -22,11 +24,12 @@ class ProcessStage(Stage):
|
|
|
22
24
|
await self.star_request_sub_stage.initialize(ctx)
|
|
23
25
|
|
|
24
26
|
async def process(
|
|
25
|
-
self,
|
|
26
|
-
|
|
27
|
+
self,
|
|
28
|
+
event: AstrMessageEvent,
|
|
29
|
+
) -> None | AsyncGenerator[None, None]:
|
|
27
30
|
"""处理事件"""
|
|
28
|
-
activated_handlers:
|
|
29
|
-
"activated_handlers"
|
|
31
|
+
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
|
|
32
|
+
"activated_handlers",
|
|
30
33
|
)
|
|
31
34
|
# 有插件 Handler 被激活
|
|
32
35
|
if activated_handlers:
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
-
from ..context import PipelineContext
|
|
2
|
-
from astrbot.core.provider.entities import ProviderRequest
|
|
3
1
|
from astrbot.api import logger, sp
|
|
2
|
+
from astrbot.core.provider.entities import ProviderRequest
|
|
3
|
+
|
|
4
|
+
from ..context import PipelineContext
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
async def inject_kb_context(
|
|
@@ -8,14 +9,14 @@ async def inject_kb_context(
|
|
|
8
9
|
p_ctx: PipelineContext,
|
|
9
10
|
req: ProviderRequest,
|
|
10
11
|
) -> None:
|
|
11
|
-
"""
|
|
12
|
+
"""Inject knowledge base context into the provider request
|
|
12
13
|
|
|
13
14
|
Args:
|
|
14
15
|
umo: Unique message object (session ID)
|
|
15
16
|
p_ctx: Pipeline context
|
|
16
17
|
req: Provider request
|
|
17
|
-
"""
|
|
18
18
|
|
|
19
|
+
"""
|
|
19
20
|
kb_mgr = p_ctx.plugin_manager.context.kb_manager
|
|
20
21
|
|
|
21
22
|
# 1. 优先读取会话级配置
|
|
@@ -45,7 +46,7 @@ async def inject_kb_context(
|
|
|
45
46
|
|
|
46
47
|
if invalid_kb_ids:
|
|
47
48
|
logger.warning(
|
|
48
|
-
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}"
|
|
49
|
+
f"[知识库] 会话 {umo} 配置的以下知识库无效: {invalid_kb_ids}",
|
|
49
50
|
)
|
|
50
51
|
|
|
51
52
|
if not kb_names:
|
|
@@ -1,18 +1,19 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
from datetime import datetime, timedelta
|
|
3
2
|
from collections import defaultdict, deque
|
|
4
|
-
from
|
|
5
|
-
from
|
|
6
|
-
|
|
7
|
-
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
3
|
+
from collections.abc import AsyncGenerator
|
|
4
|
+
from datetime import datetime, timedelta
|
|
5
|
+
|
|
8
6
|
from astrbot.core import logger
|
|
9
7
|
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
|
8
|
+
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|
9
|
+
|
|
10
|
+
from ..context import PipelineContext
|
|
11
|
+
from ..stage import Stage, register_stage
|
|
10
12
|
|
|
11
13
|
|
|
12
14
|
@register_stage
|
|
13
15
|
class RateLimitStage(Stage):
|
|
14
|
-
"""
|
|
15
|
-
检查是否需要限制消息发送的限流器。
|
|
16
|
+
"""检查是否需要限制消息发送的限流器。
|
|
16
17
|
|
|
17
18
|
使用 Fixed Window 算法。
|
|
18
19
|
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
|
|
@@ -20,32 +21,30 @@ class RateLimitStage(Stage):
|
|
|
20
21
|
|
|
21
22
|
def __init__(self):
|
|
22
23
|
# 存储每个会话的请求时间队列
|
|
23
|
-
self.event_timestamps:
|
|
24
|
+
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
|
|
24
25
|
# 为每个会话设置一个锁,避免并发冲突
|
|
25
|
-
self.locks:
|
|
26
|
+
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
|
26
27
|
# 限流参数
|
|
27
28
|
self.rate_limit_count: int = 0
|
|
28
29
|
self.rate_limit_time: timedelta = timedelta(0)
|
|
29
30
|
|
|
30
31
|
async def initialize(self, ctx: PipelineContext) -> None:
|
|
31
|
-
"""
|
|
32
|
-
初始化限流器,根据配置设置限流参数。
|
|
33
|
-
"""
|
|
32
|
+
"""初始化限流器,根据配置设置限流参数。"""
|
|
34
33
|
self.rate_limit_count = ctx.astrbot_config["platform_settings"]["rate_limit"][
|
|
35
34
|
"count"
|
|
36
35
|
]
|
|
37
36
|
self.rate_limit_time = timedelta(
|
|
38
|
-
seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"]
|
|
37
|
+
seconds=ctx.astrbot_config["platform_settings"]["rate_limit"]["time"],
|
|
39
38
|
)
|
|
40
39
|
self.rl_strategy = ctx.astrbot_config["platform_settings"]["rate_limit"][
|
|
41
40
|
"strategy"
|
|
42
41
|
] # stall or discard
|
|
43
42
|
|
|
44
43
|
async def process(
|
|
45
|
-
self,
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
|
|
44
|
+
self,
|
|
45
|
+
event: AstrMessageEvent,
|
|
46
|
+
) -> None | AsyncGenerator[None, None]:
|
|
47
|
+
"""检查并处理限流逻辑。如果触发限流,流水线会 stall 并在窗口期后自动恢复。
|
|
49
48
|
|
|
50
49
|
Args:
|
|
51
50
|
event (AstrMessageEvent): 当前消息事件。
|
|
@@ -53,6 +52,7 @@ class RateLimitStage(Stage):
|
|
|
53
52
|
|
|
54
53
|
Returns:
|
|
55
54
|
MessageEventResult: 继续或停止事件处理的结果。
|
|
55
|
+
|
|
56
56
|
"""
|
|
57
57
|
session_id = event.session_id
|
|
58
58
|
now = datetime.now()
|
|
@@ -66,32 +66,33 @@ class RateLimitStage(Stage):
|
|
|
66
66
|
if len(timestamps) < self.rate_limit_count:
|
|
67
67
|
timestamps.append(now)
|
|
68
68
|
break
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
return event.stop_event()
|
|
69
|
+
next_window_time = timestamps[0] + self.rate_limit_time
|
|
70
|
+
stall_duration = (next_window_time - now).total_seconds() + 0.3
|
|
71
|
+
|
|
72
|
+
match self.rl_strategy:
|
|
73
|
+
case RateLimitStrategy.STALL.value:
|
|
74
|
+
logger.info(
|
|
75
|
+
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。",
|
|
76
|
+
)
|
|
77
|
+
await asyncio.sleep(stall_duration)
|
|
78
|
+
now = datetime.now()
|
|
79
|
+
case RateLimitStrategy.DISCARD.value:
|
|
80
|
+
logger.info(
|
|
81
|
+
f"会话 {session_id} 被限流。根据限流策略,此请求已被丢弃,直到限额于 {stall_duration:.2f} 秒后重置。",
|
|
82
|
+
)
|
|
83
|
+
return event.stop_event()
|
|
85
84
|
|
|
86
85
|
def _remove_expired_timestamps(
|
|
87
|
-
self,
|
|
86
|
+
self,
|
|
87
|
+
timestamps: deque[datetime],
|
|
88
|
+
now: datetime,
|
|
88
89
|
) -> None:
|
|
89
|
-
"""
|
|
90
|
-
移除时间窗口外的时间戳。
|
|
90
|
+
"""移除时间窗口外的时间戳。
|
|
91
91
|
|
|
92
92
|
Args:
|
|
93
93
|
timestamps (Deque[datetime]): 当前会话的时间戳队列。
|
|
94
94
|
now (datetime): 当前时间,用于计算过期时间。
|
|
95
|
+
|
|
95
96
|
"""
|
|
96
97
|
expiry_threshold: datetime = now - self.rate_limit_time
|
|
97
98
|
while timestamps and timestamps[0] < expiry_threshold:
|