headroom-ai 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (114) hide show
  1. headroom/__init__.py +212 -0
  2. headroom/cache/__init__.py +76 -0
  3. headroom/cache/anthropic.py +517 -0
  4. headroom/cache/base.py +342 -0
  5. headroom/cache/compression_feedback.py +613 -0
  6. headroom/cache/compression_store.py +814 -0
  7. headroom/cache/dynamic_detector.py +1026 -0
  8. headroom/cache/google.py +884 -0
  9. headroom/cache/openai.py +584 -0
  10. headroom/cache/registry.py +175 -0
  11. headroom/cache/semantic.py +451 -0
  12. headroom/ccr/__init__.py +77 -0
  13. headroom/ccr/context_tracker.py +582 -0
  14. headroom/ccr/mcp_server.py +319 -0
  15. headroom/ccr/response_handler.py +772 -0
  16. headroom/ccr/tool_injection.py +415 -0
  17. headroom/cli.py +219 -0
  18. headroom/client.py +977 -0
  19. headroom/compression/__init__.py +42 -0
  20. headroom/compression/detector.py +424 -0
  21. headroom/compression/handlers/__init__.py +22 -0
  22. headroom/compression/handlers/base.py +219 -0
  23. headroom/compression/handlers/code_handler.py +506 -0
  24. headroom/compression/handlers/json_handler.py +418 -0
  25. headroom/compression/masks.py +345 -0
  26. headroom/compression/universal.py +465 -0
  27. headroom/config.py +474 -0
  28. headroom/exceptions.py +192 -0
  29. headroom/integrations/__init__.py +159 -0
  30. headroom/integrations/agno/__init__.py +53 -0
  31. headroom/integrations/agno/hooks.py +345 -0
  32. headroom/integrations/agno/model.py +625 -0
  33. headroom/integrations/agno/providers.py +154 -0
  34. headroom/integrations/langchain/__init__.py +106 -0
  35. headroom/integrations/langchain/agents.py +326 -0
  36. headroom/integrations/langchain/chat_model.py +1002 -0
  37. headroom/integrations/langchain/langsmith.py +324 -0
  38. headroom/integrations/langchain/memory.py +319 -0
  39. headroom/integrations/langchain/providers.py +200 -0
  40. headroom/integrations/langchain/retriever.py +371 -0
  41. headroom/integrations/langchain/streaming.py +341 -0
  42. headroom/integrations/mcp/__init__.py +37 -0
  43. headroom/integrations/mcp/server.py +533 -0
  44. headroom/memory/__init__.py +37 -0
  45. headroom/memory/extractor.py +390 -0
  46. headroom/memory/fast_store.py +621 -0
  47. headroom/memory/fast_wrapper.py +311 -0
  48. headroom/memory/inline_extractor.py +229 -0
  49. headroom/memory/store.py +434 -0
  50. headroom/memory/worker.py +260 -0
  51. headroom/memory/wrapper.py +321 -0
  52. headroom/models/__init__.py +39 -0
  53. headroom/models/registry.py +687 -0
  54. headroom/parser.py +293 -0
  55. headroom/pricing/__init__.py +51 -0
  56. headroom/pricing/anthropic_prices.py +81 -0
  57. headroom/pricing/litellm_pricing.py +113 -0
  58. headroom/pricing/openai_prices.py +91 -0
  59. headroom/pricing/registry.py +188 -0
  60. headroom/providers/__init__.py +61 -0
  61. headroom/providers/anthropic.py +621 -0
  62. headroom/providers/base.py +131 -0
  63. headroom/providers/cohere.py +362 -0
  64. headroom/providers/google.py +427 -0
  65. headroom/providers/litellm.py +297 -0
  66. headroom/providers/openai.py +566 -0
  67. headroom/providers/openai_compatible.py +521 -0
  68. headroom/proxy/__init__.py +19 -0
  69. headroom/proxy/server.py +2683 -0
  70. headroom/py.typed +0 -0
  71. headroom/relevance/__init__.py +124 -0
  72. headroom/relevance/base.py +106 -0
  73. headroom/relevance/bm25.py +255 -0
  74. headroom/relevance/embedding.py +255 -0
  75. headroom/relevance/hybrid.py +259 -0
  76. headroom/reporting/__init__.py +5 -0
  77. headroom/reporting/generator.py +549 -0
  78. headroom/storage/__init__.py +41 -0
  79. headroom/storage/base.py +125 -0
  80. headroom/storage/jsonl.py +220 -0
  81. headroom/storage/sqlite.py +289 -0
  82. headroom/telemetry/__init__.py +91 -0
  83. headroom/telemetry/collector.py +764 -0
  84. headroom/telemetry/models.py +880 -0
  85. headroom/telemetry/toin.py +1579 -0
  86. headroom/tokenizer.py +80 -0
  87. headroom/tokenizers/__init__.py +75 -0
  88. headroom/tokenizers/base.py +210 -0
  89. headroom/tokenizers/estimator.py +198 -0
  90. headroom/tokenizers/huggingface.py +317 -0
  91. headroom/tokenizers/mistral.py +245 -0
  92. headroom/tokenizers/registry.py +398 -0
  93. headroom/tokenizers/tiktoken_counter.py +248 -0
  94. headroom/transforms/__init__.py +106 -0
  95. headroom/transforms/base.py +57 -0
  96. headroom/transforms/cache_aligner.py +357 -0
  97. headroom/transforms/code_compressor.py +1313 -0
  98. headroom/transforms/content_detector.py +335 -0
  99. headroom/transforms/content_router.py +1158 -0
  100. headroom/transforms/llmlingua_compressor.py +638 -0
  101. headroom/transforms/log_compressor.py +529 -0
  102. headroom/transforms/pipeline.py +297 -0
  103. headroom/transforms/rolling_window.py +350 -0
  104. headroom/transforms/search_compressor.py +365 -0
  105. headroom/transforms/smart_crusher.py +2682 -0
  106. headroom/transforms/text_compressor.py +259 -0
  107. headroom/transforms/tool_crusher.py +338 -0
  108. headroom/utils.py +215 -0
  109. headroom_ai-0.2.13.dist-info/METADATA +315 -0
  110. headroom_ai-0.2.13.dist-info/RECORD +114 -0
  111. headroom_ai-0.2.13.dist-info/WHEEL +4 -0
  112. headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
  113. headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
  114. headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/parser.py ADDED
