AstrBot 4.10.5__py3-none-any.whl → 4.11.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- astrbot/api/event/filter/__init__.py +4 -0
- astrbot/builtin_stars/builtin_commands/commands/tts.py +2 -2
- astrbot/cli/__init__.py +1 -1
- astrbot/core/agent/context/compressor.py +243 -0
- astrbot/core/agent/context/config.py +35 -0
- astrbot/core/agent/context/manager.py +120 -0
- astrbot/core/agent/context/token_counter.py +64 -0
- astrbot/core/agent/context/truncator.py +141 -0
- astrbot/core/agent/runners/tool_loop_agent_runner.py +48 -1
- astrbot/core/config/default.py +89 -28
- astrbot/core/conversation_mgr.py +4 -0
- astrbot/core/core_lifecycle.py +1 -0
- astrbot/core/db/__init__.py +1 -0
- astrbot/core/db/migration/migra_token_usage.py +61 -0
- astrbot/core/db/po.py +7 -0
- astrbot/core/db/sqlite.py +5 -1
- astrbot/core/pipeline/process_stage/method/agent_request.py +1 -1
- astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +70 -57
- astrbot/core/pipeline/result_decorate/stage.py +1 -1
- astrbot/core/pipeline/session_status_check/stage.py +1 -1
- astrbot/core/pipeline/waking_check/stage.py +1 -1
- astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +1 -1
- astrbot/core/provider/entities.py +5 -0
- astrbot/core/provider/manager.py +27 -12
- astrbot/core/provider/sources/openai_source.py +2 -1
- astrbot/core/star/context.py +14 -1
- astrbot/core/star/register/__init__.py +2 -0
- astrbot/core/star/register/star_handler.py +24 -0
- astrbot/core/star/session_llm_manager.py +38 -26
- astrbot/core/star/session_plugin_manager.py +23 -11
- astrbot/core/star/star_handler.py +1 -0
- astrbot/core/umop_config_router.py +9 -6
- astrbot/core/utils/migra_helper.py +8 -0
- astrbot/dashboard/routes/backup.py +1 -0
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/METADATA +3 -1
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/RECORD +39 -33
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/WHEEL +0 -0
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/entry_points.txt +0 -0
- {astrbot-4.10.5.dist-info → astrbot-4.11.0.dist-info}/licenses/LICENSE +0 -0
astrbot/api/event/filter/__init__.py CHANGED

@@ -21,6 +21,9 @@ from astrbot.core.star.register import (
 from astrbot.core.star.register import register_on_llm_request as on_llm_request
 from astrbot.core.star.register import register_on_llm_response as on_llm_response
 from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
+from astrbot.core.star.register import (
+    register_on_waiting_llm_request as on_waiting_llm_request,
+)
 from astrbot.core.star.register import register_permission_type as permission_type
 from astrbot.core.star.register import (
     register_platform_adapter_type as platform_adapter_type,

@@ -46,6 +49,7 @@ __all__ = [
     "on_llm_request",
     "on_llm_response",
     "on_platform_loaded",
+    "on_waiting_llm_request",
     "permission_type",
     "platform_adapter_type",
     "regex",
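The hunks above re-export the new register_on_waiting_llm_request hook as filter.on_waiting_llm_request. A hypothetical plugin-side sketch of registering a handler with it, assuming the hook follows the same decorator pattern as the neighboring on_llm_request export (only the export name itself is confirmed by this diff):

    from astrbot.api.event import filter

    # Assumed signature, mirroring the established on_llm_request pattern;
    # the event/request parameters here are illustrative.
    @filter.on_waiting_llm_request()
    async def before_waiting_llm(self, event, req):
        ...  # inspect or adjust the pending LLM request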
astrbot/builtin_stars/builtin_commands/commands/tts.py CHANGED

@@ -14,13 +14,13 @@ class TTSCommand:
     async def tts(self, event: AstrMessageEvent):
         """Toggle text-to-speech (session level)."""
         umo = event.unified_msg_origin
-        ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
+        ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
         cfg = self.context.get_config(umo=umo)
         tts_enable = cfg["provider_tts_settings"]["enable"]

         # Toggle the status
         new_status = not ses_tts
-        SessionServiceManager.set_tts_status_for_session(umo, new_status)
+        await SessionServiceManager.set_tts_status_for_session(umo, new_status)

         status_text = "已开启" if new_status else "已关闭"
astrbot/cli/__init__.py CHANGED

@@ -1 +1 @@
-__version__ = "4.10.5"
+__version__ = "4.11.0"
astrbot/core/agent/context/compressor.py ADDED

@@ -0,0 +1,243 @@
+from typing import TYPE_CHECKING, Protocol, runtime_checkable
+
+from ..message import Message
+
+if TYPE_CHECKING:
+    from astrbot import logger
+else:
+    try:
+        from astrbot import logger
+    except ImportError:
+        import logging
+
+        logger = logging.getLogger("astrbot")
+
+if TYPE_CHECKING:
+    from astrbot.core.provider.provider import Provider
+
+from ..context.truncator import ContextTruncator
+
+
+@runtime_checkable
+class ContextCompressor(Protocol):
+    """
+    Protocol for context compressors.
+    Provides an interface for compressing message lists.
+    """
+
+    def should_compress(
+        self, messages: list[Message], current_tokens: int, max_tokens: int
+    ) -> bool:
+        """Check if compression is needed.
+
+        Args:
+            messages: The message list to evaluate.
+            current_tokens: The current token count.
+            max_tokens: The maximum allowed tokens for the model.
+
+        Returns:
+            True if compression is needed, False otherwise.
+        """
+        ...
+
+    async def __call__(self, messages: list[Message]) -> list[Message]:
+        """Compress the message list.
+
+        Args:
+            messages: The original message list.
+
+        Returns:
+            The compressed message list.
+        """
+        ...
+
+
+class TruncateByTurnsCompressor:
+    """Truncate by turns compressor implementation.
+    Truncates the message list by removing older turns.
+    """
+
+    def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
+        """Initialize the truncate by turns compressor.
+
+        Args:
+            truncate_turns: The number of turns to remove when truncating (default: 1).
+            compression_threshold: The compression trigger threshold (default: 0.82).
+        """
+        self.truncate_turns = truncate_turns
+        self.compression_threshold = compression_threshold
+
+    def should_compress(
+        self, messages: list[Message], current_tokens: int, max_tokens: int
+    ) -> bool:
+        """Check if compression is needed.
+
+        Args:
+            messages: The message list to evaluate.
+            current_tokens: The current token count.
+            max_tokens: The maximum allowed tokens.
+
+        Returns:
+            True if compression is needed, False otherwise.
+        """
+        if max_tokens <= 0 or current_tokens <= 0:
+            return False
+        usage_rate = current_tokens / max_tokens
+        return usage_rate > self.compression_threshold
+
+    async def __call__(self, messages: list[Message]) -> list[Message]:
+        truncator = ContextTruncator()
+        truncated_messages = truncator.truncate_by_dropping_oldest_turns(
+            messages,
+            drop_turns=self.truncate_turns,
+        )
+        return truncated_messages
+
+
+def split_history(
+    messages: list[Message], keep_recent: int
+) -> tuple[list[Message], list[Message], list[Message]]:
+    """Split the message list into system messages, messages to summarize, and recent messages.
+
+    Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
+
+    Args:
+        messages: The original message list.
+        keep_recent: The number of latest messages to keep.
+
+    Returns:
+        tuple: (system_messages, messages_to_summarize, recent_messages)
+    """
+    # keep the system messages
+    first_non_system = 0
+    for i, msg in enumerate(messages):
+        if msg.role != "system":
+            first_non_system = i
+            break
+
+    system_messages = messages[:first_non_system]
+    non_system_messages = messages[first_non_system:]
+
+    if len(non_system_messages) <= keep_recent:
+        return system_messages, [], non_system_messages
+
+    # Find the split point, ensuring recent_messages starts with a user message
+    # This maintains complete conversation turns
+    split_index = len(non_system_messages) - keep_recent
+
+    # Search backward from split_index to find the first user message
+    # This ensures recent_messages starts with a user message (complete turn)
+    while split_index > 0 and non_system_messages[split_index].role != "user":
+        # TODO: +=1 or -=1 ? calculate by tokens
+        split_index -= 1
+
+    # If we couldn't find a user message, keep all messages as recent
+    if split_index == 0:
+        return system_messages, [], non_system_messages
+
+    messages_to_summarize = non_system_messages[:split_index]
+    recent_messages = non_system_messages[split_index:]
+
+    return system_messages, messages_to_summarize, recent_messages
+
+
+class LLMSummaryCompressor:
+    """LLM-based summary compressor.
+    Uses LLM to summarize the old conversation history, keeping the latest messages.
+    """
+
+    def __init__(
+        self,
+        provider: "Provider",
+        keep_recent: int = 4,
+        instruction_text: str | None = None,
+        compression_threshold: float = 0.82,
+    ):
+        """Initialize the LLM summary compressor.
+
+        Args:
+            provider: The LLM provider instance.
+            keep_recent: The number of latest messages to keep (default: 4).
+            instruction_text: Custom instruction for summary generation.
+            compression_threshold: The compression trigger threshold (default: 0.82).
+        """
+        self.provider = provider
+        self.keep_recent = keep_recent
+        self.compression_threshold = compression_threshold
+
+        self.instruction_text = instruction_text or (
+            "Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
+            "1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
+            "2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
+            "3. If there was an initial user goal, state it first and describe the current progress/status.\n"
+            "4. Write the summary in the user's language.\n"
+        )
+
+    def should_compress(
+        self, messages: list[Message], current_tokens: int, max_tokens: int
+    ) -> bool:
+        """Check if compression is needed.
+
+        Args:
+            messages: The message list to evaluate.
+            current_tokens: The current token count.
+            max_tokens: The maximum allowed tokens.
+
+        Returns:
+            True if compression is needed, False otherwise.
+        """
+        if max_tokens <= 0 or current_tokens <= 0:
+            return False
+        usage_rate = current_tokens / max_tokens
+        return usage_rate > self.compression_threshold
+
+    async def __call__(self, messages: list[Message]) -> list[Message]:
+        """Use LLM to generate a summary of the conversation history.
+
+        Process:
+        1. Divide messages: keep the system message and the latest N messages.
+        2. Send the old messages + the instruction message to the LLM.
+        3. Reconstruct the message list: [system message, summary message, latest messages].
+        """
+        if len(messages) <= self.keep_recent + 1:
+            return messages
+
+        system_messages, messages_to_summarize, recent_messages = split_history(
+            messages, self.keep_recent
+        )
+
+        if not messages_to_summarize:
+            return messages
+
+        # build payload
+        instruction_message = Message(role="user", content=self.instruction_text)
+        llm_payload = messages_to_summarize + [instruction_message]
+
+        # generate summary
+        try:
+            response = await self.provider.text_chat(contexts=llm_payload)
+            summary_content = response.completion_text
+        except Exception as e:
+            logger.error(f"Failed to generate summary: {e}")
+            return messages
+
+        # build result
+        result = []
+        result.extend(system_messages)
+
+        result.append(
+            Message(
+                role="user",
+                content=f"Our previous history conversation summary: {summary_content}",
+            )
+        )
+        result.append(
+            Message(
+                role="assistant",
+                content="Acknowledged the summary of our previous conversation history.",
+            )
+        )
+
+        result.extend(recent_messages)
+
+        return result
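Any object exposing should_compress and an async __call__ satisfies the ContextCompressor protocol above, so a custom strategy can be plugged in without subclassing. A minimal sketch, with import paths inferred from the relative imports in this file (the KeepLastNCompressor class itself is illustrative, not part of the package):

    from astrbot.core.agent.context.compressor import ContextCompressor
    from astrbot.core.agent.message import Message


    class KeepLastNCompressor:
        """Illustrative strategy: keep the system prompt plus the last n messages."""

        def __init__(self, n: int = 6, threshold: float = 0.9):
            self.n = n
            self.threshold = threshold

        def should_compress(
            self, messages: list[Message], current_tokens: int, max_tokens: int
        ) -> bool:
            return max_tokens > 0 and current_tokens / max_tokens > self.threshold

        async def __call__(self, messages: list[Message]) -> list[Message]:
            system = [m for m in messages if m.role == "system"]
            rest = [m for m in messages if m.role != "system"]
            return system + rest[-self.n :]


    # The protocol is @runtime_checkable, so structural conformance can be asserted.
    assert isinstance(KeepLastNCompressor(), ContextCompressor)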
astrbot/core/agent/context/config.py ADDED

@@ -0,0 +1,35 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from .compressor import ContextCompressor
+from .token_counter import TokenCounter
+
+if TYPE_CHECKING:
+    from astrbot.core.provider.provider import Provider
+
+
+@dataclass
+class ContextConfig:
+    """Context configuration class."""
+
+    max_context_tokens: int = 0
+    """Maximum number of context tokens. <= 0 means no limit."""
+    enforce_max_turns: int = -1  # -1 means no limit
+    """Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
+    truncate_turns: int = 1
+    """Number of conversation turns to discard at once when truncation is triggered.
+    Two processes will use this value:
+
+    1. Enforce max turns truncation.
+    2. Truncation by turns compression strategy.
+    """
+    llm_compress_instruction: str | None = None
+    """Instruction prompt for LLM-based compression."""
+    llm_compress_keep_recent: int = 0
+    """Number of recent messages to keep during LLM-based compression."""
+    llm_compress_provider: "Provider | None" = None
+    """LLM provider used for compression tasks. If None, truncation strategy is used."""
+    custom_token_counter: TokenCounter | None = None
+    """Custom token counting method. If None, the default method is used."""
+    custom_compressor: ContextCompressor | None = None
+    """Custom context compression method. If None, the default method is used."""
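This dataclass is the switch that decides which strategy ContextManager builds (see manager.py below): setting a provider enables LLM summarization, otherwise truncation by turns is used. A sketch of two typical configurations, using only field names from the dataclass above:

    from astrbot.core.agent.context.config import ContextConfig

    # Pure truncation: hard-cap turns first, then drop two turns per pass
    # whenever estimated usage crosses the compressor's threshold.
    truncate_cfg = ContextConfig(
        max_context_tokens=8000,
        enforce_max_turns=20,
        truncate_turns=2,
    )

    # With a provider set, ContextManager selects LLMSummaryCompressor instead.
    # `some_provider` is a hypothetical Provider instance.
    # summary_cfg = ContextConfig(
    #     max_context_tokens=8000,
    #     llm_compress_provider=some_provider,
    #     llm_compress_keep_recent=4,
    # )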
astrbot/core/agent/context/manager.py ADDED

@@ -0,0 +1,120 @@
+from astrbot import logger
+
+from ..message import Message
+from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
+from .config import ContextConfig
+from .token_counter import EstimateTokenCounter
+from .truncator import ContextTruncator
+
+
+class ContextManager:
+    """Context compression manager."""
+
+    def __init__(
+        self,
+        config: ContextConfig,
+    ):
+        """Initialize the context manager.
+
+        There are two strategies to handle context limit reached:
+        1. Truncate by turns: remove older messages by turns.
+        2. LLM-based compression: use LLM to summarize old messages.
+
+        Args:
+            config: The context configuration.
+        """
+        self.config = config
+
+        self.token_counter = config.custom_token_counter or EstimateTokenCounter()
+        self.truncator = ContextTruncator()
+
+        if config.custom_compressor:
+            self.compressor = config.custom_compressor
+        elif config.llm_compress_provider:
+            self.compressor = LLMSummaryCompressor(
+                provider=config.llm_compress_provider,
+                keep_recent=config.llm_compress_keep_recent,
+                instruction_text=config.llm_compress_instruction,
+            )
+        else:
+            self.compressor = TruncateByTurnsCompressor(
+                truncate_turns=config.truncate_turns
+            )
+
+    async def process(
+        self, messages: list[Message], trusted_token_usage: int = 0
+    ) -> list[Message]:
+        """Process the messages.
+
+        Args:
+            messages: The original message list.
+            trusted_token_usage: Token usage reported by the LLM API, if available
+                (0 when the API does not return one).
+
+        Returns:
+            The processed message list.
+        """
+        try:
+            result = messages
+
+            # 1. Turn-based truncation (enforce max turns)
+            if self.config.enforce_max_turns != -1:
+                result = self.truncator.truncate_by_turns(
+                    result,
+                    keep_most_recent_turns=self.config.enforce_max_turns,
+                    drop_turns=self.config.truncate_turns,
+                )
+
+            # 2. Token-based compression
+            if self.config.max_context_tokens > 0:
+                total_tokens = self.token_counter.count_tokens(
+                    result, trusted_token_usage
+                )
+
+                if self.compressor.should_compress(
+                    result, total_tokens, self.config.max_context_tokens
+                ):
+                    result = await self._run_compression(result, total_tokens)
+
+            return result
+        except Exception as e:
+            logger.error(f"Error during context processing: {e}", exc_info=True)
+            return messages
+
+    async def _run_compression(
+        self, messages: list[Message], prev_tokens: int
+    ) -> list[Message]:
+        """
+        Compress/truncate the messages.
+
+        Args:
+            messages: The original message list.
+            prev_tokens: The token count before compression.
+
+        Returns:
+            The compressed/truncated message list.
+        """
+        logger.debug("Compress triggered, starting compression...")
+
+        messages = await self.compressor(messages)
+
+        # double check
+        tokens_after_summary = self.token_counter.count_tokens(messages)
+
+        # calculate compress rate
+        compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
+        logger.info(
+            f"Compress completed."
+            f" {prev_tokens} -> {tokens_after_summary} tokens,"
+            f" compression rate: {compress_rate:.2f}%.",
+        )
+
+        # last check
+        if self.compressor.should_compress(
+            messages, tokens_after_summary, self.config.max_context_tokens
+        ):
+            logger.info(
+                "Context still exceeds max tokens after compression, applying halving truncation..."
+            )
+            # still need compress, truncate by half
+            messages = self.truncator.truncate_by_halving(messages)
+
+        return messages
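End to end, process applies the turn cap (step 1) and then token-based compression (step 2), falling back to the original list on any internal error. A runnable sketch with the default truncation strategy, assuming Message accepts the role/content keywords used throughout this diff:

    import asyncio

    from astrbot.core.agent.context.config import ContextConfig
    from astrbot.core.agent.context.manager import ContextManager
    from astrbot.core.agent.message import Message


    async def main():
        manager = ContextManager(ContextConfig(max_context_tokens=100))
        history = [Message(role="system", content="You are helpful.")]
        for i in range(50):
            history.append(Message(role="user", content=f"question {i}"))
            history.append(Message(role="assistant", content=f"answer {i}"))
        trimmed = await manager.process(history, trusted_token_usage=0)
        print(len(history), "->", len(trimmed))  # the trimmed list is shorter


    asyncio.run(main())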
astrbot/core/agent/context/token_counter.py ADDED

@@ -0,0 +1,64 @@
+import json
+from typing import Protocol, runtime_checkable
+
+from ..message import Message, TextPart
+
+
+@runtime_checkable
+class TokenCounter(Protocol):
+    """
+    Protocol for token counters.
+    Provides an interface for counting tokens in message lists.
+    """
+
+    def count_tokens(
+        self, messages: list[Message], trusted_token_usage: int = 0
+    ) -> int:
+        """Count the total tokens in the message list.
+
+        Args:
+            messages: The message list.
+            trusted_token_usage: The total token usage returned by the LLM API.
+                In some cases this value is more accurate, but some APIs do not
+                return it, so it defaults to 0.
+
+        Returns:
+            The total token count.
+        """
+        ...
+
+
+class EstimateTokenCounter:
+    """Estimate token counter implementation.
+    Provides a simple estimation of token count based on character types.
+    """
+
+    def count_tokens(
+        self, messages: list[Message], trusted_token_usage: int = 0
+    ) -> int:
+        if trusted_token_usage > 0:
+            return trusted_token_usage
+
+        total = 0
+        for msg in messages:
+            content = msg.content
+            if isinstance(content, str):
+                total += self._estimate_tokens(content)
+            elif isinstance(content, list):
+                # Handle multimodal content
+                for part in content:
+                    if isinstance(part, TextPart):
+                        total += self._estimate_tokens(part.text)
+
+            # Handle tool calls
+            if msg.tool_calls:
+                for tc in msg.tool_calls:
+                    tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
+                    total += self._estimate_tokens(tc_str)
+
+        return total
+
+    def _estimate_tokens(self, text: str) -> int:
+        chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
+        other_count = len(text) - chinese_count
+        return int(chinese_count * 0.6 + other_count * 0.3)
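The fallback estimator weights CJK characters at 0.6 tokens each and all other characters at 0.3, and short-circuits to the API-reported count when one is available. A small worked example, with import paths inferred from the relative imports above:

    from astrbot.core.agent.context.token_counter import EstimateTokenCounter
    from astrbot.core.agent.message import Message

    counter = EstimateTokenCounter()
    msgs = [
        Message(role="user", content="你好世界"),          # int(4 * 0.6) = 2
        Message(role="assistant", content="hello world"),  # int(11 * 0.3) = 3
    ]
    print(counter.count_tokens(msgs))                          # 5
    print(counter.count_tokens(msgs, trusted_token_usage=42))  # 42 (trusted value wins)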
astrbot/core/agent/context/truncator.py ADDED

@@ -0,0 +1,141 @@
+from ..message import Message
+
+
+class ContextTruncator:
+    """Context truncator."""
+
+    def fix_messages(self, messages: list[Message]) -> list[Message]:
+        fixed_messages = []
+        for message in messages:
+            if message.role == "tool":
+                # A tool block must be preceded by user and assistant blocks
+                if len(fixed_messages) < 2:
+                    # This is likely caused by context truncation,
+                    # so simply clear the preceding context
+                    fixed_messages = []
+                else:
+                    fixed_messages.append(message)
+            else:
+                fixed_messages.append(message)
+        return fixed_messages
+
+    def truncate_by_turns(
+        self,
+        messages: list[Message],
+        keep_most_recent_turns: int,
+        drop_turns: int = 1,
+    ) -> list[Message]:
+        """Truncate the context list so it does not exceed the maximum length.
+        A turn consists of one user message and one assistant message.
+        This method guarantees that the truncated context list conforms to the OpenAI context format.
+
+        Args:
+            messages: The context list.
+            keep_most_recent_turns: The number of most recent turns to keep.
+            drop_turns: The number of turns to drop at once.
+
+        Returns:
+            The truncated context list.
+        """
+        if keep_most_recent_turns == -1:
+            return messages
+
+        first_non_system = 0
+        for i, msg in enumerate(messages):
+            if msg.role != "system":
+                first_non_system = i
+                break
+
+        system_messages = messages[:first_non_system]
+        non_system_messages = messages[first_non_system:]
+
+        if len(non_system_messages) // 2 <= keep_most_recent_turns:
+            return messages
+
+        num_to_keep = keep_most_recent_turns - drop_turns + 1
+        if num_to_keep <= 0:
+            truncated_contexts = []
+        else:
+            truncated_contexts = non_system_messages[-num_to_keep * 2 :]
+
+        # Find the first index whose role is "user" to keep the context format valid
+        index = next(
+            (i for i, item in enumerate(truncated_contexts) if item.role == "user"),
+            None,
+        )
+        if index is not None and index > 0:
+            truncated_contexts = truncated_contexts[index:]
+
+        result = system_messages + truncated_contexts
+
+        return self.fix_messages(result)
+
+    def truncate_by_dropping_oldest_turns(
+        self,
+        messages: list[Message],
+        drop_turns: int = 1,
+    ) -> list[Message]:
+        """Drop the oldest N conversation turns."""
+        if drop_turns <= 0:
+            return messages
+
+        first_non_system = 0
+        for i, msg in enumerate(messages):
+            if msg.role != "system":
+                first_non_system = i
+                break
+
+        system_messages = messages[:first_non_system]
+        non_system_messages = messages[first_non_system:]
+
+        if len(non_system_messages) // 2 <= drop_turns:
+            truncated_non_system = []
+        else:
+            truncated_non_system = non_system_messages[drop_turns * 2 :]
+
+        index = next(
+            (i for i, item in enumerate(truncated_non_system) if item.role == "user"),
+            None,
+        )
+        if index is not None:
+            truncated_non_system = truncated_non_system[index:]
+        elif truncated_non_system:
+            truncated_non_system = []
+
+        result = system_messages + truncated_non_system
+
+        return self.fix_messages(result)
+
+    def truncate_by_halving(
+        self,
+        messages: list[Message],
+    ) -> list[Message]:
+        """Halving strategy: delete 50% of the messages."""
+        if len(messages) <= 2:
+            return messages
+
+        first_non_system = 0
+        for i, msg in enumerate(messages):
+            if msg.role != "system":
+                first_non_system = i
+                break
+
+        system_messages = messages[:first_non_system]
+        non_system_messages = messages[first_non_system:]
+
+        messages_to_delete = len(non_system_messages) // 2
+        if messages_to_delete == 0:
+            return messages
+
+        truncated_non_system = non_system_messages[messages_to_delete:]
+
+        index = next(
+            (i for i, item in enumerate(truncated_non_system) if item.role == "user"),
+            None,
+        )
+        if index is not None:
+            truncated_non_system = truncated_non_system[index:]
+
+        result = system_messages + truncated_non_system
+
+        return self.fix_messages(result)
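When truncate_by_turns fires, it keeps keep_most_recent_turns - drop_turns + 1 of the newest turns (plus any leading system messages), then re-aligns the window so it starts on a user message. A small walkthrough under the same Message assumption as above:

    from astrbot.core.agent.context.truncator import ContextTruncator
    from astrbot.core.agent.message import Message

    history = [Message(role="system", content="sys")]
    for i in range(6):  # six complete user/assistant turns
        history.append(Message(role="user", content=f"u{i}"))
        history.append(Message(role="assistant", content=f"a{i}"))

    truncator = ContextTruncator()
    # 6 turns > keep_most_recent_turns=4, so truncation fires and keeps
    # 4 - 2 + 1 = 3 turns: u3/a3, u4/a4, u5/a5, plus the system message.
    trimmed = truncator.truncate_by_turns(history, keep_most_recent_turns=4, drop_turns=2)
    print(len(trimmed))  # 7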