headroom_ai-0.2.13-py3-none-any.whl

Files changed (114)
  1. headroom/__init__.py +212 -0
  2. headroom/cache/__init__.py +76 -0
  3. headroom/cache/anthropic.py +517 -0
  4. headroom/cache/base.py +342 -0
  5. headroom/cache/compression_feedback.py +613 -0
  6. headroom/cache/compression_store.py +814 -0
  7. headroom/cache/dynamic_detector.py +1026 -0
  8. headroom/cache/google.py +884 -0
  9. headroom/cache/openai.py +584 -0
  10. headroom/cache/registry.py +175 -0
  11. headroom/cache/semantic.py +451 -0
  12. headroom/ccr/__init__.py +77 -0
  13. headroom/ccr/context_tracker.py +582 -0
  14. headroom/ccr/mcp_server.py +319 -0
  15. headroom/ccr/response_handler.py +772 -0
  16. headroom/ccr/tool_injection.py +415 -0
  17. headroom/cli.py +219 -0
  18. headroom/client.py +977 -0
  19. headroom/compression/__init__.py +42 -0
  20. headroom/compression/detector.py +424 -0
  21. headroom/compression/handlers/__init__.py +22 -0
  22. headroom/compression/handlers/base.py +219 -0
  23. headroom/compression/handlers/code_handler.py +506 -0
  24. headroom/compression/handlers/json_handler.py +418 -0
  25. headroom/compression/masks.py +345 -0
  26. headroom/compression/universal.py +465 -0
  27. headroom/config.py +474 -0
  28. headroom/exceptions.py +192 -0
  29. headroom/integrations/__init__.py +159 -0
  30. headroom/integrations/agno/__init__.py +53 -0
  31. headroom/integrations/agno/hooks.py +345 -0
  32. headroom/integrations/agno/model.py +625 -0
  33. headroom/integrations/agno/providers.py +154 -0
  34. headroom/integrations/langchain/__init__.py +106 -0
  35. headroom/integrations/langchain/agents.py +326 -0
  36. headroom/integrations/langchain/chat_model.py +1002 -0
  37. headroom/integrations/langchain/langsmith.py +324 -0
  38. headroom/integrations/langchain/memory.py +319 -0
  39. headroom/integrations/langchain/providers.py +200 -0
  40. headroom/integrations/langchain/retriever.py +371 -0
  41. headroom/integrations/langchain/streaming.py +341 -0
  42. headroom/integrations/mcp/__init__.py +37 -0
  43. headroom/integrations/mcp/server.py +533 -0
  44. headroom/memory/__init__.py +37 -0
  45. headroom/memory/extractor.py +390 -0
  46. headroom/memory/fast_store.py +621 -0
  47. headroom/memory/fast_wrapper.py +311 -0
  48. headroom/memory/inline_extractor.py +229 -0
  49. headroom/memory/store.py +434 -0
  50. headroom/memory/worker.py +260 -0
  51. headroom/memory/wrapper.py +321 -0
  52. headroom/models/__init__.py +39 -0
  53. headroom/models/registry.py +687 -0
  54. headroom/parser.py +293 -0
  55. headroom/pricing/__init__.py +51 -0
  56. headroom/pricing/anthropic_prices.py +81 -0
  57. headroom/pricing/litellm_pricing.py +113 -0
  58. headroom/pricing/openai_prices.py +91 -0
  59. headroom/pricing/registry.py +188 -0
  60. headroom/providers/__init__.py +61 -0
  61. headroom/providers/anthropic.py +621 -0
  62. headroom/providers/base.py +131 -0
  63. headroom/providers/cohere.py +362 -0
  64. headroom/providers/google.py +427 -0
  65. headroom/providers/litellm.py +297 -0
  66. headroom/providers/openai.py +566 -0
  67. headroom/providers/openai_compatible.py +521 -0
  68. headroom/proxy/__init__.py +19 -0
  69. headroom/proxy/server.py +2683 -0
  70. headroom/py.typed +0 -0
  71. headroom/relevance/__init__.py +124 -0
  72. headroom/relevance/base.py +106 -0
  73. headroom/relevance/bm25.py +255 -0
  74. headroom/relevance/embedding.py +255 -0
  75. headroom/relevance/hybrid.py +259 -0
  76. headroom/reporting/__init__.py +5 -0
  77. headroom/reporting/generator.py +549 -0
  78. headroom/storage/__init__.py +41 -0
  79. headroom/storage/base.py +125 -0
  80. headroom/storage/jsonl.py +220 -0
  81. headroom/storage/sqlite.py +289 -0
  82. headroom/telemetry/__init__.py +91 -0
  83. headroom/telemetry/collector.py +764 -0
  84. headroom/telemetry/models.py +880 -0
  85. headroom/telemetry/toin.py +1579 -0
  86. headroom/tokenizer.py +80 -0
  87. headroom/tokenizers/__init__.py +75 -0
  88. headroom/tokenizers/base.py +210 -0
  89. headroom/tokenizers/estimator.py +198 -0
  90. headroom/tokenizers/huggingface.py +317 -0
  91. headroom/tokenizers/mistral.py +245 -0
  92. headroom/tokenizers/registry.py +398 -0
  93. headroom/tokenizers/tiktoken_counter.py +248 -0
  94. headroom/transforms/__init__.py +106 -0
  95. headroom/transforms/base.py +57 -0
  96. headroom/transforms/cache_aligner.py +357 -0
  97. headroom/transforms/code_compressor.py +1313 -0
  98. headroom/transforms/content_detector.py +335 -0
  99. headroom/transforms/content_router.py +1158 -0
  100. headroom/transforms/llmlingua_compressor.py +638 -0
  101. headroom/transforms/log_compressor.py +529 -0
  102. headroom/transforms/pipeline.py +297 -0
  103. headroom/transforms/rolling_window.py +350 -0
  104. headroom/transforms/search_compressor.py +365 -0
  105. headroom/transforms/smart_crusher.py +2682 -0
  106. headroom/transforms/text_compressor.py +259 -0
  107. headroom/transforms/tool_crusher.py +338 -0
  108. headroom/utils.py +215 -0
  109. headroom_ai-0.2.13.dist-info/METADATA +315 -0
  110. headroom_ai-0.2.13.dist-info/RECORD +114 -0
  111. headroom_ai-0.2.13.dist-info/WHEEL +4 -0
  112. headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
  113. headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
  114. headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/providers/base.py
@@ -0,0 +1,131 @@
+ """Base provider protocol for Headroom SDK.
+
+ Providers are responsible for:
+ - Token counting (model-specific)
+ - Model context limits
+ - Cost estimation (optional)
+
+ This module defines the protocols that all providers must implement.
+ """
+
+ from __future__ import annotations
+
+ from abc import ABC, abstractmethod
+ from typing import Any, Protocol, runtime_checkable
+
+
+ @runtime_checkable
+ class TokenCounter(Protocol):
+     """Protocol for token counting implementations."""
+
+     def count_text(self, text: str) -> int:
+         """Count tokens in a text string."""
+         ...
+
+     def count_message(self, message: dict[str, Any]) -> int:
+         """Count tokens in a single message dict."""
+         ...
+
+     def count_messages(self, messages: list[dict[str, Any]]) -> int:
+         """Count tokens in a list of messages."""
+         ...
+
+
+ class Provider(ABC):
+     """
+     Abstract base class for LLM providers.
+
+     Providers encapsulate all model-specific behavior:
+     - Token counting
+     - Context window limits
+     - Cost estimation
+
+     Implementations must be explicit - no silent fallbacks.
+     """
+
+     @property
+     @abstractmethod
+     def name(self) -> str:
+         """Provider name (e.g., 'openai', 'anthropic')."""
+         ...
+
+     @abstractmethod
+     def get_token_counter(self, model: str) -> TokenCounter:
+         """
+         Get a token counter for a specific model.
+
+         Args:
+             model: The model name.
+
+         Returns:
+             TokenCounter instance for the model.
+
+         Raises:
+             ValueError: If model is not supported by this provider.
+         """
+         ...
+
+     @abstractmethod
+     def get_context_limit(self, model: str) -> int:
+         """
+         Get the context window limit for a model.
+
+         Args:
+             model: The model name.
+
+         Returns:
+             Maximum context tokens for the model.
+
+         Raises:
+             ValueError: If model is not recognized.
+         """
+         ...
+
+     @abstractmethod
+     def supports_model(self, model: str) -> bool:
+         """
+         Check if this provider supports a given model.
+
+         Args:
+             model: The model name.
+
+         Returns:
+             True if the model is supported.
+         """
+         ...
+
+     def estimate_cost(
+         self,
+         input_tokens: int,
+         output_tokens: int,
+         model: str,
+         cached_tokens: int = 0,
+     ) -> float | None:
+         """
+         Estimate API cost in USD.
+
+         Args:
+             input_tokens: Number of input tokens.
+             output_tokens: Number of output tokens.
+             model: Model name.
+             cached_tokens: Number of cached input tokens.
+
+         Returns:
+             Estimated cost in USD, or None if cost estimation not available.
+         """
+         return None
+
+     def get_output_buffer(self, model: str, default: int = 4000) -> int:
+         """
+         Get recommended output buffer for a model.
+
+         Some models (like reasoning models) produce longer outputs.
+
+         Args:
+             model: The model name.
+             default: Default buffer if no model-specific recommendation.
+
+         Returns:
+             Recommended output token buffer.
+         """
+         return default
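
For orientation before the Cohere implementation below, here is a minimal sketch of a concrete provider built on this ABC. Everything in it (the EchoProvider/EchoTokenCounter names, the single "echo-1" model, the 8192-token limit, and the flat 4-characters-per-token estimate) is an illustrative assumption, not code shipped in the package:

    from typing import Any

    from headroom.providers.base import Provider, TokenCounter


    class EchoTokenCounter:
        """Hypothetical counter: roughly 4 characters per token."""

        def count_text(self, text: str) -> int:
            return (len(text) + 3) // 4 if text else 0

        def count_message(self, message: dict[str, Any]) -> int:
            # Count the content, plus a small allowance for role/format tokens.
            return self.count_text(str(message.get("content", ""))) + 4

        def count_messages(self, messages: list[dict[str, Any]]) -> int:
            return sum(self.count_message(m) for m in messages)


    class EchoProvider(Provider):
        @property
        def name(self) -> str:
            return "echo"

        def supports_model(self, model: str) -> bool:
            return model == "echo-1"  # hypothetical model name

        def get_token_counter(self, model: str) -> TokenCounter:
            # Fail loudly per the ABC contract: no silent fallbacks.
            if not self.supports_model(model):
                raise ValueError(f"Model '{model}' is not supported by provider 'echo'.")
            return EchoTokenCounter()

        def get_context_limit(self, model: str) -> int:
            if not self.supports_model(model):
                raise ValueError(f"Unknown context limit for model '{model}'.")
            return 8192  # illustrative limit

Because TokenCounter is a @runtime_checkable Protocol, isinstance(EchoTokenCounter(), TokenCounter) is True even though the class never subclasses it; this structural typing is what lets a provider return any object exposing the three count_* methods.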
headroom/providers/cohere.py
@@ -0,0 +1,362 @@
+ """Cohere provider for Headroom SDK.
+
+ Token counting uses Cohere's official tokenize API when a client
+ is provided. This gives accurate counts for all content types.
+
+ Usage:
+     import cohere
+     from headroom import CohereProvider
+
+     client = cohere.ClientV2()  # Uses CO_API_KEY env var
+     provider = CohereProvider(client=client)  # Accurate counting via API
+
+     # Or without client (uses estimation - less accurate)
+     provider = CohereProvider()  # Warning: approximate counting
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import warnings
+ from datetime import date
+ from typing import Any
+
+ from headroom.tokenizers import EstimatingTokenCounter
+
+ from .base import Provider, TokenCounter
+
+ try:
+     import litellm
+
+     LITELLM_AVAILABLE = True
+ except ImportError:
+     LITELLM_AVAILABLE = False
+
+ logger = logging.getLogger(__name__)
+
+ # Warning flags
+ _FALLBACK_WARNING_SHOWN = False
+
+ # Pricing metadata
+ _PRICING_LAST_UPDATED = date(2025, 1, 6)
+
+ # Cohere model context limits
+ _CONTEXT_LIMITS: dict[str, int] = {
+     # Command A (latest, 2025)
+     "command-a-03-2025": 256000,
+     "command-a": 256000,
+     # Command R+ (2024)
+     "command-r-plus-08-2024": 128000,
+     "command-r-plus": 128000,
+     # Command R (2024)
+     "command-r-08-2024": 128000,
+     "command-r": 128000,
+     # Command (legacy)
+     "command": 4096,
+     "command-light": 4096,
+     "command-nightly": 128000,
+     # Embed models
+     "embed-english-v3.0": 512,
+     "embed-multilingual-v3.0": 512,
+     "embed-english-light-v3.0": 512,
+     "embed-multilingual-light-v3.0": 512,
+ }
+
+ # Fallback pricing - LiteLLM is preferred source
+ # Pricing per 1M tokens (input, output)
+ _PRICING: dict[str, tuple[float, float]] = {
+     "command-a-03-2025": (2.50, 10.00),
+     "command-a": (2.50, 10.00),
+     "command-r-plus-08-2024": (2.50, 10.00),
+     "command-r-plus": (2.50, 10.00),
+     "command-r-08-2024": (0.15, 0.60),
+     "command-r": (0.15, 0.60),
+     "command": (1.00, 2.00),
+     "command-light": (0.30, 0.60),
+ }
+
+
+ class CohereTokenCounter:
+     """Token counter for Cohere models.
+
+     When a Cohere client is provided, uses the official tokenize API
+     for accurate counting. Falls back to estimation when no client
+     is available.
+
+     Usage:
+         import cohere
+         client = cohere.ClientV2()
+
+         # With API (accurate)
+         counter = CohereTokenCounter("command-r-plus", client=client)
+
+         # Without API (estimation)
+         counter = CohereTokenCounter("command-r-plus")
+     """
+
+     def __init__(self, model: str, client: Any = None):
+         """Initialize Cohere token counter.
+
+         Args:
+             model: Cohere model name.
+             client: Optional cohere.ClientV2 for API-based counting.
+         """
+         global _FALLBACK_WARNING_SHOWN
+
+         self.model = model
+         self._client = client
+         self._use_api = client is not None
+
+         # Cohere uses ~4 chars per token
+         self._estimator = EstimatingTokenCounter(chars_per_token=4.0)
+
+         if not self._use_api and not _FALLBACK_WARNING_SHOWN:
+             warnings.warn(
+                 "CohereProvider: No client provided, using estimation. "
+                 "For accurate counting, pass a Cohere client: "
+                 "CohereProvider(client=cohere.ClientV2())",
+                 UserWarning,
+                 stacklevel=4,
+             )
+             _FALLBACK_WARNING_SHOWN = True
+
+     def count_text(self, text: str) -> int:
+         """Count tokens in text.
+
+         Uses tokenize API if client available, otherwise estimates.
+         """
+         if not text:
+             return 0
+
+         if self._use_api:
+             try:
+                 response = self._client.tokenize(
+                     text=text,
+                     model=self.model,
+                 )
+                 return len(response.tokens)
+             except Exception as e:
+                 logger.debug(f"Cohere tokenize API failed: {e}, using estimation")
+
+         return self._estimator.count_text(text)
+
+     def count_message(self, message: dict[str, Any]) -> int:
+         """Count tokens in a message."""
+         content = self._extract_content(message)
+         tokens = self.count_text(content)
+         tokens += 4  # Message overhead (role tokens, etc.)
+         return tokens
+
+     def count_messages(self, messages: list[dict[str, Any]]) -> int:
+         """Count tokens in messages."""
+         if not messages:
+             return 0
+
+         # For API-based counting, concatenate all content
+         if self._use_api:
+             try:
+                 all_content = []
+                 for msg in messages:
+                     content = self._extract_content(msg)
+                     role = msg.get("role", "user")
+                     all_content.append(f"{role}: {content}")
+
+                 full_text = "\n".join(all_content)
+                 response = self._client.tokenize(
+                     text=full_text,
+                     model=self.model,
+                 )
+                 return len(response.tokens)
+             except Exception as e:
+                 logger.debug(f"Cohere tokenize API failed: {e}, using estimation")
+
+         # Fallback to estimation
+         total = sum(self.count_message(msg) for msg in messages)
+         total += 3  # Priming tokens
+         return total
+
+     def _extract_content(self, message: dict[str, Any]) -> str:
+         """Extract text content from message."""
+         content = message.get("content", "")
+         if isinstance(content, str):
+             return content
+         elif isinstance(content, list):
+             parts = []
+             for part in content:
+                 if isinstance(part, dict) and part.get("type") == "text":
+                     parts.append(part.get("text", ""))
+                 elif isinstance(part, str):
+                     parts.append(part)
+             return "\n".join(parts)
+         return str(content)
+
+
+ class CohereProvider(Provider):
+     """Provider for Cohere Command models.
+
+     Supports Command R, Command R+, and Command A model families.
+
+     Example:
+         import cohere
+         client = cohere.ClientV2()
+
+         # With client (accurate token counting via API)
+         provider = CohereProvider(client=client)
+
+         # Without client (estimation-based counting)
+         provider = CohereProvider()
+
+         # Token counting
+         counter = provider.get_token_counter("command-r-plus")
+         tokens = counter.count_text("Hello, world!")
+
+         # Context limits
+         limit = provider.get_context_limit("command-a")  # 256K tokens
+
+         # Cost estimation
+         cost = provider.estimate_cost(
+             input_tokens=100000,
+             output_tokens=10000,
+             model="command-r-plus",
+         )
+     """
+
+     def __init__(self, client: Any = None):
+         """Initialize Cohere provider.
+
+         Args:
+             client: Optional cohere.ClientV2 for API-based token counting.
+                 If provided, uses tokenize API for accurate counts.
+         """
+         self._client = client
+
+     @property
+     def name(self) -> str:
+         return "cohere"
+
+     def supports_model(self, model: str) -> bool:
+         """Check if model is a known Cohere model."""
+         model_lower = model.lower()
+         if model_lower in _CONTEXT_LIMITS:
+             return True
+         # Check prefix match
+         for prefix in ["command-a", "command-r", "command", "embed-"]:
+             if model_lower.startswith(prefix):
+                 return True
+         return False
+
+     def get_token_counter(self, model: str) -> TokenCounter:
+         """Get token counter for a Cohere model.
+
+         Uses tokenize API if client was provided, otherwise estimates.
+         """
+         if not self.supports_model(model):
+             raise ValueError(
+                 f"Model '{model}' is not recognized as a Cohere model. "
+                 f"Supported models: {list(_CONTEXT_LIMITS.keys())}"
+             )
+         return CohereTokenCounter(model, client=self._client)
+
+     def get_context_limit(self, model: str) -> int:
+         """Get context limit for a Cohere model.
+
+         Tries LiteLLM first (with and without 'cohere/' prefix),
+         then falls back to built-in limits.
+         """
+         # Try LiteLLM first
+         if LITELLM_AVAILABLE:
+             for model_variant in [f"cohere/{model}", model]:
+                 try:
+                     info = litellm.get_model_info(model_variant)
+                     if info and "max_input_tokens" in info:
+                         result = info["max_input_tokens"]
+                         if result is not None:
+                             return result
+                     if info and "max_tokens" in info:
+                         result = info["max_tokens"]
+                         if result is not None:
+                             return result
+                 except Exception:
+                     pass
+
+         # Fallback to built-in limits
+         model_lower = model.lower()
+
+         # Direct match
+         if model_lower in _CONTEXT_LIMITS:
+             return _CONTEXT_LIMITS[model_lower]
+
+         # Prefix match
+         for prefix, limit in [
+             ("command-a", 256000),
+             ("command-r-plus", 128000),
+             ("command-r", 128000),
+             ("command", 4096),
+             ("embed-", 512),
+         ]:
+             if model_lower.startswith(prefix):
+                 return limit
+
+         raise ValueError(
+             f"Unknown context limit for model '{model}'. "
+             f"Known models: {list(_CONTEXT_LIMITS.keys())}"
+         )
+
+     def estimate_cost(
+         self,
+         input_tokens: int,
+         output_tokens: int,
+         model: str,
+         cached_tokens: int = 0,
+     ) -> float | None:
+         """Estimate cost for Cohere API call.
+
+         Tries LiteLLM first (with and without 'cohere/' prefix),
+         then falls back to built-in pricing.
+
+         Args:
+             input_tokens: Number of input tokens.
+             output_tokens: Number of output tokens.
+             model: Model name.
+             cached_tokens: Not used by Cohere.
+
+         Returns:
+             Estimated cost in USD, or None if pricing unknown.
+         """
+         # Try LiteLLM first
+         if LITELLM_AVAILABLE:
+             for model_variant in [f"cohere/{model}", model]:
+                 try:
+                     cost = litellm.completion_cost(
+                         model=model_variant,
+                         prompt="",
+                         completion="",
+                         prompt_tokens=input_tokens,
+                         completion_tokens=output_tokens,
+                     )
+                     if cost is not None:
+                         return cost
+                 except Exception:
+                     pass
+
+         # Fallback to built-in pricing
+         model_lower = model.lower()
+
+         # Find pricing
+         input_price, output_price = None, None
+         for model_prefix, (inp, outp) in _PRICING.items():
+             if model_lower.startswith(model_prefix):
+                 input_price, output_price = inp, outp
+                 break
+
+         if input_price is None:
+             return None
+
+         input_cost = (input_tokens / 1_000_000) * input_price
+         output_cost = (output_tokens / 1_000_000) * (output_price or 0)
+
+         return input_cost + output_cost
+
+     def get_output_buffer(self, model: str, default: int = 4000) -> int:
+         """Get recommended output buffer."""
+         return default
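
To make the fallback arithmetic in estimate_cost concrete, a short usage sketch; the import path follows the module docstring above, and the token counts are made-up figures:

    from headroom import CohereProvider

    provider = CohereProvider()  # no client: estimation path, one-time UserWarning

    counter = provider.get_token_counter("command-r-plus")
    print(counter.count_text("Hello, world!"))      # estimated: 13 chars / ~4 per token
    print(provider.get_context_limit("command-r"))  # 128000 (LiteLLM or built-in table)

    # Built-in table: command-r-plus is $2.50 per 1M input, $10.00 per 1M output, so
    # (100_000 / 1_000_000) * 2.50 + (10_000 / 1_000_000) * 10.00 = 0.25 + 0.10 = 0.35
    cost = provider.estimate_cost(
        input_tokens=100_000,
        output_tokens=10_000,
        model="command-r-plus",
    )
    print(cost)  # 0.35 when the fallback table is used; LiteLLM's live pricing may differ

Note the design choice visible in both get_context_limit and estimate_cost: LiteLLM is consulted first and the hard-coded tables are a dated fallback (_PRICING_LAST_UPDATED), so stale built-in prices only matter when litellm is not installed.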