headroom-ai 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/integrations/langchain/chat_model.py
@@ -0,0 +1,1002 @@
"""LangChain integration for Headroom SDK.

This module provides seamless integration with LangChain, enabling automatic
context optimization for any LangChain chat model.

Key insight: LangChain callbacks CANNOT modify messages (by design - see
https://github.com/langchain-ai/langchain/issues/8725). Therefore, we wrap
the chat model itself to intercept and transform messages.

Components:
    1. HeadroomChatModel - Wraps any BaseChatModel to apply Headroom transforms
    2. HeadroomCallbackHandler - Tracks metrics and token usage (observability only)
    3. HeadroomRunnable - LCEL-compatible Runnable for chain composition
    4. optimize_messages() - Standalone function for manual optimization

Example:
    from langchain_openai import ChatOpenAI
    from headroom.integrations import HeadroomChatModel

    # Wrap any LangChain chat model
    llm = ChatOpenAI(model="gpt-4o")
    optimized_llm = HeadroomChatModel(llm)

    # Use normally - Headroom automatically optimizes context
    response = optimized_llm.invoke("What is 2+2?")
"""

from __future__ import annotations

import asyncio
import json
import logging
from collections.abc import AsyncIterator, Iterator, Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from uuid import UUID, uuid4

# LangChain imports - these are optional dependencies
try:
    from langchain_core.callbacks import BaseCallbackHandler
    from langchain_core.language_models import BaseChatModel
    from langchain_core.messages import (
        AIMessage,
        BaseMessage,
        HumanMessage,
        SystemMessage,
        ToolMessage,
    )
    from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult  # noqa: F401
    from langchain_core.runnables import RunnableLambda
    from pydantic import ConfigDict, Field, PrivateAttr

    LANGCHAIN_AVAILABLE = True
except ImportError:
    LANGCHAIN_AVAILABLE = False
    BaseChatModel = object  # type: ignore[misc,assignment]
    BaseCallbackHandler = object  # type: ignore[misc,assignment]
    ConfigDict = lambda **kwargs: {}  # type: ignore[assignment,misc]  # noqa: E731
    Field = lambda **kwargs: None  # type: ignore[assignment]  # noqa: E731
    PrivateAttr = lambda **kwargs: None  # type: ignore[assignment]  # noqa: E731

from headroom import HeadroomConfig, HeadroomMode
from headroom.providers import OpenAIProvider
from headroom.transforms import TransformPipeline

from .providers import get_headroom_provider, get_model_name_from_langchain

logger = logging.getLogger(__name__)


def _check_langchain_available() -> None:
    """Raise ImportError if LangChain is not installed."""
    if not LANGCHAIN_AVAILABLE:
        raise ImportError(
            "LangChain is required for this integration. "
            "Install with: pip install headroom[langchain] "
            "or: pip install langchain-core langchain-openai"
        )


def langchain_available() -> bool:
    """Check if LangChain is installed."""
    return LANGCHAIN_AVAILABLE

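
# A minimal guard sketch, assuming the optional langchain-openai package is
# installed and credentials are configured; the helper name and the "gpt-4o"
# model below are illustrative, not part of the published API.
def _sketch_guarded_wrapping() -> None:
    if not langchain_available():
        logger.warning("LangChain not installed; skipping Headroom wrapping")
        return
    from langchain_openai import ChatOpenAI

    wrapped = HeadroomChatModel(ChatOpenAI(model="gpt-4o"))
    logger.info("Wrapping active: %s", wrapped._llm_type)
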

@dataclass
class OptimizationMetrics:
    """Metrics from a single optimization pass."""

    request_id: str
    timestamp: datetime
    tokens_before: int
    tokens_after: int
    tokens_saved: int
    savings_percent: float
    transforms_applied: list[str]
    model: str

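
# Worked example of the fields above, using the same formula the wrappers apply
# below: 1200 tokens in and 900 tokens out gives tokens_saved = 300 and
# savings_percent = 300 / 1200 * 100 = 25.0. All values, including the
# transform name, are illustrative.
_EXAMPLE_METRICS = OptimizationMetrics(
    request_id="example",
    timestamp=datetime(2024, 1, 1),
    tokens_before=1200,
    tokens_after=900,
    tokens_saved=300,
    savings_percent=25.0,
    transforms_applied=["smart_crusher"],
    model="gpt-4o",
)
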

class HeadroomChatModel(BaseChatModel):
    """LangChain chat model wrapper that applies Headroom optimizations.

    Wraps any LangChain BaseChatModel and automatically optimizes the context
    before each API call. This is the recommended way to use Headroom with
    LangChain because:

    1. Callbacks cannot modify messages (LangChain design limitation)
    2. Wrapping ensures ALL calls go through optimization
    3. Works with streaming, tools, and all LangChain features

    Example:
        from langchain_openai import ChatOpenAI
        from headroom.integrations import HeadroomChatModel

        # Basic usage
        llm = ChatOpenAI(model="gpt-4o")
        optimized = HeadroomChatModel(llm)
        response = optimized.invoke([HumanMessage("Hello!")])

        # With custom config
        from headroom import HeadroomConfig, HeadroomMode
        config = HeadroomConfig(default_mode=HeadroomMode.OPTIMIZE)
        optimized = HeadroomChatModel(llm, config=config)

        # Access metrics
        print(f"Saved {optimized.total_tokens_saved} tokens")

    Attributes:
        wrapped_model: The underlying LangChain chat model
        headroom_client: HeadroomClient instance for optimization
        metrics_history: List of OptimizationMetrics from recent calls
        total_tokens_saved: Running total of tokens saved
    """

    # Pydantic model fields
    wrapped_model: Any = Field(description="The wrapped LangChain chat model")
    headroom_config: Any = Field(default=None, description="Headroom configuration")
    mode: HeadroomMode = Field(default=HeadroomMode.OPTIMIZE, description="Headroom mode")
    auto_detect_provider: bool = Field(
        default=True,
        description="Auto-detect provider from wrapped model (OpenAI, Anthropic, Google)",
    )

    # Private attributes (not serialized)
    _metrics_history: list = PrivateAttr(default_factory=list)
    _total_tokens_saved: int = PrivateAttr(default=0)
    _pipeline: Any = PrivateAttr(default=None)
    _provider: Any = PrivateAttr(default=None)

    # Pydantic v2 config for LangChain compatibility
    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(
        self,
        wrapped_model: BaseChatModel,
        config: HeadroomConfig | None = None,
        mode: HeadroomMode = HeadroomMode.OPTIMIZE,
        auto_detect_provider: bool = True,
        **kwargs: Any,
    ) -> None:
        """Initialize HeadroomChatModel.

        Args:
            wrapped_model: Any LangChain BaseChatModel to wrap
            config: HeadroomConfig for optimization settings
            mode: HeadroomMode (AUDIT, OPTIMIZE, or SIMULATE)
            auto_detect_provider: Auto-detect provider from wrapped model.
                When True (default), automatically detects if the wrapped model
                is OpenAI, Anthropic, Google, etc. and uses the appropriate
                Headroom provider for accurate token counting.
            **kwargs: Additional arguments passed to BaseChatModel
        """
        _check_langchain_available()

        super().__init__(  # type: ignore[call-arg]
            wrapped_model=wrapped_model,
            headroom_config=config or HeadroomConfig(),
            mode=mode,
            auto_detect_provider=auto_detect_provider,
            **kwargs,
        )
        self._metrics_history = []
        self._total_tokens_saved = 0
        self._pipeline = None
        self._provider = None

    @property
    def _llm_type(self) -> str:
        """Return identifier for this LLM type."""
        return f"headroom-{self.wrapped_model._llm_type}"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Return identifying parameters."""
        return {
            "wrapped_model": self.wrapped_model._identifying_params,
            "headroom_mode": self.mode.value,
        }

    @property
    def pipeline(self) -> TransformPipeline:
        """Lazily initialize TransformPipeline.

        When auto_detect_provider is True, automatically detects the provider
        from the wrapped model's class path (e.g., ChatAnthropic -> AnthropicProvider).
        """
        if self._pipeline is None:
            if self.auto_detect_provider:
                self._provider = get_headroom_provider(self.wrapped_model)
                logger.debug(f"Auto-detected provider: {self._provider.__class__.__name__}")
            else:
                self._provider = OpenAIProvider()
            self._pipeline = TransformPipeline(
                config=self.headroom_config,
                provider=self._provider,
            )
        pipeline: TransformPipeline = self._pipeline
        return pipeline

    @property
    def total_tokens_saved(self) -> int:
        """Total tokens saved across all calls."""
        return self._total_tokens_saved

    @property
    def metrics_history(self) -> list[OptimizationMetrics]:
        """History of optimization metrics."""
        return self._metrics_history.copy()

    def _convert_messages_to_openai(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
        """Convert LangChain messages to OpenAI format for Headroom."""
        result = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                result.append({"role": "system", "content": msg.content})
            elif isinstance(msg, HumanMessage):
                result.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                entry = {"role": "assistant", "content": msg.content}
                if msg.tool_calls:
                    entry["tool_calls"] = [
                        {
                            "id": tc["id"],
                            "type": "function",
                            "function": {
                                "name": tc["name"],
                                "arguments": json.dumps(tc["args"]),
                            },
                        }
                        for tc in msg.tool_calls
                    ]
                result.append(entry)
            elif isinstance(msg, ToolMessage):
                result.append(
                    {
                        "role": "tool",
                        "tool_call_id": msg.tool_call_id,
                        "content": msg.content,
                    }
                )
            else:
                # Generic fallback
                result.append(
                    {
                        "role": getattr(msg, "type", "user"),
                        "content": msg.content,
                    }
                )
        return result

    def _convert_messages_from_openai(self, messages: list[dict[str, Any]]) -> list[BaseMessage]:
        """Convert OpenAI format messages back to LangChain format."""
        result: list[BaseMessage] = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                result.append(SystemMessage(content=content))
            elif role == "user":
                result.append(HumanMessage(content=content))
            elif role == "assistant":
                tool_calls = []
                if "tool_calls" in msg:
                    for tc in msg["tool_calls"]:
                        tool_calls.append(
                            {
                                "id": tc["id"],
                                "name": tc["function"]["name"],
                                "args": json.loads(tc["function"]["arguments"]),
                            }
                        )
                result.append(AIMessage(content=content, tool_calls=tool_calls))
            elif role == "tool":
                result.append(
                    ToolMessage(
                        content=content,
                        tool_call_id=msg.get("tool_call_id", ""),
                    )
                )
        return result

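    # For reference, an AIMessage carrying one tool call round-trips through the
    # two converters above as an OpenAI-style dict shaped like this (illustrative
    # values only):
    #   {"role": "assistant", "content": "",
    #    "tool_calls": [{"id": "call_1", "type": "function",
    #                    "function": {"name": "search", "arguments": "{\"q\": \"x\"}"}}]}
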
    def _optimize_messages(
        self, messages: list[BaseMessage]
    ) -> tuple[list[BaseMessage], OptimizationMetrics]:
        """Apply Headroom optimization to messages."""
        request_id = str(uuid4())

        # Convert to OpenAI format
        openai_messages = self._convert_messages_to_openai(messages)

        # Get model name from wrapped model
        model = get_model_name_from_langchain(self.wrapped_model)

        # Ensure pipeline is initialized (this also sets up provider)
        _ = self.pipeline

        # Get model context limit from provider
        model_limit = self._provider.get_context_limit(model) if self._provider else 128000

        # Ensure model is a string
        model_str = str(model) if model else "gpt-4o"

        # Apply Headroom transforms via pipeline
        result = self.pipeline.apply(
            messages=openai_messages,
            model=model_str,
            model_limit=model_limit,
        )

        # Create metrics
        metrics = OptimizationMetrics(
            request_id=request_id,
            timestamp=datetime.now(),
            tokens_before=result.tokens_before,
            tokens_after=result.tokens_after,
            tokens_saved=result.tokens_before - result.tokens_after,
            savings_percent=(
                (result.tokens_before - result.tokens_after) / result.tokens_before * 100
                if result.tokens_before > 0
                else 0
            ),
            transforms_applied=result.transforms_applied,
            model=model_str,
        )

        # Track metrics
        self._metrics_history.append(metrics)
        self._total_tokens_saved += metrics.tokens_saved

        # Keep only last 100 metrics
        if len(self._metrics_history) > 100:
            self._metrics_history = self._metrics_history[-100:]

        # Convert back to LangChain format
        optimized_messages = self._convert_messages_from_openai(result.messages)

        return optimized_messages, metrics

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate response with Headroom optimization.

        This is the core method called by invoke(), batch(), etc.
        """
        # Optimize messages
        optimized_messages, metrics = self._optimize_messages(messages)

        logger.info(
            f"Headroom optimized: {metrics.tokens_before} -> {metrics.tokens_after} tokens "
            f"({metrics.savings_percent:.1f}% saved)"
        )

        # Call wrapped model with optimized messages
        result: ChatResult = self.wrapped_model._generate(
            optimized_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

        return result

    def _stream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream response with Headroom optimization."""
        # Optimize messages
        optimized_messages, metrics = self._optimize_messages(messages)

        logger.info(
            f"Headroom optimized (streaming): {metrics.tokens_before} -> "
            f"{metrics.tokens_after} tokens"
        )

        # Stream from wrapped model
        yield from self.wrapped_model._stream(
            optimized_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

    async def _agenerate(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generate response with Headroom optimization.

        This enables `await model.ainvoke(messages)` to work correctly.
        The optimization step runs in a thread executor since it's CPU-bound.
        """
        # Run optimization in executor (CPU-bound)
        loop = asyncio.get_event_loop()
        optimized_messages, metrics = await loop.run_in_executor(
            None, self._optimize_messages, messages
        )

        logger.info(
            f"Headroom optimized (async): {metrics.tokens_before} -> {metrics.tokens_after} tokens "
            f"({metrics.savings_percent:.1f}% saved)"
        )

        # Call wrapped model's async generate
        result: ChatResult = await self.wrapped_model._agenerate(
            optimized_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        )

        return result

    async def _astream(
        self,
        messages: list[BaseMessage],
        stop: list[str] | None = None,
        run_manager: Any = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async stream response with Headroom optimization.

        This enables `async for chunk in model.astream(messages)` to work correctly.
        """
        # Run optimization in executor (CPU-bound)
        loop = asyncio.get_event_loop()
        optimized_messages, metrics = await loop.run_in_executor(
            None, self._optimize_messages, messages
        )

        logger.info(
            f"Headroom optimized (async streaming): {metrics.tokens_before} -> "
            f"{metrics.tokens_after} tokens"
        )

        # Async stream from wrapped model
        async for chunk in self.wrapped_model._astream(
            optimized_messages,
            stop=stop,
            run_manager=run_manager,
            **kwargs,
        ):
            yield chunk

    def bind_tools(self, tools: Sequence[Any], **kwargs: Any) -> HeadroomChatModel:
        """Bind tools to the wrapped model."""
        new_wrapped = self.wrapped_model.bind_tools(tools, **kwargs)
        return HeadroomChatModel(
            wrapped_model=new_wrapped,
            config=self.headroom_config,
            mode=self.mode,
            auto_detect_provider=self.auto_detect_provider,
        )

    def get_savings_summary(self) -> dict[str, Any]:
        """Get summary of token savings."""
        if not self._metrics_history:
            return {
                "total_requests": 0,
                "total_tokens_saved": 0,
                "average_savings_percent": 0,
            }

        return {
            "total_requests": len(self._metrics_history),
            "total_tokens_saved": self._total_tokens_saved,
            "average_savings_percent": sum(m.savings_percent for m in self._metrics_history)
            / len(self._metrics_history),
            "total_tokens_before": sum(m.tokens_before for m in self._metrics_history),
            "total_tokens_after": sum(m.tokens_after for m in self._metrics_history),
        }

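
# A usage sketch for the wrapper above: bind a tool, run a call, then read the
# aggregate savings. Assumes langchain-openai is installed with credentials
# configured; the "echo" tool, helper name, and model name are illustrative.
def _sketch_savings_report() -> None:
    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    echo_tool = {
        "type": "function",
        "function": {
            "name": "echo",
            "description": "Echo the given text back to the caller.",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"],
            },
        },
    }
    llm = HeadroomChatModel(ChatOpenAI(model="gpt-4o")).bind_tools([echo_tool])
    llm.invoke([HumanMessage("Echo 'hello' back to me.")])
    summary = llm.get_savings_summary()
    logger.info(
        "Headroom saved %s tokens across %s requests",
        summary["total_tokens_saved"],
        summary["total_requests"],
    )
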

class HeadroomCallbackHandler(BaseCallbackHandler):
    """LangChain callback handler for Headroom metrics and observability.

    NOTE: Callbacks CANNOT modify messages in LangChain (by design).
    Use HeadroomChatModel for actual optimization. This handler is for:

    1. Tracking token usage across chains
    2. Logging optimization metrics
    3. Alerting on high token usage
    4. Integration with observability platforms

    Example:
        from langchain_openai import ChatOpenAI
        from headroom.integrations import HeadroomCallbackHandler

        handler = HeadroomCallbackHandler(
            log_level="INFO",
            token_alert_threshold=10000,
        )

        llm = ChatOpenAI(model="gpt-4o", callbacks=[handler])
        response = llm.invoke("Hello!")

        # Check metrics
        print(f"Total tokens: {handler.total_tokens}")
        print(f"Alerts: {handler.alerts}")
    """

    def __init__(
        self,
        log_level: str = "INFO",
        token_alert_threshold: int | None = None,
        cost_alert_threshold: float | None = None,
    ):
        """Initialize callback handler.

        Args:
            log_level: Logging level for metrics ("DEBUG", "INFO", "WARNING")
            token_alert_threshold: Alert if request exceeds this many tokens
            cost_alert_threshold: Alert if estimated cost exceeds this amount
        """
        _check_langchain_available()

        self.log_level = log_level
        self.token_alert_threshold = token_alert_threshold
        self.cost_alert_threshold = cost_alert_threshold

        # Metrics tracking
        self._requests: list[dict[str, Any]] = []
        self._total_tokens = 0
        self._alerts: list[str] = []
        self._current_request: dict[str, Any] | None = None

    @property
    def total_tokens(self) -> int:
        """Total tokens used across all requests."""
        return self._total_tokens

    @property
    def total_requests(self) -> int:
        """Total number of requests tracked."""
        return len(self._requests)

    @property
    def alerts(self) -> list[str]:
        """List of alerts triggered."""
        return self._alerts.copy()

    @property
    def requests(self) -> list[dict[str, Any]]:
        """List of request metrics."""
        return self._requests.copy()

    def on_llm_start(
        self,
        serialized: dict[str, Any],
        prompts: list[str],
        **kwargs: Any,
    ) -> None:
        """Called when LLM starts processing."""
        self._current_request = {
            "start_time": datetime.now(),
            "model": serialized.get("name", "unknown"),
            "prompt_count": len(prompts),
            "estimated_input_tokens": sum(len(p) // 4 for p in prompts),  # Rough estimate
        }

        if self.log_level == "DEBUG":
            logger.debug(f"LLM request started: {self._current_request}")

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        **kwargs: Any,
    ) -> None:
        """Called when chat model starts processing."""
        # Estimate tokens from messages
        total_content = ""
        for msg_list in messages:
            for msg in msg_list:
                content = msg.content if isinstance(msg.content, str) else str(msg.content)
                total_content += content

        estimated_tokens = len(total_content) // 4  # Rough estimate

        self._current_request = {
            "start_time": datetime.now(),
            "model": serialized.get("name", serialized.get("id", ["unknown"])[-1]),
            "message_count": sum(len(ml) for ml in messages),
            "estimated_input_tokens": estimated_tokens,
        }

        # Check token alert
        if self.token_alert_threshold and estimated_tokens > self.token_alert_threshold:
            alert = (
                f"Token alert: {estimated_tokens} tokens exceeds "
                f"threshold {self.token_alert_threshold}"
            )
            self._alerts.append(alert)
            logger.warning(alert)

        if self.log_level in ("DEBUG", "INFO"):
            logger.log(
                logging.DEBUG if self.log_level == "DEBUG" else logging.INFO,
                f"Chat model request: ~{estimated_tokens} input tokens",
            )

    def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Called when LLM finishes processing."""
        if self._current_request is None:
            return

        # Extract token usage from response if available
        token_usage = {}
        if hasattr(response, "llm_output") and response.llm_output:
            token_usage = response.llm_output.get("token_usage", {})

        self._current_request["end_time"] = datetime.now()
        self._current_request["duration_ms"] = (
            self._current_request["end_time"] - self._current_request["start_time"]
        ).total_seconds() * 1000

        if token_usage:
            self._current_request["input_tokens"] = token_usage.get("prompt_tokens", 0)
            self._current_request["output_tokens"] = token_usage.get("completion_tokens", 0)
            self._current_request["total_tokens"] = token_usage.get("total_tokens", 0)
            self._total_tokens += self._current_request["total_tokens"]

        self._requests.append(self._current_request)

        # Keep only last 1000 requests
        if len(self._requests) > 1000:
            self._requests = self._requests[-1000:]

        if self.log_level in ("DEBUG", "INFO"):
            tokens_info = f"{self._current_request.get('total_tokens', 'unknown')} tokens"
            duration = f"{self._current_request['duration_ms']:.0f}ms"
            logger.log(
                logging.DEBUG if self.log_level == "DEBUG" else logging.INFO,
                f"LLM request completed: {tokens_info} in {duration}",
            )

        self._current_request = None

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        **kwargs: Any,
    ) -> Any:
        """Called when LLM encounters an error."""
        if self._current_request:
            self._current_request["error"] = str(error)
            self._current_request["end_time"] = datetime.now()
            self._requests.append(self._current_request)
            self._current_request = None

        logger.error(f"LLM error: {error}")

    def get_summary(self) -> dict[str, Any]:
        """Get summary of all tracked requests."""
        if not self._requests:
            return {
                "total_requests": 0,
                "total_tokens": 0,
                "average_tokens": 0,
                "average_duration_ms": 0,
                "errors": 0,
                "alerts": len(self._alerts),
            }

        successful = [r for r in self._requests if "error" not in r]
        total_tokens = sum(r.get("total_tokens", 0) for r in successful)

        return {
            "total_requests": len(self._requests),
            "successful_requests": len(successful),
            "total_tokens": total_tokens,
            "average_tokens": total_tokens / len(successful) if successful else 0,
            "average_duration_ms": (
                sum(r.get("duration_ms", 0) for r in successful) / len(successful)
                if successful
                else 0
            ),
            "errors": len(self._requests) - len(successful),
            "alerts": len(self._alerts),
        }

    def reset(self) -> None:
        """Reset all tracked metrics."""
        self._requests = []
        self._total_tokens = 0
        self._alerts = []
        self._current_request = None

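
# A sketch combining both pieces: HeadroomChatModel rewrites the context while
# HeadroomCallbackHandler observes the same run via per-invoke callbacks. The
# helper name, model name, and alert threshold are illustrative assumptions.
def _sketch_observed_optimization() -> None:
    from langchain_core.messages import HumanMessage
    from langchain_openai import ChatOpenAI

    handler = HeadroomCallbackHandler(log_level="INFO", token_alert_threshold=10_000)
    llm = HeadroomChatModel(ChatOpenAI(model="gpt-4o"))
    llm.invoke([HumanMessage("Hello!")], config={"callbacks": [handler]})
    logger.info("Observed usage: %s", handler.get_summary())
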

class HeadroomRunnable:
    """LCEL-compatible Runnable for Headroom optimization.

    Use this to add Headroom optimization to any LangChain chain using LCEL.

    Example:
        from langchain_openai import ChatOpenAI
        from langchain_core.prompts import ChatPromptTemplate
        from headroom.integrations import HeadroomRunnable

        prompt = ChatPromptTemplate.from_messages([
            ("system", "You are a helpful assistant."),
            ("user", "{input}"),
        ])
        llm = ChatOpenAI(model="gpt-4o")

        # Add Headroom optimization to chain
        chain = prompt | HeadroomRunnable() | llm
        response = chain.invoke({"input": "Hello!"})
    """

    def __init__(
        self,
        config: HeadroomConfig | None = None,
        mode: HeadroomMode = HeadroomMode.OPTIMIZE,
    ):
        """Initialize HeadroomRunnable.

        Args:
            config: HeadroomConfig for optimization settings
            mode: HeadroomMode (AUDIT, OPTIMIZE, or SIMULATE)
        """
        _check_langchain_available()

        self.config = config or HeadroomConfig()
        self.mode = mode
        self._pipeline: TransformPipeline | None = None
        self._provider: OpenAIProvider | None = None
        self._metrics_history: list[OptimizationMetrics] = []

    @property
    def pipeline(self) -> TransformPipeline:
        """Lazily initialize TransformPipeline."""
        if self._pipeline is None:
            self._provider = OpenAIProvider()
            self._pipeline = TransformPipeline(
                config=self.config,
                provider=self._provider,
            )
        return self._pipeline

    def __or__(self, other: Any) -> Any:
        """Support pipe operator for LCEL composition."""
        from langchain_core.runnables import RunnableSequence

        return RunnableSequence(first=self.as_runnable(), last=other)

    def __ror__(self, other: Any) -> Any:
        """Support reverse pipe operator."""
        from langchain_core.runnables import RunnableSequence

        return RunnableSequence(first=other, last=self.as_runnable())

    def as_runnable(self) -> RunnableLambda:
        """Convert to LangChain Runnable."""
        return RunnableLambda(self._optimize)

    def _optimize(self, input_data: Any) -> Any:
        """Optimize input messages."""
        # Handle different input types
        if isinstance(input_data, list):
            messages = input_data
        elif hasattr(input_data, "messages"):
            messages = input_data.messages
        elif hasattr(input_data, "to_messages"):
            messages = input_data.to_messages()
        else:
            # Can't optimize, pass through
            return input_data

        # Convert messages to OpenAI format
        openai_messages = []
        for msg in messages:
            if isinstance(msg, SystemMessage):
                openai_messages.append({"role": "system", "content": msg.content})
            elif isinstance(msg, HumanMessage):
                openai_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                openai_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, ToolMessage):
                openai_messages.append(
                    {
                        "role": "tool",
                        "tool_call_id": msg.tool_call_id,
                        "content": msg.content,
                    }
                )
            elif hasattr(msg, "type") and hasattr(msg, "content"):
                openai_messages.append(
                    {
                        "role": msg.type,
                        "content": msg.content,
                    }
                )

        # Get model context limit
        model = "gpt-4o"  # Default model for estimation
        model_limit = self._provider.get_context_limit(model) if self._provider else 128000

        # Apply Headroom transforms via pipeline
        result = self.pipeline.apply(
            messages=openai_messages,
            model=model,
            model_limit=model_limit,
        )

        # Track metrics
        metrics = OptimizationMetrics(
            request_id=str(uuid4()),
            timestamp=datetime.now(),
            tokens_before=result.tokens_before,
            tokens_after=result.tokens_after,
            tokens_saved=result.tokens_before - result.tokens_after,
            savings_percent=(
                (result.tokens_before - result.tokens_after) / result.tokens_before * 100
                if result.tokens_before > 0
                else 0
            ),
            transforms_applied=result.transforms_applied,
            model="gpt-4o",
        )
        self._metrics_history.append(metrics)

        # Convert back to LangChain messages
        output_messages: list[BaseMessage] = []
        for msg in result.messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")

            if role == "system":
                output_messages.append(SystemMessage(content=content))
            elif role == "user":
                output_messages.append(HumanMessage(content=content))
            elif role == "assistant":
                output_messages.append(AIMessage(content=content))
            elif role == "tool":
                output_messages.append(
                    ToolMessage(
                        content=content,
                        tool_call_id=msg.get("tool_call_id", ""),
                    )
                )

        return output_messages

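
# A sketch of explicit LCEL composition: as_runnable() yields a RunnableLambda
# that can be dropped into a chain without relying on the __or__/__ror__
# helpers above. The helper name, prompt, and model are illustrative assumptions.
def _sketch_explicit_lcel_chain() -> None:
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI

    prompt = ChatPromptTemplate.from_messages([("user", "{input}")])
    chain = prompt | HeadroomRunnable().as_runnable() | ChatOpenAI(model="gpt-4o")
    chain.invoke({"input": "Hello!"})
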

def optimize_messages(
    messages: list[BaseMessage],
    config: HeadroomConfig | None = None,
    mode: HeadroomMode = HeadroomMode.OPTIMIZE,
    model: str = "gpt-4o",
) -> tuple[list[BaseMessage], dict[str, Any]]:
    """Standalone function to optimize LangChain messages.

    Use this for manual optimization when you need fine-grained control.

    Args:
        messages: List of LangChain BaseMessage objects
        config: HeadroomConfig for optimization settings
        mode: HeadroomMode (AUDIT, OPTIMIZE, or SIMULATE)
        model: Model name for token estimation

    Returns:
        Tuple of (optimized_messages, metrics_dict)

    Example:
        from langchain_core.messages import HumanMessage, SystemMessage
        from headroom.integrations import optimize_messages

        messages = [
            SystemMessage(content="You are helpful."),
            HumanMessage(content="What is 2+2?"),
        ]

        optimized, metrics = optimize_messages(messages)
        print(f"Saved {metrics['tokens_saved']} tokens")
    """
    _check_langchain_available()

    config = config or HeadroomConfig()
    provider = OpenAIProvider()
    pipeline = TransformPipeline(config=config, provider=provider)

    # Convert to OpenAI format
    openai_messages = []
    for msg in messages:
        if isinstance(msg, SystemMessage):
            openai_messages.append({"role": "system", "content": msg.content})
        elif isinstance(msg, HumanMessage):
            openai_messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            entry = {"role": "assistant", "content": msg.content}
            if hasattr(msg, "tool_calls") and msg.tool_calls:
                entry["tool_calls"] = [
                    {
                        "id": tc["id"],
                        "type": "function",
                        "function": {
                            "name": tc["name"],
                            "arguments": json.dumps(tc["args"]),
                        },
                    }
                    for tc in msg.tool_calls
                ]
            openai_messages.append(entry)
        elif isinstance(msg, ToolMessage):
            openai_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": msg.tool_call_id,
                    "content": msg.content,
                }
            )

    # Get model context limit
    model_limit = provider.get_context_limit(model)

    # Apply transforms via pipeline
    result = pipeline.apply(
        messages=openai_messages,
        model=model,
        model_limit=model_limit,
    )

    # Convert back
    output_messages: list[BaseMessage] = []
    for openai_msg in result.messages:
        role = openai_msg.get("role", "user")
        content = openai_msg.get("content", "")

        if role == "system":
            output_messages.append(SystemMessage(content=content))
        elif role == "user":
            output_messages.append(HumanMessage(content=content))
        elif role == "assistant":
            tool_calls = []
            if "tool_calls" in openai_msg:
                for tc in openai_msg["tool_calls"]:
                    tool_calls.append(
                        {
                            "id": tc["id"],
                            "name": tc["function"]["name"],
                            "args": json.loads(tc["function"]["arguments"]),
                        }
                    )
            output_messages.append(AIMessage(content=content, tool_calls=tool_calls))
        elif role == "tool":
            output_messages.append(
                ToolMessage(
                    content=content,
                    tool_call_id=openai_msg.get("tool_call_id", ""),
                )
            )

    metrics = {
        "tokens_before": result.tokens_before,
        "tokens_after": result.tokens_after,
        "tokens_saved": result.tokens_before - result.tokens_after,
        "savings_percent": (
            (result.tokens_before - result.tokens_after) / result.tokens_before * 100
            if result.tokens_before > 0
            else 0
        ),
        "transforms_applied": result.transforms_applied,
    }

    return output_messages, metrics