headroom-ai 0.2.13 (headroom_ai-0.2.13-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114)
  1. headroom/__init__.py +212 -0
  2. headroom/cache/__init__.py +76 -0
  3. headroom/cache/anthropic.py +517 -0
  4. headroom/cache/base.py +342 -0
  5. headroom/cache/compression_feedback.py +613 -0
  6. headroom/cache/compression_store.py +814 -0
  7. headroom/cache/dynamic_detector.py +1026 -0
  8. headroom/cache/google.py +884 -0
  9. headroom/cache/openai.py +584 -0
  10. headroom/cache/registry.py +175 -0
  11. headroom/cache/semantic.py +451 -0
  12. headroom/ccr/__init__.py +77 -0
  13. headroom/ccr/context_tracker.py +582 -0
  14. headroom/ccr/mcp_server.py +319 -0
  15. headroom/ccr/response_handler.py +772 -0
  16. headroom/ccr/tool_injection.py +415 -0
  17. headroom/cli.py +219 -0
  18. headroom/client.py +977 -0
  19. headroom/compression/__init__.py +42 -0
  20. headroom/compression/detector.py +424 -0
  21. headroom/compression/handlers/__init__.py +22 -0
  22. headroom/compression/handlers/base.py +219 -0
  23. headroom/compression/handlers/code_handler.py +506 -0
  24. headroom/compression/handlers/json_handler.py +418 -0
  25. headroom/compression/masks.py +345 -0
  26. headroom/compression/universal.py +465 -0
  27. headroom/config.py +474 -0
  28. headroom/exceptions.py +192 -0
  29. headroom/integrations/__init__.py +159 -0
  30. headroom/integrations/agno/__init__.py +53 -0
  31. headroom/integrations/agno/hooks.py +345 -0
  32. headroom/integrations/agno/model.py +625 -0
  33. headroom/integrations/agno/providers.py +154 -0
  34. headroom/integrations/langchain/__init__.py +106 -0
  35. headroom/integrations/langchain/agents.py +326 -0
  36. headroom/integrations/langchain/chat_model.py +1002 -0
  37. headroom/integrations/langchain/langsmith.py +324 -0
  38. headroom/integrations/langchain/memory.py +319 -0
  39. headroom/integrations/langchain/providers.py +200 -0
  40. headroom/integrations/langchain/retriever.py +371 -0
  41. headroom/integrations/langchain/streaming.py +341 -0
  42. headroom/integrations/mcp/__init__.py +37 -0
  43. headroom/integrations/mcp/server.py +533 -0
  44. headroom/memory/__init__.py +37 -0
  45. headroom/memory/extractor.py +390 -0
  46. headroom/memory/fast_store.py +621 -0
  47. headroom/memory/fast_wrapper.py +311 -0
  48. headroom/memory/inline_extractor.py +229 -0
  49. headroom/memory/store.py +434 -0
  50. headroom/memory/worker.py +260 -0
  51. headroom/memory/wrapper.py +321 -0
  52. headroom/models/__init__.py +39 -0
  53. headroom/models/registry.py +687 -0
  54. headroom/parser.py +293 -0
  55. headroom/pricing/__init__.py +51 -0
  56. headroom/pricing/anthropic_prices.py +81 -0
  57. headroom/pricing/litellm_pricing.py +113 -0
  58. headroom/pricing/openai_prices.py +91 -0
  59. headroom/pricing/registry.py +188 -0
  60. headroom/providers/__init__.py +61 -0
  61. headroom/providers/anthropic.py +621 -0
  62. headroom/providers/base.py +131 -0
  63. headroom/providers/cohere.py +362 -0
  64. headroom/providers/google.py +427 -0
  65. headroom/providers/litellm.py +297 -0
  66. headroom/providers/openai.py +566 -0
  67. headroom/providers/openai_compatible.py +521 -0
  68. headroom/proxy/__init__.py +19 -0
  69. headroom/proxy/server.py +2683 -0
  70. headroom/py.typed +0 -0
  71. headroom/relevance/__init__.py +124 -0
  72. headroom/relevance/base.py +106 -0
  73. headroom/relevance/bm25.py +255 -0
  74. headroom/relevance/embedding.py +255 -0
  75. headroom/relevance/hybrid.py +259 -0
  76. headroom/reporting/__init__.py +5 -0
  77. headroom/reporting/generator.py +549 -0
  78. headroom/storage/__init__.py +41 -0
  79. headroom/storage/base.py +125 -0
  80. headroom/storage/jsonl.py +220 -0
  81. headroom/storage/sqlite.py +289 -0
  82. headroom/telemetry/__init__.py +91 -0
  83. headroom/telemetry/collector.py +764 -0
  84. headroom/telemetry/models.py +880 -0
  85. headroom/telemetry/toin.py +1579 -0
  86. headroom/tokenizer.py +80 -0
  87. headroom/tokenizers/__init__.py +75 -0
  88. headroom/tokenizers/base.py +210 -0
  89. headroom/tokenizers/estimator.py +198 -0
  90. headroom/tokenizers/huggingface.py +317 -0
  91. headroom/tokenizers/mistral.py +245 -0
  92. headroom/tokenizers/registry.py +398 -0
  93. headroom/tokenizers/tiktoken_counter.py +248 -0
  94. headroom/transforms/__init__.py +106 -0
  95. headroom/transforms/base.py +57 -0
  96. headroom/transforms/cache_aligner.py +357 -0
  97. headroom/transforms/code_compressor.py +1313 -0
  98. headroom/transforms/content_detector.py +335 -0
  99. headroom/transforms/content_router.py +1158 -0
  100. headroom/transforms/llmlingua_compressor.py +638 -0
  101. headroom/transforms/log_compressor.py +529 -0
  102. headroom/transforms/pipeline.py +297 -0
  103. headroom/transforms/rolling_window.py +350 -0
  104. headroom/transforms/search_compressor.py +365 -0
  105. headroom/transforms/smart_crusher.py +2682 -0
  106. headroom/transforms/text_compressor.py +259 -0
  107. headroom/transforms/tool_crusher.py +338 -0
  108. headroom/utils.py +215 -0
  109. headroom_ai-0.2.13.dist-info/METADATA +315 -0
  110. headroom_ai-0.2.13.dist-info/RECORD +114 -0
  111. headroom_ai-0.2.13.dist-info/WHEEL +4 -0
  112. headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
  113. headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
  114. headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
@@ -0,0 +1,1002 @@
1
+ """LangChain integration for Headroom SDK.
2
+
3
+ This module provides seamless integration with LangChain, enabling automatic
4
+ context optimization for any LangChain chat model.
5
+
6
+ Key insight: LangChain callbacks CANNOT modify messages (by design - see
7
+ https://github.com/langchain-ai/langchain/issues/8725). Therefore, we wrap
8
+ the chat model itself to intercept and transform messages.
9
+
10
+ Components:
11
+ 1. HeadroomChatModel - Wraps any BaseChatModel to apply Headroom transforms
12
+ 2. HeadroomCallbackHandler - Tracks metrics and token usage (observability only)
13
+ 3. HeadroomRunnable - LCEL-compatible Runnable for chain composition
14
+ 4. optimize_messages() - Standalone function for manual optimization
15
+
16
+ Example:
17
+ from langchain_openai import ChatOpenAI
18
+ from headroom.integrations import HeadroomChatModel
19
+
20
+ # Wrap any LangChain chat model
21
+ llm = ChatOpenAI(model="gpt-4o")
22
+ optimized_llm = HeadroomChatModel(llm)
23
+
24
+ # Use normally - Headroom automatically optimizes context
25
+ response = optimized_llm.invoke("What is 2+2?")
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import asyncio
31
+ import json
32
+ import logging
33
+ from collections.abc import AsyncIterator, Iterator, Sequence
34
+ from dataclasses import dataclass
35
+ from datetime import datetime
36
+ from typing import Any
37
+ from uuid import UUID, uuid4
38
+
39
+ # LangChain imports - these are optional dependencies
40
+ try:
41
+ from langchain_core.callbacks import BaseCallbackHandler
42
+ from langchain_core.language_models import BaseChatModel
43
+ from langchain_core.messages import (
44
+ AIMessage,
45
+ BaseMessage,
46
+ HumanMessage,
47
+ SystemMessage,
48
+ ToolMessage,
49
+ )
50
+ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult # noqa: F401
51
+ from langchain_core.runnables import RunnableLambda
52
+ from pydantic import ConfigDict, Field, PrivateAttr
53
+
54
+ LANGCHAIN_AVAILABLE = True
55
+ except ImportError:
56
+ LANGCHAIN_AVAILABLE = False
57
+ BaseChatModel = object # type: ignore[misc,assignment]
58
+ BaseCallbackHandler = object # type: ignore[misc,assignment]
59
+ ConfigDict = lambda **kwargs: {} # type: ignore[assignment,misc] # noqa: E731
60
+ Field = lambda **kwargs: None # type: ignore[assignment] # noqa: E731
61
+ PrivateAttr = lambda **kwargs: None # type: ignore[assignment] # noqa: E731
62
+
63
+ from headroom import HeadroomConfig, HeadroomMode
64
+ from headroom.providers import OpenAIProvider
65
+ from headroom.transforms import TransformPipeline
66
+
67
+ from .providers import get_headroom_provider, get_model_name_from_langchain
68
+
69
+ logger = logging.getLogger(__name__)
70
+
71
+
72
+ def _check_langchain_available() -> None:
73
+ """Raise ImportError if LangChain is not installed."""
74
+ if not LANGCHAIN_AVAILABLE:
75
+ raise ImportError(
76
+ "LangChain is required for this integration. "
77
+ "Install with: pip install headroom[langchain] "
78
+ "or: pip install langchain-core langchain-openai"
79
+ )
80
+
81
+
82
+ def langchain_available() -> bool:
83
+ """Check if LangChain is installed."""
84
+ return LANGCHAIN_AVAILABLE
85
+
86
+
87
+ @dataclass
88
+ class OptimizationMetrics:
89
+ """Metrics from a single optimization pass."""
90
+
91
+ request_id: str
92
+ timestamp: datetime
93
+ tokens_before: int
94
+ tokens_after: int
95
+ tokens_saved: int
96
+ savings_percent: float
97
+ transforms_applied: list[str]
98
+ model: str
99
+
100
+
101
+ class HeadroomChatModel(BaseChatModel):
102
+ """LangChain chat model wrapper that applies Headroom optimizations.
103
+
104
+ Wraps any LangChain BaseChatModel and automatically optimizes the context
105
+ before each API call. This is the recommended way to use Headroom with
106
+ LangChain because:
107
+
108
+ 1. Callbacks cannot modify messages (LangChain design limitation)
109
+ 2. Wrapping ensures ALL calls go through optimization
110
+ 3. Works with streaming, tools, and all LangChain features
111
+
112
+ Example:
113
+ from langchain_openai import ChatOpenAI
114
+ from headroom.integrations import HeadroomChatModel
115
+
116
+ # Basic usage
117
+ llm = ChatOpenAI(model="gpt-4o")
118
+ optimized = HeadroomChatModel(llm)
119
+ response = optimized.invoke([HumanMessage("Hello!")])
120
+
121
+ # With custom config
122
+ from headroom import HeadroomConfig, HeadroomMode
123
+ config = HeadroomConfig(default_mode=HeadroomMode.OPTIMIZE)
124
+ optimized = HeadroomChatModel(llm, config=config)
125
+
126
+ # Access metrics
127
+ print(f"Saved {optimized.total_tokens_saved} tokens")
128
+
129
+ Attributes:
130
+ wrapped_model: The underlying LangChain chat model
131
+ headroom_client: HeadroomClient instance for optimization
132
+ metrics_history: List of OptimizationMetrics from recent calls
133
+ total_tokens_saved: Running total of tokens saved
134
+ """
135
+
136
+ # Pydantic model fields
137
+ wrapped_model: Any = Field(description="The wrapped LangChain chat model")
138
+ headroom_config: Any = Field(default=None, description="Headroom configuration")
139
+ mode: HeadroomMode = Field(default=HeadroomMode.OPTIMIZE, description="Headroom mode")
140
+ auto_detect_provider: bool = Field(
141
+ default=True,
142
+ description="Auto-detect provider from wrapped model (OpenAI, Anthropic, Google)",
143
+ )
144
+
145
+ # Private attributes (not serialized)
146
+ _metrics_history: list = PrivateAttr(default_factory=list)
147
+ _total_tokens_saved: int = PrivateAttr(default=0)
148
+ _pipeline: Any = PrivateAttr(default=None)
149
+ _provider: Any = PrivateAttr(default=None)
150
+
151
+ # Pydantic v2 config for LangChain compatibility
152
+ model_config = ConfigDict(arbitrary_types_allowed=True)
153
+
154
+ def __init__(
155
+ self,
156
+ wrapped_model: BaseChatModel,
157
+ config: HeadroomConfig | None = None,
158
+ mode: HeadroomMode = HeadroomMode.OPTIMIZE,
159
+ auto_detect_provider: bool = True,
160
+ **kwargs: Any,
161
+ ) -> None:
162
+ """Initialize HeadroomChatModel.
163
+
164
+ Args:
165
+ wrapped_model: Any LangChain BaseChatModel to wrap
166
+ config: HeadroomConfig for optimization settings
167
+ mode: HeadroomMode (AUDIT, OPTIMIZE, or SIMULATE)
168
+ auto_detect_provider: Auto-detect provider from wrapped model.
169
+ When True (default), automatically detects if the wrapped model
170
+ is OpenAI, Anthropic, Google, etc. and uses the appropriate
171
+ Headroom provider for accurate token counting.
172
+ **kwargs: Additional arguments passed to BaseChatModel
173
+ """
174
+ _check_langchain_available()
175
+
176
+ super().__init__( # type: ignore[call-arg]
177
+ wrapped_model=wrapped_model,
178
+ headroom_config=config or HeadroomConfig(),
179
+ mode=mode,
180
+ auto_detect_provider=auto_detect_provider,
181
+ **kwargs,
182
+ )
183
+ self._metrics_history = []
184
+ self._total_tokens_saved = 0
185
+ self._pipeline = None
186
+ self._provider = None
187
+
188
+ @property
189
+ def _llm_type(self) -> str:
190
+ """Return identifier for this LLM type."""
191
+ return f"headroom-{self.wrapped_model._llm_type}"
192
+
193
+ @property
194
+ def _identifying_params(self) -> dict[str, Any]:
195
+ """Return identifying parameters."""
196
+ return {
197
+ "wrapped_model": self.wrapped_model._identifying_params,
198
+ "headroom_mode": self.mode.value,
199
+ }
200
+
201
+ @property
202
+ def pipeline(self) -> TransformPipeline:
203
+ """Lazily initialize TransformPipeline.
204
+
205
+ When auto_detect_provider is True, automatically detects the provider
206
+ from the wrapped model's class path (e.g., ChatAnthropic -> AnthropicProvider).
207
+ """
208
+ if self._pipeline is None:
209
+ if self.auto_detect_provider:
210
+ self._provider = get_headroom_provider(self.wrapped_model)
211
+ logger.debug(f"Auto-detected provider: {self._provider.__class__.__name__}")
212
+ else:
213
+ self._provider = OpenAIProvider()
214
+ self._pipeline = TransformPipeline(
215
+ config=self.headroom_config,
216
+ provider=self._provider,
217
+ )
218
+ pipeline: TransformPipeline = self._pipeline
219
+ return pipeline
220
+
221
+ @property
222
+ def total_tokens_saved(self) -> int:
223
+ """Total tokens saved across all calls."""
224
+ return self._total_tokens_saved
225
+
226
+ @property
227
+ def metrics_history(self) -> list[OptimizationMetrics]:
228
+ """History of optimization metrics."""
229
+ return self._metrics_history.copy()
230
+
231
+ def _convert_messages_to_openai(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
232
+ """Convert LangChain messages to OpenAI format for Headroom."""
233
+ result = []
234
+ for msg in messages:
235
+ if isinstance(msg, SystemMessage):
236
+ result.append({"role": "system", "content": msg.content})
237
+ elif isinstance(msg, HumanMessage):
238
+ result.append({"role": "user", "content": msg.content})
239
+ elif isinstance(msg, AIMessage):
240
+ entry = {"role": "assistant", "content": msg.content}
241
+ if msg.tool_calls:
242
+ entry["tool_calls"] = [
243
+ {
244
+ "id": tc["id"],
245
+ "type": "function",
246
+ "function": {
247
+ "name": tc["name"],
248
+ "arguments": json.dumps(tc["args"]),
249
+ },
250
+ }
251
+ for tc in msg.tool_calls
252
+ ]
253
+ result.append(entry)
254
+ elif isinstance(msg, ToolMessage):
255
+ result.append(
256
+ {
257
+ "role": "tool",
258
+ "tool_call_id": msg.tool_call_id,
259
+ "content": msg.content,
260
+ }
261
+ )
262
+ else:
263
+ # Generic fallback
264
+ result.append(
265
+ {
266
+ "role": getattr(msg, "type", "user"),
267
+ "content": msg.content,
268
+ }
269
+ )
270
+ return result
271
+
272
+ def _convert_messages_from_openai(self, messages: list[dict[str, Any]]) -> list[BaseMessage]:
273
+ """Convert OpenAI format messages back to LangChain format."""
274
+ result: list[BaseMessage] = []
275
+ for msg in messages:
276
+ role = msg.get("role", "user")
277
+ content = msg.get("content", "")
278
+
279
+ if role == "system":
280
+ result.append(SystemMessage(content=content))
281
+ elif role == "user":
282
+ result.append(HumanMessage(content=content))
283
+ elif role == "assistant":
284
+ tool_calls = []
285
+ if "tool_calls" in msg:
286
+ for tc in msg["tool_calls"]:
287
+ tool_calls.append(
288
+ {
289
+ "id": tc["id"],
290
+ "name": tc["function"]["name"],
291
+ "args": json.loads(tc["function"]["arguments"]),
292
+ }
293
+ )
294
+ result.append(AIMessage(content=content, tool_calls=tool_calls))
295
+ elif role == "tool":
296
+ result.append(
297
+ ToolMessage(
298
+ content=content,
299
+ tool_call_id=msg.get("tool_call_id", ""),
300
+ )
301
+ )
302
+ return result
303
+
304
+ def _optimize_messages(
305
+ self, messages: list[BaseMessage]
306
+ ) -> tuple[list[BaseMessage], OptimizationMetrics]:
307
+ """Apply Headroom optimization to messages."""
308
+ request_id = str(uuid4())
309
+
310
+ # Convert to OpenAI format
311
+ openai_messages = self._convert_messages_to_openai(messages)
312
+
313
+ # Get model name from wrapped model
314
+ model = get_model_name_from_langchain(self.wrapped_model)
315
+
316
+ # Ensure pipeline is initialized (this also sets up provider)
317
+ _ = self.pipeline
318
+
319
+ # Get model context limit from provider
320
+ model_limit = self._provider.get_context_limit(model) if self._provider else 128000
321
+
322
+ # Ensure model is a string
323
+ model_str = str(model) if model else "gpt-4o"
324
+
325
+ # Apply Headroom transforms via pipeline
326
+ result = self.pipeline.apply(
327
+ messages=openai_messages,
328
+ model=model_str,
329
+ model_limit=model_limit,
330
+ )
331
+
332
+ # Create metrics
333
+ metrics = OptimizationMetrics(
334
+ request_id=request_id,
335
+ timestamp=datetime.now(),
336
+ tokens_before=result.tokens_before,
337
+ tokens_after=result.tokens_after,
338
+ tokens_saved=result.tokens_before - result.tokens_after,
339
+ savings_percent=(
340
+ (result.tokens_before - result.tokens_after) / result.tokens_before * 100
341
+ if result.tokens_before > 0
342
+ else 0
343
+ ),
344
+ transforms_applied=result.transforms_applied,
345
+ model=model_str,
346
+ )
347
+
348
+ # Track metrics
349
+ self._metrics_history.append(metrics)
350
+ self._total_tokens_saved += metrics.tokens_saved
351
+
352
+ # Keep only last 100 metrics
353
+ if len(self._metrics_history) > 100:
354
+ self._metrics_history = self._metrics_history[-100:]
355
+
356
+ # Convert back to LangChain format
357
+ optimized_messages = self._convert_messages_from_openai(result.messages)
358
+
359
+ return optimized_messages, metrics
360
+
361
+ def _generate(
362
+ self,
363
+ messages: list[BaseMessage],
364
+ stop: list[str] | None = None,
365
+ run_manager: Any = None,
366
+ **kwargs: Any,
367
+ ) -> ChatResult:
368
+ """Generate response with Headroom optimization.
369
+
370
+ This is the core method called by invoke(), batch(), etc.
371
+ """
372
+ # Optimize messages
373
+ optimized_messages, metrics = self._optimize_messages(messages)
374
+
375
+ logger.info(
376
+ f"Headroom optimized: {metrics.tokens_before} -> {metrics.tokens_after} tokens "
377
+ f"({metrics.savings_percent:.1f}% saved)"
378
+ )
379
+
380
+ # Call wrapped model with optimized messages
381
+ result: ChatResult = self.wrapped_model._generate(
382
+ optimized_messages,
383
+ stop=stop,
384
+ run_manager=run_manager,
385
+ **kwargs,
386
+ )
387
+
388
+ return result
389
+
390
+ def _stream(
391
+ self,
392
+ messages: list[BaseMessage],
393
+ stop: list[str] | None = None,
394
+ run_manager: Any = None,
395
+ **kwargs: Any,
396
+ ) -> Iterator[ChatGenerationChunk]:
397
+ """Stream response with Headroom optimization."""
398
+ # Optimize messages
399
+ optimized_messages, metrics = self._optimize_messages(messages)
400
+
401
+ logger.info(
402
+ f"Headroom optimized (streaming): {metrics.tokens_before} -> "
403
+ f"{metrics.tokens_after} tokens"
404
+ )
405
+
406
+ # Stream from wrapped model
407
+ yield from self.wrapped_model._stream(
408
+ optimized_messages,
409
+ stop=stop,
410
+ run_manager=run_manager,
411
+ **kwargs,
412
+ )
413
+
414
+ async def _agenerate(
415
+ self,
416
+ messages: list[BaseMessage],
417
+ stop: list[str] | None = None,
418
+ run_manager: Any = None,
419
+ **kwargs: Any,
420
+ ) -> ChatResult:
421
+ """Async generate response with Headroom optimization.
422
+
423
+ This enables `await model.ainvoke(messages)` to work correctly.
424
+ The optimization step runs in a thread executor since it's CPU-bound.
425
+ """
426
+ # Run optimization in executor (CPU-bound)
427
+ loop = asyncio.get_event_loop()
428
+ optimized_messages, metrics = await loop.run_in_executor(
429
+ None, self._optimize_messages, messages
430
+ )
431
+
432
+ logger.info(
433
+ f"Headroom optimized (async): {metrics.tokens_before} -> {metrics.tokens_after} tokens "
434
+ f"({metrics.savings_percent:.1f}% saved)"
435
+ )
436
+
437
+ # Call wrapped model's async generate
438
+ result: ChatResult = await self.wrapped_model._agenerate(
439
+ optimized_messages,
440
+ stop=stop,
441
+ run_manager=run_manager,
442
+ **kwargs,
443
+ )
444
+
445
+ return result
446
+
447
+ async def _astream(
448
+ self,
449
+ messages: list[BaseMessage],
450
+ stop: list[str] | None = None,
451
+ run_manager: Any = None,
452
+ **kwargs: Any,
453
+ ) -> AsyncIterator[ChatGenerationChunk]:
454
+ """Async stream response with Headroom optimization.
455
+
456
+ This enables `async for chunk in model.astream(messages)` to work correctly.
457
+ """
458
+ # Run optimization in executor (CPU-bound)
459
+ loop = asyncio.get_event_loop()
460
+ optimized_messages, metrics = await loop.run_in_executor(
461
+ None, self._optimize_messages, messages
462
+ )
463
+
464
+ logger.info(
465
+ f"Headroom optimized (async streaming): {metrics.tokens_before} -> "
466
+ f"{metrics.tokens_after} tokens"
467
+ )
468
+
469
+ # Async stream from wrapped model
470
+ async for chunk in self.wrapped_model._astream(
471
+ optimized_messages,
472
+ stop=stop,
473
+ run_manager=run_manager,
474
+ **kwargs,
475
+ ):
476
+ yield chunk
477
+
478
+ def bind_tools(self, tools: Sequence[Any], **kwargs: Any) -> HeadroomChatModel:
479
+ """Bind tools to the wrapped model."""
480
+ new_wrapped = self.wrapped_model.bind_tools(tools, **kwargs)
481
+ return HeadroomChatModel(
482
+ wrapped_model=new_wrapped,
483
+ config=self.headroom_config,
484
+ mode=self.mode,
485
+ auto_detect_provider=self.auto_detect_provider,
486
+ )
487
+
488
+ def get_savings_summary(self) -> dict[str, Any]:
489
+ """Get summary of token savings."""
490
+ if not self._metrics_history:
491
+ return {
492
+ "total_requests": 0,
493
+ "total_tokens_saved": 0,
494
+ "average_savings_percent": 0,
495
+ }
496
+
497
+ return {
498
+ "total_requests": len(self._metrics_history),
499
+ "total_tokens_saved": self._total_tokens_saved,
500
+ "average_savings_percent": sum(m.savings_percent for m in self._metrics_history)
501
+ / len(self._metrics_history),
502
+ "total_tokens_before": sum(m.tokens_before for m in self._metrics_history),
503
+ "total_tokens_after": sum(m.tokens_after for m in self._metrics_history),
504
+ }
505
+
506
+
507
+ class HeadroomCallbackHandler(BaseCallbackHandler):
508
+ """LangChain callback handler for Headroom metrics and observability.
509
+
510
+ NOTE: Callbacks CANNOT modify messages in LangChain (by design).
511
+ Use HeadroomChatModel for actual optimization. This handler is for:
512
+
513
+ 1. Tracking token usage across chains
514
+ 2. Logging optimization metrics
515
+ 3. Alerting on high token usage
516
+ 4. Integration with observability platforms
517
+
518
+ Example:
519
+ from langchain_openai import ChatOpenAI
520
+ from headroom.integrations import HeadroomCallbackHandler
521
+
522
+ handler = HeadroomCallbackHandler(
523
+ log_level="INFO",
524
+ token_alert_threshold=10000,
525
+ )
526
+
527
+ llm = ChatOpenAI(model="gpt-4o", callbacks=[handler])
528
+ response = llm.invoke("Hello!")
529
+
530
+ # Check metrics
531
+ print(f"Total tokens: {handler.total_tokens}")
532
+ print(f"Alerts: {handler.alerts}")
533
+ """
534
+
535
+ def __init__(
536
+ self,
537
+ log_level: str = "INFO",
538
+ token_alert_threshold: int | None = None,
539
+ cost_alert_threshold: float | None = None,
540
+ ):
541
+ """Initialize callback handler.
542
+
543
+ Args:
544
+ log_level: Logging level for metrics ("DEBUG", "INFO", "WARNING")
545
+ token_alert_threshold: Alert if request exceeds this many tokens
546
+ cost_alert_threshold: Alert if estimated cost exceeds this amount
547
+ """
548
+ _check_langchain_available()
549
+
550
+ self.log_level = log_level
551
+ self.token_alert_threshold = token_alert_threshold
552
+ self.cost_alert_threshold = cost_alert_threshold
553
+
554
+ # Metrics tracking
555
+ self._requests: list[dict[str, Any]] = []
556
+ self._total_tokens = 0
557
+ self._alerts: list[str] = []
558
+ self._current_request: dict[str, Any] | None = None
559
+
560
+ @property
561
+ def total_tokens(self) -> int:
562
+ """Total tokens used across all requests."""
563
+ return self._total_tokens
564
+
565
+ @property
566
+ def total_requests(self) -> int:
567
+ """Total number of requests tracked."""
568
+ return len(self._requests)
569
+
570
+ @property
571
+ def alerts(self) -> list[str]:
572
+ """List of alerts triggered."""
573
+ return self._alerts.copy()
574
+
575
+ @property
576
+ def requests(self) -> list[dict[str, Any]]:
577
+ """List of request metrics."""
578
+ return self._requests.copy()
579
+
580
+ def on_llm_start(
581
+ self,
582
+ serialized: dict[str, Any],
583
+ prompts: list[str],
584
+ **kwargs: Any,
585
+ ) -> None:
586
+ """Called when LLM starts processing."""
587
+ self._current_request = {
588
+ "start_time": datetime.now(),
589
+ "model": serialized.get("name", "unknown"),
590
+ "prompt_count": len(prompts),
591
+ "estimated_input_tokens": sum(len(p) // 4 for p in prompts), # Rough estimate
592
+ }
593
+
594
+ if self.log_level == "DEBUG":
595
+ logger.debug(f"LLM request started: {self._current_request}")
596
+
597
+ def on_chat_model_start(
598
+ self,
599
+ serialized: dict[str, Any],
600
+ messages: list[list[BaseMessage]],
601
+ **kwargs: Any,
602
+ ) -> None:
603
+ """Called when chat model starts processing."""
604
+ # Estimate tokens from messages
605
+ total_content = ""
606
+ for msg_list in messages:
607
+ for msg in msg_list:
608
+ content = msg.content if isinstance(msg.content, str) else str(msg.content)
609
+ total_content += content
610
+
611
+ estimated_tokens = len(total_content) // 4 # Rough estimate
612
+
613
+ self._current_request = {
614
+ "start_time": datetime.now(),
615
+ "model": serialized.get("name", serialized.get("id", ["unknown"])[-1]),
616
+ "message_count": sum(len(ml) for ml in messages),
617
+ "estimated_input_tokens": estimated_tokens,
618
+ }
619
+
620
+ # Check token alert
621
+ if self.token_alert_threshold and estimated_tokens > self.token_alert_threshold:
622
+ alert = (
623
+ f"Token alert: {estimated_tokens} tokens exceeds "
624
+ f"threshold {self.token_alert_threshold}"
625
+ )
626
+ self._alerts.append(alert)
627
+ logger.warning(alert)
628
+
629
+ if self.log_level in ("DEBUG", "INFO"):
630
+ logger.log(
631
+ logging.DEBUG if self.log_level == "DEBUG" else logging.INFO,
632
+ f"Chat model request: ~{estimated_tokens} input tokens",
633
+ )
634
+
635
+ def on_llm_end(self, response: Any, **kwargs: Any) -> None:
636
+ """Called when LLM finishes processing."""
637
+ if self._current_request is None:
638
+ return
639
+
640
+ # Extract token usage from response if available
641
+ token_usage = {}
642
+ if hasattr(response, "llm_output") and response.llm_output:
643
+ token_usage = response.llm_output.get("token_usage", {})
644
+
645
+ self._current_request["end_time"] = datetime.now()
646
+ self._current_request["duration_ms"] = (
647
+ self._current_request["end_time"] - self._current_request["start_time"]
648
+ ).total_seconds() * 1000
649
+
650
+ if token_usage:
651
+ self._current_request["input_tokens"] = token_usage.get("prompt_tokens", 0)
652
+ self._current_request["output_tokens"] = token_usage.get("completion_tokens", 0)
653
+ self._current_request["total_tokens"] = token_usage.get("total_tokens", 0)
654
+ self._total_tokens += self._current_request["total_tokens"]
655
+
656
+ self._requests.append(self._current_request)
657
+
658
+ # Keep only last 1000 requests
659
+ if len(self._requests) > 1000:
660
+ self._requests = self._requests[-1000:]
661
+
662
+ if self.log_level in ("DEBUG", "INFO"):
663
+ tokens_info = f"{self._current_request.get('total_tokens', 'unknown')} tokens"
664
+ duration = f"{self._current_request['duration_ms']:.0f}ms"
665
+ logger.log(
666
+ logging.DEBUG if self.log_level == "DEBUG" else logging.INFO,
667
+ f"LLM request completed: {tokens_info} in {duration}",
668
+ )
669
+
670
+ self._current_request = None
671
+
672
+ def on_llm_error(
673
+ self,
674
+ error: BaseException,
675
+ *,
676
+ run_id: UUID,
677
+ parent_run_id: UUID | None = None,
678
+ **kwargs: Any,
679
+ ) -> Any:
680
+ """Called when LLM encounters an error."""
681
+ if self._current_request:
682
+ self._current_request["error"] = str(error)
683
+ self._current_request["end_time"] = datetime.now()
684
+ self._requests.append(self._current_request)
685
+ self._current_request = None
686
+
687
+ logger.error(f"LLM error: {error}")
688
+
689
+ def get_summary(self) -> dict[str, Any]:
690
+ """Get summary of all tracked requests."""
691
+ if not self._requests:
692
+ return {
693
+ "total_requests": 0,
694
+ "total_tokens": 0,
695
+ "average_tokens": 0,
696
+ "average_duration_ms": 0,
697
+ "errors": 0,
698
+ "alerts": len(self._alerts),
699
+ }
700
+
701
+ successful = [r for r in self._requests if "error" not in r]
702
+ total_tokens = sum(r.get("total_tokens", 0) for r in successful)
703
+
704
+ return {
705
+ "total_requests": len(self._requests),
706
+ "successful_requests": len(successful),
707
+ "total_tokens": total_tokens,
708
+ "average_tokens": total_tokens / len(successful) if successful else 0,
709
+ "average_duration_ms": (
710
+ sum(r.get("duration_ms", 0) for r in successful) / len(successful)
711
+ if successful
712
+ else 0
713
+ ),
714
+ "errors": len(self._requests) - len(successful),
715
+ "alerts": len(self._alerts),
716
+ }
717
+
718
+ def reset(self) -> None:
719
+ """Reset all tracked metrics."""
720
+ self._requests = []
721
+ self._total_tokens = 0
722
+ self._alerts = []
723
+ self._current_request = None
724
+
725
+
726
+ class HeadroomRunnable:
727
+ """LCEL-compatible Runnable for Headroom optimization.
728
+
729
+ Use this to add Headroom optimization to any LangChain chain using LCEL.
730
+
731
+ Example:
732
+ from langchain_openai import ChatOpenAI
733
+ from langchain_core.prompts import ChatPromptTemplate
734
+ from headroom.integrations import HeadroomRunnable
735
+
736
+ prompt = ChatPromptTemplate.from_messages([
737
+ ("system", "You are a helpful assistant."),
738
+ ("user", "{input}"),
739
+ ])
740
+ llm = ChatOpenAI(model="gpt-4o")
741
+
742
+ # Add Headroom optimization to chain
743
+ chain = prompt | HeadroomRunnable() | llm
744
+ response = chain.invoke({"input": "Hello!"})
745
+ """
746
+
747
+ def __init__(
748
+ self,
749
+ config: HeadroomConfig | None = None,
750
+ mode: HeadroomMode = HeadroomMode.OPTIMIZE,
751
+ ):
752
+ """Initialize HeadroomRunnable.
753
+
754
+ Args:
755
+ config: HeadroomConfig for optimization settings
756
+ mode: HeadroomMode (AUDIT, OPTIMIZE, or SIMULATE)
757
+ """
758
+ _check_langchain_available()
759
+
760
+ self.config = config or HeadroomConfig()
761
+ self.mode = mode
762
+ self._pipeline: TransformPipeline | None = None
763
+ self._provider: OpenAIProvider | None = None
764
+ self._metrics_history: list[OptimizationMetrics] = []
765
+
766
+ @property
767
+ def pipeline(self) -> TransformPipeline:
768
+ """Lazily initialize TransformPipeline."""
769
+ if self._pipeline is None:
770
+ self._provider = OpenAIProvider()
771
+ self._pipeline = TransformPipeline(
772
+ config=self.config,
773
+ provider=self._provider,
774
+ )
775
+ return self._pipeline
776
+
777
+ def __or__(self, other: Any) -> Any:
778
+ """Support pipe operator for LCEL composition."""
779
+ from langchain_core.runnables import RunnableSequence
780
+
781
+ return RunnableSequence(first=self.as_runnable(), last=other)
782
+
783
+ def __ror__(self, other: Any) -> Any:
784
+ """Support reverse pipe operator."""
785
+ from langchain_core.runnables import RunnableSequence
786
+
787
+ return RunnableSequence(first=other, last=self.as_runnable())
788
+
789
+ def as_runnable(self) -> RunnableLambda:
790
+ """Convert to LangChain Runnable."""
791
+ return RunnableLambda(self._optimize)
792
+
793
+ def _optimize(self, input_data: Any) -> Any:
794
+ """Optimize input messages."""
795
+ # Handle different input types
796
+ if isinstance(input_data, list):
797
+ messages = input_data
798
+ elif hasattr(input_data, "messages"):
799
+ messages = input_data.messages
800
+ elif hasattr(input_data, "to_messages"):
801
+ messages = input_data.to_messages()
802
+ else:
803
+ # Can't optimize, pass through
804
+ return input_data
805
+
806
+ # Convert messages to OpenAI format
807
+ openai_messages = []
808
+ for msg in messages:
809
+ if isinstance(msg, SystemMessage):
810
+ openai_messages.append({"role": "system", "content": msg.content})
811
+ elif isinstance(msg, HumanMessage):
812
+ openai_messages.append({"role": "user", "content": msg.content})
813
+ elif isinstance(msg, AIMessage):
814
+ openai_messages.append({"role": "assistant", "content": msg.content})
815
+ elif isinstance(msg, ToolMessage):
816
+ openai_messages.append(
817
+ {
818
+ "role": "tool",
819
+ "tool_call_id": msg.tool_call_id,
820
+ "content": msg.content,
821
+ }
822
+ )
823
+ elif hasattr(msg, "type") and hasattr(msg, "content"):
824
+ openai_messages.append(
825
+ {
826
+ "role": msg.type,
827
+ "content": msg.content,
828
+ }
829
+ )
830
+
831
+ # Get model context limit
832
+ model = "gpt-4o" # Default model for estimation
833
+ model_limit = self._provider.get_context_limit(model) if self._provider else 128000
834
+
835
+ # Apply Headroom transforms via pipeline
836
+ result = self.pipeline.apply(
837
+ messages=openai_messages,
838
+ model=model,
839
+ model_limit=model_limit,
840
+ )
841
+
842
+ # Track metrics
843
+ metrics = OptimizationMetrics(
844
+ request_id=str(uuid4()),
845
+ timestamp=datetime.now(),
846
+ tokens_before=result.tokens_before,
847
+ tokens_after=result.tokens_after,
848
+ tokens_saved=result.tokens_before - result.tokens_after,
849
+ savings_percent=(
850
+ (result.tokens_before - result.tokens_after) / result.tokens_before * 100
851
+ if result.tokens_before > 0
852
+ else 0
853
+ ),
854
+ transforms_applied=result.transforms_applied,
855
+ model="gpt-4o",
856
+ )
857
+ self._metrics_history.append(metrics)
858
+
859
+ # Convert back to LangChain messages
860
+ output_messages: list[BaseMessage] = []
861
+ for msg in result.messages:
862
+ role = msg.get("role", "user")
863
+ content = msg.get("content", "")
864
+
865
+ if role == "system":
866
+ output_messages.append(SystemMessage(content=content))
867
+ elif role == "user":
868
+ output_messages.append(HumanMessage(content=content))
869
+ elif role == "assistant":
870
+ output_messages.append(AIMessage(content=content))
871
+ elif role == "tool":
872
+ output_messages.append(
873
+ ToolMessage(
874
+ content=content,
875
+ tool_call_id=msg.get("tool_call_id", ""),
876
+ )
877
+ )
878
+
879
+ return output_messages
880
+
881
+
882
+ def optimize_messages(
883
+ messages: list[BaseMessage],
884
+ config: HeadroomConfig | None = None,
885
+ mode: HeadroomMode = HeadroomMode.OPTIMIZE,
886
+ model: str = "gpt-4o",
887
+ ) -> tuple[list[BaseMessage], dict[str, Any]]:
888
+ """Standalone function to optimize LangChain messages.
889
+
890
+ Use this for manual optimization when you need fine-grained control.
891
+
892
+ Args:
893
+ messages: List of LangChain BaseMessage objects
894
+ config: HeadroomConfig for optimization settings
895
+ mode: HeadroomMode (AUDIT, OPTIMIZE, or SIMULATE)
896
+ model: Model name for token estimation
897
+
898
+ Returns:
899
+ Tuple of (optimized_messages, metrics_dict)
900
+
901
+ Example:
902
+ from langchain_core.messages import HumanMessage, SystemMessage
903
+ from headroom.integrations import optimize_messages
904
+
905
+ messages = [
906
+ SystemMessage(content="You are helpful."),
907
+ HumanMessage(content="What is 2+2?"),
908
+ ]
909
+
910
+ optimized, metrics = optimize_messages(messages)
911
+ print(f"Saved {metrics['tokens_saved']} tokens")
912
+ """
913
+ _check_langchain_available()
914
+
915
+ config = config or HeadroomConfig()
916
+ provider = OpenAIProvider()
917
+ pipeline = TransformPipeline(config=config, provider=provider)
918
+
919
+ # Convert to OpenAI format
920
+ openai_messages = []
921
+ for msg in messages:
922
+ if isinstance(msg, SystemMessage):
923
+ openai_messages.append({"role": "system", "content": msg.content})
924
+ elif isinstance(msg, HumanMessage):
925
+ openai_messages.append({"role": "user", "content": msg.content})
926
+ elif isinstance(msg, AIMessage):
927
+ entry = {"role": "assistant", "content": msg.content}
928
+ if hasattr(msg, "tool_calls") and msg.tool_calls:
929
+ entry["tool_calls"] = [
930
+ {
931
+ "id": tc["id"],
932
+ "type": "function",
933
+ "function": {
934
+ "name": tc["name"],
935
+ "arguments": json.dumps(tc["args"]),
936
+ },
937
+ }
938
+ for tc in msg.tool_calls
939
+ ]
940
+ openai_messages.append(entry)
941
+ elif isinstance(msg, ToolMessage):
942
+ openai_messages.append(
943
+ {
944
+ "role": "tool",
945
+ "tool_call_id": msg.tool_call_id,
946
+ "content": msg.content,
947
+ }
948
+ )
949
+
950
+ # Get model context limit
951
+ model_limit = provider.get_context_limit(model)
952
+
953
+ # Apply transforms via pipeline
954
+ result = pipeline.apply(
955
+ messages=openai_messages,
956
+ model=model,
957
+ model_limit=model_limit,
958
+ )
959
+
960
+ # Convert back
961
+ output_messages: list[BaseMessage] = []
962
+ for openai_msg in result.messages:
963
+ role = openai_msg.get("role", "user")
964
+ content = openai_msg.get("content", "")
965
+
966
+ if role == "system":
967
+ output_messages.append(SystemMessage(content=content))
968
+ elif role == "user":
969
+ output_messages.append(HumanMessage(content=content))
970
+ elif role == "assistant":
971
+ tool_calls = []
972
+ if "tool_calls" in openai_msg:
973
+ for tc in openai_msg["tool_calls"]:
974
+ tool_calls.append(
975
+ {
976
+ "id": tc["id"],
977
+ "name": tc["function"]["name"],
978
+ "args": json.loads(tc["function"]["arguments"]),
979
+ }
980
+ )
981
+ output_messages.append(AIMessage(content=content, tool_calls=tool_calls))
982
+ elif role == "tool":
983
+ output_messages.append(
984
+ ToolMessage(
985
+ content=content,
986
+ tool_call_id=openai_msg.get("tool_call_id", ""),
987
+ )
988
+ )
989
+
990
+ metrics = {
991
+ "tokens_before": result.tokens_before,
992
+ "tokens_after": result.tokens_after,
993
+ "tokens_saved": result.tokens_before - result.tokens_after,
994
+ "savings_percent": (
995
+ (result.tokens_before - result.tokens_after) / result.tokens_before * 100
996
+ if result.tokens_before > 0
997
+ else 0
998
+ ),
999
+ "transforms_applied": result.transforms_applied,
1000
+ }
1001
+
1002
+ return output_messages, metrics
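
The docstrings above illustrate each component in isolation. As a consolidated illustration, the sketch below shows how the pieces defined in chat_model.py fit together end to end. It is a minimal sketch, not part of the package: it assumes langchain-openai is installed, an OpenAI API key is configured in the environment, and that "gpt-4o" is an available model name.

    # Minimal usage sketch (not part of the package). Assumes langchain-openai
    # is installed, OPENAI_API_KEY is set, and "gpt-4o" is an available model.
    import asyncio

    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_openai import ChatOpenAI

    from headroom import HeadroomConfig, HeadroomMode
    from headroom.integrations import (
        HeadroomCallbackHandler,
        HeadroomChatModel,
        optimize_messages,
    )

    # Observability only: the callback handler tracks tokens but cannot modify messages.
    handler = HeadroomCallbackHandler(token_alert_threshold=10_000)

    # Wrap the chat model so every call is optimized before it reaches the API.
    llm = ChatOpenAI(model="gpt-4o", callbacks=[handler])
    optimized = HeadroomChatModel(llm, config=HeadroomConfig(), mode=HeadroomMode.OPTIMIZE)

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Summarize this conversation in two sentences."),
    ]

    # Sync path: invoke() -> _generate() -> _optimize_messages() -> wrapped model.
    print(optimized.invoke(messages).content)

    # Async path: ainvoke() -> _agenerate(), which runs optimization in a thread executor.
    print(asyncio.run(optimized.ainvoke(messages)).content)

    # Per-wrapper savings and per-handler request metrics.
    print(optimized.get_savings_summary())
    print(handler.get_summary())

    # Standalone path: optimize a message list without wrapping a model.
    trimmed, metrics = optimize_messages(messages, model="gpt-4o")
    print(f"Saved {metrics['tokens_saved']} tokens ({metrics['savings_percent']:.1f}%)")

The async calls work because the wrapper implements _agenerate and _astream itself, so the wrapped model's native async path is used rather than LangChain's default fallback of running the sync implementation in a thread.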