ccproxy-api 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148) hide show
  1. ccproxy/__init__.py +4 -0
  2. ccproxy/__main__.py +7 -0
  3. ccproxy/_version.py +21 -0
  4. ccproxy/adapters/__init__.py +11 -0
  5. ccproxy/adapters/base.py +80 -0
  6. ccproxy/adapters/openai/__init__.py +43 -0
  7. ccproxy/adapters/openai/adapter.py +915 -0
  8. ccproxy/adapters/openai/models.py +412 -0
  9. ccproxy/adapters/openai/streaming.py +449 -0
  10. ccproxy/api/__init__.py +28 -0
  11. ccproxy/api/app.py +225 -0
  12. ccproxy/api/dependencies.py +140 -0
  13. ccproxy/api/middleware/__init__.py +11 -0
  14. ccproxy/api/middleware/auth.py +0 -0
  15. ccproxy/api/middleware/cors.py +55 -0
  16. ccproxy/api/middleware/errors.py +703 -0
  17. ccproxy/api/middleware/headers.py +51 -0
  18. ccproxy/api/middleware/logging.py +175 -0
  19. ccproxy/api/middleware/request_id.py +69 -0
  20. ccproxy/api/middleware/server_header.py +62 -0
  21. ccproxy/api/responses.py +84 -0
  22. ccproxy/api/routes/__init__.py +16 -0
  23. ccproxy/api/routes/claude.py +181 -0
  24. ccproxy/api/routes/health.py +489 -0
  25. ccproxy/api/routes/metrics.py +1033 -0
  26. ccproxy/api/routes/proxy.py +238 -0
  27. ccproxy/auth/__init__.py +75 -0
  28. ccproxy/auth/bearer.py +68 -0
  29. ccproxy/auth/credentials_adapter.py +93 -0
  30. ccproxy/auth/dependencies.py +229 -0
  31. ccproxy/auth/exceptions.py +79 -0
  32. ccproxy/auth/manager.py +102 -0
  33. ccproxy/auth/models.py +118 -0
  34. ccproxy/auth/oauth/__init__.py +26 -0
  35. ccproxy/auth/oauth/models.py +49 -0
  36. ccproxy/auth/oauth/routes.py +396 -0
  37. ccproxy/auth/oauth/storage.py +0 -0
  38. ccproxy/auth/storage/__init__.py +12 -0
  39. ccproxy/auth/storage/base.py +57 -0
  40. ccproxy/auth/storage/json_file.py +159 -0
  41. ccproxy/auth/storage/keyring.py +192 -0
  42. ccproxy/claude_sdk/__init__.py +20 -0
  43. ccproxy/claude_sdk/client.py +169 -0
  44. ccproxy/claude_sdk/converter.py +331 -0
  45. ccproxy/claude_sdk/options.py +120 -0
  46. ccproxy/cli/__init__.py +14 -0
  47. ccproxy/cli/commands/__init__.py +8 -0
  48. ccproxy/cli/commands/auth.py +553 -0
  49. ccproxy/cli/commands/config/__init__.py +14 -0
  50. ccproxy/cli/commands/config/commands.py +766 -0
  51. ccproxy/cli/commands/config/schema_commands.py +119 -0
  52. ccproxy/cli/commands/serve.py +630 -0
  53. ccproxy/cli/docker/__init__.py +34 -0
  54. ccproxy/cli/docker/adapter_factory.py +157 -0
  55. ccproxy/cli/docker/params.py +278 -0
  56. ccproxy/cli/helpers.py +144 -0
  57. ccproxy/cli/main.py +193 -0
  58. ccproxy/cli/options/__init__.py +14 -0
  59. ccproxy/cli/options/claude_options.py +216 -0
  60. ccproxy/cli/options/core_options.py +40 -0
  61. ccproxy/cli/options/security_options.py +48 -0
  62. ccproxy/cli/options/server_options.py +117 -0
  63. ccproxy/config/__init__.py +40 -0
  64. ccproxy/config/auth.py +154 -0
  65. ccproxy/config/claude.py +124 -0
  66. ccproxy/config/cors.py +79 -0
  67. ccproxy/config/discovery.py +87 -0
  68. ccproxy/config/docker_settings.py +265 -0
  69. ccproxy/config/loader.py +108 -0
  70. ccproxy/config/observability.py +158 -0
  71. ccproxy/config/pricing.py +88 -0
  72. ccproxy/config/reverse_proxy.py +31 -0
  73. ccproxy/config/scheduler.py +89 -0
  74. ccproxy/config/security.py +14 -0
  75. ccproxy/config/server.py +81 -0
  76. ccproxy/config/settings.py +534 -0
  77. ccproxy/config/validators.py +231 -0
  78. ccproxy/core/__init__.py +274 -0
  79. ccproxy/core/async_utils.py +675 -0
  80. ccproxy/core/constants.py +97 -0
  81. ccproxy/core/errors.py +256 -0
  82. ccproxy/core/http.py +328 -0
  83. ccproxy/core/http_transformers.py +428 -0
  84. ccproxy/core/interfaces.py +247 -0
  85. ccproxy/core/logging.py +189 -0
  86. ccproxy/core/middleware.py +114 -0
  87. ccproxy/core/proxy.py +143 -0
  88. ccproxy/core/system.py +38 -0
  89. ccproxy/core/transformers.py +259 -0
  90. ccproxy/core/types.py +129 -0
  91. ccproxy/core/validators.py +288 -0
  92. ccproxy/docker/__init__.py +67 -0
  93. ccproxy/docker/adapter.py +588 -0
  94. ccproxy/docker/docker_path.py +207 -0
  95. ccproxy/docker/middleware.py +103 -0
  96. ccproxy/docker/models.py +228 -0
  97. ccproxy/docker/protocol.py +192 -0
  98. ccproxy/docker/stream_process.py +264 -0
  99. ccproxy/docker/validators.py +173 -0
  100. ccproxy/models/__init__.py +123 -0
  101. ccproxy/models/errors.py +42 -0
  102. ccproxy/models/messages.py +243 -0
  103. ccproxy/models/requests.py +85 -0
  104. ccproxy/models/responses.py +227 -0
  105. ccproxy/models/types.py +102 -0
  106. ccproxy/observability/__init__.py +51 -0
  107. ccproxy/observability/access_logger.py +400 -0
  108. ccproxy/observability/context.py +447 -0
  109. ccproxy/observability/metrics.py +539 -0
  110. ccproxy/observability/pushgateway.py +366 -0
  111. ccproxy/observability/sse_events.py +303 -0
  112. ccproxy/observability/stats_printer.py +755 -0
  113. ccproxy/observability/storage/__init__.py +1 -0
  114. ccproxy/observability/storage/duckdb_simple.py +665 -0
  115. ccproxy/observability/storage/models.py +55 -0
  116. ccproxy/pricing/__init__.py +19 -0
  117. ccproxy/pricing/cache.py +212 -0
  118. ccproxy/pricing/loader.py +267 -0
  119. ccproxy/pricing/models.py +106 -0
  120. ccproxy/pricing/updater.py +309 -0
  121. ccproxy/scheduler/__init__.py +39 -0
  122. ccproxy/scheduler/core.py +335 -0
  123. ccproxy/scheduler/exceptions.py +34 -0
  124. ccproxy/scheduler/manager.py +186 -0
  125. ccproxy/scheduler/registry.py +150 -0
  126. ccproxy/scheduler/tasks.py +484 -0
  127. ccproxy/services/__init__.py +10 -0
  128. ccproxy/services/claude_sdk_service.py +614 -0
  129. ccproxy/services/credentials/__init__.py +55 -0
  130. ccproxy/services/credentials/config.py +105 -0
  131. ccproxy/services/credentials/manager.py +562 -0
  132. ccproxy/services/credentials/oauth_client.py +482 -0
  133. ccproxy/services/proxy_service.py +1536 -0
  134. ccproxy/static/.keep +0 -0
  135. ccproxy/testing/__init__.py +34 -0
  136. ccproxy/testing/config.py +148 -0
  137. ccproxy/testing/content_generation.py +197 -0
  138. ccproxy/testing/mock_responses.py +262 -0
  139. ccproxy/testing/response_handlers.py +161 -0
  140. ccproxy/testing/scenarios.py +241 -0
  141. ccproxy/utils/__init__.py +6 -0
  142. ccproxy/utils/cost_calculator.py +210 -0
  143. ccproxy/utils/streaming_metrics.py +199 -0
  144. ccproxy_api-0.1.0.dist-info/METADATA +253 -0
  145. ccproxy_api-0.1.0.dist-info/RECORD +148 -0
  146. ccproxy_api-0.1.0.dist-info/WHEEL +4 -0
  147. ccproxy_api-0.1.0.dist-info/entry_points.txt +2 -0
  148. ccproxy_api-0.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,55 @@
