headroom_ai-0.2.13-py3-none-any.whl
This diff shows the content of a publicly available package version as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
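
Two of the LangChain integration modules added in this release, `headroom/integrations/langchain/langsmith.py` (+324) and `headroom/integrations/langchain/memory.py` (+319), are reproduced in full below. As a rough sketch of how they compose, assembled only from the docstring examples in those files (`HeadroomChatModel` itself lives in `chat_model.py`, which is not shown in this section):

```python
# Hedged sketch based on the docstrings below; assumes
# `pip install headroom-ai[langchain] langchain-openai langchain-community`.
import os

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_openai import ChatOpenAI

from headroom.integrations import (
    HeadroomChatMessageHistory,
    HeadroomChatModel,
    HeadroomLangSmithCallbackHandler,
)

os.environ["LANGCHAIN_TRACING_V2"] = "true"  # enable LangSmith tracing

# Compression metrics flow into LangSmith traces via the callback handler...
handler = HeadroomLangSmithCallbackHandler()
llm = HeadroomChatModel(ChatOpenAI(model="gpt-4o"), callbacks=[handler])

# ...while stored conversation history is compressed on read by the wrapper.
history = HeadroomChatMessageHistory(
    ChatMessageHistory(),
    compress_threshold_tokens=4000,
    keep_recent_turns=5,
)
```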
headroom/integrations/langchain/langsmith.py
@@ -0,0 +1,324 @@
+"""LangSmith integration for Headroom compression metrics.
+
+This module provides HeadroomLangSmithCallbackHandler, a LangChain callback
+handler that adds Headroom compression metrics to LangSmith traces.
+
+When used with HeadroomChatModel, it automatically captures:
+- Tokens before/after optimization
+- Savings percentage
+- Transforms applied
+- Per-request compression details
+
+Example:
+    import os
+    from langchain_openai import ChatOpenAI
+    from headroom.integrations import (
+        HeadroomChatModel,
+        HeadroomLangSmithCallbackHandler,
+    )
+
+    # Enable LangSmith tracing
+    os.environ["LANGCHAIN_TRACING_V2"] = "true"
+    os.environ["LANGCHAIN_API_KEY"] = "..."
+
+    # Create handler
+    handler = HeadroomLangSmithCallbackHandler()
+
+    # Use with HeadroomChatModel
+    llm = HeadroomChatModel(
+        ChatOpenAI(model="gpt-4o"),
+        callbacks=[handler],
+    )
+
+    # Traces will include headroom.* metadata
+    response = llm.invoke("Hello!")
+"""
+
+from __future__ import annotations
+
+import logging
+import os
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+from uuid import UUID
+
+# LangChain imports - these are optional dependencies
+try:
+    from langchain_core.callbacks import BaseCallbackHandler
+    from langchain_core.messages import BaseMessage
+    from langchain_core.outputs import LLMResult
+
+    LANGCHAIN_AVAILABLE = True
+except ImportError:
+    LANGCHAIN_AVAILABLE = False
+    BaseCallbackHandler = object  # type: ignore[misc,assignment]
+    LLMResult = object  # type: ignore[misc,assignment]
+
+# LangSmith imports - optional
+try:
+    from langsmith import Client as LangSmithClient
+
+    LANGSMITH_AVAILABLE = True
+except ImportError:
+    LANGSMITH_AVAILABLE = False
+    LangSmithClient = None  # type: ignore[misc,assignment]
+
+logger = logging.getLogger(__name__)
+
+
+def _check_langchain_available() -> None:
+    """Raise ImportError if LangChain is not installed."""
+    if not LANGCHAIN_AVAILABLE:
+        raise ImportError(
+            "LangChain is required for this integration. "
+            "Install with: pip install headroom[langchain] "
+            "or: pip install langchain-core"
+        )
+
+
+@dataclass
+class PendingMetrics:
+    """Metrics pending attachment to a LangSmith run."""
+
+    tokens_before: int
+    tokens_after: int
+    tokens_saved: int
+    savings_percent: float
+    transforms_applied: list[str]
+    timestamp: datetime = field(default_factory=datetime.now)
+
+
+class HeadroomLangSmithCallbackHandler(BaseCallbackHandler):
+    """Callback handler that adds Headroom metrics to LangSmith traces.
+
+    Integrates with LangSmith to provide visibility into context
+    optimization within traces. Metrics appear as metadata with
+    the `headroom.` prefix.
+
+    Works automatically when:
+    1. LANGCHAIN_TRACING_V2=true is set
+    2. Used as a callback with HeadroomChatModel
+    3. LangSmith API key is configured
+
+    Example:
+        from headroom.integrations import (
+            HeadroomChatModel,
+            HeadroomLangSmithCallbackHandler,
+        )
+
+        handler = HeadroomLangSmithCallbackHandler()
+        llm = HeadroomChatModel(
+            ChatOpenAI(model="gpt-4o"),
+            callbacks=[handler],
+        )
+
+        response = llm.invoke("Hello!")
+        # LangSmith trace now includes:
+        # - headroom.tokens_before
+        # - headroom.tokens_after
+        # - headroom.tokens_saved
+        # - headroom.savings_percent
+        # - headroom.transforms_applied
+
+    Attributes:
+        langsmith_client: LangSmith client for updating runs.
+        pending_metrics: Metrics waiting to be attached to runs.
+    """
+
+    def __init__(
+        self,
+        langsmith_client: Any = None,
+        auto_update_runs: bool = True,
+    ):
+        """Initialize HeadroomLangSmithCallbackHandler.
+
+        Args:
+            langsmith_client: LangSmith client instance. Auto-creates
+                one if not provided and LangSmith is available.
+            auto_update_runs: If True, automatically updates LangSmith
+                runs with Headroom metadata. Default True.
+        """
+        _check_langchain_available()
+
+        self._client = langsmith_client
+        self._auto_update = auto_update_runs
+        self._pending_metrics: dict[str, PendingMetrics] = {}
+        self._run_metrics: dict[str, dict[str, Any]] = {}
+
+        # Initialize LangSmith client if available and not provided
+        if self._client is None and LANGSMITH_AVAILABLE and auto_update_runs:
+            try:
+                if os.environ.get("LANGCHAIN_API_KEY"):
+                    self._client = LangSmithClient()
+            except Exception as e:
+                logger.debug(f"Could not initialize LangSmith client: {e}")
+
+    def set_headroom_metrics(
+        self,
+        run_id: str | UUID,
+        tokens_before: int,
+        tokens_after: int,
+        transforms_applied: list[str] | None = None,
+    ) -> None:
+        """Set Headroom metrics for a run.
+
+        Call this from HeadroomChatModel after optimization to attach
+        metrics to the current run.
+
+        Args:
+            run_id: The LangSmith run ID.
+            tokens_before: Token count before optimization.
+            tokens_after: Token count after optimization.
+            transforms_applied: List of transforms that were applied.
+        """
+        run_id_str = str(run_id)
+        tokens_saved = tokens_before - tokens_after
+        savings_percent = (tokens_saved / tokens_before * 100) if tokens_before > 0 else 0.0
+
+        metrics = PendingMetrics(
+            tokens_before=tokens_before,
+            tokens_after=tokens_after,
+            tokens_saved=tokens_saved,
+            savings_percent=savings_percent,
+            transforms_applied=transforms_applied or [],
+        )
+
+        self._pending_metrics[run_id_str] = metrics
+
+        logger.debug(
+            f"Headroom metrics set for run {run_id_str}: "
+            f"{tokens_before} -> {tokens_after} tokens ({savings_percent:.1f}% saved)"
+        )
+
+    def on_chat_model_start(
+        self,
+        serialized: dict[str, Any],
+        messages: list[list[BaseMessage]],
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        """Called when chat model starts.
+
+        Records the run ID for later metric attachment.
+        """
+        run_id_str = str(run_id)
+        # Initialize empty metrics for this run
+        self._run_metrics[run_id_str] = {}
+
+    def on_llm_end(
+        self,
+        response: LLMResult,
+        *,
+        run_id: UUID,
+        **kwargs: Any,
+    ) -> None:
+        """Called when LLM completes.
+
+        Attaches pending Headroom metrics to the LangSmith run.
+        """
+        run_id_str = str(run_id)
+
+        # Check for pending metrics
+        if run_id_str in self._pending_metrics:
+            metrics = self._pending_metrics.pop(run_id_str)
+            self._attach_metrics_to_run(run_id_str, metrics)
+
+    def _attach_metrics_to_run(self, run_id: str, metrics: PendingMetrics) -> None:
+        """Attach Headroom metrics to a LangSmith run.
+
+        Args:
+            run_id: The run ID.
+            metrics: Metrics to attach.
+        """
+        metadata = {
+            "headroom.tokens_before": metrics.tokens_before,
+            "headroom.tokens_after": metrics.tokens_after,
+            "headroom.tokens_saved": metrics.tokens_saved,
+            "headroom.savings_percent": round(metrics.savings_percent, 2),
+            "headroom.transforms_applied": metrics.transforms_applied,
+            "headroom.optimization_timestamp": metrics.timestamp.isoformat(),
+        }
+
+        # Store in run metrics
+        self._run_metrics[run_id] = metadata
+
+        # Update LangSmith run if client available
+        if self._client and self._auto_update:
+            try:
+                self._client.update_run(
+                    run_id=run_id,
+                    extra={"metadata": metadata},
+                )
+                logger.debug(f"Updated LangSmith run {run_id} with Headroom metrics")
+            except Exception as e:
+                logger.debug(f"Could not update LangSmith run: {e}")
+
+    def get_run_metrics(self, run_id: str | UUID) -> dict[str, Any]:
+        """Get Headroom metrics for a specific run.
+
+        Args:
+            run_id: The run ID.
+
+        Returns:
+            Dictionary of headroom.* metrics for the run.
+        """
+        return self._run_metrics.get(str(run_id), {})
+
+    def get_all_metrics(self) -> dict[str, dict[str, Any]]:
+        """Get all recorded run metrics.
+
+        Returns:
+            Dictionary mapping run IDs to their metrics.
+        """
+        return self._run_metrics.copy()
+
+    def get_summary(self) -> dict[str, Any]:
+        """Get summary statistics across all runs.
+
+        Returns:
+            Summary with total runs, tokens saved, etc.
+        """
+        if not self._run_metrics:
+            return {
+                "total_runs": 0,
+                "total_tokens_saved": 0,
+                "average_savings_percent": 0,
+            }
+
+        total_saved = sum(m.get("headroom.tokens_saved", 0) for m in self._run_metrics.values())
+        savings_percents = [
+            m.get("headroom.savings_percent", 0) for m in self._run_metrics.values()
+        ]
+
+        return {
+            "total_runs": len(self._run_metrics),
+            "total_tokens_saved": total_saved,
+            "average_savings_percent": (
+                sum(savings_percents) / len(savings_percents) if savings_percents else 0
+            ),
+        }
+
+    def reset(self) -> None:
+        """Clear all recorded metrics."""
+        self._pending_metrics.clear()
+        self._run_metrics.clear()
+
+
+def is_langsmith_available() -> bool:
+    """Check if LangSmith is available and configured.
+
+    Returns:
+        True if LangSmith is installed and API key is set.
+    """
+    return LANGSMITH_AVAILABLE and bool(os.environ.get("LANGCHAIN_API_KEY"))
+
+
+def is_langsmith_tracing_enabled() -> bool:
+    """Check if LangSmith tracing is enabled.
+
+    Returns:
+        True if LANGCHAIN_TRACING_V2 is set to "true".
+    """
+    return os.environ.get("LANGCHAIN_TRACING_V2", "").lower() == "true"
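
The handler also works without the automatic HeadroomChatModel path. A minimal sketch (not from the package) of the manual API shown above, assuming headroom-ai and langchain-core are installed; `auto_update_runs=False` keeps everything in memory, the run ID is a stand-in, the transform names are guessed from this release's file listing, and calling `on_llm_end` directly only simulates LangChain firing the callback:

```python
from uuid import UUID

from headroom.integrations import HeadroomLangSmithCallbackHandler

# With auto_update_runs=False no LangSmith client is created and
# update_run() is never called; metrics stay local to the handler.
handler = HeadroomLangSmithCallbackHandler(auto_update_runs=False)

run_id = UUID("00000000-0000-0000-0000-000000000001")  # stand-in run ID
handler.set_headroom_metrics(
    run_id,
    tokens_before=12_000,
    tokens_after=4_800,
    transforms_applied=["rolling_window", "smart_crusher"],  # hypothetical names
)

# Normally LangChain invokes this callback; metrics stay "pending" until it fires.
handler.on_llm_end(None, run_id=run_id)

print(handler.get_run_metrics(run_id)["headroom.savings_percent"])  # 60.0
print(handler.get_summary())  # {'total_runs': 1, 'total_tokens_saved': 7200, ...}
```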
headroom/integrations/langchain/memory.py
@@ -0,0 +1,319 @@
+"""Memory integration for LangChain with automatic compression.
+
+This module provides HeadroomChatMessageHistory, a wrapper for any LangChain
+chat message history that automatically compresses conversation history
+when it exceeds a token threshold.
+
+Example:
+    from langchain.memory import ConversationBufferMemory
+    from langchain_community.chat_message_histories import ChatMessageHistory
+    from headroom.integrations import HeadroomChatMessageHistory
+
+    # Wrap any chat message history
+    base_history = ChatMessageHistory()
+    compressed_history = HeadroomChatMessageHistory(base_history)
+
+    # Use with ConversationBufferMemory (zero code changes to chain)
+    memory = ConversationBufferMemory(chat_memory=compressed_history)
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from headroom.providers.base import Provider
+
+# LangChain imports - these are optional dependencies
+try:
+    from langchain_core.chat_history import BaseChatMessageHistory
+    from langchain_core.messages import (
+        AIMessage,
+        BaseMessage,
+        HumanMessage,
+        SystemMessage,
+        ToolMessage,
+    )
+
+    LANGCHAIN_AVAILABLE = True
+except ImportError:
+    LANGCHAIN_AVAILABLE = False
+    BaseChatMessageHistory = object  # type: ignore[misc,assignment]
+
+from headroom import HeadroomConfig
+from headroom.config import RollingWindowConfig
+from headroom.providers import OpenAIProvider
+from headroom.transforms import TransformPipeline
+
+logger = logging.getLogger(__name__)
+
+
+def _check_langchain_available() -> None:
+    """Raise ImportError if LangChain is not installed."""
+    if not LANGCHAIN_AVAILABLE:
+        raise ImportError(
+            "LangChain is required for this integration. "
+            "Install with: pip install headroom[langchain] "
+            "or: pip install langchain-core"
+        )
+
+
+class HeadroomChatMessageHistory(BaseChatMessageHistory):
+    """Wraps any LangChain chat message history with automatic compression.
+
+    When conversation history exceeds the token threshold, automatically
+    applies RollingWindow compression to keep recent turns while fitting
+    within the limit.
+
+    This works with ANY memory type because it wraps at the storage layer:
+    - ConversationBufferMemory
+    - ConversationSummaryMemory
+    - ConversationBufferWindowMemory
+    - Redis, PostgreSQL, or any custom history
+
+    Example:
+        from langchain.memory import ConversationBufferMemory
+        from langchain_community.chat_message_histories import ChatMessageHistory
+        from headroom.integrations import HeadroomChatMessageHistory
+
+        # Wrap base history
+        base = ChatMessageHistory()
+        compressed = HeadroomChatMessageHistory(
+            base,
+            compress_threshold_tokens=4000,
+            keep_recent_turns=5,
+        )
+
+        # Use with any memory class
+        memory = ConversationBufferMemory(chat_memory=compressed)
+
+        # Messages are compressed automatically when accessed
+        chain = ConversationChain(llm=llm, memory=memory)
+        chain.invoke({"input": "Hello!"})
+
+    Attributes:
+        base_history: The underlying chat message history
+        compress_threshold_tokens: Token count that triggers compression
+        keep_recent_turns: Minimum recent turns to always preserve
+        model: Model name for token counting (default: "gpt-4o")
+    """
+
+    def __init__(
+        self,
+        base_history: BaseChatMessageHistory,
+        compress_threshold_tokens: int = 4000,
+        keep_recent_turns: int = 5,
+        model: str = "gpt-4o",
+        provider: Provider | None = None,
+    ):
+        """Initialize HeadroomChatMessageHistory.
+
+        Args:
+            base_history: Any LangChain BaseChatMessageHistory to wrap
+            compress_threshold_tokens: Apply compression when history exceeds
+                this many tokens. Default 4000.
+            keep_recent_turns: Minimum number of recent user/assistant turns
+                to always preserve during compression. Default 5.
+            model: Model name for token counting. Default "gpt-4o".
+            provider: Headroom provider for token counting. Auto-uses
+                OpenAIProvider if not specified.
+        """
+        _check_langchain_available()
+
+        self._base = base_history
+        self._threshold = compress_threshold_tokens
+        self._keep_recent_turns = keep_recent_turns
+        self._model = model
+        self._provider: Provider = provider or OpenAIProvider()
+
+        # Track compression stats
+        self._compression_count = 0
+        self._total_tokens_saved = 0
+
+    @property
+    def messages(self) -> list[BaseMessage]:  # type: ignore[override]
+        """Get messages, applying compression if over threshold.
+
+        Returns:
+            List of messages, potentially compressed to fit within threshold.
+        """
+        raw_messages = self._base.messages
+
+        if not raw_messages:
+            return []
+
+        # Count tokens
+        token_count = self._count_tokens(raw_messages)
+
+        if token_count <= self._threshold:
+            return list(raw_messages)
+
+        # Apply compression
+        compressed = self._apply_rolling_window(raw_messages)
+        tokens_after = self._count_tokens(compressed)
+
+        self._compression_count += 1
+        self._total_tokens_saved += token_count - tokens_after
+
+        logger.info(
+            f"HeadroomChatMessageHistory compressed: {token_count} -> {tokens_after} tokens "
+            f"({len(raw_messages)} -> {len(compressed)} messages)"
+        )
+
+        return compressed
+
+    def add_message(self, message: BaseMessage) -> None:
+        """Add a message to the underlying history.
+
+        Args:
+            message: The message to add.
+        """
+        self._base.add_message(message)
+
+    def add_user_message(self, message: HumanMessage | str) -> None:
+        """Add a user message to the history.
+
+        Args:
+            message: The user message (string or HumanMessage).
+        """
+        self._base.add_user_message(message)
+
+    def add_ai_message(self, message: AIMessage | str) -> None:
+        """Add an AI message to the history.
+
+        Args:
+            message: The AI message (string or AIMessage).
+        """
+        self._base.add_ai_message(message)
+
+    def clear(self) -> None:
+        """Clear all messages from history."""
+        self._base.clear()
+
+    def _count_tokens(self, messages: list[BaseMessage]) -> int:
+        """Count tokens in messages using provider's tokenizer.
+
+        Args:
+            messages: List of messages to count.
+
+        Returns:
+            Total token count.
+        """
+        token_counter = self._provider.get_token_counter(self._model)
+        total = 0
+        for msg in messages:
+            content = msg.content if isinstance(msg.content, str) else str(msg.content)
+            total += token_counter.count_text(content)
+        return total
+
+    def _apply_rolling_window(self, messages: list[BaseMessage]) -> list[BaseMessage]:
+        """Apply RollingWindow compression to messages.
+
+        Args:
+            messages: Messages to compress.
+
+        Returns:
+            Compressed messages fitting within threshold.
+        """
+        # Convert to OpenAI format for Headroom transforms
+        openai_messages = self._convert_to_openai(messages)
+
+        # Use TransformPipeline which handles tokenizer setup
+        config = HeadroomConfig(
+            rolling_window=RollingWindowConfig(keep_last_turns=self._keep_recent_turns),
+        )
+        pipeline = TransformPipeline(config=config, provider=self._provider)
+
+        # Apply compression via pipeline
+        result = pipeline.apply(
+            messages=openai_messages,
+            model=self._model,
+            model_limit=self._threshold,
+        )
+
+        # Convert back to LangChain format
+        return self._convert_from_openai(result.messages)
+
+    def _convert_to_openai(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
+        """Convert LangChain messages to OpenAI format.
+
+        Args:
+            messages: LangChain messages.
+
+        Returns:
+            OpenAI format messages.
+        """
+        result = []
+        for msg in messages:
+            content = msg.content if isinstance(msg.content, str) else str(msg.content)
+
+            if isinstance(msg, SystemMessage):
+                result.append({"role": "system", "content": content})
+            elif isinstance(msg, HumanMessage):
+                result.append({"role": "user", "content": content})
+            elif isinstance(msg, AIMessage):
+                entry: dict[str, Any] = {"role": "assistant", "content": content}
+                if hasattr(msg, "tool_calls") and msg.tool_calls:
+                    entry["tool_calls"] = msg.tool_calls
+                result.append(entry)
+            elif isinstance(msg, ToolMessage):
+                result.append(
+                    {
+                        "role": "tool",
+                        "tool_call_id": getattr(msg, "tool_call_id", ""),
+                        "content": content,
+                    }
+                )
+            else:
+                # Generic fallback
+                result.append(
+                    {
+                        "role": getattr(msg, "type", "user"),
+                        "content": content,
+                    }
+                )
+        return result
+
+    def _convert_from_openai(self, messages: list[dict[str, Any]]) -> list[BaseMessage]:
+        """Convert OpenAI format back to LangChain messages.
+
+        Args:
+            messages: OpenAI format messages.
+
+        Returns:
+            LangChain messages.
+        """
+        result: list[BaseMessage] = []
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", "")
+
+            if role == "system":
+                result.append(SystemMessage(content=content))
+            elif role == "user":
+                result.append(HumanMessage(content=content))
+            elif role == "assistant":
+                tool_calls = msg.get("tool_calls", [])
+                result.append(AIMessage(content=content, tool_calls=tool_calls))
+            elif role == "tool":
+                result.append(
+                    ToolMessage(
+                        content=content,
+                        tool_call_id=msg.get("tool_call_id", ""),
+                    )
+                )
+        return result
+
+    def get_compression_stats(self) -> dict[str, Any]:
+        """Get statistics about compression operations.
+
+        Returns:
+            Dictionary with compression_count, total_tokens_saved.
+        """
+        return {
+            "compression_count": self._compression_count,
+            "total_tokens_saved": self._total_tokens_saved,
+            "threshold_tokens": self._threshold,
+            "keep_recent_turns": self._keep_recent_turns,
+        }
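
A minimal sketch (not from the package) exercising the wrapper above; it assumes headroom-ai[langchain] plus langchain-community are installed and that OpenAIProvider can count tokens locally (e.g. via tiktoken). The exact message counts and savings depend on the tokenizer and on RollingWindow's behavior:

```python
from langchain_community.chat_message_histories import ChatMessageHistory

from headroom.integrations import HeadroomChatMessageHistory

base = ChatMessageHistory()
history = HeadroomChatMessageHistory(
    base,
    compress_threshold_tokens=500,  # low threshold so compression triggers quickly
    keep_recent_turns=2,
)

# The write path is a pass-through to the wrapped history.
for i in range(40):
    history.add_user_message(f"Question {i}: " + "lorem ipsum " * 20)
    history.add_ai_message(f"Answer {i}: " + "dolor sit amet " * 20)

# Reading .messages is what triggers compression; the wrapped store
# keeps the full transcript.
compressed = history.messages
print(len(base.messages), "stored vs", len(compressed), "returned")
print(history.get_compression_stats())
# e.g. {'compression_count': 1, 'total_tokens_saved': ..., 'threshold_tokens': 500,
#       'keep_recent_turns': 2}
```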