AstrBot 4.10.5__py3-none-any.whl → 4.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- astrbot/api/event/filter/__init__.py +4 -0
- astrbot/builtin_stars/builtin_commands/commands/tts.py +2 -2
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/context/compressor.py +243 -0
- astrbot/core/agent/context/config.py +35 -0
- astrbot/core/agent/context/manager.py +120 -0
- astrbot/core/agent/context/token_counter.py +64 -0
- astrbot/core/agent/context/truncator.py +141 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +48 -1
- astrbot/core/config/default.py +89 -28
- astrbot/core/conversation_mgr.py +4 -0
- astrbot/core/core_lifecycle.py +1 -0
- astrbot/core/db/__init__.py +1 -0
- astrbot/core/db/migration/migra_token_usage.py +61 -0
- astrbot/core/db/po.py +7 -0
- astrbot/core/db/sqlite.py +5 -1
- astrbot/core/pipeline/process_stage/method/agent_request.py +1 -1
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +70 -57
- astrbot/core/pipeline/result_decorate/stage.py +1 -1
- astrbot/core/pipeline/session_status_check/stage.py +1 -1
- astrbot/core/pipeline/waking_check/stage.py +1 -1
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +1 -1
- astrbot/core/provider/entities.py +5 -0
- astrbot/core/provider/manager.py +27 -12
- astrbot/core/provider/sources/openai_source.py +2 -1
- astrbot/core/star/context.py +14 -1
- astrbot/core/star/register/__init__.py +2 -0
- astrbot/core/star/register/star_handler.py +24 -0
- astrbot/core/star/session_llm_manager.py +38 -26
- astrbot/core/star/session_plugin_manager.py +23 -11
- astrbot/core/star/star_handler.py +1 -0
- astrbot/core/umop_config_router.py +9 -6
- astrbot/core/utils/migra_helper.py +8 -0
- astrbot/dashboard/routes/backup.py +1 -0
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/METADATA +3 -1
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/RECORD +39 -33
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/WHEEL +0 -0
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -25,6 +25,10 @@ from astrbot.core.provider.entities import (
|
|
|
25
25
|
)
|
|
26
26
|
from astrbot.core.provider.provider import Provider
|
|
27
27
|
|
|
28
|
+
from ..context.compressor import ContextCompressor
|
|
29
|
+
from ..context.config import ContextConfig
|
|
30
|
+
from ..context.manager import ContextManager
|
|
31
|
+
from ..context.token_counter import TokenCounter
|
|
28
32
|
from ..hooks import BaseAgentRunHooks
|
|
29
33
|
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
|
30
34
|
from ..response import AgentResponseData, AgentStats
|
|
@@ -47,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|
|
47
51
|
run_context: ContextWrapper[TContext],
|
|
48
52
|
tool_executor: BaseFunctionToolExecutor[TContext],
|
|
49
53
|
agent_hooks: BaseAgentRunHooks[TContext],
|
|
54
|
+
streaming: bool = False,
|
|
55
|
+
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
|
56
|
+
# -1 means no limit
|
|
57
|
+
enforce_max_turns: int = -1,
|
|
58
|
+
# llm compressor
|
|
59
|
+
llm_compress_instruction: str | None = None,
|
|
60
|
+
llm_compress_keep_recent: int = 0,
|
|
61
|
+
llm_compress_provider: Provider | None = None,
|
|
62
|
+
# truncate by turns compressor
|
|
63
|
+
truncate_turns: int = 1,
|
|
64
|
+
# customize
|
|
65
|
+
custom_token_counter: TokenCounter | None = None,
|
|
66
|
+
custom_compressor: ContextCompressor | None = None,
|
|
50
67
|
**kwargs: T.Any,
|
|
51
68
|
) -> None:
|
|
52
69
|
self.req = request
|
|
53
|
-
self.streaming =
|
|
70
|
+
self.streaming = streaming
|
|
71
|
+
self.enforce_max_turns = enforce_max_turns
|
|
72
|
+
self.llm_compress_instruction = llm_compress_instruction
|
|
73
|
+
self.llm_compress_keep_recent = llm_compress_keep_recent
|
|
74
|
+
self.llm_compress_provider = llm_compress_provider
|
|
75
|
+
self.truncate_turns = truncate_turns
|
|
76
|
+
self.custom_token_counter = custom_token_counter
|
|
77
|
+
self.custom_compressor = custom_compressor
|
|
78
|
+
# we will do compress when:
|
|
79
|
+
# 1. before requesting LLM
|
|
80
|
+
# TODO: 2. after LLM output a tool call
|
|
81
|
+
self.context_config = ContextConfig(
|
|
82
|
+
# <=0 will never do compress
|
|
83
|
+
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
|
84
|
+
# enforce max turns before compression
|
|
85
|
+
enforce_max_turns=self.enforce_max_turns,
|
|
86
|
+
truncate_turns=self.truncate_turns,
|
|
87
|
+
llm_compress_instruction=self.llm_compress_instruction,
|
|
88
|
+
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
|
89
|
+
llm_compress_provider=self.llm_compress_provider,
|
|
90
|
+
custom_token_counter=self.custom_token_counter,
|
|
91
|
+
custom_compressor=self.custom_compressor,
|
|
92
|
+
)
|
|
93
|
+
self.context_manager = ContextManager(self.context_config)
|
|
94
|
+
|
|
54
95
|
self.provider = provider
|
|
55
96
|
self.final_llm_resp = None
|
|
56
97
|
self._state = AgentState.IDLE
|
|
@@ -110,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|
|
110
151
|
self._transition_state(AgentState.RUNNING)
|
|
111
152
|
llm_resp_result = None
|
|
112
153
|
|
|
154
|
+
# do truncate and compress
|
|
155
|
+
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
|
|
156
|
+
self.run_context.messages = await self.context_manager.process(
|
|
157
|
+
self.run_context.messages, trusted_token_usage=token_usage
|
|
158
|
+
)
|
|
159
|
+
|
|
113
160
|
async for llm_response in self._iter_llm_responses():
|
|
114
161
|
if llm_response.is_chunk:
|
|
115
162
|
# update ttft
|
astrbot/core/config/default.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
|
|
5
5
|
|
|
6
6
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|
7
7
|
|
|
8
|
-
VERSION = "4.10.5"
|
|
8
|
+
VERSION = "4.11.0"
|
|
9
9
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
|
10
10
|
|
|
11
11
|
WEBHOOK_SUPPORTED_PLATFORMS = [
|
|
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
|
|
|
83
83
|
"default_personality": "default",
|
|
84
84
|
"persona_pool": ["*"],
|
|
85
85
|
"prompt_prefix": "{{prompt}}",
|
|
86
|
+
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
|
|
87
|
+
"llm_compress_instruction": (
|
|
88
|
+
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
|
89
|
+
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
|
90
|
+
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
|
91
|
+
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
|
92
|
+
"4. Write the summary in the user's language.\n"
|
|
93
|
+
),
|
|
94
|
+
"llm_compress_keep_recent": 4,
|
|
95
|
+
"llm_compress_provider_id": "",
|
|
86
96
|
"max_context_length": -1,
|
|
87
97
|
"dequeue_context_length": 1,
|
|
88
98
|
"streaming_response": False,
|
|
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
|
|
|
179
189
|
model: str
|
|
180
190
|
modalities: list
|
|
181
191
|
custom_extra_body: dict[str, Any]
|
|
192
|
+
max_context_tokens: int
|
|
182
193
|
|
|
183
194
|
|
|
184
195
|
CHAT_PROVIDER_TEMPLATE = {
|
|
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
|
|
|
187
198
|
"model": "",
|
|
188
199
|
"modalities": [],
|
|
189
200
|
"custom_extra_body": {},
|
|
201
|
+
"max_context_tokens": 0,
|
|
190
202
|
}
|
|
191
203
|
|
|
192
204
|
"""
|
|
@@ -227,7 +239,7 @@ CONFIG_METADATA_2 = {
|
|
|
227
239
|
"callback_server_host": "0.0.0.0",
|
|
228
240
|
"port": 6196,
|
|
229
241
|
},
|
|
230
|
-
"OneBot v11": {
|
|
242
|
+
"OneBot v11 (QQ 个人号等)": {
|
|
231
243
|
"id": "default",
|
|
232
244
|
"type": "aiocqhttp",
|
|
233
245
|
"enable": False,
|
|
@@ -235,16 +247,6 @@ CONFIG_METADATA_2 = {
|
|
|
235
247
|
"ws_reverse_port": 6199,
|
|
236
248
|
"ws_reverse_token": "",
|
|
237
249
|
},
|
|
238
|
-
"WeChatPadPro": {
|
|
239
|
-
"id": "wechatpadpro",
|
|
240
|
-
"type": "wechatpadpro",
|
|
241
|
-
"enable": False,
|
|
242
|
-
"admin_key": "stay33",
|
|
243
|
-
"host": "这里填写你的局域网IP或者公网服务器IP",
|
|
244
|
-
"port": 8059,
|
|
245
|
-
"wpp_active_message_poll": False,
|
|
246
|
-
"wpp_active_message_poll_interval": 3,
|
|
247
|
-
},
|
|
248
250
|
"微信公众平台": {
|
|
249
251
|
"id": "weixin_official_account",
|
|
250
252
|
"type": "weixin_official_account",
|
|
@@ -374,6 +376,16 @@ CONFIG_METADATA_2 = {
|
|
|
374
376
|
"satori_heartbeat_interval": 10,
|
|
375
377
|
"satori_reconnect_delay": 5,
|
|
376
378
|
},
|
|
379
|
+
"WeChatPadPro": {
|
|
380
|
+
"id": "wechatpadpro",
|
|
381
|
+
"type": "wechatpadpro",
|
|
382
|
+
"enable": False,
|
|
383
|
+
"admin_key": "stay33",
|
|
384
|
+
"host": "这里填写你的局域网IP或者公网服务器IP",
|
|
385
|
+
"port": 8059,
|
|
386
|
+
"wpp_active_message_poll": False,
|
|
387
|
+
"wpp_active_message_poll_interval": 3,
|
|
388
|
+
},
|
|
377
389
|
# "WebChat": {
|
|
378
390
|
# "id": "webchat",
|
|
379
391
|
# "type": "webchat",
|
|
@@ -2033,6 +2045,11 @@ CONFIG_METADATA_2 = {
|
|
|
2033
2045
|
"type": "string",
|
|
2034
2046
|
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
|
2035
2047
|
},
|
|
2048
|
+
"max_context_tokens": {
|
|
2049
|
+
"description": "模型上下文窗口大小",
|
|
2050
|
+
"type": "int",
|
|
2051
|
+
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
|
|
2052
|
+
},
|
|
2036
2053
|
"dify_api_key": {
|
|
2037
2054
|
"description": "API Key",
|
|
2038
2055
|
"type": "string",
|
|
@@ -2540,6 +2557,66 @@ CONFIG_METADATA_3 = {
|
|
|
2540
2557
|
# "provider_settings.enable": True,
|
|
2541
2558
|
# },
|
|
2542
2559
|
# },
|
|
2560
|
+
"truncate_and_compress": {
|
|
2561
|
+
"description": "上下文管理策略",
|
|
2562
|
+
"type": "object",
|
|
2563
|
+
"items": {
|
|
2564
|
+
"provider_settings.max_context_length": {
|
|
2565
|
+
"description": "最多携带对话轮数",
|
|
2566
|
+
"type": "int",
|
|
2567
|
+
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
|
2568
|
+
"condition": {
|
|
2569
|
+
"provider_settings.agent_runner_type": "local",
|
|
2570
|
+
},
|
|
2571
|
+
},
|
|
2572
|
+
"provider_settings.dequeue_context_length": {
|
|
2573
|
+
"description": "丢弃对话轮数",
|
|
2574
|
+
"type": "int",
|
|
2575
|
+
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
|
2576
|
+
"condition": {
|
|
2577
|
+
"provider_settings.agent_runner_type": "local",
|
|
2578
|
+
},
|
|
2579
|
+
},
|
|
2580
|
+
"provider_settings.context_limit_reached_strategy": {
|
|
2581
|
+
"description": "超出模型上下文窗口时的处理方式",
|
|
2582
|
+
"type": "string",
|
|
2583
|
+
"options": ["truncate_by_turns", "llm_compress"],
|
|
2584
|
+
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
|
2585
|
+
"condition": {
|
|
2586
|
+
"provider_settings.agent_runner_type": "local",
|
|
2587
|
+
},
|
|
2588
|
+
"hint": "",
|
|
2589
|
+
},
|
|
2590
|
+
"provider_settings.llm_compress_instruction": {
|
|
2591
|
+
"description": "上下文压缩提示词",
|
|
2592
|
+
"type": "text",
|
|
2593
|
+
"hint": "如果为空则使用默认提示词。",
|
|
2594
|
+
"condition": {
|
|
2595
|
+
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
|
2596
|
+
"provider_settings.agent_runner_type": "local",
|
|
2597
|
+
},
|
|
2598
|
+
},
|
|
2599
|
+
"provider_settings.llm_compress_keep_recent": {
|
|
2600
|
+
"description": "压缩时保留最近对话轮数",
|
|
2601
|
+
"type": "int",
|
|
2602
|
+
"hint": "始终保留的最近 N 轮对话。",
|
|
2603
|
+
"condition": {
|
|
2604
|
+
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
|
2605
|
+
"provider_settings.agent_runner_type": "local",
|
|
2606
|
+
},
|
|
2607
|
+
},
|
|
2608
|
+
"provider_settings.llm_compress_provider_id": {
|
|
2609
|
+
"description": "用于上下文压缩的模型提供商 ID",
|
|
2610
|
+
"type": "string",
|
|
2611
|
+
"_special": "select_provider",
|
|
2612
|
+
"hint": "留空时将降级为“按对话轮数截断”的策略。",
|
|
2613
|
+
"condition": {
|
|
2614
|
+
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
|
2615
|
+
"provider_settings.agent_runner_type": "local",
|
|
2616
|
+
},
|
|
2617
|
+
},
|
|
2618
|
+
},
|
|
2619
|
+
},
|
|
2543
2620
|
"others": {
|
|
2544
2621
|
"description": "其他配置",
|
|
2545
2622
|
"type": "object",
|
|
@@ -2604,22 +2681,6 @@ CONFIG_METADATA_3 = {
|
|
|
2604
2681
|
"provider_settings.streaming_response": True,
|
|
2605
2682
|
},
|
|
2606
2683
|
},
|
|
2607
|
-
"provider_settings.max_context_length": {
|
|
2608
|
-
"description": "最多携带对话轮数",
|
|
2609
|
-
"type": "int",
|
|
2610
|
-
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
|
2611
|
-
"condition": {
|
|
2612
|
-
"provider_settings.agent_runner_type": "local",
|
|
2613
|
-
},
|
|
2614
|
-
},
|
|
2615
|
-
"provider_settings.dequeue_context_length": {
|
|
2616
|
-
"description": "丢弃对话轮数",
|
|
2617
|
-
"type": "int",
|
|
2618
|
-
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
|
2619
|
-
"condition": {
|
|
2620
|
-
"provider_settings.agent_runner_type": "local",
|
|
2621
|
-
},
|
|
2622
|
-
},
|
|
2623
2684
|
"provider_settings.wake_prefix": {
|
|
2624
2685
|
"description": "LLM 聊天额外唤醒前缀 ",
|
|
2625
2686
|
"type": "string",
|
astrbot/core/conversation_mgr.py
CHANGED
|
@@ -69,6 +69,7 @@ class ConversationManager:
|
|
|
69
69
|
persona_id=conv_v2.persona_id,
|
|
70
70
|
created_at=created_at,
|
|
71
71
|
updated_at=updated_at,
|
|
72
|
+
token_usage=conv_v2.token_usage,
|
|
72
73
|
)
|
|
73
74
|
|
|
74
75
|
async def new_conversation(
|
|
@@ -256,6 +257,7 @@ class ConversationManager:
|
|
|
256
257
|
history: list[dict] | None = None,
|
|
257
258
|
title: str | None = None,
|
|
258
259
|
persona_id: str | None = None,
|
|
260
|
+
token_usage: int | None = None,
|
|
259
261
|
) -> None:
|
|
260
262
|
"""更新会话的对话.
|
|
261
263
|
|
|
@@ -263,6 +265,7 @@ class ConversationManager:
|
|
|
263
265
|
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
|
264
266
|
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
|
265
267
|
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
|
268
|
+
token_usage (int | None): token 使用量。None 表示不更新
|
|
266
269
|
|
|
267
270
|
"""
|
|
268
271
|
if not conversation_id:
|
|
@@ -274,6 +277,7 @@ class ConversationManager:
|
|
|
274
277
|
title=title,
|
|
275
278
|
persona_id=persona_id,
|
|
276
279
|
content=history,
|
|
280
|
+
token_usage=token_usage,
|
|
277
281
|
)
|
|
278
282
|
|
|
279
283
|
async def update_conversation_title(
|
astrbot/core/core_lifecycle.py
CHANGED
astrbot/core/db/__init__.py
CHANGED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
"""Migration script to add token_usage column to conversations table.
|
|
2
|
+
|
|
3
|
+
This migration adds the token_usage field to track token consumption for each conversation.
|
|
4
|
+
|
|
5
|
+
Changes:
|
|
6
|
+
- Adds token_usage column to conversations table (default: 0)
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from sqlalchemy import text
|
|
10
|
+
|
|
11
|
+
from astrbot.api import logger, sp
|
|
12
|
+
from astrbot.core.db import BaseDatabase
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
async def migrate_token_usage(db_helper: BaseDatabase):
|
|
16
|
+
"""Add token_usage column to conversations table.
|
|
17
|
+
|
|
18
|
+
This migration adds a new column to track token consumption in conversations.
|
|
19
|
+
"""
|
|
20
|
+
# 检查是否已经完成迁移
|
|
21
|
+
migration_done = await db_helper.get_preference(
|
|
22
|
+
"global", "global", "migration_done_token_usage_1"
|
|
23
|
+
)
|
|
24
|
+
if migration_done:
|
|
25
|
+
return
|
|
26
|
+
|
|
27
|
+
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
|
|
28
|
+
|
|
29
|
+
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
|
|
30
|
+
|
|
31
|
+
try:
|
|
32
|
+
async with db_helper.get_db() as session:
|
|
33
|
+
# 检查列是否已存在
|
|
34
|
+
result = await session.execute(text("PRAGMA table_info(conversations)"))
|
|
35
|
+
columns = result.fetchall()
|
|
36
|
+
column_names = [col[1] for col in columns]
|
|
37
|
+
|
|
38
|
+
if "token_usage" in column_names:
|
|
39
|
+
logger.info("token_usage 列已存在,跳过迁移")
|
|
40
|
+
await sp.put_async(
|
|
41
|
+
"global", "global", "migration_done_token_usage_1", True
|
|
42
|
+
)
|
|
43
|
+
return
|
|
44
|
+
|
|
45
|
+
# 添加 token_usage 列
|
|
46
|
+
await session.execute(
|
|
47
|
+
text(
|
|
48
|
+
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
|
|
49
|
+
)
|
|
50
|
+
)
|
|
51
|
+
await session.commit()
|
|
52
|
+
|
|
53
|
+
logger.info("token_usage 列添加成功")
|
|
54
|
+
|
|
55
|
+
# 标记迁移完成
|
|
56
|
+
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
|
|
57
|
+
logger.info("token_usage 迁移完成")
|
|
58
|
+
|
|
59
|
+
except Exception as e:
|
|
60
|
+
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
|
61
|
+
raise
|
astrbot/core/db/po.py
CHANGED
|
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
|
|
|
54
54
|
)
|
|
55
55
|
title: str | None = Field(default=None, max_length=255)
|
|
56
56
|
persona_id: str | None = Field(default=None)
|
|
57
|
+
token_usage: int = Field(default=0, nullable=False)
|
|
58
|
+
"""content is a list of OpenAI-formated messages in list[dict] format.
|
|
59
|
+
token_usage is the total token value of the messages.
|
|
60
|
+
when 0, will use estimated token counter.
|
|
61
|
+
"""
|
|
57
62
|
|
|
58
63
|
__table_args__ = (
|
|
59
64
|
UniqueConstraint(
|
|
@@ -313,6 +318,8 @@ class Conversation:
|
|
|
313
318
|
persona_id: str | None = ""
|
|
314
319
|
created_at: int = 0
|
|
315
320
|
updated_at: int = 0
|
|
321
|
+
token_usage: int = 0
|
|
322
|
+
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
|
|
316
323
|
|
|
317
324
|
|
|
318
325
|
class Personality(TypedDict):
|
astrbot/core/db/sqlite.py
CHANGED
|
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
241
241
|
session.add(new_conversation)
|
|
242
242
|
return new_conversation
|
|
243
243
|
|
|
244
|
-
async def update_conversation(
|
|
244
|
+
async def update_conversation(
|
|
245
|
+
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
|
246
|
+
):
|
|
245
247
|
async with self.get_db() as session:
|
|
246
248
|
session: AsyncSession
|
|
247
249
|
async with session.begin():
|
|
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
|
|
|
255
257
|
values["persona_id"] = persona_id
|
|
256
258
|
if content is not None:
|
|
257
259
|
values["content"] = content
|
|
260
|
+
if token_usage is not None:
|
|
261
|
+
values["token_usage"] = token_usage
|
|
258
262
|
if not values:
|
|
259
263
|
return None
|
|
260
264
|
query = query.values(**values)
|
|
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
|
|
|
38
38
|
)
|
|
39
39
|
return
|
|
40
40
|
|
|
41
|
-
if not SessionServiceManager.should_process_llm_request(event):
|
|
41
|
+
if not await SessionServiceManager.should_process_llm_request(event):
|
|
42
42
|
logger.debug(
|
|
43
43
|
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
|
|
44
44
|
)
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
"""本地 Agent 模式的 LLM 调用 Stage"""
|
|
2
2
|
|
|
3
3
|
import asyncio
|
|
4
|
-
import copy
|
|
5
4
|
import json
|
|
6
5
|
from collections.abc import AsyncGenerator
|
|
7
6
|
|
|
8
7
|
from astrbot.core import logger
|
|
9
8
|
from astrbot.core.agent.message import Message
|
|
9
|
+
from astrbot.core.agent.response import AgentStats
|
|
10
10
|
from astrbot.core.agent.tool import ToolSet
|
|
11
11
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
|
12
12
|
from astrbot.core.conversation_mgr import Conversation
|
|
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
|
|
|
24
24
|
)
|
|
25
25
|
from astrbot.core.star.star_handler import EventType, star_map
|
|
26
26
|
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
|
27
|
+
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
|
27
28
|
from astrbot.core.utils.metrics import Metric
|
|
28
29
|
from astrbot.core.utils.session_lock import session_lock_manager
|
|
29
30
|
|
|
@@ -41,11 +42,6 @@ class InternalAgentSubStage(Stage):
|
|
|
41
42
|
self.ctx = ctx
|
|
42
43
|
conf = ctx.astrbot_config
|
|
43
44
|
settings = conf["provider_settings"]
|
|
44
|
-
self.max_context_length = settings["max_context_length"] # int
|
|
45
|
-
self.dequeue_context_length: int = min(
|
|
46
|
-
max(1, settings["dequeue_context_length"]),
|
|
47
|
-
self.max_context_length - 1,
|
|
48
|
-
)
|
|
49
45
|
self.streaming_response: bool = settings["streaming_response"]
|
|
50
46
|
self.unsupported_streaming_strategy: str = settings[
|
|
51
47
|
"unsupported_streaming_strategy"
|
|
@@ -65,6 +61,25 @@ class InternalAgentSubStage(Stage):
|
|
|
65
61
|
"moonshotai_api_key", ""
|
|
66
62
|
)
|
|
67
63
|
|
|
64
|
+
# 上下文管理相关
|
|
65
|
+
self.context_limit_reached_strategy: str = settings.get(
|
|
66
|
+
"context_limit_reached_strategy", "truncate_by_turns"
|
|
67
|
+
)
|
|
68
|
+
self.llm_compress_instruction: str = settings.get(
|
|
69
|
+
"llm_compress_instruction", ""
|
|
70
|
+
)
|
|
71
|
+
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
|
72
|
+
self.llm_compress_provider_id: str = settings.get(
|
|
73
|
+
"llm_compress_provider_id", ""
|
|
74
|
+
)
|
|
75
|
+
self.max_context_length = settings["max_context_length"] # int
|
|
76
|
+
self.dequeue_context_length: int = min(
|
|
77
|
+
max(1, settings["dequeue_context_length"]),
|
|
78
|
+
self.max_context_length - 1,
|
|
79
|
+
)
|
|
80
|
+
if self.dequeue_context_length <= 0:
|
|
81
|
+
self.dequeue_context_length = 1
|
|
82
|
+
|
|
68
83
|
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
|
69
84
|
|
|
70
85
|
def _select_provider(self, event: AstrMessageEvent):
|
|
@@ -167,34 +182,6 @@ class InternalAgentSubStage(Stage):
|
|
|
167
182
|
},
|
|
168
183
|
)
|
|
169
184
|
|
|
170
|
-
def _truncate_contexts(
|
|
171
|
-
self,
|
|
172
|
-
contexts: list[dict],
|
|
173
|
-
) -> list[dict]:
|
|
174
|
-
"""截断上下文列表,确保不超过最大长度"""
|
|
175
|
-
if self.max_context_length == -1:
|
|
176
|
-
return contexts
|
|
177
|
-
|
|
178
|
-
if len(contexts) // 2 <= self.max_context_length:
|
|
179
|
-
return contexts
|
|
180
|
-
|
|
181
|
-
truncated_contexts = contexts[
|
|
182
|
-
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
|
183
|
-
]
|
|
184
|
-
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
|
185
|
-
index = next(
|
|
186
|
-
(
|
|
187
|
-
i
|
|
188
|
-
for i, item in enumerate(truncated_contexts)
|
|
189
|
-
if item.get("role") == "user"
|
|
190
|
-
),
|
|
191
|
-
None,
|
|
192
|
-
)
|
|
193
|
-
if index is not None and index > 0:
|
|
194
|
-
truncated_contexts = truncated_contexts[index:]
|
|
195
|
-
|
|
196
|
-
return truncated_contexts
|
|
197
|
-
|
|
198
185
|
def _modalities_fix(
|
|
199
186
|
self,
|
|
200
187
|
provider: Provider,
|
|
@@ -296,6 +283,7 @@ class InternalAgentSubStage(Stage):
|
|
|
296
283
|
req: ProviderRequest,
|
|
297
284
|
llm_response: LLMResponse | None,
|
|
298
285
|
all_messages: list[Message],
|
|
286
|
+
runner_stats: AgentStats | None,
|
|
299
287
|
):
|
|
300
288
|
if (
|
|
301
289
|
not req
|
|
@@ -322,27 +310,37 @@ class InternalAgentSubStage(Stage):
|
|
|
322
310
|
continue
|
|
323
311
|
message_to_save.append(message.model_dump())
|
|
324
312
|
|
|
313
|
+
# get token usage from agent runner stats
|
|
314
|
+
token_usage = None
|
|
315
|
+
if runner_stats:
|
|
316
|
+
token_usage = runner_stats.token_usage.total
|
|
317
|
+
|
|
325
318
|
await self.conv_manager.update_conversation(
|
|
326
319
|
event.unified_msg_origin,
|
|
327
320
|
req.conversation.cid,
|
|
328
321
|
history=message_to_save,
|
|
322
|
+
token_usage=token_usage,
|
|
329
323
|
)
|
|
330
324
|
|
|
331
|
-
def
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
325
|
+
def _get_compress_provider(self) -> Provider | None:
|
|
326
|
+
if not self.llm_compress_provider_id:
|
|
327
|
+
return None
|
|
328
|
+
if self.context_limit_reached_strategy != "llm_compress":
|
|
329
|
+
return None
|
|
330
|
+
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
|
331
|
+
self.llm_compress_provider_id,
|
|
332
|
+
)
|
|
333
|
+
if provider is None:
|
|
334
|
+
logger.warning(
|
|
335
|
+
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
|
336
|
+
)
|
|
337
|
+
return None
|
|
338
|
+
if not isinstance(provider, Provider):
|
|
339
|
+
logger.warning(
|
|
340
|
+
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
|
341
|
+
)
|
|
342
|
+
return None
|
|
343
|
+
return provider
|
|
346
344
|
|
|
347
345
|
async def process(
|
|
348
346
|
self, event: AstrMessageEvent, provider_wake_prefix: str
|
|
@@ -364,6 +362,10 @@ class InternalAgentSubStage(Stage):
|
|
|
364
362
|
streaming_response = bool(enable_streaming)
|
|
365
363
|
|
|
366
364
|
logger.debug("ready to request llm provider")
|
|
365
|
+
|
|
366
|
+
# 通知等待调用 LLM(在获取锁之前)
|
|
367
|
+
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
|
368
|
+
|
|
367
369
|
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
|
368
370
|
logger.debug("acquired session lock for llm request")
|
|
369
371
|
if event.get_extra("provider_request"):
|
|
@@ -422,9 +424,10 @@ class InternalAgentSubStage(Stage):
|
|
|
422
424
|
await self._apply_kb(event, req)
|
|
423
425
|
|
|
424
426
|
# truncate contexts to fit max length
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
427
|
+
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
|
428
|
+
# if req.contexts:
|
|
429
|
+
# req.contexts = self._truncate_contexts(req.contexts)
|
|
430
|
+
# self._fix_messages(req.contexts)
|
|
428
431
|
|
|
429
432
|
# session_id
|
|
430
433
|
if not req.session_id:
|
|
@@ -440,8 +443,6 @@ class InternalAgentSubStage(Stage):
|
|
|
440
443
|
self.unsupported_streaming_strategy == "turn_off"
|
|
441
444
|
and not event.platform_meta.support_streaming_message
|
|
442
445
|
)
|
|
443
|
-
# 备份 req.contexts
|
|
444
|
-
backup_contexts = copy.deepcopy(req.contexts)
|
|
445
446
|
|
|
446
447
|
# run agent
|
|
447
448
|
agent_runner = AgentRunner()
|
|
@@ -452,6 +453,15 @@ class InternalAgentSubStage(Stage):
|
|
|
452
453
|
context=self.ctx.plugin_manager.context,
|
|
453
454
|
event=event,
|
|
454
455
|
)
|
|
456
|
+
|
|
457
|
+
# inject model context length limit
|
|
458
|
+
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
|
459
|
+
model = provider.get_model()
|
|
460
|
+
if model_info := LLM_METADATAS.get(model):
|
|
461
|
+
provider.provider_config["max_context_tokens"] = model_info[
|
|
462
|
+
"limit"
|
|
463
|
+
]["context"]
|
|
464
|
+
|
|
455
465
|
await agent_runner.reset(
|
|
456
466
|
provider=provider,
|
|
457
467
|
request=req,
|
|
@@ -462,6 +472,11 @@ class InternalAgentSubStage(Stage):
|
|
|
462
472
|
tool_executor=FunctionToolExecutor(),
|
|
463
473
|
agent_hooks=MAIN_AGENT_HOOKS,
|
|
464
474
|
streaming=streaming_response,
|
|
475
|
+
llm_compress_instruction=self.llm_compress_instruction,
|
|
476
|
+
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
|
477
|
+
llm_compress_provider=self._get_compress_provider(),
|
|
478
|
+
truncate_turns=self.dequeue_context_length,
|
|
479
|
+
enforce_max_turns=self.max_context_length,
|
|
465
480
|
)
|
|
466
481
|
|
|
467
482
|
if streaming_response and not stream_to_general:
|
|
@@ -507,14 +522,12 @@ class InternalAgentSubStage(Stage):
|
|
|
507
522
|
):
|
|
508
523
|
yield
|
|
509
524
|
|
|
510
|
-
# 恢复备份的 contexts
|
|
511
|
-
req.contexts = backup_contexts
|
|
512
|
-
|
|
513
525
|
await self._save_to_history(
|
|
514
526
|
event,
|
|
515
527
|
req,
|
|
516
528
|
agent_runner.get_final_llm_resp(),
|
|
517
529
|
agent_runner.run_context.messages,
|
|
530
|
+
agent_runner.stats,
|
|
518
531
|
)
|
|
519
532
|
|
|
520
533
|
# 异步处理 WebChat 特殊情况
|
|
@@ -260,7 +260,7 @@ class ResultDecorateStage(Stage):
|
|
|
260
260
|
should_tts = (
|
|
261
261
|
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
|
|
262
262
|
and result.is_llm_result()
|
|
263
|
-
and SessionServiceManager.should_process_tts_request(event)
|
|
263
|
+
and await SessionServiceManager.should_process_tts_request(event)
|
|
264
264
|
and random.random() <= self.tts_trigger_probability
|
|
265
265
|
and tts_provider
|
|
266
266
|
)
|
|
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
|
|
|
21
21
|
event: AstrMessageEvent,
|
|
22
22
|
) -> None | AsyncGenerator[None, None]:
|
|
23
23
|
# 检查会话是否整体启用
|
|
24
|
-
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
|
24
|
+
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
|
25
25
|
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
|
26
26
|
|
|
27
27
|
# workaround for #2309
|