headroom-ai 0.2.13__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the differences between versions as they appear in their public registries.
- headroom/__init__.py +212 -0
- headroom/cache/__init__.py +76 -0
- headroom/cache/anthropic.py +517 -0
- headroom/cache/base.py +342 -0
- headroom/cache/compression_feedback.py +613 -0
- headroom/cache/compression_store.py +814 -0
- headroom/cache/dynamic_detector.py +1026 -0
- headroom/cache/google.py +884 -0
- headroom/cache/openai.py +584 -0
- headroom/cache/registry.py +175 -0
- headroom/cache/semantic.py +451 -0
- headroom/ccr/__init__.py +77 -0
- headroom/ccr/context_tracker.py +582 -0
- headroom/ccr/mcp_server.py +319 -0
- headroom/ccr/response_handler.py +772 -0
- headroom/ccr/tool_injection.py +415 -0
- headroom/cli.py +219 -0
- headroom/client.py +977 -0
- headroom/compression/__init__.py +42 -0
- headroom/compression/detector.py +424 -0
- headroom/compression/handlers/__init__.py +22 -0
- headroom/compression/handlers/base.py +219 -0
- headroom/compression/handlers/code_handler.py +506 -0
- headroom/compression/handlers/json_handler.py +418 -0
- headroom/compression/masks.py +345 -0
- headroom/compression/universal.py +465 -0
- headroom/config.py +474 -0
- headroom/exceptions.py +192 -0
- headroom/integrations/__init__.py +159 -0
- headroom/integrations/agno/__init__.py +53 -0
- headroom/integrations/agno/hooks.py +345 -0
- headroom/integrations/agno/model.py +625 -0
- headroom/integrations/agno/providers.py +154 -0
- headroom/integrations/langchain/__init__.py +106 -0
- headroom/integrations/langchain/agents.py +326 -0
- headroom/integrations/langchain/chat_model.py +1002 -0
- headroom/integrations/langchain/langsmith.py +324 -0
- headroom/integrations/langchain/memory.py +319 -0
- headroom/integrations/langchain/providers.py +200 -0
- headroom/integrations/langchain/retriever.py +371 -0
- headroom/integrations/langchain/streaming.py +341 -0
- headroom/integrations/mcp/__init__.py +37 -0
- headroom/integrations/mcp/server.py +533 -0
- headroom/memory/__init__.py +37 -0
- headroom/memory/extractor.py +390 -0
- headroom/memory/fast_store.py +621 -0
- headroom/memory/fast_wrapper.py +311 -0
- headroom/memory/inline_extractor.py +229 -0
- headroom/memory/store.py +434 -0
- headroom/memory/worker.py +260 -0
- headroom/memory/wrapper.py +321 -0
- headroom/models/__init__.py +39 -0
- headroom/models/registry.py +687 -0
- headroom/parser.py +293 -0
- headroom/pricing/__init__.py +51 -0
- headroom/pricing/anthropic_prices.py +81 -0
- headroom/pricing/litellm_pricing.py +113 -0
- headroom/pricing/openai_prices.py +91 -0
- headroom/pricing/registry.py +188 -0
- headroom/providers/__init__.py +61 -0
- headroom/providers/anthropic.py +621 -0
- headroom/providers/base.py +131 -0
- headroom/providers/cohere.py +362 -0
- headroom/providers/google.py +427 -0
- headroom/providers/litellm.py +297 -0
- headroom/providers/openai.py +566 -0
- headroom/providers/openai_compatible.py +521 -0
- headroom/proxy/__init__.py +19 -0
- headroom/proxy/server.py +2683 -0
- headroom/py.typed +0 -0
- headroom/relevance/__init__.py +124 -0
- headroom/relevance/base.py +106 -0
- headroom/relevance/bm25.py +255 -0
- headroom/relevance/embedding.py +255 -0
- headroom/relevance/hybrid.py +259 -0
- headroom/reporting/__init__.py +5 -0
- headroom/reporting/generator.py +549 -0
- headroom/storage/__init__.py +41 -0
- headroom/storage/base.py +125 -0
- headroom/storage/jsonl.py +220 -0
- headroom/storage/sqlite.py +289 -0
- headroom/telemetry/__init__.py +91 -0
- headroom/telemetry/collector.py +764 -0
- headroom/telemetry/models.py +880 -0
- headroom/telemetry/toin.py +1579 -0
- headroom/tokenizer.py +80 -0
- headroom/tokenizers/__init__.py +75 -0
- headroom/tokenizers/base.py +210 -0
- headroom/tokenizers/estimator.py +198 -0
- headroom/tokenizers/huggingface.py +317 -0
- headroom/tokenizers/mistral.py +245 -0
- headroom/tokenizers/registry.py +398 -0
- headroom/tokenizers/tiktoken_counter.py +248 -0
- headroom/transforms/__init__.py +106 -0
- headroom/transforms/base.py +57 -0
- headroom/transforms/cache_aligner.py +357 -0
- headroom/transforms/code_compressor.py +1313 -0
- headroom/transforms/content_detector.py +335 -0
- headroom/transforms/content_router.py +1158 -0
- headroom/transforms/llmlingua_compressor.py +638 -0
- headroom/transforms/log_compressor.py +529 -0
- headroom/transforms/pipeline.py +297 -0
- headroom/transforms/rolling_window.py +350 -0
- headroom/transforms/search_compressor.py +365 -0
- headroom/transforms/smart_crusher.py +2682 -0
- headroom/transforms/text_compressor.py +259 -0
- headroom/transforms/tool_crusher.py +338 -0
- headroom/utils.py +215 -0
- headroom_ai-0.2.13.dist-info/METADATA +315 -0
- headroom_ai-0.2.13.dist-info/RECORD +114 -0
- headroom_ai-0.2.13.dist-info/WHEEL +4 -0
- headroom_ai-0.2.13.dist-info/entry_points.txt +2 -0
- headroom_ai-0.2.13.dist-info/licenses/LICENSE +190 -0
- headroom_ai-0.2.13.dist-info/licenses/NOTICE +43 -0
headroom/parser.py
ADDED
@@ -0,0 +1,293 @@
"""Message parsing utilities for Headroom SDK."""

from __future__ import annotations

import hashlib
import re
from typing import TYPE_CHECKING, Any

from .config import Block, WasteSignals

if TYPE_CHECKING:
    from .tokenizer import Tokenizer


# Patterns for detecting waste signals
HTML_TAG_PATTERN = re.compile(r"<[^>]+>")
HTML_COMMENT_PATTERN = re.compile(r"<!--[\s\S]*?-->")
BASE64_PATTERN = re.compile(r"[A-Za-z0-9+/]{50,}={0,2}")
WHITESPACE_PATTERN = re.compile(r"[ \t]{4,}|\n{3,}")
JSON_BLOCK_PATTERN = re.compile(r"\{[\s\S]{500,}\}")

# Patterns for RAG detection (best effort)
RAG_MARKERS = [
    r"\[Document\s*\d+\]",
    r"\[Source:\s*",
    r"<context>",
    r"<document>",
    r"Retrieved from:",
    r"From the knowledge base:",
]
RAG_PATTERN = re.compile("|".join(RAG_MARKERS), re.IGNORECASE)


def compute_hash(text: str) -> str:
    """Compute SHA256 hash of text, truncated to 16 chars."""
    return hashlib.sha256(text.encode()).hexdigest()[:16]


def detect_waste_signals(text: str, tokenizer: Tokenizer) -> WasteSignals:
    """
    Detect waste signals in text.

    Args:
        text: The text to analyze.
        tokenizer: Tokenizer for counting tokens.

    Returns:
        WasteSignals with detected waste.
    """
    signals = WasteSignals()

    if not text:
        return signals

    # HTML tags and comments
    html_matches = HTML_TAG_PATTERN.findall(text) + HTML_COMMENT_PATTERN.findall(text)
    if html_matches:
        html_text = "".join(html_matches)
        signals.html_noise_tokens = tokenizer.count_text(html_text)

    # Base64 blobs
    base64_matches = BASE64_PATTERN.findall(text)
    if base64_matches:
        base64_text = "".join(base64_matches)
        signals.base64_tokens = tokenizer.count_text(base64_text)

    # Excessive whitespace
    ws_matches = WHITESPACE_PATTERN.findall(text)
    if ws_matches:
        # Count tokens that could be saved by normalizing
        ws_text = "".join(ws_matches)
        signals.whitespace_tokens = max(0, tokenizer.count_text(ws_text) - len(ws_matches))

    # Large JSON blocks
    json_matches = JSON_BLOCK_PATTERN.findall(text)
    if json_matches:
        for match in json_matches:
            tokens = tokenizer.count_text(match)
            if tokens > 500:
                signals.json_bloat_tokens += tokens

    return signals


def is_rag_content(text: str) -> bool:
    """Check if text appears to be RAG-injected content."""
    return RAG_PATTERN.search(text) is not None


def parse_message_to_blocks(
    message: dict[str, Any],
    index: int,
    tokenizer: Tokenizer,
) -> list[Block]:
    """
    Parse a single message into Block objects.

    Args:
        message: The message dict to parse.
        index: Position in the message list.
        tokenizer: Tokenizer for token counting.

    Returns:
        List of Block objects (usually 1, but tool_calls may produce multiple).
    """
    blocks: list[Block] = []
    role = message.get("role", "unknown")

    # Handle content
    content = message.get("content")
    if content:
        if isinstance(content, str):
            text = content
        elif isinstance(content, list):
            # Multi-modal - extract text parts
            text_parts = []
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    text_parts.append(part.get("text", ""))
                elif isinstance(part, str):
                    text_parts.append(part)
            text = "\n".join(text_parts)
        else:
            text = str(content)

        # Determine block kind
        if role == "system":
            kind = "system"
        elif role == "user":
            # Check if this looks like RAG content
            kind = "rag" if is_rag_content(text) else "user"
        elif role == "assistant":
            kind = "assistant"
        elif role == "tool":
            kind = "tool_result"
        else:
            kind = "unknown"

        # Build flags
        flags: dict[str, Any] = {}
        if role == "tool":
            flags["tool_call_id"] = message.get("tool_call_id")

        # Detect waste
        waste = detect_waste_signals(text, tokenizer)
        if waste.total() > 0:
            flags["waste_signals"] = waste.to_dict()

        blocks.append(
            Block(
                kind=kind,  # type: ignore[arg-type]
                text=text,
                tokens_est=tokenizer.count_text(text) + 4,  # Add message overhead
                content_hash=compute_hash(text),
                source_index=index,
                flags=flags,
            )
        )

    # Handle tool calls (assistant messages with tool_calls)
    tool_calls = message.get("tool_calls")
    if tool_calls:
        for tc in tool_calls:
            func = tc.get("function", {})
            tc_text = f"{func.get('name', 'unknown')}({func.get('arguments', '')})"

            blocks.append(
                Block(
                    kind="tool_call",
                    text=tc_text,
                    tokens_est=tokenizer.count_text(tc_text) + 10,
                    content_hash=compute_hash(tc_text),
                    source_index=index,
                    flags={
                        "tool_call_id": tc.get("id"),
                        "function_name": func.get("name"),
                    },
                )
            )

    # If no content or tool_calls, create a minimal block
    if not blocks:
        blocks.append(
            Block(
                kind="unknown",
                text="",
                tokens_est=4,
                content_hash=compute_hash(""),
                source_index=index,
                flags={},
            )
        )

    return blocks


def parse_messages(
    messages: list[dict[str, Any]],
    tokenizer: Tokenizer,
) -> tuple[list[Block], dict[str, int], WasteSignals]:
    """
    Parse all messages into blocks with analysis.

    Args:
        messages: List of message dicts.
        tokenizer: Tokenizer instance for token counting.

    Returns:
        Tuple of (blocks, block_breakdown, total_waste_signals)
    """
    all_blocks: list[Block] = []
    total_waste = WasteSignals()

    for i, msg in enumerate(messages):
        blocks = parse_message_to_blocks(msg, i, tokenizer)
        all_blocks.extend(blocks)

        # Accumulate waste signals
        for block in blocks:
            if "waste_signals" in block.flags:
                ws = block.flags["waste_signals"]
                total_waste.json_bloat_tokens += ws.get("json_bloat", 0)
                total_waste.html_noise_tokens += ws.get("html_noise", 0)
                total_waste.base64_tokens += ws.get("base64", 0)
                total_waste.whitespace_tokens += ws.get("whitespace", 0)
                total_waste.dynamic_date_tokens += ws.get("dynamic_date", 0)
                total_waste.repetition_tokens += ws.get("repetition", 0)

    # Compute block breakdown
    breakdown: dict[str, int] = {}
    for block in all_blocks:
        kind = block.kind
        breakdown[kind] = breakdown.get(kind, 0) + block.tokens_est

    return all_blocks, breakdown, total_waste


def find_tool_units(messages: list[dict[str, Any]]) -> list[tuple[int, list[int]]]:
    """
    Find tool call units (assistant with tool_calls + corresponding tool responses).

    A tool unit is atomic - if the assistant message is dropped, all its
    tool responses must also be dropped.

    Args:
        messages: List of message dicts.

    Returns:
        List of (assistant_index, [tool_response_indices]) tuples.
    """
    units: list[tuple[int, list[int]]] = []

    # Build map of tool_call_id -> message index for tool responses
    tool_response_map: dict[str, int] = {}
    for i, msg in enumerate(messages):
        if msg.get("role") == "tool":
            tc_id = msg.get("tool_call_id")
            if tc_id:
                tool_response_map[tc_id] = i

    # Find assistant messages with tool_calls
    for i, msg in enumerate(messages):
        if msg.get("role") == "assistant" and msg.get("tool_calls"):
            tool_calls = msg["tool_calls"]
            response_indices: list[int] = []

            for tc in tool_calls:
                tc_id = tc.get("id")
                if tc_id and tc_id in tool_response_map:
                    response_indices.append(tool_response_map[tc_id])

            if response_indices:
                units.append((i, sorted(response_indices)))

    return units


def get_message_content_text(message: dict[str, Any]) -> str:
    """Extract text content from a message."""
    content = message.get("content")
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "text":
                parts.append(part.get("text", ""))
            elif isinstance(part, str):
                parts.append(part)
        return "\n".join(parts)
    return str(content)
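
For orientation, a minimal usage sketch of the two entry points above. StubTokenizer is a hypothetical stand-in for the SDK's Tokenizer (headroom/tokenizer.py, not expanded in this diff); the only method the parser actually calls on it is count_text.

from headroom.parser import find_tool_units, parse_messages


class StubTokenizer:
    """Hypothetical stand-in for headroom's Tokenizer; only count_text is needed."""

    def count_text(self, text: str) -> int:
        # Crude whitespace count, for illustration only.
        return len(text.split())


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "[Document 1] Retrieved from: internal wiki ..."},
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {"id": "call_1", "function": {"name": "search", "arguments": '{"q": "x"}'}}
        ],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "search results here"},
]

blocks, breakdown, waste = parse_messages(messages, StubTokenizer())
# The user message matches RAG_PATTERN ("[Document 1]"), so its block kind is "rag";
# breakdown maps kind -> estimated tokens, e.g. {"system": ..., "rag": ..., ...}.

units = find_tool_units(messages)
# [(2, [3])]: the assistant message at index 2 owns the tool response at index 3,
# so anything that drops index 2 must drop index 3 as well.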
headroom/pricing/__init__.py
ADDED
@@ -0,0 +1,51 @@
"""Pricing module for LLM cost estimation.

This module provides pricing information and cost estimation utilities
for various LLM providers. Uses LiteLLM's community-maintained pricing
database for up-to-date costs across 100+ models.
"""

# Legacy imports for backwards compatibility
from .anthropic_prices import (
    ANTHROPIC_PRICES,
    get_anthropic_registry,
)
from .anthropic_prices import (
    LAST_UPDATED as ANTHROPIC_LAST_UPDATED,
)
from .litellm_pricing import (
    LiteLLMModelPricing,
    estimate_cost,
    get_litellm_model_cost,
    get_model_pricing,
    list_available_models,
)
from .openai_prices import (
    LAST_UPDATED as OPENAI_LAST_UPDATED,
)
from .openai_prices import (
    OPENAI_PRICES,
    get_openai_registry,
)
from .registry import CostEstimate, ModelPricing, PricingRegistry

__all__ = [
    # LiteLLM-based pricing (preferred)
    "LiteLLMModelPricing",
    "estimate_cost",
    "get_litellm_model_cost",
    "get_model_pricing",
    "list_available_models",
    # Core classes
    "CostEstimate",
    "ModelPricing",
    "PricingRegistry",
    # Legacy - OpenAI (deprecated, use LiteLLM instead)
    "OPENAI_LAST_UPDATED",
    "OPENAI_PRICES",
    "get_openai_registry",
    # Legacy - Anthropic (deprecated, use LiteLLM instead)
    "ANTHROPIC_LAST_UPDATED",
    "ANTHROPIC_PRICES",
    "get_anthropic_registry",
]
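
This __init__ exposes two import surfaces, both visible in __all__ above; a quick sketch of each, with the LiteLLM-backed helpers as the stated preference:

# Preferred surface: LiteLLM-backed lookups against the live community database.
from headroom.pricing import estimate_cost, get_model_pricing, list_available_models

# Legacy surface: static snapshot tables, kept for backwards compatibility.
from headroom.pricing import ANTHROPIC_PRICES, OPENAI_PRICES, get_anthropic_registry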
headroom/pricing/anthropic_prices.py
ADDED
@@ -0,0 +1,81 @@
"""Anthropic model pricing information."""

from datetime import date

from .registry import ModelPricing, PricingRegistry

# Last verified date for pricing information
LAST_UPDATED = date(2025, 1, 6)

# Official pricing page
SOURCE_URL = "https://www.anthropic.com/pricing"

# All prices are in USD per 1 million tokens
ANTHROPIC_PRICES: dict[str, ModelPricing] = {
    "claude-3-5-sonnet-20241022": ModelPricing(
        model="claude-3-5-sonnet-20241022",
        provider="anthropic",
        input_per_1m=3.00,
        output_per_1m=15.00,
        cached_input_per_1m=0.30,
        batch_input_per_1m=1.50,
        batch_output_per_1m=7.50,
        context_window=200_000,
        notes="Most intelligent Claude model, best for complex tasks",
    ),
    "claude-3-5-sonnet-latest": ModelPricing(
        model="claude-3-5-sonnet-latest",
        provider="anthropic",
        input_per_1m=3.00,
        output_per_1m=15.00,
        cached_input_per_1m=0.30,
        batch_input_per_1m=1.50,
        batch_output_per_1m=7.50,
        context_window=200_000,
        notes="Alias for claude-3-5-sonnet-20241022",
    ),
    "claude-3-5-haiku-20241022": ModelPricing(
        model="claude-3-5-haiku-20241022",
        provider="anthropic",
        input_per_1m=0.80,
        output_per_1m=4.00,
        cached_input_per_1m=0.08,
        batch_input_per_1m=0.40,
        batch_output_per_1m=2.00,
        context_window=200_000,
        notes="Fast and cost-effective for simple tasks",
    ),
    "claude-3-opus-20240229": ModelPricing(
        model="claude-3-opus-20240229",
        provider="anthropic",
        input_per_1m=15.00,
        output_per_1m=75.00,
        cached_input_per_1m=1.50,
        batch_input_per_1m=7.50,
        batch_output_per_1m=37.50,
        context_window=200_000,
        notes="Previous generation powerful model for complex tasks",
    ),
    "claude-3-haiku-20240307": ModelPricing(
        model="claude-3-haiku-20240307",
        provider="anthropic",
        input_per_1m=0.25,
        output_per_1m=1.25,
        cached_input_per_1m=0.03,
        context_window=200_000,
        notes="Previous generation fastest and most compact model",
    ),
}


def get_anthropic_registry() -> PricingRegistry:
    """Create and return an Anthropic pricing registry.

    Returns:
        PricingRegistry configured with Anthropic model prices.
    """
    return PricingRegistry(
        last_updated=LAST_UPDATED,
        source_url=SOURCE_URL,
        prices=ANTHROPIC_PRICES.copy(),
    )
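
A worked cost sketch against this snapshot. ModelPricing itself is defined in headroom/pricing/registry.py, which this diff view does not expand, so plain attribute access on its fields is an assumption here (inferred from the keyword construction above). All figures are USD per 1M tokens.

from headroom.pricing.anthropic_prices import ANTHROPIC_PRICES

# Assumption: ModelPricing fields are readable as attributes.
haiku = ANTHROPIC_PRICES["claude-3-5-haiku-20241022"]

input_tokens, output_tokens = 120_000, 8_000
cost = (
    (input_tokens / 1_000_000) * haiku.input_per_1m      # 0.12 * 0.80 = 0.096
    + (output_tokens / 1_000_000) * haiku.output_per_1m  # 0.008 * 4.00 = 0.032
)
# cost == 0.128 USD; at batch rates (0.40 / 2.00) the same call would be 0.064 USD.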
headroom/pricing/litellm_pricing.py
ADDED
@@ -0,0 +1,113 @@
"""LiteLLM-based pricing for model cost estimation.

Uses LiteLLM's community-maintained model cost database instead of
hardcoded values. This provides up-to-date pricing for 100+ models.

See: https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import litellm


@dataclass
class LiteLLMModelPricing:
    """Pricing information from LiteLLM's database.

    All costs are in USD per 1 million tokens.
    """

    model: str
    input_cost_per_1m: float
    output_cost_per_1m: float
    max_tokens: int | None = None
    max_input_tokens: int | None = None
    max_output_tokens: int | None = None
    supports_vision: bool = False
    supports_function_calling: bool = False


def get_litellm_model_cost() -> dict[str, Any]:
    """Get LiteLLM's full model cost dictionary.

    Returns:
        Dictionary mapping model names to their pricing/capability info.
    """
    return litellm.model_cost


def get_model_pricing(model: str) -> LiteLLMModelPricing | None:
    """Get pricing for a model from LiteLLM's database.

    Args:
        model: Model name (e.g., 'gpt-4o', 'claude-3-5-sonnet-20241022').

    Returns:
        LiteLLMModelPricing if found, None otherwise.
    """
    cost_data = litellm.model_cost

    # Try exact match first
    info = cost_data.get(model)

    # Try common provider prefixes if not found
    if info is None:
        for prefix in ["openai/", "anthropic/", "google/", "mistral/", "deepseek/"]:
            if f"{prefix}{model}" in cost_data:
                info = cost_data[f"{prefix}{model}"]
                break

    if info is None:
        return None

    # LiteLLM stores cost per token, convert to per 1M
    input_per_token = info.get("input_cost_per_token", 0) or 0
    output_per_token = info.get("output_cost_per_token", 0) or 0

    return LiteLLMModelPricing(
        model=model,
        input_cost_per_1m=input_per_token * 1_000_000,
        output_cost_per_1m=output_per_token * 1_000_000,
        max_tokens=info.get("max_tokens"),
        max_input_tokens=info.get("max_input_tokens"),
        max_output_tokens=info.get("max_output_tokens"),
        supports_vision=info.get("supports_vision", False),
        supports_function_calling=info.get("supports_function_calling", False),
    )


def estimate_cost(
    model: str,
    input_tokens: int = 0,
    output_tokens: int = 0,
) -> float | None:
    """Estimate cost for a model using LiteLLM's pricing.

    Args:
        model: Model name.
        input_tokens: Number of input tokens.
        output_tokens: Number of output tokens.

    Returns:
        Estimated cost in USD, or None if model not found.
    """
    pricing = get_model_pricing(model)
    if pricing is None:
        return None

    input_cost = (input_tokens / 1_000_000) * pricing.input_cost_per_1m
    output_cost = (output_tokens / 1_000_000) * pricing.output_cost_per_1m
    return input_cost + output_cost


def list_available_models() -> list[str]:
    """List all models with pricing info in LiteLLM's database.

    Returns:
        List of model names.
    """
    return list(litellm.model_cost.keys())
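
A usage sketch for the helpers above. The database ships with the litellm package, so exact figures depend on the installed version.

from headroom.pricing.litellm_pricing import estimate_cost, get_model_pricing

pricing = get_model_pricing("gpt-4o-mini")
if pricing is not None:
    # Costs were normalized from per-token to per-1M by get_model_pricing.
    print(pricing.input_cost_per_1m, pricing.output_cost_per_1m, pricing.max_input_tokens)

cost = estimate_cost("gpt-4o-mini", input_tokens=50_000, output_tokens=2_000)
# Returns None rather than raising when the model is missing from the database.
print(f"${cost:.4f}" if cost is not None else "model not in LiteLLM's database")

# Note the prefix fallback in get_model_pricing: a bare name may still resolve
# when only a provider-prefixed key (e.g. "openai/<model>") exists in the database.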
headroom/pricing/openai_prices.py
ADDED
@@ -0,0 +1,91 @@
"""OpenAI model pricing information."""

from datetime import date

from .registry import ModelPricing, PricingRegistry

# Last verified date for pricing information
LAST_UPDATED = date(2025, 1, 6)

# Official pricing page
SOURCE_URL = "https://openai.com/api/pricing/"

# All prices are in USD per 1 million tokens
OPENAI_PRICES: dict[str, ModelPricing] = {
    "gpt-4o": ModelPricing(
        model="gpt-4o",
        provider="openai",
        input_per_1m=2.50,
        output_per_1m=10.00,
        cached_input_per_1m=1.25,
        context_window=128_000,
        notes="Most capable GPT-4o model",
    ),
    "gpt-4o-mini": ModelPricing(
        model="gpt-4o-mini",
        provider="openai",
        input_per_1m=0.15,
        output_per_1m=0.60,
        cached_input_per_1m=0.075,
        context_window=128_000,
        notes="Affordable small model for fast, lightweight tasks",
    ),
    "o1": ModelPricing(
        model="o1",
        provider="openai",
        input_per_1m=15.00,
        output_per_1m=60.00,
        cached_input_per_1m=7.50,
        context_window=200_000,
        notes="Reasoning model for complex, multi-step tasks",
    ),
    "o1-mini": ModelPricing(
        model="o1-mini",
        provider="openai",
        input_per_1m=1.10,
        output_per_1m=4.40,
        cached_input_per_1m=0.55,
        context_window=128_000,
        notes="Smaller reasoning model, cost-effective for coding tasks",
    ),
    "o3-mini": ModelPricing(
        model="o3-mini",
        provider="openai",
        input_per_1m=1.10,
        output_per_1m=4.40,
        cached_input_per_1m=0.55,
        context_window=200_000,
        notes="Latest small reasoning model",
    ),
    "gpt-4-turbo": ModelPricing(
        model="gpt-4-turbo",
        provider="openai",
        input_per_1m=10.00,
        output_per_1m=30.00,
        cached_input_per_1m=5.00,
        context_window=128_000,
        notes="Previous generation GPT-4 Turbo model",
    ),
    "gpt-3.5-turbo": ModelPricing(
        model="gpt-3.5-turbo",
        provider="openai",
        input_per_1m=0.50,
        output_per_1m=1.50,
        cached_input_per_1m=0.25,
        context_window=16_385,
        notes="Fast, inexpensive model for simple tasks",
    ),
}


def get_openai_registry() -> PricingRegistry:
    """Create and return an OpenAI pricing registry.

    Returns:
        PricingRegistry configured with OpenAI model prices.
    """
    return PricingRegistry(
        last_updated=LAST_UPDATED,
        source_url=SOURCE_URL,
        prices=OPENAI_PRICES.copy(),
    )
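
The same arithmetic with the cached-input discount, using this OpenAI snapshot. As with the Anthropic sketch, attribute access on ModelPricing is assumed from its keyword construction; figures are USD per 1M tokens and reflect this snapshot only.

from headroom.pricing.openai_prices import OPENAI_PRICES

gpt4o = OPENAI_PRICES["gpt-4o"]

fresh, cached, out = 20_000, 80_000, 4_000
cost = (
    (fresh / 1_000_000) * gpt4o.input_per_1m            # 0.02  * 2.50  = 0.05
    + (cached / 1_000_000) * gpt4o.cached_input_per_1m  # 0.08  * 1.25  = 0.10
    + (out / 1_000_000) * gpt4o.output_per_1m           # 0.004 * 10.00 = 0.04
)
# cost == 0.19 USD; without the cache discount the 100k input tokens alone
# would cost 0.25 USD, which is what cache-aware transforms try to exploit.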