headroom-ai 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. headroom/__init__.py +212 -0
  2. headroom/cache/__init__.py +76 -0
  3. headroom/cache/anthropic.py +517 -0
  4. headroom/cache/base.py +342 -0
  5. headroom/cache/compression_feedback.py +613 -0
  6. headroom/cache/compression_store.py +814 -0
  7. headroom/cache/dynamic_detector.py +1026 -0
  8. headroom/cache/google.py +884 -0
  9. headroom/cache/openai.py +584 -0
  10. headroom/cache/registry.py +175 -0
  11. headroom/cache/semantic.py +451 -0
  12. headroom/ccr/__init__.py +77 -0
  13. headroom/ccr/context_tracker.py +582 -0
  14. headroom/ccr/mcp_server.py +319 -0
  15. headroom/ccr/response_handler.py +772 -0
  16. headroom/ccr/tool_injection.py +415 -0
  17. headroom/cli.py +219 -0
  18. headroom/client.py +977 -0
  19. headroom/compression/__init__.py +42 -0
  20. headroom/compression/detector.py +424 -0
  21. headroom/compression/handlers/__init__.py +22 -0
  22. headroom/compression/handlers/base.py +219 -0
  23. headroom/compression/handlers/code_handler.py +506 -0
  24. headroom/compression/handlers/json_handler.py +418 -0
  25. headroom/compression/masks.py +345 -0
  26. headroom/compression/universal.py +465 -0
  27. headroom/config.py +474 -0
  28. headroom/exceptions.py +192 -0
  29. headroom/integrations/__init__.py +159 -0
  30. headroom/integrations/agno/__init__.py +53 -0
  31. headroom/integrations/agno/hooks.py +345 -0
  32. headroom/integrations/agno/model.py +625 -0
  33. headroom/integrations/agno/providers.py +154 -0
  34. headroom/integrations/langchain/__init__.py +106 -0
  35. headroom/integrations/langchain/agents.py +326 -0
  36. headroom/integrations/langchain/chat_model.py +1002 -0
  37. headroom/integrations/langchain/langsmith.py +324 -0
  38. headroom/integrations/langchain/memory.py +319 -0
  39. headroom/integrations/langchain/providers.py +200 -0
  40. headroom/integrations/langchain/retriever.py +371 -0
  41. headroom/integrations/langchain/streaming.py +341 -0
  42. headroom/integrations/mcp/__init__.py +37 -0
  43. headroom/integrations/mcp/server.py +533 -0
  44. headroom/memory/__init__.py +37 -0
  45. headroom/memory/extractor.py +390 -0
  46. headroom/memory/fast_store.py +621 -0
  47. headroom/memory/fast_wrapper.py +311 -0
  48. headroom/memory/inline_extractor.py +229 -0
  49. headroom/memory/store.py +434 -0
  50. headroom/memory/worker.py +260 -0
  51. headroom/memory/wrapper.py +321 -0
  52. headroom/models/__init__.py +39 -0
  53. headroom/models/registry.py +687 -0
  54. headroom/parser.py +293 -0
  55. headroom/pricing/__init__.py +51 -0
  56. headroom/pricing/anthropic_prices.py +81 -0
  57. headroom/pricing/litellm_pricing.py +113 -0
  58. headroom/pricing/openai_prices.py +91 -0
  59. headroom/pricing/registry.py +188 -0
  60. headroom/providers/__init__.py +61 -0
  61. headroom/providers/anthropic.py +621 -0
  62. headroom/providers/base.py +131 -0
  63. headroom/providers/cohere.py +362 -0
  64. headroom/providers/google.py +427 -0
  65. headroom/providers/litellm.py +297 -0
  66. headroom/providers/openai.py +566 -0
  67. headroom/providers/openai_compatible.py +521 -0
  68. headroom/proxy/__init__.py +19 -0
  69. headroom/proxy/server.py +2683 -0
  70. headroom/py.typed +0 -0
  71. headroom/relevance/__init__.py +124 -0
  72. headroom/relevance/base.py +106 -0
  73. headroom/relevance/bm25.py +255 -0
  74. headroom/relevance/embedding.py +255 -0
  75. headroom/relevance/hybrid.py +259 -0
  76. headroom/reporting/__init__.py +5 -0
  77. headroom/reporting/generator.py +549 -0
  78. headroom/storage/__init__.py +41 -0
  79. headroom/storage/base.py +125 -0
  80. headroom/storage/jsonl.py +220 -0
  81. headroom/storage/sqlite.py +289 -0
  82. headroom/telemetry/__init__.py +91 -0
  83. headroom/telemetry/collector.py +764 -0
  84. headroom/telemetry/models.py +880 -0
  85. headroom/telemetry/toin.py +1579 -0
  86. headroom/tokenizer.py +80 -0
  87. headroom/tokenizers/__init__.py +75 -0
  88. headroom/tokenizers/base.py +210 -0
  89. headroom/tokenizers/estimator.py +198 -0
  90. headroom/tokenizers/huggingface.py +317 -0
  91. headroom/tokenizers/mistral.py +245 -0
  92. headroom/tokenizers/registry.py +398 -0
  93. headroom/tokenizers/tiktoken_counter.py +248 -0
  94. headroom/transforms/__init__.py +106 -0
  95. headroom/transforms/base.py +57 -0
  96. headroom/transforms/cache_aligner.py +357 -0
  97. headroom/transforms/code_compressor.py +1313 -0
  98. headroom/transforms/content_detector.py +335 -0
  99. headroom/transforms/content_router.py +1158 -0
  100. headroom/transforms/llmlingua_compressor.py +638 -0
  101. headroom/transforms/log_compressor.py +529 -0
  102. headroom/transforms/pipeline.py +297 -0
  103. headroom/transforms/rolling_window.py +350 -0
  104. headroom/transforms/search_compressor.py +365 -0
  105. headroom/transforms/smart_crusher.py +2682 -0
  106. headroom/transforms/text_compressor.py +259 -0
  107. headroom/transforms/tool_crusher.py +338 -0
  108. headroom/utils.py +215 -0
  109. headroom_ai-0.2.13.dist-info/METADATA +315 -0
  110. headroom_ai-0.2.13.dist-info/RECORD +114 -0
  111. headroom_ai-0.2.13.dist-info/WHEEL +4 -0
  112. headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
  113. headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
  114. headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/transforms/pipeline.py
@@ -0,0 +1,297 @@
+ """Transform pipeline orchestration for Headroom SDK."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import TYPE_CHECKING, Any
+
+ from ..config import (
+     CacheAlignerConfig,
+     DiffArtifact,
+     HeadroomConfig,
+     RollingWindowConfig,
+     ToolCrusherConfig,
+     TransformDiff,
+     TransformResult,
+ )
+ from ..tokenizer import Tokenizer
+ from ..utils import deep_copy_messages
+ from .base import Transform
+ from .cache_aligner import CacheAligner
+ from .rolling_window import RollingWindow
+ from .smart_crusher import SmartCrusher
+ from .tool_crusher import ToolCrusher
+
+ if TYPE_CHECKING:
+     from ..providers.base import Provider
+
+ logger = logging.getLogger(__name__)
+
+
+ class TransformPipeline:
+     """
+     Orchestrates multiple transforms in the correct order.
+
+     Transform order:
+     1. Cache Aligner - normalize prefix for cache hits
+     2. Tool Crusher - compress tool outputs
+     3. Rolling Window - enforce token limits
+     """
+
+     def __init__(
+         self,
+         config: HeadroomConfig | None = None,
+         transforms: list[Transform] | None = None,
+         provider: Provider | None = None,
+     ):
+         """
+         Initialize pipeline.
+
+         Args:
+             config: Headroom configuration.
+             transforms: Optional custom transform list (overrides config).
+             provider: Provider for model-specific behavior.
+         """
+         self.config = config or HeadroomConfig()
+         self._provider = provider
+
+         if transforms is not None:
+             self.transforms = transforms
+         else:
+             self.transforms = self._build_default_transforms()
+
+     def _build_default_transforms(self) -> list[Transform]:
+         """Build default transform pipeline from config."""
+         transforms: list[Transform] = []
+
+         # Order matters!
+
+         # 1. Cache Aligner (prefix stabilization)
+         if self.config.cache_aligner.enabled:
+             transforms.append(CacheAligner(self.config.cache_aligner))
+
+         # 2. Tool Output Compression
+         # SmartCrusher (statistical) takes precedence over ToolCrusher (fixed rules)
+         if self.config.smart_crusher.enabled:
+             # Use smart statistical crushing
+             from .smart_crusher import SmartCrusherConfig as SCConfig
+
+             smart_config = SCConfig(
+                 enabled=True,
+                 min_items_to_analyze=self.config.smart_crusher.min_items_to_analyze,
+                 min_tokens_to_crush=self.config.smart_crusher.min_tokens_to_crush,
+                 variance_threshold=self.config.smart_crusher.variance_threshold,
+                 uniqueness_threshold=self.config.smart_crusher.uniqueness_threshold,
+                 similarity_threshold=self.config.smart_crusher.similarity_threshold,
+                 max_items_after_crush=self.config.smart_crusher.max_items_after_crush,
+                 preserve_change_points=self.config.smart_crusher.preserve_change_points,
+                 factor_out_constants=self.config.smart_crusher.factor_out_constants,
+                 include_summaries=self.config.smart_crusher.include_summaries,
+             )
+             transforms.append(SmartCrusher(smart_config))
+         elif self.config.tool_crusher.enabled:
+             # Fallback to fixed-rule crushing
+             transforms.append(ToolCrusher(self.config.tool_crusher))
+
+         # 3. Rolling Window (enforce limits last)
+         if self.config.rolling_window.enabled:
+             transforms.append(RollingWindow(self.config.rolling_window))
+
+         return transforms
+
+     def _get_tokenizer(self, model: str) -> Tokenizer:
+         """Get tokenizer for model using provider."""
+         if self._provider is None:
+             raise ValueError(
+                 "Provider is required for token counting. "
+                 "Pass a provider to TransformPipeline or HeadroomClient."
+             )
+         token_counter = self._provider.get_token_counter(model)
+         return Tokenizer(token_counter, model)
+
+     def apply(
+         self,
+         messages: list[dict[str, Any]],
+         model: str,
+         **kwargs: Any,
+     ) -> TransformResult:
+         """
+         Apply all transforms in sequence.
+
+         Args:
+             messages: List of messages to transform.
+             model: Model name for token counting.
+             **kwargs: Additional arguments passed to transforms.
+                 - model_limit: Context limit override.
+                 - output_buffer: Output buffer override.
+                 - tool_profiles: Per-tool compression profiles.
+                 - request_id: Optional request ID for diff artifact.
+
+         Returns:
+             Combined TransformResult.
+         """
+         tokenizer = self._get_tokenizer(model)
+
+         # Get model limit from kwargs (should be set by client)
+         model_limit = kwargs.get("model_limit")
+         if model_limit is None:
+             raise ValueError(
+                 "model_limit is required. Provide it via kwargs or "
+                 "configure model_context_limits in HeadroomClient."
+             )
+
+         # Start with original tokens
+         tokens_before = tokenizer.count_messages(messages)
+
+         logger.debug(
+             "Pipeline starting: %d messages, %d tokens, model=%s",
+             len(messages),
+             tokens_before,
+             model,
+         )
+
+         # Track all transforms applied
+         all_transforms: list[str] = []
+         all_markers: list[str] = []
+         all_warnings: list[str] = []
+
+         # Track transform diffs if enabled
+         transform_diffs: list[TransformDiff] = []
+         generate_diff = self.config.generate_diff_artifact
+
+         current_messages = deep_copy_messages(messages)
+
+         for transform in self.transforms:
+             # Check if transform should run
+             if not transform.should_apply(current_messages, tokenizer, **kwargs):
+                 continue
+
+             # Track tokens before this transform (for diff)
+             tokens_before_transform = tokenizer.count_messages(current_messages)
+
+             # Apply transform
+             result = transform.apply(current_messages, tokenizer, **kwargs)
+
+             # Update messages for next transform
+             current_messages = result.messages
+
+             # Track tokens after this transform (for diff)
+             tokens_after_transform = tokenizer.count_messages(current_messages)
+
+             # Accumulate results
+             all_transforms.extend(result.transforms_applied)
+             all_markers.extend(result.markers_inserted)
+             all_warnings.extend(result.warnings)
+
+             # Log transform results
+             if result.transforms_applied:
+                 logger.info(
+                     "Transform %s: %d -> %d tokens (saved %d)",
+                     transform.name,
+                     tokens_before_transform,
+                     tokens_after_transform,
+                     tokens_before_transform - tokens_after_transform,
+                 )
+             else:
+                 logger.debug("Transform %s: no changes", transform.name)
+
+             # Record diff if enabled
+             if generate_diff:
+                 transform_diffs.append(
+                     TransformDiff(
+                         transform_name=transform.name,
+                         tokens_before=tokens_before_transform,
+                         tokens_after=tokens_after_transform,
+                         tokens_saved=tokens_before_transform - tokens_after_transform,
+                         details=", ".join(result.transforms_applied)
+                         if result.transforms_applied
+                         else "",
+                     )
+                 )
+
+         # Final token count
+         tokens_after = tokenizer.count_messages(current_messages)
+
+         # Log pipeline summary
+         total_saved = tokens_before - tokens_after
+         if total_saved > 0:
+             logger.info(
+                 "Pipeline complete: %d -> %d tokens (saved %d, %.1f%% reduction)",
+                 tokens_before,
+                 tokens_after,
+                 total_saved,
+                 (total_saved / tokens_before * 100) if tokens_before > 0 else 0,
+             )
+         else:
+             logger.debug("Pipeline complete: no token savings")
+
+         # Build diff artifact if enabled
+         diff_artifact = None
+         if generate_diff:
+             diff_artifact = DiffArtifact(
+                 request_id=kwargs.get("request_id", ""),
+                 original_tokens=tokens_before,
+                 optimized_tokens=tokens_after,
+                 total_tokens_saved=tokens_before - tokens_after,
+                 transforms=transform_diffs,
+             )
+
+         return TransformResult(
+             messages=current_messages,
+             tokens_before=tokens_before,
+             tokens_after=tokens_after,
+             transforms_applied=all_transforms,
+             markers_inserted=all_markers,
+             warnings=all_warnings,
+             diff_artifact=diff_artifact,
+         )
+
+     def simulate(
+         self,
+         messages: list[dict[str, Any]],
+         model: str,
+         **kwargs: Any,
+     ) -> TransformResult:
+         """
+         Simulate transforms without modifying messages.
+
+         Same as apply() but returns what WOULD happen.
+
+         Args:
+             messages: List of messages.
+             model: Model name.
+             **kwargs: Additional arguments.
+
+         Returns:
+             TransformResult with simulated changes.
+         """
+         # apply() already works on a copy, so this is safe
+         return self.apply(messages, model, **kwargs)
+
+
+ def create_pipeline(
+     tool_crusher_config: ToolCrusherConfig | None = None,
+     cache_aligner_config: CacheAlignerConfig | None = None,
+     rolling_window_config: RollingWindowConfig | None = None,
+ ) -> TransformPipeline:
+     """
+     Create a pipeline with specific configurations.
+
+     Args:
+         tool_crusher_config: Tool crusher configuration.
+         cache_aligner_config: Cache aligner configuration.
+         rolling_window_config: Rolling window configuration.
+
+     Returns:
+         Configured TransformPipeline.
+     """
+     config = HeadroomConfig()
+
+     if tool_crusher_config is not None:
+         config.tool_crusher = tool_crusher_config
+     if cache_aligner_config is not None:
+         config.cache_aligner = cache_aligner_config
+     if rolling_window_config is not None:
+         config.rolling_window = rolling_window_config
+
+     return TransformPipeline(config)
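
For orientation, a minimal usage sketch of the pipeline file above, under stated assumptions: TransformPipeline, HeadroomConfig, apply(), and the required model_limit kwarg all appear in the code, but the OpenAIProvider import and its no-argument constructor are guesses, since the wheel ships headroom/providers/openai.py without this diff showing its public class name.

from headroom.config import HeadroomConfig
from headroom.transforms.pipeline import TransformPipeline
from headroom.providers.openai import OpenAIProvider  # assumed class name, not shown in this diff

# The pipeline builds its default transform list (cache aligner -> crusher ->
# rolling window) from the config; the provider supplies the token counter.
pipeline = TransformPipeline(config=HeadroomConfig(), provider=OpenAIProvider())

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Analyze this long transcript."},
]

# apply() raises ValueError without model_limit, so pass it explicitly.
result = pipeline.apply(messages, model="gpt-4o", model_limit=128000)
print(result.tokens_before, "->", result.tokens_after, result.transforms_applied)
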
headroom/transforms/rolling_window.py
@@ -0,0 +1,350 @@
+ """Rolling window transform for Headroom SDK."""
+
+ from __future__ import annotations
+
+ import logging
+ from typing import Any
+
+ from ..config import RollingWindowConfig, TransformResult
+ from ..parser import find_tool_units
+ from ..tokenizer import Tokenizer
+ from ..tokenizers import EstimatingTokenCounter
+ from ..utils import create_dropped_context_marker, deep_copy_messages
+ from .base import Transform
+
+ logger = logging.getLogger(__name__)
+
+
+ class RollingWindow(Transform):
+     """
+     Apply rolling window to keep messages within token budget.
+
+     Drop order (deterministic):
+     1. Oldest TOOL UNITS (assistant+tool_calls paired with tool responses)
+     2. Oldest user+assistant pairs
+     3. Oldest RAG blocks (if detectable)
+
+     CRITICAL: Tool calls and tool results are atomic DROP UNITS.
+     Never orphan a tool result.
+
+     Never drops:
+     - System prompt
+     - Stable instructions
+     - Last N conversational turns (configurable)
+     """
+
+     name = "rolling_window"
+
+     def __init__(self, config: RollingWindowConfig | None = None):
+         """
+         Initialize rolling window.
+
+         Args:
+             config: Configuration for window behavior.
+         """
+         self.config = config or RollingWindowConfig()
+
+     def should_apply(
+         self,
+         messages: list[dict[str, Any]],
+         tokenizer: Tokenizer,
+         **kwargs: Any,
+     ) -> bool:
+         """Check if token cap is exceeded."""
+         if not self.config.enabled:
+             return False
+
+         model_limit = kwargs.get("model_limit", 128000)
+         output_buffer = kwargs.get("output_buffer", self.config.output_buffer_tokens)
+
+         current_tokens = tokenizer.count_messages(messages)
+         available = model_limit - output_buffer
+
+         return bool(current_tokens > available)
+
+     def apply(
+         self,
+         messages: list[dict[str, Any]],
+         tokenizer: Tokenizer,
+         **kwargs: Any,
+     ) -> TransformResult:
+         """
+         Apply rolling window to messages.
+
+         Args:
+             messages: List of messages.
+             tokenizer: Tokenizer for counting.
+             **kwargs: Must include 'model_limit', optionally 'output_buffer'.
+
+         Returns:
+             TransformResult with windowed messages.
+         """
+         model_limit = kwargs.get("model_limit", 128000)
+         output_buffer = kwargs.get("output_buffer", self.config.output_buffer_tokens)
+         available = model_limit - output_buffer
+
+         tokens_before = tokenizer.count_messages(messages)
+         result_messages = deep_copy_messages(messages)
+         transforms_applied: list[str] = []
+         markers_inserted: list[str] = []
+         warnings: list[str] = []
+
+         dropped_count = 0
+         tool_units_dropped = 0
+
+         # If already under budget, no changes needed
+         current_tokens = tokens_before
+         if current_tokens <= available:
+             return TransformResult(
+                 messages=result_messages,
+                 tokens_before=tokens_before,
+                 tokens_after=tokens_before,
+                 transforms_applied=[],
+                 warnings=[],
+             )
+
+         # Identify protected indices
+         protected = self._get_protected_indices(result_messages)
+
+         # Identify tool units
+         tool_units = find_tool_units(result_messages)
+
+         # Create drop candidates with priorities
+         drop_candidates = self._build_drop_candidates(result_messages, protected, tool_units)
+
+         # Drop until under budget
+         indices_to_drop: set[int] = set()
+
+         for candidate in drop_candidates:
+             if current_tokens <= available:
+                 break
+
+             # Get indices for this candidate
+             candidate_indices = candidate["indices"]
+
+             # Skip if any are protected
+             if any(idx in protected for idx in candidate_indices):
+                 continue
+
+             # Skip if already dropped
+             if any(idx in indices_to_drop for idx in candidate_indices):
+                 continue
+
+             # Calculate tokens saved
+             tokens_saved = sum(
+                 tokenizer.count_message(result_messages[idx])
+                 for idx in candidate_indices
+                 if idx < len(result_messages)
+             )
+
+             indices_to_drop.update(candidate_indices)
+             current_tokens -= tokens_saved
+             dropped_count += 1
+
+             if candidate["type"] == "tool_unit":
+                 tool_units_dropped += 1
+
+         # Remove dropped messages (in reverse order to preserve indices)
+         for idx in sorted(indices_to_drop, reverse=True):
+             if idx < len(result_messages):
+                 del result_messages[idx]
+
+         # Insert marker if we dropped anything
+         if dropped_count > 0:
+             logger.info(
+                 "RollingWindow: dropped %d units (%d tool units) to fit budget: %d -> %d tokens",
+                 dropped_count,
+                 tool_units_dropped,
+                 tokens_before,
+                 current_tokens,
+             )
+             marker = create_dropped_context_marker("token_cap", dropped_count)
+             markers_inserted.append(marker)
+
+             # Insert marker after system messages
+             insert_idx = 0
+             for i, msg in enumerate(result_messages):
+                 if msg.get("role") != "system":
+                     insert_idx = i
+                     break
+             else:
+                 insert_idx = len(result_messages)
+
+             result_messages.insert(
+                 insert_idx,
+                 {
+                     "role": "user",
+                     "content": marker,
+                 },
+             )
+
+             transforms_applied.append(f"window_cap:{dropped_count}")
+
+         tokens_after = tokenizer.count_messages(result_messages)
+
+         result = TransformResult(
+             messages=result_messages,
+             tokens_before=tokens_before,
+             tokens_after=tokens_after,
+             transforms_applied=transforms_applied,
+             markers_inserted=markers_inserted,
+             warnings=warnings,
+         )
+
+         return result
+
+     def _get_protected_indices(self, messages: list[dict[str, Any]]) -> set[int]:
+         """Get indices that should never be dropped."""
+         protected: set[int] = set()
+
+         # Protect system messages
+         if self.config.keep_system:
+             for i, msg in enumerate(messages):
+                 if msg.get("role") == "system":
+                     protected.add(i)
+
+         # Protect last N turns
+         if self.config.keep_last_turns > 0:
+             # Count turns from end (user+assistant = 1 turn)
+             turns_seen = 0
+             i = len(messages) - 1
+
+             while i >= 0 and turns_seen < self.config.keep_last_turns:
+                 msg = messages[i]
+                 role = msg.get("role")
+
+                 # Protect this message
+                 protected.add(i)
+
+                 # Count turns
+                 if role == "user":
+                     turns_seen += 1
+
+                 i -= 1
+
+         # Also protect any tool responses that belong to protected assistant messages
+         for i in list(protected):
+             msg = messages[i]
+             if msg.get("role") == "assistant" and msg.get("tool_calls"):
+                 # Find and protect corresponding tool responses
+                 tool_call_ids = {tc.get("id") for tc in msg.get("tool_calls", [])}
+                 for j, other_msg in enumerate(messages):
+                     if other_msg.get("role") == "tool":
+                         if other_msg.get("tool_call_id") in tool_call_ids:
+                             protected.add(j)
+
+         return protected
+
+     def _build_drop_candidates(
+         self,
+         messages: list[dict[str, Any]],
+         protected: set[int],
+         tool_units: list[tuple[int, list[int]]],
+     ) -> list[dict[str, Any]]:
+         """
+         Build ordered list of drop candidates.
+
+         Returns candidates in drop priority order (first to drop first).
+         """
+         candidates: list[dict[str, Any]] = []
+
+         # Track which indices are part of tool units
+         tool_unit_indices: set[int] = set()
+         for assistant_idx, response_indices in tool_units:
+             tool_unit_indices.add(assistant_idx)
+             tool_unit_indices.update(response_indices)
+
+         # Priority 1: Oldest tool units (all indices as atomic unit)
+         for assistant_idx, response_indices in tool_units:
+             if assistant_idx in protected:
+                 continue
+
+             all_indices = [assistant_idx] + response_indices
+             candidates.append(
+                 {
+                     "type": "tool_unit",
+                     "indices": all_indices,
+                     "priority": 1,
+                     "position": assistant_idx,  # For sorting by age
+                 }
+             )
+
+         # Priority 2: Oldest non-tool messages (user/assistant pairs)
+         i = 0
+         while i < len(messages):
+             msg = messages[i]
+             role = msg.get("role")
+
+             if i in protected or i in tool_unit_indices:
+                 i += 1
+                 continue
+
+             if role in ("user", "assistant"):
+                 # Try to find a pair
+                 if role == "user" and i + 1 < len(messages):
+                     next_msg = messages[i + 1]
+                     if next_msg.get("role") == "assistant" and i + 1 not in tool_unit_indices:
+                         candidates.append(
+                             {
+                                 "type": "turn",
+                                 "indices": [i, i + 1],
+                                 "priority": 2,
+                                 "position": i,
+                             }
+                         )
+                         i += 2
+                         continue
+
+                 # Single message
+                 candidates.append(
+                     {
+                         "type": "single",
+                         "indices": [i],
+                         "priority": 2,
+                         "position": i,
+                     }
+                 )
+
+             i += 1
+
+         # Sort by priority, then by position (oldest first)
+         candidates.sort(key=lambda c: (c["priority"], c["position"]))
+
+         return candidates
+
+
+ def apply_rolling_window(
+     messages: list[dict[str, Any]],
+     model_limit: int,
+     output_buffer: int = 4000,
+     keep_last_turns: int = 2,
+     config: RollingWindowConfig | None = None,
+ ) -> tuple[list[dict[str, Any]], list[str]]:
+     """
+     Convenience function to apply rolling window.
+
+     Args:
+         messages: List of messages.
+         model_limit: Model's context limit.
+         output_buffer: Tokens to reserve for output.
+         keep_last_turns: Number of recent turns to protect.
+         config: Optional configuration.
+
+     Returns:
+         Tuple of (windowed_messages, dropped_descriptions).
+     """
+     cfg = config or RollingWindowConfig()
+     cfg.output_buffer_tokens = output_buffer
+     cfg.keep_last_turns = keep_last_turns
+
+     window = RollingWindow(cfg)
+     tokenizer = Tokenizer(EstimatingTokenCounter())  # type: ignore[arg-type]
+
+     result = window.apply(
+         messages,
+         tokenizer,
+         model_limit=model_limit,
+         output_buffer=output_buffer,
+     )
+
+     return result.messages, result.transforms_applied
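
And a minimal sketch of the apply_rolling_window convenience function above. Unlike the pipeline, it is self-contained: no provider is required because it falls back to EstimatingTokenCounter internally. The message history here is invented for illustration; every name used comes from the file itself.

from headroom.transforms.rolling_window import apply_rolling_window

# Invented oversized history; in practice the middle turns would be large
# enough to push the conversation past the available token budget.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Run the full report."},
    {"role": "assistant", "content": "...tens of thousands of tokens of output..."},
    {"role": "user", "content": "Summarize the results so far."},
]

windowed, applied = apply_rolling_window(
    messages,
    model_limit=128000,   # model context window
    output_buffer=4000,   # tokens reserved for the completion
    keep_last_turns=2,    # most recent turns are never dropped
)
# applied is e.g. ["window_cap:3"] when three drop units were removed,
# and [] when the history already fits within model_limit - output_buffer.
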