headroom_ai-0.2.13-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/providers/openai.py
@@ -0,0 +1,566 @@
+"""OpenAI provider implementation for Headroom SDK.
+
+Token counting is accurate (uses tiktoken).
+Cost estimates are APPROXIMATE - always verify against your actual billing.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+import warnings
+from datetime import date
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, cast
+
+from .base import Provider, TokenCounter
+
+logger = logging.getLogger(__name__)
+
+# Pricing metadata for transparency
+_PRICING_LAST_UPDATED = date(2025, 1, 14)
+_PRICING_STALE_DAYS = 60  # Warn if pricing data is older than this
+
+# Warning tracking
+_PRICING_WARNING_SHOWN = False
+_UNKNOWN_MODEL_WARNINGS: set[str] = set()
+
+try:
+    import tiktoken
+
+    TIKTOKEN_AVAILABLE = True
+except ImportError:
+    TIKTOKEN_AVAILABLE = False
+
+try:
+    import litellm
+
+    LITELLM_AVAILABLE = True
+except ImportError:
+    LITELLM_AVAILABLE = False
+
+
+# OpenAI model to tiktoken encoding mappings
+_MODEL_ENCODINGS: dict[str, str] = {
+    # GPT-4o and newer use o200k_base
+    "gpt-4o": "o200k_base",
+    "gpt-4o-mini": "o200k_base",
+    "gpt-4o-2024": "o200k_base",
+    "o1": "o200k_base",
+    "o1-preview": "o200k_base",
+    "o1-mini": "o200k_base",
+    "o3": "o200k_base",
+    "o3-mini": "o200k_base",
+    # GPT-4 and GPT-3.5 use cl100k_base
+    "gpt-4": "cl100k_base",
+    "gpt-4-turbo": "cl100k_base",
+    "gpt-3.5": "cl100k_base",
+}
+
+# OpenAI context window limits
+_CONTEXT_LIMITS: dict[str, int] = {
+    # GPT-4o series
+    "gpt-4o": 128000,
+    "gpt-4o-mini": 128000,
+    "gpt-4o-2024-11-20": 128000,
+    "gpt-4o-2024-08-06": 128000,
+    "gpt-4o-2024-05-13": 128000,
+    # GPT-4 Turbo
+    "gpt-4-turbo": 128000,
+    "gpt-4-turbo-preview": 128000,
+    "gpt-4-1106-preview": 128000,
+    # GPT-4
+    "gpt-4": 8192,
+    "gpt-4-32k": 32768,
+    # GPT-3.5
+    "gpt-3.5-turbo": 16385,
+    "gpt-3.5-turbo-16k": 16385,
+    # o1/o3 reasoning models
+    "o1": 200000,
+    "o1-preview": 128000,
+    "o1-mini": 128000,
+    "o3": 200000,
+    "o3-mini": 200000,
+}
+
+# Fallback pricing - LiteLLM is preferred source
+# OpenAI pricing per 1M tokens (input, output)
+# NOTE: These are ESTIMATES. Always verify against actual OpenAI billing.
+# Last updated: 2025-01-14
+_PRICING: dict[str, tuple[float, float]] = {
+    "gpt-4o": (2.50, 10.00),
+    "gpt-4o-mini": (0.15, 0.60),
+    "gpt-4-turbo": (10.00, 30.00),
+    "gpt-4": (30.00, 60.00),
+    "gpt-3.5-turbo": (0.50, 1.50),
+    "o1": (15.00, 60.00),
+    "o1-preview": (15.00, 60.00),
+    "o1-mini": (3.00, 12.00),
+    "o3": (10.00, 40.00),
+    "o3-mini": (1.10, 4.40),
+}
+
+# Pattern-based defaults for unknown models
+_PATTERN_DEFAULTS = {
+    "gpt-4o": {"context": 128000, "encoding": "o200k_base", "pricing": (2.50, 10.00)},
+    "gpt-4-turbo": {"context": 128000, "encoding": "cl100k_base", "pricing": (10.00, 30.00)},
+    "gpt-4": {"context": 8192, "encoding": "cl100k_base", "pricing": (30.00, 60.00)},
+    "gpt-3.5": {"context": 16385, "encoding": "cl100k_base", "pricing": (0.50, 1.50)},
+    "o1": {"context": 200000, "encoding": "o200k_base", "pricing": (15.00, 60.00)},
+    "o3": {"context": 200000, "encoding": "o200k_base", "pricing": (10.00, 40.00)},
+}
+
+# Default for completely unknown OpenAI models
+_UNKNOWN_OPENAI_DEFAULT = {
+    "context": 128000,
+    "encoding": "o200k_base",
+    "pricing": (2.50, 10.00),  # GPT-4o tier as reasonable default
+}
+
+
+def _load_custom_model_config() -> dict[str, Any]:
+    """Load custom model configuration from environment or config file.
+
+    Checks (in order):
+    1. HEADROOM_MODEL_LIMITS environment variable (JSON string or file path)
+    2. ~/.headroom/models.json config file
+
+    Returns:
+        Dict with 'context_limits' and 'pricing' keys.
+    """
+    config: dict[str, Any] = {"context_limits": {}, "pricing": {}, "encodings": {}}
+
+    # Check environment variable
+    env_config = os.environ.get("HEADROOM_MODEL_LIMITS", "")
+    if env_config:
+        try:
+            # Check if it's a file path
+            if os.path.isfile(env_config):
+                with open(env_config) as f:
+                    loaded = json.load(f)
+            else:
+                # Try to parse as JSON string
+                loaded = json.loads(env_config)
+
+            openai_config = loaded.get("openai", loaded)
+            if "context_limits" in openai_config:
+                config["context_limits"].update(openai_config["context_limits"])
+            if "pricing" in openai_config:
+                config["pricing"].update(openai_config["pricing"])
+            if "encodings" in openai_config:
+                config["encodings"].update(openai_config["encodings"])
+
+            logger.debug("Loaded custom OpenAI model config from HEADROOM_MODEL_LIMITS")
+        except (json.JSONDecodeError, OSError) as e:
+            logger.warning(f"Failed to load HEADROOM_MODEL_LIMITS: {e}")
+
+    # Check config file
+    config_file = Path.home() / ".headroom" / "models.json"
+    if config_file.exists():
+        try:
+            with open(config_file) as f:
+                loaded = json.load(f)
+
+            openai_config = loaded.get("openai", {})
+            if "context_limits" in openai_config:
+                for model, limit in openai_config["context_limits"].items():
+                    if model not in config["context_limits"]:
+                        config["context_limits"][model] = limit
+            if "pricing" in openai_config:
+                for model, pricing in openai_config["pricing"].items():
+                    if model not in config["pricing"]:
+                        config["pricing"][model] = pricing
+            if "encodings" in openai_config:
+                for model, encoding in openai_config["encodings"].items():
+                    if model not in config["encodings"]:
+                        config["encodings"][model] = encoding
+
+            logger.debug(f"Loaded custom OpenAI model config from {config_file}")
+        except (json.JSONDecodeError, OSError) as e:
+            logger.warning(f"Failed to load {config_file}: {e}")
+
+    return config
+
+
+def _infer_model_family(model: str) -> str | None:
+    """Infer the model family from model name for pattern-based defaults."""
+    model_lower = model.lower()
+
+    # Check in order of specificity
+    if model_lower.startswith("gpt-4o"):
+        return "gpt-4o"
+    elif model_lower.startswith("gpt-4-turbo"):
+        return "gpt-4-turbo"
+    elif model_lower.startswith("gpt-4"):
+        return "gpt-4"
+    elif model_lower.startswith("gpt-3.5"):
+        return "gpt-3.5"
+    elif model_lower.startswith("o1"):
+        return "o1"
+    elif model_lower.startswith("o3"):
+        return "o3"
+
+    return None
+
+
+def _check_pricing_staleness() -> str | None:
+    """Check if pricing data is stale and return warning message if so."""
+    global _PRICING_WARNING_SHOWN
+    days_old = (date.today() - _PRICING_LAST_UPDATED).days
+    if days_old > _PRICING_STALE_DAYS and not _PRICING_WARNING_SHOWN:
+        _PRICING_WARNING_SHOWN = True
+        return (
+            f"OpenAI pricing data is {days_old} days old. "
+            "Cost estimates may be inaccurate. Verify against actual billing."
+        )
+    return None
+
+
+@lru_cache(maxsize=8)
+def _get_encoding(encoding_name: str) -> Any:
+    """Get tiktoken encoding, cached."""
+    if not TIKTOKEN_AVAILABLE:
+        raise RuntimeError(
+            "tiktoken is required for OpenAI provider. Install with: pip install tiktoken"
+        )
+    return tiktoken.get_encoding(encoding_name)
+
+
+def _get_encoding_name_for_model(model: str, custom_encodings: dict[str, str] | None = None) -> str:
+    """Get the encoding name for a model with fallback support."""
+    # Check custom encodings first
+    if custom_encodings and model in custom_encodings:
+        return custom_encodings[model]
+
+    # Direct match
+    if model in _MODEL_ENCODINGS:
+        return _MODEL_ENCODINGS[model]
+
+    # Prefix match for versioned models
+    for prefix, encoding in _MODEL_ENCODINGS.items():
+        if model.startswith(prefix):
+            return encoding
+
+    # Pattern-based inference
+    family = _infer_model_family(model)
+    if family and family in _PATTERN_DEFAULTS:
+        return cast(str, _PATTERN_DEFAULTS[family]["encoding"])
+
+    # Default for unknown models
+    return cast(str, _UNKNOWN_OPENAI_DEFAULT["encoding"])
+
+
+class OpenAITokenCounter:
+    """Token counter using tiktoken for OpenAI models."""
+
+    def __init__(self, model: str, custom_encodings: dict[str, str] | None = None):
+        """
+        Initialize token counter for a model.
+
+        Args:
+            model: OpenAI model name.
+            custom_encodings: Optional custom model -> encoding mappings.
+
+        Raises:
+            RuntimeError: If tiktoken is not installed.
+        """
+        self.model = model
+        encoding_name = _get_encoding_name_for_model(model, custom_encodings)
+        self._encoding = _get_encoding(encoding_name)
+
+    def count_text(self, text: str) -> int:
+        """Count tokens in text."""
+        if not text:
+            return 0
+        return len(self._encoding.encode(text))
+
+    def count_message(self, message: dict[str, Any]) -> int:
+        """
+        Count tokens in a single message.
+
+        Accounts for ChatML format overhead.
+        """
+        # Base overhead per message (role + delimiters)
+        tokens = 4
+
+        role = message.get("role", "")
+        tokens += self.count_text(role)
+
+        content = message.get("content")
+        if content:
+            if isinstance(content, str):
+                tokens += self.count_text(content)
+            elif isinstance(content, list):
+                for part in content:
+                    if isinstance(part, dict):
+                        if part.get("type") == "text":
+                            tokens += self.count_text(part.get("text", ""))
+                        elif part.get("type") == "image_url":
+                            tokens += 85  # Low detail image estimate
+                    elif isinstance(part, str):
+                        tokens += self.count_text(part)
+
+        # Name field
+        name = message.get("name")
+        if name:
+            tokens += self.count_text(name) + 1
+
+        # Tool calls in assistant messages
+        tool_calls = message.get("tool_calls")
+        if tool_calls:
+            for tc in tool_calls:
+                func = tc.get("function", {})
+                tokens += self.count_text(func.get("name", ""))
+                tokens += self.count_text(func.get("arguments", ""))
+                tokens += self.count_text(tc.get("id", ""))
+                tokens += 10  # Structural overhead
+
+        # Tool call ID for tool responses
+        tool_call_id = message.get("tool_call_id")
+        if tool_call_id:
+            tokens += self.count_text(tool_call_id) + 2
+
+        return tokens
+
+    def count_messages(self, messages: list[dict[str, Any]]) -> int:
+        """Count tokens in a list of messages."""
+        total = sum(self.count_message(msg) for msg in messages)
+        # Add priming tokens for assistant response
+        total += 3
+        return total
+
+
+class OpenAIProvider(Provider):
+    """Provider implementation for OpenAI models.
+
+    Custom Model Configuration:
+    You can configure custom models via environment variable or config file:
+
+    1. Environment variable (JSON string):
+       export HEADROOM_MODEL_LIMITS='{"openai": {"context_limits": {"my-model": 128000}}}'
+
+    2. Environment variable (file path):
+       export HEADROOM_MODEL_LIMITS=/path/to/models.json
+
+    3. Config file (~/.headroom/models.json):
+       {
+         "openai": {
+           "context_limits": {"my-model": 128000},
+           "pricing": {"my-model": [2.50, 10.00]}
+         }
+       }
+    """
+
+    def __init__(self, context_limits: dict[str, int] | None = None):
+        """Initialize OpenAI provider.
+
+        Args:
+            context_limits: Optional override for model context limits.
+        """
+        # Build limits: defaults -> config file -> env var -> explicit
+        self._context_limits = {**_CONTEXT_LIMITS}
+        self._pricing = {**_PRICING}
+        self._encodings: dict[str, str] = {**_MODEL_ENCODINGS}
+
+        # Load from config file and env var
+        custom_config = _load_custom_model_config()
+        self._context_limits.update(custom_config["context_limits"])
+        self._encodings.update(custom_config["encodings"])
+
+        # Handle pricing (can be tuple or list from JSON)
+        for model, pricing in custom_config["pricing"].items():
+            if isinstance(pricing, (list, tuple)) and len(pricing) >= 2:
+                self._pricing[model] = (float(pricing[0]), float(pricing[1]))
+
+        # Explicit overrides take precedence
+        if context_limits:
+            self._context_limits.update(context_limits)
+
+        self._token_counters: dict[str, OpenAITokenCounter] = {}
+
+    @property
+    def name(self) -> str:
+        return "openai"
+
+    def supports_model(self, model: str) -> bool:
+        """Check if model is a known OpenAI model."""
+        if model in self._context_limits:
+            return True
+        # Check prefix match
+        for prefix in self._context_limits:
+            if model.startswith(prefix):
+                return True
+        # Support any gpt-* or o1/o3 model
+        model_lower = model.lower()
+        return (
+            model_lower.startswith("gpt-")
+            or model_lower.startswith("o1")
+            or model_lower.startswith("o3")
+        )
+
+    def get_token_counter(self, model: str) -> TokenCounter:
+        """Get token counter for an OpenAI model."""
+        if model not in self._token_counters:
+            self._token_counters[model] = OpenAITokenCounter(
+                model=model, custom_encodings=self._encodings
+            )
+        return self._token_counters[model]
+
+    def get_context_limit(self, model: str) -> int:
+        """Get context limit for an OpenAI model.
+
+        Resolution order:
+        1. LiteLLM (if available, most up-to-date)
+        2. Explicit context_limits passed to constructor
+        3. HEADROOM_MODEL_LIMITS environment variable
+        4. ~/.headroom/models.json config file
+        5. Built-in _CONTEXT_LIMITS
+        6. Pattern-based inference (gpt-4o, gpt-4, etc.)
+        7. Default fallback (128K)
+
+        Never raises an exception - uses sensible defaults for unknown models.
+        """
+        # Try LiteLLM first
+        if LITELLM_AVAILABLE:
+            try:
+                info = litellm.get_model_info(model)
+                if info and "max_input_tokens" in info:
+                    max_tokens = info["max_input_tokens"]
+                    if max_tokens is not None:
+                        return int(max_tokens)
+            except Exception:
+                pass
+
+        # Fall back to hardcoded
+        return self._get_context_limit_manual(model)
+
+    def _get_context_limit_manual(self, model: str) -> int:
+        """Get context limit using hardcoded values (fallback)."""
+        if model in self._context_limits:
+            return self._context_limits[model]
+
+        # Prefix match
+        for prefix, limit in self._context_limits.items():
+            if model.startswith(prefix):
+                return limit
+
+        # Pattern-based inference
+        family = _infer_model_family(model)
+        if family and family in _PATTERN_DEFAULTS:
+            limit = cast(int, _PATTERN_DEFAULTS[family]["context"])
+            self._warn_unknown_model(model, limit, f"inferred from '{family}' family")
+            self._context_limits[model] = limit
+            return limit
+
+        # Default for unknown OpenAI models
+        limit = cast(int, _UNKNOWN_OPENAI_DEFAULT["context"])
+        self._warn_unknown_model(model, limit, "using default limit")
+        self._context_limits[model] = limit
+        return limit
+
+    def _warn_unknown_model(self, model: str, limit: int, reason: str) -> None:
+        """Warn about unknown model (once per model)."""
+        global _UNKNOWN_MODEL_WARNINGS
+        if model not in _UNKNOWN_MODEL_WARNINGS:
+            _UNKNOWN_MODEL_WARNINGS.add(model)
+            logger.warning(
+                f"Unknown OpenAI model '{model}': {reason} ({limit:,} tokens). "
+                f"To configure explicitly, set HEADROOM_MODEL_LIMITS env var or "
+                f"add to ~/.headroom/models.json"
+            )
+
+    def estimate_cost(
+        self,
+        input_tokens: int,
+        output_tokens: int,
+        model: str,
+        cached_tokens: int = 0,
+    ) -> float | None:
+        """Estimate cost for OpenAI API call.
+
+        ⚠️ IMPORTANT: This is an ESTIMATE only.
+        - Pricing data may be outdated
+        - Cached token discount assumed at 50% (actual may vary)
+        - Always verify against your actual OpenAI billing
+
+        Args:
+            input_tokens: Number of input tokens.
+            output_tokens: Number of output tokens.
+            model: Model name.
+            cached_tokens: Number of cached tokens (estimated 50% discount).
+
+        Returns:
+            Estimated cost in USD, or None if pricing unknown.
+        """
+        # Try LiteLLM first (most up-to-date pricing)
+        if LITELLM_AVAILABLE:
+            try:
+                # LiteLLM uses per-token pricing, returns total cost
+                cost = litellm.completion_cost(
+                    model=model,
+                    prompt_tokens=input_tokens,
+                    completion_tokens=output_tokens,
+                )
+                if cost is not None and cost > 0:
+                    return float(cost)
+            except Exception:
+                pass  # Fall through to manual pricing
+
+        # Fall back to hardcoded pricing
+        return self._estimate_cost_manual(input_tokens, output_tokens, model, cached_tokens)
+
+    def _estimate_cost_manual(
+        self,
+        input_tokens: int,
+        output_tokens: int,
+        model: str,
+        cached_tokens: int = 0,
+    ) -> float | None:
+        """Estimate cost using hardcoded pricing (fallback)."""
+        # Check for stale pricing and warn once
+        staleness_warning = _check_pricing_staleness()
+        if staleness_warning:
+            warnings.warn(staleness_warning, UserWarning, stacklevel=2)
+
+        pricing = self._get_pricing(model)
+        if not pricing:
+            return None
+
+        input_price, output_price = pricing
+
+        # Calculate cost (cached tokens get estimated 50% discount)
+        # NOTE: Actual OpenAI cache discount may vary
+        regular_input = input_tokens - cached_tokens
+        cached_cost = (cached_tokens / 1_000_000) * input_price * 0.5
+        regular_cost = (regular_input / 1_000_000) * input_price
+        output_cost = (output_tokens / 1_000_000) * output_price
+
+        return cached_cost + regular_cost + output_cost
+
+    def _get_pricing(self, model: str) -> tuple[float, float] | None:
+        """Get pricing for a model with fallback logic."""
+        # Direct match
+        if model in self._pricing:
+            return self._pricing[model]
+
+        # Prefix match
+        for model_prefix, pricing in self._pricing.items():
+            if model.startswith(model_prefix):
+                return pricing
+
+        # Pattern-based inference
+        family = _infer_model_family(model)
+        if family and family in _PATTERN_DEFAULTS:
+            return cast(tuple[float, float], _PATTERN_DEFAULTS[family]["pricing"])
+
+        # Default for unknown models
+        return cast(tuple[float, float], _UNKNOWN_OPENAI_DEFAULT["pricing"])
+
+    def get_output_buffer(self, model: str, default: int = 4000) -> int:
+        """Get recommended output buffer."""
+        # Reasoning models produce longer outputs
+        if model.startswith("o1") or model.startswith("o3"):
+            return 8000
+        return default
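For orientation, the sketch below exercises the API added in this file. It is not part of the wheel: the model name, messages, and output-token figure are made up for illustration, the import path is inferred from the file location in the wheel, and it assumes tiktoken is installed (if LiteLLM is also installed, the cost figure comes from LiteLLM rather than the fallback table shown above).

```python
# Illustrative usage sketch (not package code) for headroom/providers/openai.py.
from headroom.providers.openai import OpenAIProvider

provider = OpenAIProvider()
model = "gpt-4o-mini"  # hypothetical choice for the example

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize the release notes."},
]

counter = provider.get_token_counter(model)
input_tokens = counter.count_messages(messages)  # tiktoken count + ChatML overhead

limit = provider.get_context_limit(model)  # 128000 from the built-in table

# Approximate USD; per the docstring, verify against actual billing.
# Fallback math for gpt-4o-mini at (0.15, 0.60) per 1M tokens, no cached tokens:
#   cost ≈ input_tokens * 0.15 / 1e6 + 500 * 0.60 / 1e6
cost = provider.estimate_cost(
    input_tokens=input_tokens,
    output_tokens=500,
    model=model,
)

print(input_tokens, limit, cost)
```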