1
+ """
2
+ SQLModel schema definitions for observability storage.
3
+
4
+ This module provides the centralized schema definitions for access logs and metrics
5
+ using SQLModel to ensure type safety and eliminate column name repetition.
6
+ """
7
+
8
+ from datetime import datetime
9
+ from typing import Optional
10
+
11
+ from sqlmodel import Field, SQLModel
12
+
13
+
14
class AccessLog(SQLModel, table=True):
    """Access log model for storing request/response data.

    Declared with ``table=True`` and mapped to the ``access_logs`` table.
    ``request_id`` is the primary key. Fields without a default are required
    when constructing a row.
    """

    __tablename__ = "access_logs"

    # Core request identification
    request_id: str = Field(primary_key=True)
    # Filled by calling datetime.now() with no tz argument when not supplied
    # (i.e. a naive datetime); the column is indexed.
    timestamp: datetime = Field(default_factory=datetime.now, index=True)

    # Request details
    method: str
    endpoint: str
    path: str
    query: str = Field(default="")  # empty string when not supplied
    client_ip: str
    user_agent: str

    # Service and model info
    service_type: str
    model: str
    streaming: bool = Field(default=False)

    # Response details
    status_code: int
    # NOTE(review): presumably the same duration expressed in two units;
    # nothing in this class enforces that — confirm against the writer.
    duration_ms: float
    duration_seconds: float

    # Token and cost tracking (all default to zero when not supplied)
    tokens_input: int = Field(default=0)
    tokens_output: int = Field(default=0)
    cache_read_tokens: int = Field(default=0)
    cache_write_tokens: int = Field(default=0)
    cost_usd: float = Field(default=0.0)
    # NOTE(review): looks like the cost as reported by the SDK, kept alongside
    # cost_usd — confirm where each value is computed.
    cost_sdk_usd: float = Field(default=0.0)

    class Config:
        """SQLModel configuration."""

        # NOTE(review): in pydantic v2 `from_attributes` controls building a
        # model from another object's attributes rather than from a dict —
        # confirm which behaviour was intended here.
        from_attributes = True
        # Use enum values
        use_enum_values = True