@@ -0,0 +1,293 @@
1
+ """Message parsing utilities for Headroom SDK."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import hashlib
6
+ import re
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ from .config import Block, WasteSignals
10
+
11
+ if TYPE_CHECKING:
12
+ from .tokenizer import Tokenizer
13
+
14
+
15
# Patterns for detecting waste signals
# Any "<...>" span with no ">" inside it. Note that this also matches HTML
# comments such as "<!-- x -->".
HTML_TAG_PATTERN = re.compile(r"<[^>]+>")
# Non-greedy body so that two adjacent comments are matched separately.
HTML_COMMENT_PATTERN = re.compile(r"<!--[\s\S]*?-->")
# A run of 50+ base64-alphabet characters, optionally followed by "=" padding.
BASE64_PATTERN = re.compile(r"[A-Za-z0-9+/]{50,}={0,2}")
# Four or more consecutive spaces/tabs, or three or more consecutive newlines.
WHITESPACE_PATTERN = re.compile(r"[ \t]{4,}|\n{3,}")
# "{" ... "}" with at least 500 characters in between. The body is greedy, so
# a match runs to the last "}" in the text.
JSON_BLOCK_PATTERN = re.compile(r"\{[\s\S]{500,}\}")

# Patterns for RAG detection (best effort)
RAG_MARKERS = [
    r"\[Document\s*\d+\]",
    r"\[Source:\s*",
    r"<context>",
    r"<document>",
    r"Retrieved from:",
    r"From the knowledge base:",
]
# One case-insensitive alternation over every marker above.
RAG_PATTERN = re.compile("|".join(RAG_MARKERS), re.IGNORECASE)
32
+
33
+
34
def compute_hash(text: str) -> str:
    """Return the first 16 hex characters of the SHA-256 digest of ``text``.

    The text is encoded as UTF-8 before hashing.
    """
    digest = hashlib.sha256(text.encode("utf-8"))
    return digest.hexdigest()[:16]
37
+
38
+
39
def detect_waste_signals(text: str, tokenizer: Tokenizer) -> WasteSignals:
    """
    Detect waste signals in text.

    Four heuristics are applied independently: HTML tags/comments, long
    base64-looking runs, excessive whitespace, and large brace-delimited
    blocks. Each one stores a token count on the returned object.

    Args:
        text: The text to analyze.
        tokenizer: Tokenizer for counting tokens.

    Returns:
        WasteSignals with detected waste. All-default when ``text`` is empty.
    """
    signals = WasteSignals()

    if not text:
        return signals

    # HTML tags and comments.
    # HTML_TAG_PATTERN also matches comments ("<!-- x -->" is a "<...>" span),
    # so running both patterns and concatenating the results counts every
    # comment twice. Scan once with the comment alternative first: regex
    # alternation tries branches left to right and findall returns
    # non-overlapping matches, so each character is counted at most once.
    # (re.compile results are cached by the re module, so this is cheap.)
    html_pattern = re.compile(
        f"{HTML_COMMENT_PATTERN.pattern}|{HTML_TAG_PATTERN.pattern}"
    )
    html_matches = html_pattern.findall(text)
    if html_matches:
        signals.html_noise_tokens = tokenizer.count_text("".join(html_matches))

    # Base64 blobs
    base64_matches = BASE64_PATTERN.findall(text)
    if base64_matches:
        signals.base64_tokens = tokenizer.count_text("".join(base64_matches))

    # Excessive whitespace
    ws_matches = WHITESPACE_PATTERN.findall(text)
    if ws_matches:
        # Count tokens that could be saved by normalizing: each run would
        # still cost roughly one token afterwards, hence the subtraction.
        ws_tokens = tokenizer.count_text("".join(ws_matches))
        signals.whitespace_tokens = max(0, ws_tokens - len(ws_matches))

    # Large JSON blocks: only blocks that are also large in tokens count.
    for match in JSON_BLOCK_PATTERN.findall(text):
        tokens = tokenizer.count_text(match)
        if tokens > 500:
            signals.json_bloat_tokens += tokens

    return signals
83
+
84
+
85
def is_rag_content(text: str) -> bool:
    """Return True when ``text`` contains any of the known RAG markers."""
    marker = RAG_PATTERN.search(text)
    return marker is not None
88
+
89
+
90
def parse_message_to_blocks(
    message: dict[str, Any],
    index: int,
    tokenizer: Tokenizer,
) -> list[Block]:
    """
    Parse a single message into Block objects.

    Args:
        message: The message dict to parse.
        index: Position in the message list.
        tokenizer: Tokenizer for token counting.

    Returns:
        List of Block objects (usually 1, but tool_calls may produce multiple).
        A message with neither content nor tool calls yields a single empty
        block of kind "unknown".
    """
    blocks: list[Block] = []
    role = message.get("role", "unknown")

    # Handle content
    content = message.get("content")
    if content:
        # Reuse the shared extractor so string, multi-part and other content
        # are flattened the same way everywhere in this module.
        text = get_message_content_text(message)

        # Determine block kind
        if role == "system":
            kind = "system"
        elif role == "user":
            # Check if this looks like RAG content
            kind = "rag" if is_rag_content(text) else "user"
        elif role == "assistant":
            kind = "assistant"
        elif role == "tool":
            kind = "tool_result"
        else:
            kind = "unknown"

        # Build flags
        flags: dict[str, Any] = {}
        if role == "tool":
            flags["tool_call_id"] = message.get("tool_call_id")

        # Detect waste; only record it when something was found.
        waste = detect_waste_signals(text, tokenizer)
        if waste.total() > 0:
            flags["waste_signals"] = waste.to_dict()

        blocks.append(
            Block(
                kind=kind,  # type: ignore[arg-type]
                text=text,
                tokens_est=tokenizer.count_text(text) + 4,  # Add message overhead
                content_hash=compute_hash(text),
                source_index=index,
                flags=flags,
            )
        )

    # Handle tool calls (assistant messages with tool_calls)
    tool_calls = message.get("tool_calls")
    if tool_calls:
        for tc in tool_calls:
            # "function" may be present but null; treat that like a missing key
            # instead of failing on None.get(...).
            func = tc.get("function") or {}
            tc_text = f"{func.get('name', 'unknown')}({func.get('arguments', '')})"

            blocks.append(
                Block(
                    kind="tool_call",
                    text=tc_text,
                    tokens_est=tokenizer.count_text(tc_text) + 10,
                    content_hash=compute_hash(tc_text),
                    source_index=index,
                    flags={
                        "tool_call_id": tc.get("id"),
                        "function_name": func.get("name"),
                    },
                )
            )

    # If no content or tool_calls, create a minimal block
    if not blocks:
        blocks.append(
            Block(
                kind="unknown",
                text="",
                tokens_est=4,
                content_hash=compute_hash(""),
                source_index=index,
                flags={},
            )
        )

    return blocks
195
+
196
+
197
def parse_messages(
    messages: list[dict[str, Any]],
    tokenizer: Tokenizer,
) -> tuple[list[Block], dict[str, int], WasteSignals]:
    """
    Turn a conversation into blocks and aggregate statistics about them.

    Args:
        messages: List of message dicts.
        tokenizer: Tokenizer instance for token counting.

    Returns:
        Tuple of (blocks, block_breakdown, total_waste_signals), where
        block_breakdown maps each block kind to its summed token estimate.
    """
    # (key in a block's "waste_signals" dict, attribute on WasteSignals)
    waste_fields = (
        ("json_bloat", "json_bloat_tokens"),
        ("html_noise", "html_noise_tokens"),
        ("base64", "base64_tokens"),
        ("whitespace", "whitespace_tokens"),
        ("dynamic_date", "dynamic_date_tokens"),
        ("repetition", "repetition_tokens"),
    )

    all_blocks: list[Block] = []
    total_waste = WasteSignals()

    for position, message in enumerate(messages):
        message_blocks = parse_message_to_blocks(message, position, tokenizer)
        all_blocks.extend(message_blocks)

        # Fold each block's recorded waste into the running totals.
        for blk in message_blocks:
            if "waste_signals" not in blk.flags:
                continue
            recorded = blk.flags["waste_signals"]
            for key, attr in waste_fields:
                setattr(
                    total_waste,
                    attr,
                    getattr(total_waste, attr) + recorded.get(key, 0),
                )

    # Sum token estimates per block kind.
    breakdown: dict[str, int] = {}
    for blk in all_blocks:
        breakdown[blk.kind] = breakdown.get(blk.kind, 0) + blk.tokens_est

    return all_blocks, breakdown, total_waste
236
+
237
+
238
def find_tool_units(messages: list[dict[str, Any]]) -> list[tuple[int, list[int]]]:
    """
    Group each tool-calling assistant message with its tool responses.

    A tool unit is atomic - if the assistant message is dropped, all its
    tool responses must also be dropped.

    Args:
        messages: List of message dicts.

    Returns:
        List of (assistant_index, [tool_response_indices]) tuples. Assistant
        messages whose tool calls have no matching response are omitted.
    """
    # Index tool responses by their call id; a repeated id keeps the last one.
    response_index_by_call_id: dict[str, int] = {}
    for position, message in enumerate(messages):
        if message.get("role") != "tool":
            continue
        call_id = message.get("tool_call_id")
        if call_id:
            response_index_by_call_id[call_id] = position

    units: list[tuple[int, list[int]]] = []
    for position, message in enumerate(messages):
        if message.get("role") != "assistant" or not message.get("tool_calls"):
            continue

        call_ids = (call.get("id") for call in message["tool_calls"])
        matched = sorted(
            response_index_by_call_id[call_id]
            for call_id in call_ids
            if call_id and call_id in response_index_by_call_id
        )
        if matched:
            units.append((position, matched))

    return units
276
+
277
+
278
def get_message_content_text(message: dict[str, Any]) -> str:
    """Extract text content from a message.

    Args:
        message: The message dict; its "content" may be missing, a string,
            a list of parts, or any other value.

    Returns:
        "" when content is missing or None; the string itself for string
        content; the newline-joined text parts for list content (non-text
        parts are skipped); ``str(content)`` for anything else.
    """
    content = message.get("content")
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                # "text" can be present but null; joining None would raise
                # TypeError, so treat it as an empty string.
                parts.append(part.get("text") or "")
            elif isinstance(part, str):
                parts.append(part)
        return "\n".join(parts)
    return str(content)
headroom/pricing/__init__.py ADDED
@@ -0,0 +1,51 @@
1
+ """Pricing module for LLM cost estimation.
2
+
3
+ This module provides pricing information and cost estimation utilities
4
+ for various LLM providers. Uses LiteLLM's community-maintained pricing
5
+ database for up-to-date costs across 100+ models.
6
+ """
7
+
8
+ # Legacy imports for backwards compatibility
9
+ from .anthropic_prices import (
10
+ ANTHROPIC_PRICES,
11
+ get_anthropic_registry,
12
+ )
13
+ from .anthropic_prices import (
14
+ LAST_UPDATED as ANTHROPIC_LAST_UPDATED,
15
+ )
16
+ from .litellm_pricing import (
17
+ LiteLLMModelPricing,
18
+ estimate_cost,
19
+ get_litellm_model_cost,
20
+ get_model_pricing,
21
+ list_available_models,
22
+ )
23
+ from .openai_prices import (
24
+ LAST_UPDATED as OPENAI_LAST_UPDATED,
25
+ )
26
+ from .openai_prices import (
27
+ OPENAI_PRICES,
28
+ get_openai_registry,
29
+ )
30
+ from .registry import CostEstimate, ModelPricing, PricingRegistry
31
+
32
# Public names re-exported by this package, grouped by origin.
__all__ = [
    # LiteLLM-based pricing (preferred)
    "LiteLLMModelPricing",
    "estimate_cost",
    "get_litellm_model_cost",
    "get_model_pricing",
    "list_available_models",
    # Core classes
    "CostEstimate",
    "ModelPricing",
    "PricingRegistry",
    # Legacy - OpenAI (deprecated, use LiteLLM instead)
    "OPENAI_LAST_UPDATED",
    "OPENAI_PRICES",
    "get_openai_registry",
    # Legacy - Anthropic (deprecated, use LiteLLM instead)
    "ANTHROPIC_LAST_UPDATED",
    "ANTHROPIC_PRICES",
    "get_anthropic_registry",
]
headroom/pricing/anthropic_prices.py ADDED
@@ -0,0 +1,81 @@
1
+ """Anthropic model pricing information."""
2
+
3
+ from datetime import date
4
+
5
+ from .registry import ModelPricing, PricingRegistry
6
+
7
# Last verified date for pricing information
LAST_UPDATED = date(2025, 1, 6)

# Official pricing page
SOURCE_URL = "https://www.anthropic.com/pricing"

# All prices are in USD per 1 million tokens
# Keyed by model identifier; each entry repeats its key in the `model` field.
ANTHROPIC_PRICES: dict[str, ModelPricing] = {
    "claude-3-5-sonnet-20241022": ModelPricing(
        model="claude-3-5-sonnet-20241022",
        provider="anthropic",
        input_per_1m=3.00,
        output_per_1m=15.00,
        cached_input_per_1m=0.30,
        batch_input_per_1m=1.50,
        batch_output_per_1m=7.50,
        context_window=200_000,
        notes="Most intelligent Claude model, best for complex tasks",
    ),
    # Same prices as the dated claude-3-5-sonnet-20241022 entry above.
    "claude-3-5-sonnet-latest": ModelPricing(
        model="claude-3-5-sonnet-latest",
        provider="anthropic",
        input_per_1m=3.00,
        output_per_1m=15.00,
        cached_input_per_1m=0.30,
        batch_input_per_1m=1.50,
        batch_output_per_1m=7.50,
        context_window=200_000,
        notes="Alias for claude-3-5-sonnet-20241022",
    ),
    "claude-3-5-haiku-20241022": ModelPricing(
        model="claude-3-5-haiku-20241022",
        provider="anthropic",
        input_per_1m=0.80,
        output_per_1m=4.00,
        cached_input_per_1m=0.08,
        batch_input_per_1m=0.40,
        batch_output_per_1m=2.00,
        context_window=200_000,
        notes="Fast and cost-effective for simple tasks",
    ),
    "claude-3-opus-20240229": ModelPricing(
        model="claude-3-opus-20240229",
        provider="anthropic",
        input_per_1m=15.00,
        output_per_1m=75.00,
        cached_input_per_1m=1.50,
        batch_input_per_1m=7.50,
        batch_output_per_1m=37.50,
        context_window=200_000,
        notes="Previous generation powerful model for complex tasks",
    ),
    # No batch prices are given for this entry, unlike the others.
    "claude-3-haiku-20240307": ModelPricing(
        model="claude-3-haiku-20240307",
        provider="anthropic",
        input_per_1m=0.25,
        output_per_1m=1.25,
        cached_input_per_1m=0.03,
        context_window=200_000,
        notes="Previous generation fastest and most compact model",
    ),
}
69
+
70
+
71
def get_anthropic_registry() -> PricingRegistry:
    """Build a pricing registry populated with the Anthropic price table.

    Returns:
        PricingRegistry holding a shallow copy of ANTHROPIC_PRICES, so adding
        or removing entries on the registry does not touch the module table.
    """
    price_snapshot = dict(ANTHROPIC_PRICES)
    return PricingRegistry(
        last_updated=LAST_UPDATED,
        source_url=SOURCE_URL,
        prices=price_snapshot,
    )
headroom/pricing/litellm_pricing.py ADDED
@@ -0,0 +1,113 @@
1
+ """LiteLLM-based pricing for model cost estimation.
2
+
3
+ Uses LiteLLM's community-maintained model cost database instead of
4
+ hardcoded values. This provides up-to-date pricing for 100+ models.
5
+
6
+ See: https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from typing import Any
13
+
14
+ import litellm
15
+
16
+
17
@dataclass
class LiteLLMModelPricing:
    """Pricing information from LiteLLM's database.

    All costs are in USD per 1 million tokens.
    """

    # Model name exactly as the caller requested it.
    model: str
    # Per-token costs from LiteLLM scaled to per-1M-token costs.
    input_cost_per_1m: float
    output_cost_per_1m: float
    # Token limits copied from the LiteLLM entry; None when the entry has none.
    max_tokens: int | None = None
    max_input_tokens: int | None = None
    max_output_tokens: int | None = None
    # Capability flags copied from the LiteLLM entry; False when absent.
    supports_vision: bool = False
    supports_function_calling: bool = False
