router-maestro 0.1.2__py3-none-any.whl
- router_maestro/__init__.py +3 -0
- router_maestro/__main__.py +6 -0
- router_maestro/auth/__init__.py +18 -0
- router_maestro/auth/github_oauth.py +181 -0
- router_maestro/auth/manager.py +136 -0
- router_maestro/auth/storage.py +91 -0
- router_maestro/cli/__init__.py +1 -0
- router_maestro/cli/auth.py +167 -0
- router_maestro/cli/client.py +322 -0
- router_maestro/cli/config.py +132 -0
- router_maestro/cli/context.py +146 -0
- router_maestro/cli/main.py +42 -0
- router_maestro/cli/model.py +288 -0
- router_maestro/cli/server.py +117 -0
- router_maestro/cli/stats.py +76 -0
- router_maestro/config/__init__.py +72 -0
- router_maestro/config/contexts.py +29 -0
- router_maestro/config/paths.py +50 -0
- router_maestro/config/priorities.py +93 -0
- router_maestro/config/providers.py +34 -0
- router_maestro/config/server.py +115 -0
- router_maestro/config/settings.py +76 -0
- router_maestro/providers/__init__.py +31 -0
- router_maestro/providers/anthropic.py +203 -0
- router_maestro/providers/base.py +123 -0
- router_maestro/providers/copilot.py +346 -0
- router_maestro/providers/openai.py +188 -0
- router_maestro/providers/openai_compat.py +175 -0
- router_maestro/routing/__init__.py +5 -0
- router_maestro/routing/router.py +526 -0
- router_maestro/server/__init__.py +5 -0
- router_maestro/server/app.py +87 -0
- router_maestro/server/middleware/__init__.py +11 -0
- router_maestro/server/middleware/auth.py +66 -0
- router_maestro/server/oauth_sessions.py +159 -0
- router_maestro/server/routes/__init__.py +8 -0
- router_maestro/server/routes/admin.py +358 -0
- router_maestro/server/routes/anthropic.py +228 -0
- router_maestro/server/routes/chat.py +142 -0
- router_maestro/server/routes/models.py +34 -0
- router_maestro/server/schemas/__init__.py +57 -0
- router_maestro/server/schemas/admin.py +87 -0
- router_maestro/server/schemas/anthropic.py +246 -0
- router_maestro/server/schemas/openai.py +107 -0
- router_maestro/server/translation.py +636 -0
- router_maestro/stats/__init__.py +14 -0
- router_maestro/stats/heatmap.py +154 -0
- router_maestro/stats/storage.py +228 -0
- router_maestro/stats/tracker.py +73 -0
- router_maestro/utils/__init__.py +16 -0
- router_maestro/utils/logging.py +81 -0
- router_maestro/utils/tokens.py +51 -0
- router_maestro-0.1.2.dist-info/METADATA +383 -0
- router_maestro-0.1.2.dist-info/RECORD +57 -0
- router_maestro-0.1.2.dist-info/WHEEL +4 -0
- router_maestro-0.1.2.dist-info/entry_points.txt +2 -0
- router_maestro-0.1.2.dist-info/licenses/LICENSE +21 -0

--- /dev/null
+++ router_maestro/routing/router.py
@@ -0,0 +1,526 @@
+"""Model router with priority-based selection and fallback."""
+
+import time
+from collections.abc import AsyncIterator
+
+from router_maestro.auth import ApiKeyCredential, AuthManager
+from router_maestro.config import (
+    FallbackStrategy,
+    PrioritiesConfig,
+    load_priorities_config,
+    load_providers_config,
+)
+from router_maestro.providers import (
+    BaseProvider,
+    ChatRequest,
+    ChatResponse,
+    ChatStreamChunk,
+    CopilotProvider,
+    ModelInfo,
+    OpenAICompatibleProvider,
+    ProviderError,
+)
+from router_maestro.utils import get_logger
+
+logger = get_logger("routing")
+
+# Special model name that triggers auto-routing
+AUTO_ROUTE_MODEL = "router-maestro"
+
+# Cache TTL in seconds (5 minutes)
+CACHE_TTL_SECONDS = 300
+
+# Global singleton instance
+_router_instance: "Router | None" = None
+
+
+def get_router() -> "Router":
+    """Get the singleton Router instance.
+
+    Returns:
+        The global Router instance
+    """
+    global _router_instance
+    if _router_instance is None:
+        _router_instance = Router()
+        logger.info("Created singleton Router instance")
+    return _router_instance
+
+
+def reset_router() -> None:
+    """Reset the singleton Router instance.
+
+    Call this when authentication changes or to force a reload.
+    """
+    global _router_instance
+    if _router_instance is not None:
+        _router_instance.invalidate_cache()
+        _router_instance = None
+        logger.info("Reset singleton Router instance")
+
+
+class Router:
+    """Router for model requests with priority and fallback support."""
+
+    def __init__(self) -> None:
+        self.providers: dict[str, BaseProvider] = {}
+        # Model cache: maps model_id -> (provider_name, ModelInfo)
+        self._models_cache: dict[str, tuple[str, ModelInfo]] = {}
+        self._cache_initialized: bool = False
+        self._cache_timestamp: float = 0.0
+        # Priorities config cache
+        self._priorities_config: PrioritiesConfig | None = None
+        self._priorities_config_timestamp: float = 0.0
+        # Providers config cache
+        self._providers_config_timestamp: float = 0.0
+        self._load_providers()
+
+    def _load_providers(self) -> None:
+        """Load providers from configuration."""
+        custom_providers_config = load_providers_config()
+        auth_manager = AuthManager()
+
+        # Clear existing providers, but keep the copilot instance if one already exists
+        old_copilot = self.providers.get("github-copilot")
+        self.providers.clear()
+
+        # Always add the built-in GitHub Copilot provider (reuse the existing instance if available)
+        if old_copilot is not None:
+            self.providers["github-copilot"] = old_copilot
+        else:
+            copilot = CopilotProvider()
+            self.providers["github-copilot"] = copilot
+            logger.debug("Loaded built-in provider: github-copilot")
+
+        # Load custom providers from providers.json
+        for provider_name, provider_config in custom_providers_config.providers.items():
+            # Get the API key from auth storage
+            cred = auth_manager.get_credential(provider_name)
+            if isinstance(cred, ApiKeyCredential):
+                provider = OpenAICompatibleProvider(
+                    name=provider_name,
+                    base_url=provider_config.baseURL,
+                    api_key=cred.key,
+                    models={
+                        model_id: model_config.name
+                        for model_id, model_config in provider_config.models.items()
+                    },
+                )
+                self.providers[provider_name] = provider
+                logger.debug("Loaded custom provider: %s", provider_name)
+
+        self._providers_config_timestamp = time.time()
+        logger.info("Loaded %d providers", len(self.providers))
+
+    def _get_priorities_config(self) -> PrioritiesConfig:
+        """Get the priorities config with caching."""
+        # Simple time-based cache (same TTL as the models cache)
+        current_time = time.time()
+        if (
+            self._priorities_config is not None
+            and current_time - self._priorities_config_timestamp < CACHE_TTL_SECONDS
+        ):
+            return self._priorities_config
+
+        self._priorities_config = load_priorities_config()
+        self._priorities_config_timestamp = current_time
+        return self._priorities_config
+
+    def _ensure_providers_fresh(self) -> None:
+        """Ensure the providers config is fresh; reload it if expired."""
+        current_time = time.time()
+        if current_time - self._providers_config_timestamp >= CACHE_TTL_SECONDS:
+            logger.debug("Providers config expired, reloading")
+            self._load_providers()
+            # Also invalidate the models cache since providers may have changed
+            self._models_cache.clear()
+            self._cache_initialized = False
+
+    def _parse_model_key(self, model_key: str) -> tuple[str, str]:
+        """Parse a model key into provider and model.
+
+        Args:
+            model_key: Model key in the format 'provider/model'
+
+        Returns:
+            Tuple of (provider_name, model_id)
+        """
+        if "/" in model_key:
+            parts = model_key.split("/", 1)
+            return parts[0], parts[1]
+        return "", model_key
+
+    async def _ensure_models_cache(self) -> None:
+        """Ensure the models cache is populated and not expired."""
+        # First ensure the providers config is fresh
+        self._ensure_providers_fresh()
+
+        # Check whether the cache is still valid (initialized and not expired)
+        if self._cache_initialized:
+            age = time.time() - self._cache_timestamp
+            if age < CACHE_TTL_SECONDS:
+                return
+            logger.debug("Cache expired (age=%.1fs), refreshing", age)
+            self._models_cache.clear()
+
+        logger.debug("Initializing models cache")
+        for provider_name, provider in self.providers.items():
+            if provider.is_authenticated():
+                try:
+                    await provider.ensure_token()
+                    models = await provider.list_models()
+                    for model in models:
+                        # Store by model_id only (without the provider prefix).
+                        # If the same model_id exists in multiple providers, the first one wins.
+                        if model.id not in self._models_cache:
+                            self._models_cache[model.id] = (provider_name, model)
+                        # Also store with the provider prefix for explicit lookups
+                        full_key = f"{provider_name}/{model.id}"
+                        self._models_cache[full_key] = (provider_name, model)
+                    logger.debug("Cached %d models from %s", len(models), provider_name)
+                except ProviderError as e:
+                    logger.warning("Failed to load models from %s: %s", provider_name, e)
+                    continue
+
+        self._cache_initialized = True
+        self._cache_timestamp = time.time()
+        logger.info("Models cache initialized with %d entries", len(self._models_cache))
+
+    async def _resolve_provider(self, model_id: str) -> tuple[str, str, BaseProvider]:
+        """Resolve a model_id to its provider.
+
+        Args:
+            model_id: Model ID (can be 'router-maestro', 'provider/model', or just 'model')
+
+        Returns:
+            Tuple of (provider_name, actual_model_id, provider)
+
+        Raises:
+            ProviderError: If the model is not found or no models are available
+        """
+        # Check for auto-routing
+        if model_id == AUTO_ROUTE_MODEL:
+            result = await self._get_auto_route_model()
+            if not result:
+                logger.error("No models available for auto-routing")
+                raise ProviderError("No models available for auto-routing", status_code=503)
+            return result
+
+        # An explicit model was specified - find it in the cache
+        result = await self._find_model_in_cache(model_id)
+        if not result:
+            logger.warning("Model not found: %s", model_id)
+            raise ProviderError(
+                f"Model '{model_id}' not found in any provider",
+                status_code=404,
+            )
+        return result
+
+    def _create_request_with_model(
+        self, original_request: ChatRequest, model_id: str
+    ) -> ChatRequest:
+        """Create a new ChatRequest with a different model ID.
+
+        Args:
+            original_request: The original request
+            model_id: The new model ID to use
+
+        Returns:
+            New ChatRequest with the updated model
+        """
+        return ChatRequest(
+            model=model_id,
+            messages=original_request.messages,
+            temperature=original_request.temperature,
+            max_tokens=original_request.max_tokens,
+            stream=original_request.stream,
+            tools=original_request.tools,
+            tool_choice=original_request.tool_choice,
+        )
+
+    async def _get_auto_route_model(self) -> tuple[str, str, BaseProvider] | None:
+        """Get the highest-priority available model for auto-routing.
+
+        Returns:
+            Tuple of (provider_name, model_id, provider), or None if no model is available
+        """
+        await self._ensure_models_cache()
+        priorities_config = self._get_priorities_config()
+
+        # Try each priority in order
+        for priority_key in priorities_config.priorities:
+            provider_name, model_id = self._parse_model_key(priority_key)
+            if provider_name in self.providers:
+                provider = self.providers[provider_name]
+                if provider.is_authenticated():
+                    # Verify the model exists in the cache
+                    if priority_key in self._models_cache:
+                        logger.debug("Auto-route selected: %s", priority_key)
+                        return provider_name, model_id, provider
+
+        # Fallback: return the first available model from any provider
+        for model_id, (provider_name, _) in self._models_cache.items():
+            if "/" not in model_id:  # Skip full keys, only use simple model_ids
+                provider = self.providers.get(provider_name)
+                if provider and provider.is_authenticated():
+                    logger.debug("Auto-route fallback: %s/%s", provider_name, model_id)
+                    return provider_name, model_id, provider
+
+        return None
+
+    async def _find_model_in_cache(self, model_id: str) -> tuple[str, str, BaseProvider] | None:
+        """Find a model in the cache.
+
+        Args:
+            model_id: Model ID (can be 'provider/model' or just 'model')
+
+        Returns:
+            Tuple of (provider_name, actual_model_id, provider), or None
+        """
+        await self._ensure_models_cache()
+
+        # If model_id includes a provider prefix (e.g., "github-copilot/gpt-4o")
+        if "/" in model_id:
+            provider_name, actual_model_id = self._parse_model_key(model_id)
+            if provider_name in self.providers:
+                provider = self.providers[provider_name]
+                if provider.is_authenticated():
+                    # Check that the model exists for this provider
+                    if model_id in self._models_cache:
+                        return provider_name, actual_model_id, provider
+            return None
+
+        # Simple model_id (e.g., "gpt-4o") - look it up in the cache
+        if model_id in self._models_cache:
+            provider_name, _ = self._models_cache[model_id]
+            provider = self.providers.get(provider_name)
+            if provider and provider.is_authenticated():
+                return provider_name, model_id, provider
+
+        return None
+
+    def _get_fallback_candidates(
+        self,
+        current_provider: str,
+        current_model: str,
+        strategy: FallbackStrategy,
+    ) -> list[tuple[str, str, BaseProvider]]:
+        """Get an ordered list of fallback candidates based on the strategy.
+
+        Args:
+            current_provider: The provider that just failed
+            current_model: The model that was requested
+            strategy: The fallback strategy to use
+
+        Returns:
+            List of (provider_name, model_id, provider) tuples to try
+        """
+        if strategy == FallbackStrategy.NONE:
+            return []
+
+        candidates: list[tuple[str, str, BaseProvider]] = []
+        current_key = f"{current_provider}/{current_model}"
+
+        if strategy == FallbackStrategy.PRIORITY:
+            # Follow the priorities list order, starting after the current entry
+            priorities_config = self._get_priorities_config()
+            found_current = False
+
+            for priority_key in priorities_config.priorities:
+                if priority_key == current_key:
+                    found_current = True
+                    continue
+
+                if found_current:
+                    provider_name, model_id = self._parse_model_key(priority_key)
+                    if provider_name in self.providers:
+                        provider = self.providers[provider_name]
+                        if provider.is_authenticated():
+                            if priority_key in self._models_cache:
+                                candidates.append((provider_name, model_id, provider))
+
+        elif strategy == FallbackStrategy.SAME_MODEL:
+            # Only try other providers that offer the same model
+            for other_name, other_provider in self.providers.items():
+                if other_name == current_provider:
+                    continue
+                if not other_provider.is_authenticated():
+                    continue
+                other_key = f"{other_name}/{current_model}"
+                if other_key in self._models_cache:
+                    candidates.append((other_name, current_model, other_provider))
+
+        return candidates
+
+    async def _execute_with_fallback(
+        self,
+        request: ChatRequest,
+        provider_name: str,
+        actual_model_id: str,
+        provider: BaseProvider,
+        fallback: bool,
+        is_stream: bool,
+    ) -> tuple[ChatResponse | AsyncIterator[ChatStreamChunk], str]:
+        """Execute a request with fallback support.
+
+        Args:
+            request: Original chat request
+            provider_name: Name of the primary provider
+            actual_model_id: The actual model ID to use
+            provider: The primary provider instance
+            fallback: Whether to try fallback providers on error
+            is_stream: Whether this is a streaming request
+
+        Returns:
+            Tuple of (response or stream, provider_name)
+
+        Raises:
+            ProviderError: If all providers fail
+        """
+        actual_request = self._create_request_with_model(request, actual_model_id)
+
+        try:
+            await provider.ensure_token()
+            if is_stream:
+                stream = provider.chat_completion_stream(actual_request)
+                logger.info("Stream request routed to %s", provider_name)
+                return stream, provider_name
+            else:
+                response = await provider.chat_completion(actual_request)
+                logger.info("Request completed via %s", provider_name)
+                return response, provider_name
+        except ProviderError as e:
+            logger.warning("Provider %s failed: %s", provider_name, e)
+            if not fallback or not e.retryable:
+                raise
+
+            # Load the fallback config
+            priorities_config = self._get_priorities_config()
+            fallback_config = priorities_config.fallback
+
+            if fallback_config.strategy == FallbackStrategy.NONE:
+                raise
+
+            # Get fallback candidates
+            candidates = self._get_fallback_candidates(
+                provider_name, actual_model_id, fallback_config.strategy
+            )
+
+            # Try fallback candidates up to maxRetries
+            for i, (other_name, other_model_id, other_provider) in enumerate(candidates):
+                if i >= fallback_config.maxRetries:
+                    break
+
+                logger.info("Trying fallback: %s/%s", other_name, other_model_id)
+                fallback_request = self._create_request_with_model(request, other_model_id)
+
+                try:
+                    await other_provider.ensure_token()
+                    if is_stream:
+                        stream = other_provider.chat_completion_stream(fallback_request)
+                        logger.info("Stream fallback succeeded via %s", other_name)
+                        return stream, other_name
+                    else:
+                        response = await other_provider.chat_completion(fallback_request)
+                        logger.info("Fallback succeeded via %s", other_name)
+                        return response, other_name
+                except ProviderError as fallback_error:
+                    logger.warning("Fallback %s failed: %s", other_name, fallback_error)
+                    continue
+            raise
+
+    async def chat_completion(
+        self,
+        request: ChatRequest,
+        fallback: bool = True,
+    ) -> tuple[ChatResponse, str]:
+        """Route a chat completion request.
+
+        Args:
+            request: Chat completion request
+            fallback: Whether to try fallback providers on error
+
+        Returns:
+            Tuple of (response, provider_name)
+
+        Raises:
+            ProviderError: If the model is not found or all providers fail
+        """
+        provider_name, actual_model_id, provider = await self._resolve_provider(request.model)
+        logger.info("Routing request to %s/%s", provider_name, actual_model_id)
+
+        result, used_provider = await self._execute_with_fallback(
+            request, provider_name, actual_model_id, provider, fallback, is_stream=False
+        )
+        return result, used_provider  # type: ignore
+
+    async def chat_completion_stream(
+        self,
+        request: ChatRequest,
+        fallback: bool = True,
+    ) -> tuple[AsyncIterator[ChatStreamChunk], str]:
+        """Route a streaming chat completion request.
+
+        Args:
+            request: Chat completion request
+            fallback: Whether to try fallback providers on error
+
+        Returns:
+            Tuple of (stream iterator, provider_name)
+
+        Raises:
+            ProviderError: If the model is not found or all providers fail
+        """
+        provider_name, actual_model_id, provider = await self._resolve_provider(request.model)
+        logger.info("Routing stream request to %s/%s", provider_name, actual_model_id)
+
+        result, used_provider = await self._execute_with_fallback(
+            request, provider_name, actual_model_id, provider, fallback, is_stream=True
+        )
+        return result, used_provider  # type: ignore
+
+    async def list_models(self) -> list[ModelInfo]:
+        """List all available models from all authenticated providers.
+
+        Models are sorted by the priority configuration.
+
+        Returns:
+            List of available models
+        """
+        await self._ensure_models_cache()
+        priorities_config = self._get_priorities_config()
+
+        models: list[ModelInfo] = []
+        seen: set[str] = set()
+
+        # Collect all models with their full keys
+        all_models: dict[str, ModelInfo] = {}
+        for key, (_, model_info) in self._models_cache.items():
+            # Only include full keys (provider/model)
+            if "/" in key:
+                all_models[key] = model_info
+
+        # Add prioritized models first
+        for priority_key in priorities_config.priorities:
+            if priority_key in all_models and priority_key not in seen:
+                models.append(all_models[priority_key])
+                seen.add(priority_key)
+
+        # Add the remaining models
+        for key, model in all_models.items():
+            if key not in seen:
+                models.append(model)
+                seen.add(key)
+
+        logger.debug("Listed %d models", len(models))
+        return models
+
+    def invalidate_cache(self) -> None:
+        """Invalidate all caches to force a refresh."""
+        self._models_cache.clear()
+        self._cache_initialized = False
+        self._cache_timestamp = 0.0
+        self._priorities_config = None
+        self._priorities_config_timestamp = 0.0
+        self._providers_config_timestamp = 0.0
+        logger.debug("All caches invalidated")
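
Note: a minimal sketch of using this routing layer as a library. Assumptions not shown in this diff: the message payload shape accepted by ChatRequest (OpenAI-style role/content dicts are assumed here), that its optional fields (temperature, max_tokens, tools, tool_choice) have defaults, and that at least one provider has already been authenticated via the auth CLI.

import asyncio

from router_maestro.providers import ChatRequest
from router_maestro.routing import get_router


async def main() -> None:
    router = get_router()  # lazily-created process-wide singleton

    # "router-maestro" (AUTO_ROUTE_MODEL) picks the highest-priority available
    # model; an explicit "provider/model" key such as "github-copilot/gpt-4o"
    # would pin the request to one provider instead.
    request = ChatRequest(
        model="router-maestro",
        messages=[{"role": "user", "content": "Hello"}],  # assumed message shape
        stream=False,
    )
    response, provider_name = await router.chat_completion(request, fallback=True)
    print(f"answered via {provider_name}")


asyncio.run(main())

On a retryable ProviderError, the fallback path follows the priorities config loaded by load_priorities_config(): FallbackStrategy.PRIORITY walks the priorities list past the failed entry, SAME_MODEL retries the same model on other providers, and NONE disables retries; in every case attempts are capped by the config's maxRetries.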
--- /dev/null
+++ router_maestro/server/app.py
@@ -0,0 +1,87 @@
+"""FastAPI application for router-maestro."""
+
+import os
+from contextlib import asynccontextmanager
+
+from fastapi import Depends, FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from router_maestro import __version__
+from router_maestro.routing import get_router
+from router_maestro.server.middleware import verify_api_key
+from router_maestro.server.routes import admin_router, anthropic_router, chat_router, models_router
+from router_maestro.utils import get_logger, setup_logging
+
+logger = get_logger("server")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan handler."""
+    # Startup - initialize logging
+    log_level = os.environ.get("ROUTER_MAESTRO_LOG_LEVEL", "INFO")
+    setup_logging(level=log_level)
+    logger.info("Router-Maestro server starting up")
+
+    # Pre-warm model cache if any providers are authenticated
+    router = get_router()
+    authenticated_providers = [
+        name for name, provider in router.providers.items() if provider.is_authenticated()
+    ]
+    if authenticated_providers:
+        logger.info(
+            "Pre-warming model cache for authenticated providers: %s", authenticated_providers
+        )
+        try:
+            models = await router.list_models()
+            logger.info("Model cache pre-warmed with %d models", len(models))
+        except Exception as e:
+            logger.warning("Failed to pre-warm model cache: %s", e)
+
+    yield
+    # Shutdown
+    logger.info("Router-Maestro server shutting down")
+
+
+def create_app() -> FastAPI:
+    """Create the FastAPI application."""
+    app = FastAPI(
+        title="Router-Maestro",
+        description="Multi-model routing and load balancing with OpenAI-compatible API",
+        version=__version__,
+        lifespan=lifespan,
+    )
+
+    # Add CORS middleware
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    # Include routers with API key verification
+    app.include_router(chat_router, dependencies=[Depends(verify_api_key)])
+    app.include_router(models_router, dependencies=[Depends(verify_api_key)])
+    app.include_router(anthropic_router, dependencies=[Depends(verify_api_key)])
+    app.include_router(admin_router, dependencies=[Depends(verify_api_key)])
+
+    @app.get("/")
+    async def root():
+        """Root endpoint."""
+        return {
+            "name": "Router-Maestro",
+            "version": __version__,
+            "status": "running",
+        }
+
+    @app.get("/health")
+    async def health():
+        """Health check endpoint."""
+        return {"status": "healthy"}
+
+    return app
+
+
+app = create_app()
|
|
|
1
|
+
"""Authentication middleware for API key validation."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
|
|
5
|
+
from fastapi import Depends, HTTPException, Request, status
|
|
6
|
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
7
|
+
|
|
8
|
+
security = HTTPBearer(auto_error=False)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def get_server_api_key() -> str | None:
|
|
12
|
+
"""Get the server API key from environment variable."""
|
|
13
|
+
return os.environ.get("ROUTER_MAESTRO_API_KEY")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
async def verify_api_key(
|
|
17
|
+
request: Request,
|
|
18
|
+
credentials: HTTPAuthorizationCredentials | None = Depends(security),
|
|
19
|
+
) -> None:
|
|
20
|
+
"""Verify the API key from the Authorization header.
|
|
21
|
+
|
|
22
|
+
Accepts both:
|
|
23
|
+
- Authorization: Bearer <api_key>
|
|
24
|
+
- Authorization: <api_key>
|
|
25
|
+
"""
|
|
26
|
+
server_api_key = get_server_api_key()
|
|
27
|
+
|
|
28
|
+
if server_api_key is None:
|
|
29
|
+
# No API key configured, allow all requests
|
|
30
|
+
return
|
|
31
|
+
|
|
32
|
+
# Skip auth for health and root endpoints
|
|
33
|
+
if request.url.path in ("/", "/health", "/docs", "/openapi.json", "/redoc"):
|
|
34
|
+
return
|
|
35
|
+
|
|
36
|
+
# Get API key from header
|
|
37
|
+
api_key: str | None = None
|
|
38
|
+
|
|
39
|
+
if credentials:
|
|
40
|
+
api_key = credentials.credentials
|
|
41
|
+
else:
|
|
42
|
+
# Try to get from Authorization header directly (without Bearer prefix)
|
|
43
|
+
auth_header = request.headers.get("Authorization")
|
|
44
|
+
if auth_header:
|
|
45
|
+
if auth_header.startswith("Bearer "):
|
|
46
|
+
api_key = auth_header[7:]
|
|
47
|
+
else:
|
|
48
|
+
api_key = auth_header
|
|
49
|
+
|
|
50
|
+
# Also check x-api-key header for Anthropic API compatibility
|
|
51
|
+
if not api_key:
|
|
52
|
+
api_key = request.headers.get("x-api-key")
|
|
53
|
+
|
|
54
|
+
if not api_key:
|
|
55
|
+
raise HTTPException(
|
|
56
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
57
|
+
detail="Missing API key. Use 'Authorization: Bearer <api_key>' header.",
|
|
58
|
+
headers={"WWW-Authenticate": "Bearer"},
|
|
59
|
+
)
|
|
60
|
+
|
|
61
|
+
if api_key != server_api_key:
|
|
62
|
+
raise HTTPException(
|
|
63
|
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
64
|
+
detail="Invalid API key",
|
|
65
|
+
headers={"WWW-Authenticate": "Bearer"},
|
|
66
|
+
)
|
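
Note: the middleware accepts the same key in three forms and leaves "/", "/health", and the docs endpoints open. A sketch of equivalent authenticated calls against the server above, using httpx as an example client; the "/v1/models" path is an assumption (the real route is defined in router_maestro/server/routes/models.py, which is not part of this excerpt):

import httpx  # any HTTP client works; httpx is only an example dependency

BASE = "http://127.0.0.1:8000"
KEY = "my-secret-key"  # must match ROUTER_MAESTRO_API_KEY on the server

# Three equivalent ways to present the same key, per verify_api_key above.
for headers in (
    {"Authorization": f"Bearer {KEY}"},  # standard Bearer scheme
    {"Authorization": KEY},              # bare key, no "Bearer " prefix
    {"x-api-key": KEY},                  # Anthropic-compatible header
):
    resp = httpx.get(f"{BASE}/v1/models", headers=headers)  # assumed path
    print(resp.status_code)  # 200 when the key matches, 401 otherwise

# Root and health stay open even when a key is configured.
print(httpx.get(f"{BASE}/health").json())  # {"status": "healthy"}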