@@ -0,0 +1,19 @@
1
+ """Dynamic pricing system for Claude models.
2
+
3
+ This module provides dynamic pricing capabilities by downloading and caching
4
+ pricing information from external sources like LiteLLM.
5
+ """
6
+
7
+ from .cache import PricingCache
8
+ from .loader import PricingLoader
9
+ from .models import ModelPricing, PricingData
10
+ from .updater import PricingUpdater
11
+
12
+
13
# Names exported by `from <this package> import *`; each one is imported
# from a sibling module at the top of this file.
__all__ = [
    "PricingCache",
    "PricingLoader",
    "PricingUpdater",
    "ModelPricing",
    "PricingData",
]
@@ -0,0 +1,212 @@
1
+ """Pricing cache management for dynamic model pricing."""
2
+
3
+ import json
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import httpx
9
+ from structlog import get_logger
10
+
11
+ from ccproxy.config.pricing import PricingSettings
12
+
13
+
14
+ logger = get_logger(__name__)
15
+
16
+
17
class PricingCache:
    """Manages caching of model pricing data from external sources.

    The cache is a single JSON file, ``model_pricing.json``, stored in
    ``settings.cache_dir``. Freshness is decided by comparing the file's
    modification time against ``settings.cache_ttl_hours``.
    """

    def __init__(self, settings: PricingSettings) -> None:
        """Initialize pricing cache.

        Args:
            settings: Pricing configuration settings
        """
        self.settings = settings
        self.cache_dir = settings.cache_dir
        self.cache_file = self.cache_dir / "model_pricing.json"

        # Ensure cache directory exists
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _age_hours(mtime: float) -> float:
        """Return how many hours have elapsed since the timestamp ``mtime``."""
        return (time.time() - mtime) / 3600

    def _read_cache_file(self) -> dict[str, Any]:
        """Read and parse the cache file, ignoring its age.

        Returns:
            The parsed JSON object.

        Raises:
            OSError: If the file cannot be opened or read.
            ValueError: If the content is not valid UTF-8 / JSON
                (``UnicodeDecodeError`` and ``json.JSONDecodeError`` are both
                ``ValueError`` subclasses) or is not a JSON object.
        """
        with self.cache_file.open(encoding="utf-8") as f:
            data = json.load(f)

        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object in cache, got {type(data).__name__}"
            )
        return data

    def is_cache_valid(self) -> bool:
        """Check if cached pricing data is still valid.

        Returns:
            True if cache exists and is not expired
        """
        if not self.cache_file.exists():
            return False

        try:
            age_hours = self._age_hours(self.cache_file.stat().st_mtime)
        except OSError as e:
            logger.warning("cache_stats_check_failed", error=str(e))
            return False

        return age_hours < self.settings.cache_ttl_hours

    def load_cached_data(self) -> dict[str, Any] | None:
        """Load pricing data from cache.

        Returns:
            Cached pricing data or None if cache is invalid/corrupted
        """
        if not self.is_cache_valid():
            return None

        try:
            return self._read_cache_file()
        except (OSError, ValueError) as e:
            # ValueError also covers undecodable bytes and non-object JSON,
            # so a corrupted cache yields None instead of an exception.
            logger.warning("cache_load_failed", error=str(e))
            return None

    async def download_pricing_data(
        self, timeout: int | None = None
    ) -> dict[str, Any] | None:
        """Download fresh pricing data from source URL.

        Args:
            timeout: Request timeout in seconds (uses settings default if None)

        Returns:
            Downloaded pricing data or None if download failed or the
            response body is not a JSON object
        """
        if timeout is None:
            timeout = self.settings.download_timeout

        try:
            logger.info("pricing_download_start", url=self.settings.source_url)

            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.get(self.settings.source_url)
                response.raise_for_status()

                data = response.json()

        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and body decoding errors.
            logger.error("pricing_download_failed", error=str(e))
            return None

        if not isinstance(data, dict):
            # Callers iterate the result as a mapping; reject anything else.
            logger.error(
                "pricing_download_failed",
                error=f"expected a JSON object, got {type(data).__name__}",
            )
            return None

        logger.info("pricing_download_completed", model_count=len(data))
        return data

    def save_to_cache(self, data: dict[str, Any]) -> bool:
        """Save pricing data to cache.

        The data is written to a sibling ``.tmp`` file and then moved over the
        cache file, so readers never observe a partially written cache.

        Args:
            data: Pricing data to cache

        Returns:
            True if successfully saved, False otherwise
        """
        temp_file = self.cache_file.with_suffix(".tmp")

        try:
            with temp_file.open("w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)

            # Atomic rename
            temp_file.replace(self.cache_file)
            return True

        except (OSError, TypeError, ValueError) as e:
            # TypeError/ValueError: data that json cannot serialize.
            logger.error("cache_save_failed", error=str(e))

            # Best effort: do not leave a partial temp file behind.
            try:
                temp_file.unlink(missing_ok=True)
            except OSError as cleanup_error:
                logger.warning("cache_temp_cleanup_failed", error=str(cleanup_error))
            return False

    async def get_pricing_data(
        self, force_refresh: bool = False
    ) -> dict[str, Any] | None:
        """Get pricing data, from cache if valid or by downloading fresh data.

        Args:
            force_refresh: Force download even if cache is valid

        Returns:
            Pricing data or None if both cache and download fail
        """
        # Try cache first unless forced refresh
        if not force_refresh:
            cached_data = self.load_cached_data()
            if cached_data is not None:
                return cached_data

        # Download fresh data
        fresh_data = await self.download_pricing_data()
        if fresh_data is not None:
            # Save to cache for next time; a failed save is logged there and
            # does not prevent returning the downloaded data.
            self.save_to_cache(fresh_data)
            return fresh_data

        # If download failed, try to use stale cache as fallback
        if not force_refresh:
            logger.warning("pricing_download_failed_using_stale_cache")
            try:
                stale_data = self._read_cache_file()
            except (OSError, ValueError):
                # No usable stale cache; fall through to the final error.
                stale_data = None

            if stale_data is not None:
                logger.warning("stale_cache_used")
                return stale_data

        logger.error("pricing_data_unavailable")
        return None

    def clear_cache(self) -> bool:
        """Clear cached pricing data.

        Returns:
            True if cache was cleared successfully (including when there was
            no cache file to remove)
        """
        try:
            # missing_ok avoids the race between an existence check and unlink.
            self.cache_file.unlink(missing_ok=True)
            return True
        except OSError as e:
            logger.error("cache_clear_failed", error=str(e))
            return False

    def get_cache_info(self) -> dict[str, Any]:
        """Get information about cache status.

        Returns:
            Dictionary with cache information. ``valid``, ``age_hours`` and
            ``size_bytes`` keep their defaults (False/None/None) when the
            cache file is missing or cannot be inspected.
        """
        exists = self.cache_file.exists()
        info: dict[str, Any] = {
            "cache_file": str(self.cache_file),
            "cache_dir": str(self.cache_dir),
            "source_url": self.settings.source_url,
            "ttl_hours": self.settings.cache_ttl_hours,
            "exists": exists,
            "valid": False,
            "age_hours": None,
            "size_bytes": None,
        }

        if not exists:
            return info

        try:
            stat = self.cache_file.stat()
        except OSError:
            # File vanished or is unreadable; report the defaults.
            return info

        age_hours = self._age_hours(stat.st_mtime)
        info.update(
            {
                "valid": age_hours < self.settings.cache_ttl_hours,
                "age_hours": age_hours,
                "size_bytes": stat.st_size,
            }
        )
        return info
@@ -0,0 +1,267 @@
1
+ """Pricing data loader and format converter for LiteLLM pricing data."""
2
+
3
+ from decimal import Decimal
4
+ from typing import Any
5
+
6
+ from pydantic import ValidationError
7
+ from structlog import get_logger
8
+
9
+ from .models import PricingData
10
+
11
+
12
+ logger = get_logger(__name__)
13
+
14
+
15
class PricingLoader:
    """Loads and converts pricing data from LiteLLM format to internal format."""

    # Claude model name mappings for different versions
    CLAUDE_MODEL_MAPPINGS = {
        # Map versioned models to their canonical names
        "claude-3-5-sonnet-latest": "claude-3-5-sonnet-20241022",
        "claude-3-5-sonnet-20240620": "claude-3-5-sonnet-20240620",
        "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-20241022",
        "claude-3-5-haiku-latest": "claude-3-5-haiku-20241022",
        "claude-3-5-haiku-20241022": "claude-3-5-haiku-20241022",
        "claude-3-opus": "claude-3-opus-20240229",
        "claude-3-opus-20240229": "claude-3-opus-20240229",
        "claude-3-sonnet": "claude-3-sonnet-20240229",
        "claude-3-sonnet-20240229": "claude-3-sonnet-20240229",
        "claude-3-haiku": "claude-3-haiku-20240307",
        "claude-3-haiku-20240307": "claude-3-haiku-20240307",
    }

    @staticmethod
    def _per_million_tokens(cost_per_token: Any) -> Decimal:
        """Convert a per-token cost to a per-1M-token ``Decimal``.

        The value is converted to ``Decimal`` *before* scaling so the
        multiplication is exact; scaling a float first bakes binary rounding
        error into the result.

        Raises:
            decimal.InvalidOperation: If the value is not a valid number
                (an ``ArithmeticError`` subclass).
        """
        return Decimal(str(cost_per_token)) * 1_000_000

    @staticmethod
    def extract_claude_models(
        litellm_data: dict[str, Any], verbose: bool = True
    ) -> dict[str, Any]:
        """Extract Claude model entries from LiteLLM data.

        An entry is kept when its value is a dict whose ``litellm_provider``
        is ``"anthropic"`` and whose key contains ``"claude"``
        (case-insensitive).

        Args:
            litellm_data: Raw LiteLLM pricing data
            verbose: Whether to log individual model discoveries

        Returns:
            Dictionary with only Claude models
        """
        claude_models = {}

        for model_name, model_data in litellm_data.items():
            # Check if this is a Claude model
            if (
                isinstance(model_data, dict)
                and model_data.get("litellm_provider") == "anthropic"
                and "claude" in model_name.lower()
            ):
                claude_models[model_name] = model_data
                if verbose:
                    logger.debug("claude_model_found", model_name=model_name)

        if verbose:
            logger.info(
                "claude_models_extracted",
                model_count=len(claude_models),
                source="LiteLLM",
            )
        return claude_models

    @staticmethod
    def convert_to_internal_format(
        claude_models: dict[str, Any], verbose: bool = True
    ) -> dict[str, dict[str, Decimal]]:
        """Convert LiteLLM pricing format to internal format.

        LiteLLM format uses cost per token, we use cost per 1M tokens as Decimal.
        Models lacking input or output pricing, or whose pricing cannot be
        converted, are skipped. When several source names map to the same
        canonical name, the last one processed wins.

        Args:
            claude_models: Claude models in LiteLLM format
            verbose: Whether to log individual model conversions

        Returns:
            Dictionary in internal pricing format
        """
        internal_format: dict[str, dict[str, Decimal]] = {}
        to_per_million = PricingLoader._per_million_tokens

        for model_name, model_data in claude_models.items():
            try:
                # Extract pricing fields
                input_cost_per_token = model_data.get("input_cost_per_token")
                output_cost_per_token = model_data.get("output_cost_per_token")
                cache_creation_cost = model_data.get("cache_creation_input_token_cost")
                cache_read_cost = model_data.get("cache_read_input_token_cost")

                # Skip models without pricing info
                if input_cost_per_token is None or output_cost_per_token is None:
                    if verbose:
                        logger.warning("model_pricing_missing", model_name=model_name)
                    continue

                # Convert to per-1M-token pricing
                pricing = {
                    "input": to_per_million(input_cost_per_token),
                    "output": to_per_million(output_cost_per_token),
                }

                # Add cache pricing if available
                if cache_creation_cost is not None:
                    pricing["cache_write"] = to_per_million(cache_creation_cost)

                if cache_read_cost is not None:
                    pricing["cache_read"] = to_per_million(cache_read_cost)

            except (ValueError, TypeError, ArithmeticError, AttributeError) as e:
                # ArithmeticError covers decimal.InvalidOperation (unparseable
                # cost); AttributeError covers a non-dict model entry. One bad
                # model must not abort conversion of the others.
                if verbose:
                    logger.error(
                        "pricing_conversion_failed", model_name=model_name, error=str(e)
                    )
                continue

            # Map to canonical model name if needed
            canonical_name = PricingLoader.CLAUDE_MODEL_MAPPINGS.get(
                model_name, model_name
            )
            internal_format[canonical_name] = pricing

            if verbose:
                logger.debug(
                    "model_pricing_converted",
                    original_name=model_name,
                    canonical_name=canonical_name,
                    input_cost=str(pricing["input"]),
                    output_cost=str(pricing["output"]),
                )

        if verbose:
            logger.info("models_converted", model_count=len(internal_format))
        return internal_format

    @staticmethod
    def load_pricing_from_data(
        litellm_data: dict[str, Any], verbose: bool = True
    ) -> PricingData | None:
        """Load and convert pricing data from LiteLLM format.

        Args:
            litellm_data: Raw LiteLLM pricing data
            verbose: Whether to enable verbose logging

        Returns:
            Validated pricing data as PricingData model, or None if invalid
        """
        try:
            # Extract Claude models
            claude_models = PricingLoader.extract_claude_models(
                litellm_data, verbose=verbose
            )

            if not claude_models:
                if verbose:
                    logger.warning("claude_models_not_found", source="LiteLLM")
                return None

            # Convert to internal format
            internal_pricing = PricingLoader.convert_to_internal_format(
                claude_models, verbose=verbose
            )

            if not internal_pricing:
                if verbose:
                    logger.warning("pricing_data_invalid")
                return None

            # Validate and create PricingData model
            pricing_data = PricingData.from_dict(internal_pricing)

            if verbose:
                logger.info("pricing_data_loaded", model_count=len(pricing_data))

            return pricing_data

        except ValidationError as e:
            if verbose:
                logger.error("pricing_validation_failed", error=str(e))
            return None
        except Exception as e:
            # Boundary handler: this method reports failure by returning None.
            if verbose:
                logger.error("pricing_load_failed", source="LiteLLM", error=str(e))
            return None

    @staticmethod
    def validate_pricing_data(
        pricing_data: Any, verbose: bool = True
    ) -> PricingData | None:
        """Validate pricing data using Pydantic models.

        Args:
            pricing_data: Pricing data to validate (dict or PricingData)
            verbose: Whether to enable verbose logging

        Returns:
            Valid PricingData model or None if validation fails
        """
        try:
            # If already a PricingData instance, return it
            if isinstance(pricing_data, PricingData):
                if verbose:
                    logger.debug(
                        "pricing_already_validated", model_count=len(pricing_data)
                    )
                return pricing_data

            # If it's a dict, try to create PricingData from it
            if isinstance(pricing_data, dict):
                if not pricing_data:
                    if verbose:
                        logger.warning("pricing_data_empty")
                    return None

                # Try to create PricingData model
                validated_data = PricingData.from_dict(pricing_data)

                if verbose:
                    logger.debug(
                        "pricing_data_validated", model_count=len(validated_data)
                    )

                return validated_data

            # Invalid type
            if verbose:
                logger.error(
                    "pricing_data_invalid_type",
                    actual_type=type(pricing_data).__name__,
                    expected_types=["dict", "PricingData"],
                )
            return None

        except ValidationError as e:
            if verbose:
                logger.error("pricing_validation_failed", error=str(e))
            return None
        except Exception as e:
            # Boundary handler: this method reports failure by returning None.
            if verbose:
                logger.error("pricing_validation_unexpected_error", error=str(e))
            return None

    @staticmethod
    def get_model_aliases() -> dict[str, str]:
        """Get mapping of model aliases to canonical names.

        Returns:
            A copy of the alias mapping, so callers cannot mutate the
            class-level table
        """
        return PricingLoader.CLAUDE_MODEL_MAPPINGS.copy()

    @staticmethod
    def get_canonical_model_name(model_name: str) -> str:
        """Get canonical model name for a given model name.

        Args:
            model_name: Model name (possibly an alias)

        Returns:
            Canonical model name, or ``model_name`` unchanged when it has no
            entry in the mapping
        """
        return PricingLoader.CLAUDE_MODEL_MAPPINGS.get(model_name, model_name)
@@ -0,0 +1,106 @@
1
+ """Pydantic models for pricing data validation and type safety."""
2
+
3
+ from collections.abc import Iterator
4
+ from decimal import Decimal
5
+ from typing import Any
6
+
7
+ from pydantic import BaseModel, Field, RootModel, field_validator
8
+
9
+
10
class ModelPricing(BaseModel):
    """Pricing information for a single Claude model.

    All costs are in USD per 1 million tokens. ``input`` and ``output`` are
    required; the cache costs default to zero. Every field must be >= 0.
    """

    input: Decimal = Field(..., ge=0, description="Input token cost per 1M tokens")
    output: Decimal = Field(..., ge=0, description="Output token cost per 1M tokens")
    cache_read: Decimal = Field(
        default=Decimal("0"), ge=0, description="Cache read cost per 1M tokens"
    )
    cache_write: Decimal = Field(
        default=Decimal("0"), ge=0, description="Cache write cost per 1M tokens"
    )

    @field_validator("*", mode="before")
    @classmethod
    def convert_to_decimal(cls, v: Any) -> Decimal:
        """Convert numeric values to Decimal for precision.

        Numbers are converted through their ``str`` form so a float such as
        ``0.1`` becomes ``Decimal("0.1")`` rather than its binary expansion.

        Raises:
            ValueError: If the value cannot be converted. ``ValueError`` is
                used (rather than ``TypeError`` or a raw decimal error)
                because pydantic only turns ``ValueError``/``AssertionError``
                raised inside validators into a ``ValidationError``.
        """
        if isinstance(v, Decimal):
            return v
        if isinstance(v, int | float | str):
            try:
                return Decimal(str(v))
            except ArithmeticError as e:
                # decimal.InvalidOperation is an ArithmeticError subclass.
                raise ValueError(f"Cannot convert {v!r} to Decimal") from e
        raise ValueError(f"Cannot convert {type(v)} to Decimal")

    class Config:
        """Pydantic configuration."""

        arbitrary_types_allowed = True
        # Decimals are emitted as floats when encoded to JSON.
        json_encoders = {Decimal: lambda v: float(v)}
40
+
41
+
42
class PricingData(RootModel[dict[str, ModelPricing]]):
    """Pricing for every known Claude model, keyed by model name.

    A root model over ``dict[str, ModelPricing]`` that forwards the common
    mapping operations to the wrapped dictionary (``self.root``).
    """

    def __iter__(self) -> Iterator[str]:  # type: ignore[override]
        """Yield the model names."""
        return iter(self.root)

    def __getitem__(self, model_name: str) -> ModelPricing:
        """Return the pricing stored under ``model_name``."""
        entries = self.root
        return entries[model_name]

    def __contains__(self, model_name: str) -> bool:
        """Report whether ``model_name`` has a pricing entry."""
        entries = self.root
        return model_name in entries

    def __len__(self) -> int:
        """Return how many models have pricing."""
        entries = self.root
        return len(entries)

    def items(self) -> Iterator[tuple[str, ModelPricing]]:
        """Return an iterator over ``(model name, pricing)`` pairs."""
        pairs = self.root.items()
        return iter(pairs)

    def keys(self) -> Iterator[str]:
        """Return an iterator over the model names."""
        names = self.root.keys()
        return iter(names)

    def values(self) -> Iterator[ModelPricing]:
        """Return an iterator over the pricing objects."""
        prices = self.root.values()
        return iter(prices)

    def get(
        self, model_name: str, default: ModelPricing | None = None
    ) -> ModelPricing | None:
        """Return the pricing for ``model_name``, or ``default`` if absent."""
        entries = self.root
        return entries.get(model_name, default)

    def model_names(self) -> list[str]:
        """Return all model names as a list."""
        return list(self.root.keys())

    def to_dict(self) -> dict[str, dict[str, Decimal]]:
        """Export as the legacy nested-dict format for backward compatibility."""
        legacy: dict[str, dict[str, Decimal]] = {}
        for name, price in self.root.items():
            legacy[name] = {
                "input": price.input,
                "output": price.output,
                "cache_read": price.cache_read,
                "cache_write": price.cache_write,
            }
        return legacy

    @classmethod
    def from_dict(cls, data: dict[str, dict[str, Any]]) -> "PricingData":
        """Build a PricingData from the legacy nested-dict format."""
        parsed: dict[str, ModelPricing] = {}
        for name, raw_pricing in data.items():
            parsed[name] = ModelPricing(**raw_pricing)
        return cls(root=parsed)