headroom_ai-0.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/transforms/__init__.py
@@ -0,0 +1,106 @@
"""Transform modules for Headroom SDK."""

from .base import Transform
from .cache_aligner import CacheAligner
from .content_detector import ContentType, DetectionResult, detect_content_type
from .log_compressor import LogCompressionResult, LogCompressor, LogCompressorConfig
from .pipeline import TransformPipeline
from .rolling_window import RollingWindow
from .search_compressor import (
    SearchCompressionResult,
    SearchCompressor,
    SearchCompressorConfig,
)
from .smart_crusher import SmartCrusher, SmartCrusherConfig
from .text_compressor import TextCompressionResult, TextCompressor, TextCompressorConfig
from .tool_crusher import ToolCrusher

# ML-based compression (optional dependency)
try:
    from .llmlingua_compressor import (  # noqa: F401
        LLMLinguaCompressor,
        LLMLinguaConfig,
        LLMLinguaResult,
        compress_with_llmlingua,
        is_llmlingua_model_loaded,
        unload_llmlingua_model,
    )

    _LLMLINGUA_AVAILABLE = True
except ImportError:
    _LLMLINGUA_AVAILABLE = False

# AST-based code compression (optional dependency)
from .code_compressor import (
    CodeAwareCompressor,
    CodeCompressionResult,
    CodeCompressorConfig,
    CodeLanguage,
    DocstringMode,
    detect_language,
    is_tree_sitter_available,
)

# Content routing (always available, lazy-loads compressors)
from .content_router import (
    CompressionStrategy,
    ContentRouter,
    ContentRouterConfig,
    RouterCompressionResult,
    generate_source_hint,
)

__all__ = [
    # Base
    "Transform",
    "TransformPipeline",
    # JSON compression
    "ToolCrusher",
    "SmartCrusher",
    "SmartCrusherConfig",
    # Text compression (coding tasks)
    "ContentType",
    "DetectionResult",
    "detect_content_type",
    "SearchCompressor",
    "SearchCompressorConfig",
    "SearchCompressionResult",
    "LogCompressor",
    "LogCompressorConfig",
    "LogCompressionResult",
    "TextCompressor",
    "TextCompressorConfig",
    "TextCompressionResult",
    # Code-aware compression (AST-based)
    "CodeAwareCompressor",
    "CodeCompressorConfig",
    "CodeCompressionResult",
    "CodeLanguage",
    "DocstringMode",
    "detect_language",
    "is_tree_sitter_available",
    # Content routing
    "ContentRouter",
    "ContentRouterConfig",
    "RouterCompressionResult",
    "CompressionStrategy",
    "generate_source_hint",
    # Other transforms
    "CacheAligner",
    "RollingWindow",
    # ML-based compression (optional)
    "_LLMLINGUA_AVAILABLE",
]

# Conditionally add LLMLingua exports
if _LLMLINGUA_AVAILABLE:
    __all__.extend(
        [
            "LLMLinguaCompressor",
            "LLMLinguaConfig",
            "LLMLinguaResult",
            "compress_with_llmlingua",
            "is_llmlingua_model_loaded",
            "unload_llmlingua_model",
        ]
    )
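The `_LLMLINGUA_AVAILABLE` flag above controls which names the package exports at import time. A minimal consumer-side sketch (not taken from the package's documentation, and using only names exported by this `__init__.py`) mirrors that guard:

# Hedged usage sketch: only names exported by headroom.transforms above are used.
from headroom.transforms import _LLMLINGUA_AVAILABLE

if _LLMLINGUA_AVAILABLE:
    # The ML compressor names are only exported when the llmlingua extras are
    # installed, so guard the import the same way __init__.py does.
    from headroom.transforms import LLMLinguaCompressor, LLMLinguaConfig  # noqa: F401
else:
    print("llmlingua extras not installed; ML-based compression is unavailable")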
headroom/transforms/base.py
@@ -0,0 +1,57 @@
"""Base transform interface for Headroom SDK."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from ..config import TransformResult
from ..tokenizer import Tokenizer


class Transform(ABC):
    """Abstract base class for message transforms."""

    name: str = "base"

    @abstractmethod
    def apply(
        self,
        messages: list[dict[str, Any]],
        tokenizer: Tokenizer,
        **kwargs: Any,
    ) -> TransformResult:
        """
        Apply the transform to messages.

        Args:
            messages: List of message dicts to transform.
            tokenizer: Tokenizer for token counting.
            **kwargs: Additional transform-specific arguments.

        Returns:
            TransformResult with transformed messages and metadata.
        """
        pass

    def should_apply(
        self,
        messages: list[dict[str, Any]],
        tokenizer: Tokenizer,
        **kwargs: Any,
    ) -> bool:
        """
        Check if this transform should be applied.

        Default implementation always returns True.
        Override in subclasses for conditional application.

        Args:
            messages: List of message dicts.
            tokenizer: Tokenizer for token counting.
            **kwargs: Additional arguments.

        Returns:
            True if transform should be applied.
        """
        return True
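`Transform` is the extension point for custom transforms: subclasses implement `apply` and may override `should_apply`. The sketch below is illustrative only (the `TailTrimmer` name is hypothetical) and assumes `TransformResult` accepts the keyword arguments used elsewhere in this diff (`messages`, `tokens_before`, `tokens_after`, `transforms_applied`, `warnings`) with defaults for any remaining fields.

# Hedged sketch of a custom transform; not part of the package.
from typing import Any

from headroom.config import TransformResult
from headroom.tokenizer import Tokenizer
from headroom.transforms import Transform


class TailTrimmer(Transform):
    """Keep only the last N messages (illustrative example)."""

    name = "tail_trimmer"

    def __init__(self, keep_last: int = 20):
        self.keep_last = keep_last

    def should_apply(
        self, messages: list[dict[str, Any]], tokenizer: Tokenizer, **kwargs: Any
    ) -> bool:
        # Only run when there is something to trim.
        return len(messages) > self.keep_last

    def apply(
        self, messages: list[dict[str, Any]], tokenizer: Tokenizer, **kwargs: Any
    ) -> TransformResult:
        trimmed = messages[-self.keep_last :]
        return TransformResult(
            messages=trimmed,
            tokens_before=tokenizer.count_messages(messages),
            tokens_after=tokenizer.count_messages(trimmed),
            transforms_applied=["tail_trimmer"],
            warnings=[],
        )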
headroom/transforms/cache_aligner.py
@@ -0,0 +1,357 @@
"""Cache alignment transform for Headroom SDK."""

from __future__ import annotations

import logging
import re
from typing import Any

from ..config import CacheAlignerConfig, CachePrefixMetrics, TransformResult
from ..tokenizer import Tokenizer
from ..tokenizers import EstimatingTokenCounter
from ..utils import compute_short_hash, deep_copy_messages
from .base import Transform

logger = logging.getLogger(__name__)


class CacheAligner(Transform):
    """
    Align messages for optimal cache hits.

    This transform:
    1. Extracts dynamic content (dates) from system prompt
    2. Normalizes whitespace for consistent hashing
    3. Computes a stable prefix hash

    The goal is to make the prefix byte-identical across requests
    so that LLM provider caching can be effective.
    """

    name = "cache_aligner"

    def __init__(self, config: CacheAlignerConfig | None = None):
        """
        Initialize cache aligner.

        Args:
            config: Configuration for alignment behavior.
        """
        self.config = config or CacheAlignerConfig()
        self._compiled_patterns: list[re.Pattern[str]] = []
        self._compile_patterns()
        # Track previous hash for cache hit detection
        self._previous_prefix_hash: str | None = None

    def _compile_patterns(self) -> None:
        """Compile regex patterns for efficiency."""
        self._compiled_patterns = [re.compile(pattern) for pattern in self.config.date_patterns]

    def should_apply(
        self,
        messages: list[dict[str, Any]],
        tokenizer: Tokenizer,
        **kwargs: Any,
    ) -> bool:
        """Check if alignment is needed."""
        if not self.config.enabled:
            return False

        # Check if system prompt contains dynamic patterns
        for msg in messages:
            if msg.get("role") == "system":
                content = msg.get("content", "")
                if isinstance(content, str):
                    for pattern in self._compiled_patterns:
                        if pattern.search(content):
                            return True

        return False

    def apply(
        self,
        messages: list[dict[str, Any]],
        tokenizer: Tokenizer,
        **kwargs: Any,
    ) -> TransformResult:
        """
        Apply cache alignment to messages.

        Args:
            messages: List of messages.
            tokenizer: Tokenizer for counting.
            **kwargs: Additional arguments.

        Returns:
            TransformResult with aligned messages.
        """
        tokens_before = tokenizer.count_messages(messages)
        result_messages = deep_copy_messages(messages)
        transforms_applied: list[str] = []
        warnings: list[str] = []

        extracted_dates: list[str] = []

        # Process system messages
        for msg in result_messages:
            if msg.get("role") == "system":
                content = msg.get("content", "")
                if isinstance(content, str):
                    # Extract and remove date patterns
                    new_content, dates = self._extract_dates(content)

                    if dates:
                        extracted_dates.extend(dates)
                        msg["content"] = new_content

        # Normalize whitespace if configured
        if self.config.normalize_whitespace:
            for msg in result_messages:
                content = msg.get("content")
                if isinstance(content, str):
                    msg["content"] = self._normalize_whitespace(content)

        # Compute stable prefix content and hash BEFORE reinserting dates
        # This ensures the hash is based on the static content only
        stable_prefix_content = self._get_stable_prefix_content(result_messages)
        stable_hash = compute_short_hash(stable_prefix_content)

        # Compute cache metrics
        prefix_bytes = len(stable_prefix_content.encode("utf-8"))
        prefix_tokens_est = tokenizer.count_text(stable_prefix_content)
        prefix_changed = (
            self._previous_prefix_hash is not None and self._previous_prefix_hash != stable_hash
        )
        previous_hash = self._previous_prefix_hash

        # Update tracking for next request
        self._previous_prefix_hash = stable_hash

        cache_metrics = CachePrefixMetrics(
            stable_prefix_bytes=prefix_bytes,
            stable_prefix_tokens_est=prefix_tokens_est,
            stable_prefix_hash=stable_hash,
            prefix_changed=prefix_changed,
            previous_hash=previous_hash,
        )

        # If we extracted dates, add them as dynamic context
        if extracted_dates:
            # Insert dates as a small user message or append to context
            # Strategy: add as a context note after system messages
            self._reinsert_dates(result_messages, extracted_dates)
            transforms_applied.append("cache_align")
            logger.debug(
                "CacheAligner: extracted %d date patterns for cache alignment",
                len(extracted_dates),
            )

        # Log cache hit/miss
        if prefix_changed:
            logger.debug(
                "CacheAligner: prefix changed (likely cache miss), hash: %s -> %s",
                previous_hash,
                stable_hash,
            )
        else:
            logger.debug("CacheAligner: prefix stable, hash: %s", stable_hash)

        tokens_after = tokenizer.count_messages(result_messages)

        result = TransformResult(
            messages=result_messages,
            tokens_before=tokens_before,
            tokens_after=tokens_after,
            transforms_applied=transforms_applied,
            warnings=warnings,
            cache_metrics=cache_metrics,
        )

        # Store hash in flags for access by caller (backwards compatibility)
        result.markers_inserted.append(f"stable_prefix_hash:{stable_hash}")

        return result

    def _extract_dates(self, content: str) -> tuple[str, list[str]]:
        """
        Extract date patterns from content.

        Returns:
            Tuple of (content_without_dates, list_of_extracted_dates).
        """
        extracted: list[str] = []
        result = content

        for pattern in self._compiled_patterns:
            matches = pattern.findall(result)
            extracted.extend(matches)
            result = pattern.sub("", result)

        # Clean up any resulting empty lines
        if extracted:
            result = self._cleanup_empty_lines(result)

        return result, extracted

    def _normalize_whitespace(self, content: str) -> str:
        """Normalize whitespace for consistent hashing."""
        # Normalize line endings
        result = content.replace("\r\n", "\n").replace("\r", "\n")

        # Trim trailing whitespace from lines
        lines = result.split("\n")
        lines = [line.rstrip() for line in lines]

        # Collapse multiple blank lines if configured
        if self.config.collapse_blank_lines:
            new_lines: list[str] = []
            prev_blank = False
            for line in lines:
                is_blank = not line.strip()
                if is_blank and prev_blank:
                    continue
                new_lines.append(line)
                prev_blank = is_blank
            lines = new_lines

        return "\n".join(lines)

    def _cleanup_empty_lines(self, content: str) -> str:
        """Remove empty lines that result from date extraction."""
        lines = content.split("\n")
        # Remove lines that are now empty after pattern removal
        lines = [line for line in lines if line.strip() or line == ""]

        # Collapse multiple consecutive empty lines
        new_lines: list[str] = []
        prev_empty = False
        for line in lines:
            is_empty = not line.strip()
            if is_empty and prev_empty:
                continue
            new_lines.append(line)
            prev_empty = is_empty

        return "\n".join(new_lines).strip()

    def _reinsert_dates(
        self,
        messages: list[dict[str, Any]],
        dates: list[str],
    ) -> None:
        """
        Reinsert extracted dates as dynamic context.

        Strategy: Append to the end of system message with a clear separator.
        The separator marks where static (cacheable) content ends and
        dynamic content begins.

        Note: The stable prefix hash is computed BEFORE this method is called,
        so the hash is based on static content only.
        """
        if not dates:
            return

        # Format dates as a simple note
        date_note = ", ".join(dates)
        separator = self.config.dynamic_tail_separator

        # Find last system message and append dates
        for msg in reversed(messages):
            if msg.get("role") == "system":
                content = msg.get("content", "")
                if isinstance(content, str):
                    # Use separator to clearly mark dynamic content
                    msg["content"] = content.strip() + separator + date_note
                break

    def _get_stable_prefix_content(self, messages: list[dict[str, Any]]) -> str:
        """Get the stable prefix content (static portion of system messages).

        Only includes content BEFORE the dynamic_tail_separator in each
        system message. This ensures the content is stable across different
        dates/dynamic content.
        """
        prefix_parts: list[str] = []
        separator = self.config.dynamic_tail_separator

        for msg in messages:
            if msg.get("role") == "system":
                content = msg.get("content", "")
                if isinstance(content, str):
                    # Only include content BEFORE the dynamic separator
                    if separator in content:
                        content = content.split(separator)[0]
                    prefix_parts.append(content.strip())
            else:
                # Stop at first non-system message
                break

        return "\n---\n".join(prefix_parts)

    def _compute_stable_prefix_hash(self, messages: list[dict[str, Any]]) -> str:
        """Compute hash of the stable prefix portion.

        Only includes content BEFORE the dynamic_tail_separator in each
        system message. This ensures the hash is stable across different
        dates/dynamic content.
        """
        prefix_content = self._get_stable_prefix_content(messages)
        return compute_short_hash(prefix_content)

    def get_alignment_score(self, messages: list[dict[str, Any]]) -> float:
        """
        Compute cache alignment score (0-100).

        Higher score means better cache alignment potential.
        """
        score = 100.0

        for msg in messages:
            if msg.get("role") == "system":
                content = msg.get("content", "")
                if isinstance(content, str):
                    # Penalize for each dynamic pattern found
                    for pattern in self._compiled_patterns:
                        matches = pattern.findall(content)
                        score -= len(matches) * 10

                    # Penalize for inconsistent whitespace
                    if "\r" in content:
                        score -= 5
                    if "  " in content:  # Double spaces
                        score -= 2
                    if "\n\n\n" in content:  # Triple newlines
                        score -= 2

        return max(0.0, min(100.0, score))


def align_for_cache(
    messages: list[dict[str, Any]],
    config: CacheAlignerConfig | None = None,
) -> tuple[list[dict[str, Any]], str]:
    """
    Convenience function to align messages for cache.

    Args:
        messages: List of messages.
        config: Optional configuration.

    Returns:
        Tuple of (aligned_messages, stable_prefix_hash).
    """
    cfg = config or CacheAlignerConfig()
    aligner = CacheAligner(cfg)
    tokenizer = Tokenizer(EstimatingTokenCounter())  # type: ignore[arg-type]

    result = aligner.apply(messages, tokenizer)

    # Extract hash from markers
    stable_hash = ""
    for marker in result.markers_inserted:
        if marker.startswith("stable_prefix_hash:"):
            stable_hash = marker.split(":", 1)[1]
            break

    return result.messages, stable_hash
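The module-level `align_for_cache` helper wraps the class for one-off calls. A minimal usage sketch follows; it assumes the default `CacheAlignerConfig.date_patterns` (defined in headroom/config.py, not shown in this diff) match a date string like the one in the example.

# Hedged usage sketch for align_for_cache.
from headroom.transforms.cache_aligner import align_for_cache

messages = [
    {"role": "system", "content": "You are a helpful assistant.\nToday's date: 2025-01-15"},
    {"role": "user", "content": "Summarize the release notes."},
]

aligned, prefix_hash = align_for_cache(messages)

# The static portion of the system prompt hashes identically across days,
# while any extracted date is re-appended after the dynamic-tail separator.
print(prefix_hash)
print(aligned[0]["content"])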