32
+
33
+
34
def get_litellm_model_cost() -> dict[str, Any]:
    """Expose LiteLLM's model cost table.

    Returns:
        Dictionary mapping model names to their pricing/capability info.
        This is the ``litellm.model_cost`` object itself, not a copy.
    """
    cost_table: dict[str, Any] = litellm.model_cost
    return cost_table
41
+
42
+
43
def get_model_pricing(model: str) -> LiteLLMModelPricing | None:
    """Look up a model's pricing in LiteLLM's database.

    The exact name is tried first, then the name behind each of a fixed set
    of provider prefixes.

    Args:
        model: Model name (e.g., 'gpt-4o', 'claude-3-5-sonnet-20241022').

    Returns:
        LiteLLMModelPricing if found, None otherwise.
    """
    provider_prefixes = ("openai/", "anthropic/", "google/", "mistral/", "deepseek/")
    cost_table = litellm.model_cost

    entry = cost_table.get(model)
    if entry is None:
        # Fall back to the first prefixed key that exists in the table.
        prefixed_names = (f"{prefix}{model}" for prefix in provider_prefixes)
        entry = next(
            (cost_table[name] for name in prefixed_names if name in cost_table),
            None,
        )

    if entry is None:
        return None

    # LiteLLM stores cost per token; a missing or null cost counts as 0.
    per_token_in = entry.get("input_cost_per_token", 0) or 0
    per_token_out = entry.get("output_cost_per_token", 0) or 0
    tokens_per_million = 1_000_000

    return LiteLLMModelPricing(
        model=model,
        input_cost_per_1m=per_token_in * tokens_per_million,
        output_cost_per_1m=per_token_out * tokens_per_million,
        max_tokens=entry.get("max_tokens"),
        max_input_tokens=entry.get("max_input_tokens"),
        max_output_tokens=entry.get("max_output_tokens"),
        supports_vision=entry.get("supports_vision", False),
        supports_function_calling=entry.get("supports_function_calling", False),
    )
