ccproxy-api 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ccproxy/__init__.py +4 -0
- ccproxy/__main__.py +7 -0
- ccproxy/_version.py +21 -0
- ccproxy/adapters/__init__.py +11 -0
- ccproxy/adapters/base.py +80 -0
- ccproxy/adapters/openai/__init__.py +43 -0
- ccproxy/adapters/openai/adapter.py +915 -0
- ccproxy/adapters/openai/models.py +412 -0
- ccproxy/adapters/openai/streaming.py +449 -0
- ccproxy/api/__init__.py +28 -0
- ccproxy/api/app.py +225 -0
- ccproxy/api/dependencies.py +140 -0
- ccproxy/api/middleware/__init__.py +11 -0
- ccproxy/api/middleware/auth.py +0 -0
- ccproxy/api/middleware/cors.py +55 -0
- ccproxy/api/middleware/errors.py +703 -0
- ccproxy/api/middleware/headers.py +51 -0
- ccproxy/api/middleware/logging.py +175 -0
- ccproxy/api/middleware/request_id.py +69 -0
- ccproxy/api/middleware/server_header.py +62 -0
- ccproxy/api/responses.py +84 -0
- ccproxy/api/routes/__init__.py +16 -0
- ccproxy/api/routes/claude.py +181 -0
- ccproxy/api/routes/health.py +489 -0
- ccproxy/api/routes/metrics.py +1033 -0
- ccproxy/api/routes/proxy.py +238 -0
- ccproxy/auth/__init__.py +75 -0
- ccproxy/auth/bearer.py +68 -0
- ccproxy/auth/credentials_adapter.py +93 -0
- ccproxy/auth/dependencies.py +229 -0
- ccproxy/auth/exceptions.py +79 -0
- ccproxy/auth/manager.py +102 -0
- ccproxy/auth/models.py +118 -0
- ccproxy/auth/oauth/__init__.py +26 -0
- ccproxy/auth/oauth/models.py +49 -0
- ccproxy/auth/oauth/routes.py +396 -0
- ccproxy/auth/oauth/storage.py +0 -0
- ccproxy/auth/storage/__init__.py +12 -0
- ccproxy/auth/storage/base.py +57 -0
- ccproxy/auth/storage/json_file.py +159 -0
- ccproxy/auth/storage/keyring.py +192 -0
- ccproxy/claude_sdk/__init__.py +20 -0
- ccproxy/claude_sdk/client.py +169 -0
- ccproxy/claude_sdk/converter.py +331 -0
- ccproxy/claude_sdk/options.py +120 -0
- ccproxy/cli/__init__.py +14 -0
- ccproxy/cli/commands/__init__.py +8 -0
- ccproxy/cli/commands/auth.py +553 -0
- ccproxy/cli/commands/config/__init__.py +14 -0
- ccproxy/cli/commands/config/commands.py +766 -0
- ccproxy/cli/commands/config/schema_commands.py +119 -0
- ccproxy/cli/commands/serve.py +630 -0
- ccproxy/cli/docker/__init__.py +34 -0
- ccproxy/cli/docker/adapter_factory.py +157 -0
- ccproxy/cli/docker/params.py +278 -0
- ccproxy/cli/helpers.py +144 -0
- ccproxy/cli/main.py +193 -0
- ccproxy/cli/options/__init__.py +14 -0
- ccproxy/cli/options/claude_options.py +216 -0
- ccproxy/cli/options/core_options.py +40 -0
- ccproxy/cli/options/security_options.py +48 -0
- ccproxy/cli/options/server_options.py +117 -0
- ccproxy/config/__init__.py +40 -0
- ccproxy/config/auth.py +154 -0
- ccproxy/config/claude.py +124 -0
- ccproxy/config/cors.py +79 -0
- ccproxy/config/discovery.py +87 -0
- ccproxy/config/docker_settings.py +265 -0
- ccproxy/config/loader.py +108 -0
- ccproxy/config/observability.py +158 -0
- ccproxy/config/pricing.py +88 -0
- ccproxy/config/reverse_proxy.py +31 -0
- ccproxy/config/scheduler.py +89 -0
- ccproxy/config/security.py +14 -0
- ccproxy/config/server.py +81 -0
- ccproxy/config/settings.py +534 -0
- ccproxy/config/validators.py +231 -0
- ccproxy/core/__init__.py +274 -0
- ccproxy/core/async_utils.py +675 -0
- ccproxy/core/constants.py +97 -0
- ccproxy/core/errors.py +256 -0
- ccproxy/core/http.py +328 -0
- ccproxy/core/http_transformers.py +428 -0
- ccproxy/core/interfaces.py +247 -0
- ccproxy/core/logging.py +189 -0
- ccproxy/core/middleware.py +114 -0
- ccproxy/core/proxy.py +143 -0
- ccproxy/core/system.py +38 -0
- ccproxy/core/transformers.py +259 -0
- ccproxy/core/types.py +129 -0
- ccproxy/core/validators.py +288 -0
- ccproxy/docker/__init__.py +67 -0
- ccproxy/docker/adapter.py +588 -0
- ccproxy/docker/docker_path.py +207 -0
- ccproxy/docker/middleware.py +103 -0
- ccproxy/docker/models.py +228 -0
- ccproxy/docker/protocol.py +192 -0
- ccproxy/docker/stream_process.py +264 -0
- ccproxy/docker/validators.py +173 -0
- ccproxy/models/__init__.py +123 -0
- ccproxy/models/errors.py +42 -0
- ccproxy/models/messages.py +243 -0
- ccproxy/models/requests.py +85 -0
- ccproxy/models/responses.py +227 -0
- ccproxy/models/types.py +102 -0
- ccproxy/observability/__init__.py +51 -0
- ccproxy/observability/access_logger.py +400 -0
- ccproxy/observability/context.py +447 -0
- ccproxy/observability/metrics.py +539 -0
- ccproxy/observability/pushgateway.py +366 -0
- ccproxy/observability/sse_events.py +303 -0
- ccproxy/observability/stats_printer.py +755 -0
- ccproxy/observability/storage/__init__.py +1 -0
- ccproxy/observability/storage/duckdb_simple.py +665 -0
- ccproxy/observability/storage/models.py +55 -0
- ccproxy/pricing/__init__.py +19 -0
- ccproxy/pricing/cache.py +212 -0
- ccproxy/pricing/loader.py +267 -0
- ccproxy/pricing/models.py +106 -0
- ccproxy/pricing/updater.py +309 -0
- ccproxy/scheduler/__init__.py +39 -0
- ccproxy/scheduler/core.py +335 -0
- ccproxy/scheduler/exceptions.py +34 -0
- ccproxy/scheduler/manager.py +186 -0
- ccproxy/scheduler/registry.py +150 -0
- ccproxy/scheduler/tasks.py +484 -0
- ccproxy/services/__init__.py +10 -0
- ccproxy/services/claude_sdk_service.py +614 -0
- ccproxy/services/credentials/__init__.py +55 -0
- ccproxy/services/credentials/config.py +105 -0
- ccproxy/services/credentials/manager.py +562 -0
- ccproxy/services/credentials/oauth_client.py +482 -0
- ccproxy/services/proxy_service.py +1536 -0
- ccproxy/static/.keep +0 -0
- ccproxy/testing/__init__.py +34 -0
- ccproxy/testing/config.py +148 -0
- ccproxy/testing/content_generation.py +197 -0
- ccproxy/testing/mock_responses.py +262 -0
- ccproxy/testing/response_handlers.py +161 -0
- ccproxy/testing/scenarios.py +241 -0
- ccproxy/utils/__init__.py +6 -0
- ccproxy/utils/cost_calculator.py +210 -0
- ccproxy/utils/streaming_metrics.py +199 -0
- ccproxy_api-0.1.0.dist-info/METADATA +253 -0
- ccproxy_api-0.1.0.dist-info/RECORD +148 -0
- ccproxy_api-0.1.0.dist-info/WHEEL +4 -0
- ccproxy_api-0.1.0.dist-info/entry_points.txt +2 -0
- ccproxy_api-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SQLModel schema definitions for observability storage.
|
|
3
|
+
|
|
4
|
+
This module provides the centralized schema definitions for access logs and metrics
|
|
5
|
+
using SQLModel to ensure type safety and eliminate column name repetition.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from typing import Optional
|
|
10
|
+
|
|
11
|
+
from sqlmodel import Field, SQLModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AccessLog(SQLModel, table=True):
    """Row model for the ``access_logs`` table.

    Each instance describes one request/response pair: who called what,
    which service/model handled it, how it ended, and the token/cost
    accounting attached to it.
    """

    __tablename__ = "access_logs"

    # -- Identity ---------------------------------------------------------
    # request_id is the primary key; timestamp is indexed and defaults to
    # datetime.now() (naive local time) when the row object is created.
    request_id: str = Field(primary_key=True)
    timestamp: datetime = Field(index=True, default_factory=datetime.now)

    # -- Request ----------------------------------------------------------
    method: str
    endpoint: str
    path: str
    query: str = ""
    client_ip: str
    user_agent: str

    # -- Service / model --------------------------------------------------
    service_type: str
    model: str
    streaming: bool = False

    # -- Response ---------------------------------------------------------
    # The duration is stored twice, once per unit.
    status_code: int
    duration_ms: float
    duration_seconds: float

    # -- Tokens and cost (all default to zero) ----------------------------
    tokens_input: int = 0
    tokens_output: int = 0
    cache_read_tokens: int = 0
    cache_write_tokens: int = 0
    cost_usd: float = 0.0
    cost_sdk_usd: float = 0.0

    class Config:
        """Model configuration flags."""

        # Accept attribute-bearing objects as validation input.
        from_attributes = True
        # NOTE(review): no enum-typed field is declared on this model, so
        # this flag currently has nothing to act on.
        use_enum_values = True
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Dynamic pricing system for Claude models.
|
|
2
|
+
|
|
3
|
+
This module provides dynamic pricing capabilities by downloading and caching
|
|
4
|
+
pricing information from external sources like LiteLLM.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .cache import PricingCache
|
|
8
|
+
from .loader import PricingLoader
|
|
9
|
+
from .models import ModelPricing, PricingData
|
|
10
|
+
from .updater import PricingUpdater
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"PricingCache",
|
|
15
|
+
"PricingLoader",
|
|
16
|
+
"PricingUpdater",
|
|
17
|
+
"ModelPricing",
|
|
18
|
+
"PricingData",
|
|
19
|
+
]
|
ccproxy/pricing/cache.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
1
|
+
"""Pricing cache management for dynamic model pricing."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
from structlog import get_logger
|
|
10
|
+
|
|
11
|
+
from ccproxy.config.pricing import PricingSettings
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
logger = get_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class PricingCache:
    """Manages caching of model pricing data from external sources.

    Pricing data is stored as JSON in ``<cache_dir>/model_pricing.json``.
    The file's modification time is compared against
    ``settings.cache_ttl_hours`` to decide whether the cache is still fresh.
    """

    def __init__(self, settings: PricingSettings) -> None:
        """Initialize pricing cache.

        Args:
            settings: Pricing configuration settings. The attributes read by
                this class are ``cache_dir``, ``cache_ttl_hours``,
                ``source_url`` and ``download_timeout``.
        """
        self.settings = settings
        self.cache_dir = settings.cache_dir
        self.cache_file = self.cache_dir / "model_pricing.json"

        # Ensure cache directory exists
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _age_hours(mtime: float) -> float:
        """Return how many hours have elapsed since the timestamp *mtime*."""
        return (time.time() - mtime) / 3600

    def _read_cache_file(self) -> dict[str, Any]:
        """Read and parse the cache file without checking its age.

        Returns:
            The parsed JSON object.

        Raises:
            OSError: If the file cannot be opened or read.
            ValueError: If the content is not valid UTF-8 or not valid JSON
                (``UnicodeDecodeError`` and ``json.JSONDecodeError`` are both
                ``ValueError`` subclasses), or if the top-level JSON value is
                not an object.
        """
        with self.cache_file.open(encoding="utf-8") as f:
            data = json.load(f)

        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object in cache, got {type(data).__name__}"
            )
        return data

    def is_cache_valid(self) -> bool:
        """Check if cached pricing data is still valid.

        Returns:
            True if cache exists and is not expired
        """
        try:
            mtime = self.cache_file.stat().st_mtime
        except FileNotFoundError:
            # A missing cache is the normal "cold start" case; not worth a log.
            return False
        except OSError as e:
            logger.warning("cache_stats_check_failed", error=str(e))
            return False

        return self._age_hours(mtime) < self.settings.cache_ttl_hours

    def load_cached_data(self) -> dict[str, Any] | None:
        """Load pricing data from cache.

        Returns:
            Cached pricing data or None if cache is invalid/corrupted
        """
        if not self.is_cache_valid():
            return None

        try:
            return self._read_cache_file()
        except (OSError, ValueError) as e:
            # ValueError covers malformed JSON *and* undecodable bytes, so a
            # corrupted file leads to a re-download instead of an exception.
            logger.warning("cache_load_failed", error=str(e))
            return None

    async def download_pricing_data(
        self, timeout: int | None = None
    ) -> dict[str, Any] | None:
        """Download fresh pricing data from source URL.

        Args:
            timeout: Request timeout in seconds (uses settings default if None)

        Returns:
            Downloaded pricing data or None if download failed
        """
        if timeout is None:
            timeout = self.settings.download_timeout

        try:
            logger.info("pricing_download_start", url=self.settings.source_url)

            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.get(self.settings.source_url)
                response.raise_for_status()

                data = response.json()
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and decoding failures of
            # the response body.
            logger.error("pricing_download_failed", error=str(e))
            return None

        if not isinstance(data, dict):
            # Honour the declared return type; anything else would be cached
            # and only fail later when the loader iterates over it.
            logger.error(
                "pricing_download_failed",
                error=f"expected a JSON object, got {type(data).__name__}",
            )
            return None

        logger.info("pricing_download_completed", model_count=len(data))
        return data

    def save_to_cache(self, data: dict[str, Any]) -> bool:
        """Save pricing data to cache.

        The data is written to a sibling ``.tmp`` file first and then moved
        over the cache file, so readers never observe a half-written cache.

        Args:
            data: Pricing data to cache

        Returns:
            True if successfully saved, False otherwise
        """
        temp_file = self.cache_file.with_suffix(".tmp")

        try:
            with temp_file.open("w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)

            # Atomic rename
            temp_file.replace(self.cache_file)
        except (OSError, TypeError, ValueError) as e:
            # TypeError/ValueError: json.dump rejected the payload
            # (non-serialisable value, circular reference).
            logger.error("cache_save_failed", error=str(e))
            try:
                # Best-effort cleanup so a failed write leaves no stray file.
                temp_file.unlink(missing_ok=True)
            except OSError:
                # The save failure above is already logged; a leftover temp
                # file is harmless and is overwritten by the next save.
                pass
            return False

        return True

    async def get_pricing_data(
        self, force_refresh: bool = False
    ) -> dict[str, Any] | None:
        """Get pricing data, from cache if valid or by downloading fresh data.

        Resolution order:

        1. A fresh cache, unless ``force_refresh`` is set.
        2. A new download, which is also written back to the cache.
        3. The expired cache file, only when ``force_refresh`` is not set.

        Args:
            force_refresh: Force download even if cache is valid

        Returns:
            Pricing data or None if both cache and download fail
        """
        # Try cache first unless forced refresh
        if not force_refresh:
            cached_data = self.load_cached_data()
            if cached_data is not None:
                return cached_data

        # Download fresh data
        fresh_data = await self.download_pricing_data()
        if fresh_data is not None:
            # Save to cache for next time; a failed save is logged by
            # save_to_cache and must not discard the downloaded data.
            self.save_to_cache(fresh_data)
            return fresh_data

        # If download failed, try to use stale cache as fallback
        if not force_refresh:
            logger.warning("pricing_download_failed_using_stale_cache")
            try:
                stale_data = self._read_cache_file()
            except (OSError, ValueError) as e:
                # No usable stale copy either; fall through to the error below.
                logger.warning("cache_load_failed", error=str(e))
            else:
                logger.warning("stale_cache_used")
                return stale_data

        logger.error("pricing_data_unavailable")
        return None

    def clear_cache(self) -> bool:
        """Clear cached pricing data.

        Returns:
            True if the cache file is gone afterwards (including the case
            where it did not exist), False if removing it failed
        """
        try:
            self.cache_file.unlink(missing_ok=True)
        except OSError as e:
            logger.error("cache_clear_failed", error=str(e))
            return False

        return True

    def get_cache_info(self) -> dict[str, Any]:
        """Get information about cache status.

        Returns:
            Dictionary with the keys ``cache_file``, ``cache_dir``,
            ``source_url``, ``ttl_hours``, ``exists``, ``valid``,
            ``age_hours`` and ``size_bytes``. The last three keep their
            defaults (``False``/``None``) when the file is missing or cannot
            be inspected.
        """
        info: dict[str, Any] = {
            "cache_file": str(self.cache_file),
            "cache_dir": str(self.cache_dir),
            "source_url": self.settings.source_url,
            "ttl_hours": self.settings.cache_ttl_hours,
            "exists": self.cache_file.exists(),
            "valid": False,
            "age_hours": None,
            "size_bytes": None,
        }

        if info["exists"]:
            try:
                stat = self.cache_file.stat()
            except OSError as e:
                # Same event as is_cache_valid, so both failures show up alike.
                logger.warning("cache_stats_check_failed", error=str(e))
            else:
                age_hours = self._age_hours(stat.st_mtime)
                info.update(
                    valid=age_hours < self.settings.cache_ttl_hours,
                    age_hours=age_hours,
                    size_bytes=stat.st_size,
                )

        return info
|
|
@@ -0,0 +1,267 @@
|
|
|
1
|
+
"""Pricing data loader and format converter for LiteLLM pricing data."""
|
|
2
|
+
|
|
3
|
+
from decimal import Decimal
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import ValidationError
|
|
7
|
+
from structlog import get_logger
|
|
8
|
+
|
|
9
|
+
from .models import PricingData
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
logger = get_logger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class PricingLoader:
    """Loads and converts pricing data from LiteLLM format to internal format.

    All methods are static; the class only groups the conversion steps and
    the alias table they share.
    """

    # Claude model name mappings for different versions
    CLAUDE_MODEL_MAPPINGS = {
        # Map versioned models to their canonical names
        "claude-3-5-sonnet-latest": "claude-3-5-sonnet-20241022",
        "claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620",
        "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-latest": "claude-3-5-haiku-20241022",
        "claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022",
        "claude-3-opus": "claude-3-opus-20240229",
        "claude-3-opus-20240229": "claude-3-opus-20240229",
        "claude-3-sonnet": "claude-3-sonnet-20240229",
        "claude-3-sonnet-20240229": "claude-3-sonnet-20240229",
        "claude-3-haiku": "claude-3-haiku-20240307",
        "claude-3-haiku-20240307": "claude-3-haiku-20240307",
    }

    @staticmethod
    def _per_million(cost_per_token: Any) -> Decimal:
        """Convert a per-token cost into an exact per-1M-token ``Decimal``.

        The value is turned into a ``Decimal`` *before* scaling so the
        multiplication happens in decimal arithmetic; multiplying the binary
        float first can bake rounding artefacts into the price.

        Args:
            cost_per_token: Cost of a single token (number or numeric string)

        Returns:
            Cost per one million tokens, without trailing zeros

        Raises:
            ArithmeticError: If the value cannot be parsed as a decimal
                (``decimal.InvalidOperation``)
            ValueError: If the value is NaN or infinite
        """
        scaled = Decimal(str(cost_per_token)) * 1_000_000
        if not scaled.is_finite():
            raise ValueError(f"non-finite cost value: {cost_per_token!r}")

        # Drop trailing zeros, but keep whole numbers out of exponent
        # notation ("240" rather than "2.4E+2").
        if scaled == scaled.to_integral_value():
            return scaled.quantize(Decimal(1))
        return scaled.normalize()

    @staticmethod
    def extract_claude_models(
        litellm_data: dict[str, Any], verbose: bool = True
    ) -> dict[str, Any]:
        """Extract Claude model entries from LiteLLM data.

        An entry is kept when its value is a dict whose ``litellm_provider``
        is ``"anthropic"`` and whose key contains ``"claude"``
        (case-insensitive).

        Args:
            litellm_data: Raw LiteLLM pricing data
            verbose: Whether to log individual model discoveries

        Returns:
            Dictionary with only Claude models
        """
        claude_models = {}

        for model_name, model_data in litellm_data.items():
            # Check if this is a Claude model
            if (
                isinstance(model_data, dict)
                and model_data.get("litellm_provider") == "anthropic"
                and "claude" in model_name.lower()
            ):
                claude_models[model_name] = model_data
                if verbose:
                    logger.debug("claude_model_found", model_name=model_name)

        if verbose:
            logger.info(
                "claude_models_extracted",
                model_count=len(claude_models),
                source="LiteLLM",
            )
        return claude_models

    @staticmethod
    def convert_to_internal_format(
        claude_models: dict[str, Any], verbose: bool = True
    ) -> dict[str, dict[str, Decimal]]:
        """Convert LiteLLM pricing format to internal format.

        LiteLLM format uses cost per token, we use cost per 1M tokens as Decimal.
        Entries without input or output cost, or with values that cannot be
        converted, are skipped individually; the remaining models are still
        returned.

        Args:
            claude_models: Claude models in LiteLLM format
            verbose: Whether to log individual model conversions

        Returns:
            Dictionary in internal pricing format, keyed by canonical model
            name. Each value has ``input`` and ``output`` and, when present in
            the source, ``cache_write`` and ``cache_read``.
        """
        internal_format: dict[str, dict[str, Decimal]] = {}

        for model_name, model_data in claude_models.items():
            try:
                # Extract pricing fields
                input_cost_per_token = model_data.get("input_cost_per_token")
                output_cost_per_token = model_data.get("output_cost_per_token")
                cache_creation_cost = model_data.get("cache_creation_input_token_cost")
                cache_read_cost = model_data.get("cache_read_input_token_cost")

                # Skip models without pricing info
                if input_cost_per_token is None or output_cost_per_token is None:
                    if verbose:
                        logger.warning("model_pricing_missing", model_name=model_name)
                    continue

                # Convert to per-1M-token pricing
                pricing = {
                    "input": PricingLoader._per_million(input_cost_per_token),
                    "output": PricingLoader._per_million(output_cost_per_token),
                }

                # Add cache pricing if available
                if cache_creation_cost is not None:
                    pricing["cache_write"] = PricingLoader._per_million(
                        cache_creation_cost
                    )

                if cache_read_cost is not None:
                    pricing["cache_read"] = PricingLoader._per_million(cache_read_cost)

            except (ValueError, TypeError, ArithmeticError) as e:
                # ArithmeticError covers decimal.InvalidOperation, which is
                # what Decimal raises for unparsable input.
                if verbose:
                    logger.error(
                        "pricing_conversion_failed", model_name=model_name, error=str(e)
                    )
                continue

            # Map to canonical model name if needed
            canonical_name = PricingLoader.CLAUDE_MODEL_MAPPINGS.get(
                model_name, model_name
            )
            if canonical_name == model_name:
                internal_format[canonical_name] = pricing
            else:
                # An alias must not replace pricing listed under the canonical
                # name itself, whichever of the two is seen first.
                internal_format.setdefault(canonical_name, pricing)

            if verbose:
                logger.debug(
                    "model_pricing_converted",
                    original_name=model_name,
                    canonical_name=canonical_name,
                    input_cost=str(pricing["input"]),
                    output_cost=str(pricing["output"]),
                )

        if verbose:
            logger.info("models_converted", model_count=len(internal_format))
        return internal_format

    @staticmethod
    def load_pricing_from_data(
        litellm_data: dict[str, Any], verbose: bool = True
    ) -> PricingData | None:
        """Load and convert pricing data from LiteLLM format.

        Runs the full pipeline: extract the Claude entries, convert them to
        the internal per-1M-token format, then validate them as
        ``PricingData``. This method never raises; every failure is reported
        as ``None`` (and logged when ``verbose`` is set).

        Args:
            litellm_data: Raw LiteLLM pricing data
            verbose: Whether to enable verbose logging

        Returns:
            Validated pricing data as PricingData model, or None if invalid
        """
        try:
            # Extract Claude models
            claude_models = PricingLoader.extract_claude_models(
                litellm_data, verbose=verbose
            )

            if not claude_models:
                if verbose:
                    logger.warning("claude_models_not_found", source="LiteLLM")
                return None

            # Convert to internal format
            internal_pricing = PricingLoader.convert_to_internal_format(
                claude_models, verbose=verbose
            )

            if not internal_pricing:
                if verbose:
                    logger.warning("pricing_data_invalid")
                return None

            # Validate and create PricingData model
            pricing_data = PricingData.from_dict(internal_pricing)

            if verbose:
                logger.info("pricing_data_loaded", model_count=len(pricing_data))

            return pricing_data

        except ValidationError as e:
            if verbose:
                logger.error("pricing_validation_failed", error=str(e))
            return None
        except Exception as e:
            # Deliberate catch-all at the pipeline boundary: malformed
            # third-party data must not propagate to the caller.
            if verbose:
                logger.error("pricing_load_failed", source="LiteLLM", error=str(e))
            return None

    @staticmethod
    def validate_pricing_data(
        pricing_data: Any, verbose: bool = True
    ) -> PricingData | None:
        """Validate pricing data using Pydantic models.

        Args:
            pricing_data: Pricing data to validate (dict or PricingData)
            verbose: Whether to enable verbose logging

        Returns:
            Valid PricingData model or None if validation fails. An empty
            dict and any value that is neither a dict nor a PricingData also
            yield None.
        """
        try:
            # If already a PricingData instance, return it
            if isinstance(pricing_data, PricingData):
                if verbose:
                    logger.debug(
                        "pricing_already_validated", model_count=len(pricing_data)
                    )
                return pricing_data

            # If it's a dict, try to create PricingData from it
            if isinstance(pricing_data, dict):
                if not pricing_data:
                    if verbose:
                        logger.warning("pricing_data_empty")
                    return None

                # Try to create PricingData model
                validated_data = PricingData.from_dict(pricing_data)

                if verbose:
                    logger.debug(
                        "pricing_data_validated", model_count=len(validated_data)
                    )

                return validated_data

            # Invalid type
            if verbose:
                logger.error(
                    "pricing_data_invalid_type",
                    actual_type=type(pricing_data).__name__,
                    expected_types=["dict", "PricingData"],
                )
            return None

        except ValidationError as e:
            if verbose:
                logger.error("pricing_validation_failed", error=str(e))
            return None
        except Exception as e:
            # Deliberate catch-all: validation reports failure as None.
            if verbose:
                logger.error("pricing_validation_unexpected_error", error=str(e))
            return None

    @staticmethod
    def get_model_aliases() -> dict[str, str]:
        """Get mapping of model aliases to canonical names.

        Returns:
            A copy of the alias table, so callers cannot mutate the
            class-level mapping
        """
        return PricingLoader.CLAUDE_MODEL_MAPPINGS.copy()

    @staticmethod
    def get_canonical_model_name(model_name: str) -> str:
        """Get canonical model name for a given model name.

        Args:
            model_name: Model name (possibly an alias)

        Returns:
            Canonical model name, or ``model_name`` unchanged when it is not
            in the alias table
        """
        return PricingLoader.CLAUDE_MODEL_MAPPINGS.get(model_name, model_name)
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
"""Pydantic models for pricing data validation and type safety."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterator
|
|
4
|
+
from decimal import Decimal
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from pydantic import BaseModel, Field, RootModel, field_validator
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class ModelPricing(BaseModel):
    """Pricing information for a single Claude model.

    All costs are in USD per 1 million tokens. ``input`` and ``output`` are
    required; the two cache costs default to zero. Every field must be
    greater than or equal to zero.
    """

    input: Decimal = Field(..., ge=0, description="Input token cost per 1M tokens")
    output: Decimal = Field(..., ge=0, description="Output token cost per 1M tokens")
    cache_read: Decimal = Field(
        default=Decimal("0"), ge=0, description="Cache read cost per 1M tokens"
    )
    cache_write: Decimal = Field(
        default=Decimal("0"), ge=0, description="Cache write cost per 1M tokens"
    )

    @field_validator("*", mode="before")
    @classmethod
    def convert_to_decimal(cls, v: Any) -> Decimal:
        """Convert numeric values to Decimal for precision.

        Numbers go through ``str()`` first so a float such as ``0.1`` becomes
        ``Decimal("0.1")`` rather than its full binary expansion.

        Args:
            v: Raw field value (Decimal, int, float or numeric string)

        Returns:
            The value as a Decimal

        Raises:
            ValueError: If the value has an unsupported type or cannot be
                parsed. Pydantic reports a ``ValueError`` raised in a
                validator as a ``ValidationError``; other exception types
                would escape to the caller unwrapped.
        """
        if isinstance(v, Decimal):
            return v
        if not isinstance(v, int | float | str):
            raise ValueError(f"Cannot convert {type(v)} to Decimal")

        try:
            return Decimal(str(v))
        except ArithmeticError as exc:
            # decimal.InvalidOperation (an ArithmeticError) is what Decimal
            # raises for unparsable text such as "abc" or "True".
            raise ValueError(f"Cannot convert {v!r} to Decimal") from exc

    class Config:
        """Pydantic configuration."""

        # NOTE(review): class-based Config and json_encoders are deprecated
        # in Pydantic v2; consider migrating to model_config = ConfigDict(...)
        # together with the module's imports.
        arbitrary_types_allowed = True
        json_encoders = {Decimal: lambda v: float(v)}
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PricingData(RootModel[dict[str, ModelPricing]]):
    """Pricing for every known Claude model, keyed by model name.

    The model's root value is a plain ``dict[str, ModelPricing]``. The
    methods below forward to that dict so an instance can be used much like
    a read-only mapping.
    """

    def __iter__(self) -> Iterator[str]:  # type: ignore[override]
        """Yield the model names."""
        return iter(self.root)

    def __getitem__(self, model_name: str) -> ModelPricing:
        """Return the pricing stored under ``model_name``."""
        return self.root[model_name]

    def __contains__(self, model_name: str) -> bool:
        """Tell whether ``model_name`` has a pricing entry."""
        return model_name in self.root

    def __len__(self) -> int:
        """Return how many models have a pricing entry."""
        return len(self.root)

    def items(self) -> Iterator[tuple[str, ModelPricing]]:
        """Return an iterator over ``(model name, pricing)`` pairs."""
        pairs = self.root.items()
        return iter(pairs)

    def keys(self) -> Iterator[str]:
        """Return an iterator over the model names."""
        return iter(self.root)

    def values(self) -> Iterator[ModelPricing]:
        """Return an iterator over the pricing objects."""
        entries = self.root.values()
        return iter(entries)

    def get(
        self, model_name: str, default: ModelPricing | None = None
    ) -> ModelPricing | None:
        """Return the pricing for ``model_name``, or ``default`` if absent."""
        if model_name in self.root:
            return self.root[model_name]
        return default

    def model_names(self) -> list[str]:
        """Return all model names as a new list."""
        return list(self.root)

    def to_dict(self) -> dict[str, dict[str, Decimal]]:
        """Export to the legacy nested-dict format.

        Returns:
            ``{model name: {"input", "output", "cache_read", "cache_write"}}``
            with Decimal values.
        """
        legacy: dict[str, dict[str, Decimal]] = {}
        for name, entry in self.root.items():
            legacy[name] = {
                "input": entry.input,
                "output": entry.output,
                "cache_read": entry.cache_read,
                "cache_write": entry.cache_write,
            }
        return legacy

    @classmethod
    def from_dict(cls, data: dict[str, dict[str, Any]]) -> "PricingData":
        """Build a PricingData from the legacy nested-dict format.

        Each inner dict is passed to ``ModelPricing`` as keyword arguments.
        """
        parsed: dict[str, ModelPricing] = {}
        for name, raw_pricing in data.items():
            parsed[name] = ModelPricing(**raw_pricing)
        return cls(root=parsed)
|