headroom_ai-0.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
+++ headroom/transforms/pipeline.py
@@ -0,0 +1,297 @@
"""Transform pipeline orchestration for Headroom SDK."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

from ..config import (
    CacheAlignerConfig,
    DiffArtifact,
    HeadroomConfig,
    RollingWindowConfig,
    ToolCrusherConfig,
    TransformDiff,
    TransformResult,
)
from ..tokenizer import Tokenizer
from ..utils import deep_copy_messages
from .base import Transform
from .cache_aligner import CacheAligner
from .rolling_window import RollingWindow
from .smart_crusher import SmartCrusher
from .tool_crusher import ToolCrusher

if TYPE_CHECKING:
    from ..providers.base import Provider

logger = logging.getLogger(__name__)


class TransformPipeline:
    """
    Orchestrates multiple transforms in the correct order.

    Transform order:
    1. Cache Aligner - normalize prefix for cache hits
    2. Tool Crusher - compress tool outputs
    3. Rolling Window - enforce token limits
    """

    def __init__(
        self,
        config: HeadroomConfig | None = None,
        transforms: list[Transform] | None = None,
        provider: Provider | None = None,
    ):
        """
        Initialize pipeline.

        Args:
            config: Headroom configuration.
            transforms: Optional custom transform list (overrides config).
            provider: Provider for model-specific behavior.
        """
        self.config = config or HeadroomConfig()
        self._provider = provider

        if transforms is not None:
            self.transforms = transforms
        else:
            self.transforms = self._build_default_transforms()

    def _build_default_transforms(self) -> list[Transform]:
        """Build default transform pipeline from config."""
        transforms: list[Transform] = []

        # Order matters!

        # 1. Cache Aligner (prefix stabilization)
        if self.config.cache_aligner.enabled:
            transforms.append(CacheAligner(self.config.cache_aligner))

        # 2. Tool Output Compression
        # SmartCrusher (statistical) takes precedence over ToolCrusher (fixed rules)
        if self.config.smart_crusher.enabled:
            # Use smart statistical crushing
            from .smart_crusher import SmartCrusherConfig as SCConfig

            smart_config = SCConfig(
                enabled=True,
                min_items_to_analyze=self.config.smart_crusher.min_items_to_analyze,
                min_tokens_to_crush=self.config.smart_crusher.min_tokens_to_crush,
                variance_threshold=self.config.smart_crusher.variance_threshold,
                uniqueness_threshold=self.config.smart_crusher.uniqueness_threshold,
                similarity_threshold=self.config.smart_crusher.similarity_threshold,
                max_items_after_crush=self.config.smart_crusher.max_items_after_crush,
                preserve_change_points=self.config.smart_crusher.preserve_change_points,
                factor_out_constants=self.config.smart_crusher.factor_out_constants,
                include_summaries=self.config.smart_crusher.include_summaries,
            )
            transforms.append(SmartCrusher(smart_config))
        elif self.config.tool_crusher.enabled:
            # Fallback to fixed-rule crushing
            transforms.append(ToolCrusher(self.config.tool_crusher))

        # 3. Rolling Window (enforce limits last)
        if self.config.rolling_window.enabled:
            transforms.append(RollingWindow(self.config.rolling_window))

        return transforms

    def _get_tokenizer(self, model: str) -> Tokenizer:
        """Get tokenizer for model using provider."""
        if self._provider is None:
            raise ValueError(
                "Provider is required for token counting. "
                "Pass a provider to TransformPipeline or HeadroomClient."
            )
        token_counter = self._provider.get_token_counter(model)
        return Tokenizer(token_counter, model)

    def apply(
        self,
        messages: list[dict[str, Any]],
        model: str,
        **kwargs: Any,
    ) -> TransformResult:
        """
        Apply all transforms in sequence.

        Args:
            messages: List of messages to transform.
            model: Model name for token counting.
            **kwargs: Additional arguments passed to transforms.
                - model_limit: Context limit override.
                - output_buffer: Output buffer override.
                - tool_profiles: Per-tool compression profiles.
                - request_id: Optional request ID for diff artifact.

        Returns:
            Combined TransformResult.
        """
        tokenizer = self._get_tokenizer(model)

        # Get model limit from kwargs (should be set by client)
        model_limit = kwargs.get("model_limit")
        if model_limit is None:
            raise ValueError(
                "model_limit is required. Provide it via kwargs or "
                "configure model_context_limits in HeadroomClient."
            )

        # Start with original tokens
        tokens_before = tokenizer.count_messages(messages)

        logger.debug(
            "Pipeline starting: %d messages, %d tokens, model=%s",
            len(messages),
            tokens_before,
            model,
        )

        # Track all transforms applied
        all_transforms: list[str] = []
        all_markers: list[str] = []
        all_warnings: list[str] = []

        # Track transform diffs if enabled
        transform_diffs: list[TransformDiff] = []
        generate_diff = self.config.generate_diff_artifact

        current_messages = deep_copy_messages(messages)

        for transform in self.transforms:
            # Check if transform should run
            if not transform.should_apply(current_messages, tokenizer, **kwargs):
                continue

            # Track tokens before this transform (for diff)
            tokens_before_transform = tokenizer.count_messages(current_messages)

            # Apply transform
            result = transform.apply(current_messages, tokenizer, **kwargs)

            # Update messages for next transform
            current_messages = result.messages

            # Track tokens after this transform (for diff)
            tokens_after_transform = tokenizer.count_messages(current_messages)

            # Accumulate results
            all_transforms.extend(result.transforms_applied)
            all_markers.extend(result.markers_inserted)
            all_warnings.extend(result.warnings)

            # Log transform results
            if result.transforms_applied:
                logger.info(
                    "Transform %s: %d -> %d tokens (saved %d)",
                    transform.name,
                    tokens_before_transform,
                    tokens_after_transform,
                    tokens_before_transform - tokens_after_transform,
                )
            else:
                logger.debug("Transform %s: no changes", transform.name)

            # Record diff if enabled
            if generate_diff:
                transform_diffs.append(
                    TransformDiff(
                        transform_name=transform.name,
                        tokens_before=tokens_before_transform,
                        tokens_after=tokens_after_transform,
                        tokens_saved=tokens_before_transform - tokens_after_transform,
                        details=", ".join(result.transforms_applied)
                        if result.transforms_applied
                        else "",
                    )
                )

        # Final token count
        tokens_after = tokenizer.count_messages(current_messages)

        # Log pipeline summary
        total_saved = tokens_before - tokens_after
        if total_saved > 0:
            logger.info(
                "Pipeline complete: %d -> %d tokens (saved %d, %.1f%% reduction)",
                tokens_before,
                tokens_after,
                total_saved,
                (total_saved / tokens_before * 100) if tokens_before > 0 else 0,
            )
        else:
            logger.debug("Pipeline complete: no token savings")

        # Build diff artifact if enabled
        diff_artifact = None
        if generate_diff:
            diff_artifact = DiffArtifact(
                request_id=kwargs.get("request_id", ""),
                original_tokens=tokens_before,
                optimized_tokens=tokens_after,
                total_tokens_saved=tokens_before - tokens_after,
                transforms=transform_diffs,
            )

        return TransformResult(
            messages=current_messages,
            tokens_before=tokens_before,
            tokens_after=tokens_after,
            transforms_applied=all_transforms,
            markers_inserted=all_markers,
            warnings=all_warnings,
            diff_artifact=diff_artifact,
        )

    def simulate(
        self,
        messages: list[dict[str, Any]],
        model: str,
        **kwargs: Any,
    ) -> TransformResult:
        """
        Simulate transforms without modifying messages.

        Same as apply() but returns what WOULD happen.

        Args:
            messages: List of messages.
            model: Model name.
            **kwargs: Additional arguments.

        Returns:
            TransformResult with simulated changes.
        """
        # apply() already works on a copy, so this is safe
        return self.apply(messages, model, **kwargs)


def create_pipeline(
    tool_crusher_config: ToolCrusherConfig | None = None,
    cache_aligner_config: CacheAlignerConfig | None = None,
    rolling_window_config: RollingWindowConfig | None = None,
) -> TransformPipeline:
    """
    Create a pipeline with specific configurations.

    Args:
        tool_crusher_config: Tool crusher configuration.
        cache_aligner_config: Cache aligner configuration.
        rolling_window_config: Rolling window configuration.

    Returns:
        Configured TransformPipeline.
    """
    config = HeadroomConfig()

    if tool_crusher_config is not None:
        config.tool_crusher = tool_crusher_config
    if cache_aligner_config is not None:
        config.cache_aligner = cache_aligner_config
    if rolling_window_config is not None:
        config.rolling_window = rolling_window_config

    return TransformPipeline(config)
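Usage note: the hunk above is the pipeline's public entry point, so a minimal driving sketch may help. The import paths below are inferred from the file layout in this diff, StubProvider is a hypothetical stand-in for the package's real Provider classes (it implements only get_token_counter, the single method the pipeline calls), and the model name and limit are arbitrary.

from headroom.config import HeadroomConfig
from headroom.tokenizers import EstimatingTokenCounter
from headroom.transforms.pipeline import TransformPipeline

class StubProvider:
    # Hypothetical provider stub: the pipeline only needs get_token_counter().
    def get_token_counter(self, model: str) -> EstimatingTokenCounter:
        return EstimatingTokenCounter()

pipeline = TransformPipeline(config=HeadroomConfig(), provider=StubProvider())
result = pipeline.apply(
    [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarize these logs."},
    ],
    model="gpt-4o",
    model_limit=128_000,  # apply() raises ValueError if this is omitted
)
print(result.tokens_before, result.tokens_after, result.transforms_applied)

In normal use the limit is presumably supplied by HeadroomClient via model_context_limits, as the ValueError message in the hunk suggests.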
+++ headroom/transforms/rolling_window.py
@@ -0,0 +1,350 @@
"""Rolling window transform for Headroom SDK."""

from __future__ import annotations

import logging
from typing import Any

from ..config import RollingWindowConfig, TransformResult
from ..parser import find_tool_units
from ..tokenizer import Tokenizer
from ..tokenizers import EstimatingTokenCounter
from ..utils import create_dropped_context_marker, deep_copy_messages
from .base import Transform

logger = logging.getLogger(__name__)


class RollingWindow(Transform):
    """
    Apply rolling window to keep messages within token budget.

    Drop order (deterministic):
    1. Oldest TOOL UNITS (assistant+tool_calls paired with tool responses)
    2. Oldest assistant+user pairs
    3. Oldest RAG blocks (if detectable)

    CRITICAL: Tool calls and tool results are atomic DROP UNITS.
    Never orphan a tool result.

    Never drops:
    - System prompt
    - Stable instructions
    - Last N conversational turns (configurable)
    """

    name = "rolling_window"

    def __init__(self, config: RollingWindowConfig | None = None):
        """
        Initialize rolling window.

        Args:
            config: Configuration for window behavior.
        """
        self.config = config or RollingWindowConfig()

    def should_apply(
        self,
        messages: list[dict[str, Any]],
        tokenizer: Tokenizer,
        **kwargs: Any,
    ) -> bool:
        """Check if token cap is exceeded."""
        if not self.config.enabled:
            return False

        model_limit = kwargs.get("model_limit", 128000)
        output_buffer = kwargs.get("output_buffer", self.config.output_buffer_tokens)

        current_tokens = tokenizer.count_messages(messages)
        available = model_limit - output_buffer

        return bool(current_tokens > available)

    def apply(
        self,
        messages: list[dict[str, Any]],
        tokenizer: Tokenizer,
        **kwargs: Any,
    ) -> TransformResult:
        """
        Apply rolling window to messages.

        Args:
            messages: List of messages.
            tokenizer: Tokenizer for counting.
            **kwargs: Must include 'model_limit', optionally 'output_buffer'.

        Returns:
            TransformResult with windowed messages.
        """
        model_limit = kwargs.get("model_limit", 128000)
        output_buffer = kwargs.get("output_buffer", self.config.output_buffer_tokens)
        available = model_limit - output_buffer

        tokens_before = tokenizer.count_messages(messages)
        result_messages = deep_copy_messages(messages)
        transforms_applied: list[str] = []
        markers_inserted: list[str] = []
        warnings: list[str] = []

        dropped_count = 0
        tool_units_dropped = 0

        # If already under budget, no changes needed
        current_tokens = tokens_before
        if current_tokens <= available:
            return TransformResult(
                messages=result_messages,
                tokens_before=tokens_before,
                tokens_after=tokens_before,
                transforms_applied=[],
                warnings=[],
            )

        # Identify protected indices
        protected = self._get_protected_indices(result_messages)

        # Identify tool units
        tool_units = find_tool_units(result_messages)

        # Create drop candidates with priorities
        drop_candidates = self._build_drop_candidates(result_messages, protected, tool_units)

        # Drop until under budget
        indices_to_drop: set[int] = set()

        for candidate in drop_candidates:
            if current_tokens <= available:
                break

            # Get indices for this candidate
            candidate_indices = candidate["indices"]

            # Skip if any are protected
            if any(idx in protected for idx in candidate_indices):
                continue

            # Skip if already dropped
            if any(idx in indices_to_drop for idx in candidate_indices):
                continue

            # Calculate tokens saved
            tokens_saved = sum(
                tokenizer.count_message(result_messages[idx])
                for idx in candidate_indices
                if idx < len(result_messages)
            )

            indices_to_drop.update(candidate_indices)
            current_tokens -= tokens_saved
            dropped_count += 1

            if candidate["type"] == "tool_unit":
                tool_units_dropped += 1

        # Remove dropped messages (in reverse order to preserve indices)
        for idx in sorted(indices_to_drop, reverse=True):
            if idx < len(result_messages):
                del result_messages[idx]

        # Insert marker if we dropped anything
        if dropped_count > 0:
            logger.info(
                "RollingWindow: dropped %d units (%d tool units) to fit budget: %d -> %d tokens",
                dropped_count,
                tool_units_dropped,
                tokens_before,
                current_tokens,
            )
            marker = create_dropped_context_marker("token_cap", dropped_count)
            markers_inserted.append(marker)

            # Insert marker after system messages
            insert_idx = 0
            for i, msg in enumerate(result_messages):
                if msg.get("role") != "system":
                    insert_idx = i
                    break
            else:
                insert_idx = len(result_messages)

            result_messages.insert(
                insert_idx,
                {
                    "role": "user",
                    "content": marker,
                },
            )

            transforms_applied.append(f"window_cap:{dropped_count}")

        tokens_after = tokenizer.count_messages(result_messages)

        result = TransformResult(
            messages=result_messages,
            tokens_before=tokens_before,
            tokens_after=tokens_after,
            transforms_applied=transforms_applied,
            markers_inserted=markers_inserted,
            warnings=warnings,
        )

        return result

    def _get_protected_indices(self, messages: list[dict[str, Any]]) -> set[int]:
        """Get indices that should never be dropped."""
        protected: set[int] = set()

        # Protect system messages
        if self.config.keep_system:
            for i, msg in enumerate(messages):
                if msg.get("role") == "system":
                    protected.add(i)

        # Protect last N turns
        if self.config.keep_last_turns > 0:
            # Count turns from end (user+assistant = 1 turn)
            turns_seen = 0
            i = len(messages) - 1

            while i >= 0 and turns_seen < self.config.keep_last_turns:
                msg = messages[i]
                role = msg.get("role")

                # Protect this message
                protected.add(i)

                # Count turns
                if role == "user":
                    turns_seen += 1

                i -= 1

        # Also protect any tool responses that belong to protected assistant messages
        for i in list(protected):
            msg = messages[i]
            if msg.get("role") == "assistant" and msg.get("tool_calls"):
                # Find and protect corresponding tool responses
                tool_call_ids = {tc.get("id") for tc in msg.get("tool_calls", [])}
                for j, other_msg in enumerate(messages):
                    if other_msg.get("role") == "tool":
                        if other_msg.get("tool_call_id") in tool_call_ids:
                            protected.add(j)

        return protected

    def _build_drop_candidates(
        self,
        messages: list[dict[str, Any]],
        protected: set[int],
        tool_units: list[tuple[int, list[int]]],
    ) -> list[dict[str, Any]]:
        """
        Build ordered list of drop candidates.

        Returns candidates in drop priority order (first to drop first).
        """
        candidates: list[dict[str, Any]] = []

        # Track which indices are part of tool units
        tool_unit_indices: set[int] = set()
        for assistant_idx, response_indices in tool_units:
            tool_unit_indices.add(assistant_idx)
            tool_unit_indices.update(response_indices)

        # Priority 1: Oldest tool units (all indices as atomic unit)
        for assistant_idx, response_indices in tool_units:
            if assistant_idx in protected:
                continue

            all_indices = [assistant_idx] + response_indices
            candidates.append(
                {
                    "type": "tool_unit",
                    "indices": all_indices,
                    "priority": 1,
                    "position": assistant_idx,  # For sorting by age
                }
            )

        # Priority 2: Oldest non-tool messages (user/assistant pairs)
        i = 0
        while i < len(messages):
            msg = messages[i]
            role = msg.get("role")

            if i in protected or i in tool_unit_indices:
                i += 1
                continue

            if role in ("user", "assistant"):
                # Try to find a pair
                if role == "user" and i + 1 < len(messages):
                    next_msg = messages[i + 1]
                    if next_msg.get("role") == "assistant" and i + 1 not in tool_unit_indices:
                        candidates.append(
                            {
                                "type": "turn",
                                "indices": [i, i + 1],
                                "priority": 2,
                                "position": i,
                            }
                        )
                        i += 2
                        continue

                # Single message
                candidates.append(
                    {
                        "type": "single",
                        "indices": [i],
                        "priority": 2,
                        "position": i,
                    }
                )

            i += 1

        # Sort by priority, then by position (oldest first)
        candidates.sort(key=lambda c: (c["priority"], c["position"]))

        return candidates


def apply_rolling_window(
    messages: list[dict[str, Any]],
    model_limit: int,
    output_buffer: int = 4000,
    keep_last_turns: int = 2,
    config: RollingWindowConfig | None = None,
) -> tuple[list[dict[str, Any]], list[str]]:
    """
    Convenience function to apply rolling window.

    Args:
        messages: List of messages.
        model_limit: Model's context limit.
        output_buffer: Tokens to reserve for output.
        keep_last_turns: Number of recent turns to protect.
        config: Optional configuration.

    Returns:
        Tuple of (windowed_messages, dropped_descriptions).
    """
    cfg = config or RollingWindowConfig()
    cfg.output_buffer_tokens = output_buffer
    cfg.keep_last_turns = keep_last_turns

    window = RollingWindow(cfg)
    tokenizer = Tokenizer(EstimatingTokenCounter())  # type: ignore[arg-type]

    result = window.apply(
        messages,
        tokenizer,
        model_limit=model_limit,
        output_buffer=output_buffer,
    )

    return result.messages, result.transforms_applied
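Usage note: the convenience wrapper at the bottom of the hunk can be exercised on its own, since it builds an estimating tokenizer internally. A sketch follows, with an artificially small model_limit chosen purely to force drops; the exact token estimates (and therefore what gets dropped) depend on EstimatingTokenCounter's heuristics, and keep_system is assumed to default to True in RollingWindowConfig.

from headroom.transforms.rolling_window import apply_rolling_window

messages = [
    {"role": "system", "content": "Keep answers short."},
    {"role": "user", "content": "old question " * 60},
    {"role": "assistant", "content": "old answer " * 60},
    {"role": "user", "content": "What changed since yesterday?"},
]

windowed, applied = apply_rolling_window(
    messages,
    model_limit=200,     # far below the conversation size, to trigger the window
    output_buffer=50,
    keep_last_turns=1,   # protects the final user turn
)

# The system message and the last turn are protected, so the oldest
# user/assistant pair is the first drop candidate; when anything is
# dropped, a "window_cap:<n>" entry appears in `applied` and a
# dropped-context marker message is inserted after the system prompt.
print(applied)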