81
+
82
+
83
def estimate_cost(
    model: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
) -> float | None:
    """Estimate the USD cost of a request from LiteLLM's pricing.

    Args:
        model: Model name.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.

    Returns:
        Estimated cost in USD, or None if model not found.
    """
    pricing = get_model_pricing(model)
    if pricing is None:
        return None

    # Prices are per 1M tokens, so scale the token counts down first.
    input_millions = input_tokens / 1_000_000
    output_millions = output_tokens / 1_000_000
    return (
        input_millions * pricing.input_cost_per_1m
        + output_millions * pricing.output_cost_per_1m
    )
105
+
106
+
107
def list_available_models() -> list[str]:
    """Return the name of every entry in LiteLLM's model cost table.

    Returns:
        List of model names.
    """
    return [*litellm.model_cost.keys()]
headroom/pricing/openai_prices.py ADDED
@@ -0,0 +1,91 @@
1
+ """OpenAI model pricing information."""
2
+
3
+ from datetime import date
4
+
5
+ from .registry import ModelPricing, PricingRegistry
6
+
7
# Last verified date for pricing information
LAST_UPDATED = date(2025, 1, 6)

# Official pricing page
SOURCE_URL = "https://openai.com/api/pricing/"

# All prices are in USD per 1 million tokens
# Keyed by model identifier; each entry repeats its key in the `model` field.
# No entry in this table specifies batch prices.
OPENAI_PRICES: dict[str, ModelPricing] = {
    "gpt-4o": ModelPricing(
        model="gpt-4o",
        provider="openai",
        input_per_1m=2.50,
        output_per_1m=10.00,
        cached_input_per_1m=1.25,
        context_window=128_000,
        notes="Most capable GPT-4o model",
    ),
    "gpt-4o-mini": ModelPricing(
        model="gpt-4o-mini",
        provider="openai",
        input_per_1m=0.15,
        output_per_1m=0.60,
        cached_input_per_1m=0.075,
        context_window=128_000,
        notes="Affordable small model for fast, lightweight tasks",
    ),
    "o1": ModelPricing(
        model="o1",
        provider="openai",
        input_per_1m=15.00,
        output_per_1m=60.00,
        cached_input_per_1m=7.50,
        context_window=200_000,
        notes="Reasoning model for complex, multi-step tasks",
    ),
    "o1-mini": ModelPricing(
        model="o1-mini",
        provider="openai",
        input_per_1m=1.10,
        output_per_1m=4.40,
        cached_input_per_1m=0.55,
        context_window=128_000,
        notes="Smaller reasoning model, cost-effective for coding tasks",
    ),
    "o3-mini": ModelPricing(
        model="o3-mini",
        provider="openai",
        input_per_1m=1.10,
        output_per_1m=4.40,
        cached_input_per_1m=0.55,
        context_window=200_000,
        notes="Latest small reasoning model",
    ),
    "gpt-4-turbo": ModelPricing(
        model="gpt-4-turbo",
        provider="openai",
        input_per_1m=10.00,
        output_per_1m=30.00,
        cached_input_per_1m=5.00,
        context_window=128_000,
        notes="Previous generation GPT-4 Turbo model",
    ),
    "gpt-3.5-turbo": ModelPricing(
        model="gpt-3.5-turbo",
        provider="openai",
        input_per_1m=0.50,
        output_per_1m=1.50,
        cached_input_per_1m=0.25,
        context_window=16_385,
        notes="Fast, inexpensive model for simple tasks",
    ),
}
79
+
80
+
81
def get_openai_registry() -> PricingRegistry:
    """Build a pricing registry populated with the OpenAI price table.

    Returns:
        PricingRegistry holding a shallow copy of OPENAI_PRICES, so adding
        or removing entries on the registry does not touch the module table.
    """
    price_snapshot = dict(OPENAI_PRICES)
    return PricingRegistry(
        last_updated=LAST_UPDATED,
        source_url=SOURCE_URL,
        prices=price_snapshot,
    )