router_maestro-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. router_maestro/__init__.py +3 -0
  2. router_maestro/__main__.py +6 -0
  3. router_maestro/auth/__init__.py +18 -0
  4. router_maestro/auth/github_oauth.py +181 -0
  5. router_maestro/auth/manager.py +136 -0
  6. router_maestro/auth/storage.py +91 -0
  7. router_maestro/cli/__init__.py +1 -0
  8. router_maestro/cli/auth.py +167 -0
  9. router_maestro/cli/client.py +322 -0
  10. router_maestro/cli/config.py +132 -0
  11. router_maestro/cli/context.py +146 -0
  12. router_maestro/cli/main.py +42 -0
  13. router_maestro/cli/model.py +288 -0
  14. router_maestro/cli/server.py +117 -0
  15. router_maestro/cli/stats.py +76 -0
  16. router_maestro/config/__init__.py +72 -0
  17. router_maestro/config/contexts.py +29 -0
  18. router_maestro/config/paths.py +50 -0
  19. router_maestro/config/priorities.py +93 -0
  20. router_maestro/config/providers.py +34 -0
  21. router_maestro/config/server.py +115 -0
  22. router_maestro/config/settings.py +76 -0
  23. router_maestro/providers/__init__.py +31 -0
  24. router_maestro/providers/anthropic.py +203 -0
  25. router_maestro/providers/base.py +123 -0
  26. router_maestro/providers/copilot.py +346 -0
  27. router_maestro/providers/openai.py +188 -0
  28. router_maestro/providers/openai_compat.py +175 -0
  29. router_maestro/routing/__init__.py +5 -0
  30. router_maestro/routing/router.py +526 -0
  31. router_maestro/server/__init__.py +5 -0
  32. router_maestro/server/app.py +87 -0
  33. router_maestro/server/middleware/__init__.py +11 -0
  34. router_maestro/server/middleware/auth.py +66 -0
  35. router_maestro/server/oauth_sessions.py +159 -0
  36. router_maestro/server/routes/__init__.py +8 -0
  37. router_maestro/server/routes/admin.py +358 -0
  38. router_maestro/server/routes/anthropic.py +228 -0
  39. router_maestro/server/routes/chat.py +142 -0
  40. router_maestro/server/routes/models.py +34 -0
  41. router_maestro/server/schemas/__init__.py +57 -0
  42. router_maestro/server/schemas/admin.py +87 -0
  43. router_maestro/server/schemas/anthropic.py +246 -0
  44. router_maestro/server/schemas/openai.py +107 -0
  45. router_maestro/server/translation.py +636 -0
  46. router_maestro/stats/__init__.py +14 -0
  47. router_maestro/stats/heatmap.py +154 -0
  48. router_maestro/stats/storage.py +228 -0
  49. router_maestro/stats/tracker.py +73 -0
  50. router_maestro/utils/__init__.py +16 -0
  51. router_maestro/utils/logging.py +81 -0
  52. router_maestro/utils/tokens.py +51 -0
  53. router_maestro-0.1.2.dist-info/METADATA +383 -0
  54. router_maestro-0.1.2.dist-info/RECORD +57 -0
  55. router_maestro-0.1.2.dist-info/WHEEL +4 -0
  56. router_maestro-0.1.2.dist-info/entry_points.txt +2 -0
  57. router_maestro-0.1.2.dist-info/licenses/LICENSE +21 -0
router_maestro/routing/router.py
@@ -0,0 +1,526 @@
+"""Model router with priority-based selection and fallback."""
+
+import time
+from collections.abc import AsyncIterator
+
+from router_maestro.auth import ApiKeyCredential, AuthManager
+from router_maestro.config import (
+    FallbackStrategy,
+    PrioritiesConfig,
+    load_priorities_config,
+    load_providers_config,
+)
+from router_maestro.providers import (
+    BaseProvider,
+    ChatRequest,
+    ChatResponse,
+    ChatStreamChunk,
+    CopilotProvider,
+    ModelInfo,
+    OpenAICompatibleProvider,
+    ProviderError,
+)
+from router_maestro.utils import get_logger
+
+logger = get_logger("routing")
+
+# Special model name that triggers auto-routing
+AUTO_ROUTE_MODEL = "router-maestro"
+
+# Cache TTL in seconds (5 minutes)
+CACHE_TTL_SECONDS = 300
+
+# Global singleton instance
+_router_instance: "Router | None" = None
+
+
+def get_router() -> "Router":
+    """Get the singleton Router instance.
+
+    Returns:
+        The global Router instance
+    """
+    global _router_instance
+    if _router_instance is None:
+        _router_instance = Router()
+        logger.info("Created singleton Router instance")
+    return _router_instance
+
+
+def reset_router() -> None:
+    """Reset the singleton Router instance.
+
+    Call this when authentication changes or to force a reload.
+    """
+    global _router_instance
+    if _router_instance is not None:
+        _router_instance.invalidate_cache()
+        _router_instance = None
+        logger.info("Reset singleton Router instance")
+
+
+class Router:
+    """Router for model requests with priority and fallback support."""
+
+    def __init__(self) -> None:
+        self.providers: dict[str, BaseProvider] = {}
+        # Model cache: maps model_id -> (provider_name, ModelInfo)
+        self._models_cache: dict[str, tuple[str, ModelInfo]] = {}
+        self._cache_initialized: bool = False
+        self._cache_timestamp: float = 0.0
+        # Priorities config cache
+        self._priorities_config: PrioritiesConfig | None = None
+        self._priorities_config_timestamp: float = 0.0
+        # Providers config cache
+        self._providers_config_timestamp: float = 0.0
+        self._load_providers()
+
+    def _load_providers(self) -> None:
+        """Load providers from configuration."""
+        custom_providers_config = load_providers_config()
+        auth_manager = AuthManager()
+
+        # Clear existing providers, but keep the copilot instance if one already exists
+        old_copilot = self.providers.get("github-copilot")
+        self.providers.clear()
+
+        # Always add the built-in GitHub Copilot provider (reuse the existing instance if available)
+        if old_copilot is not None:
+            self.providers["github-copilot"] = old_copilot
+        else:
+            copilot = CopilotProvider()
+            self.providers["github-copilot"] = copilot
+            logger.debug("Loaded built-in provider: github-copilot")
+
+        # Load custom providers from providers.json
+        for provider_name, provider_config in custom_providers_config.providers.items():
+            # Get the API key from auth storage
+            cred = auth_manager.get_credential(provider_name)
+            if isinstance(cred, ApiKeyCredential):
+                provider = OpenAICompatibleProvider(
+                    name=provider_name,
+                    base_url=provider_config.baseURL,
+                    api_key=cred.key,
+                    models={
+                        model_id: model_config.name
+                        for model_id, model_config in provider_config.models.items()
+                    },
+                )
+                self.providers[provider_name] = provider
+                logger.debug("Loaded custom provider: %s", provider_name)
+
+        self._providers_config_timestamp = time.time()
+        logger.info("Loaded %d providers", len(self.providers))
+
+    def _get_priorities_config(self) -> PrioritiesConfig:
+        """Get the priorities config, with caching."""
+        # Simple time-based cache (same TTL as the models cache)
+        current_time = time.time()
+        if (
+            self._priorities_config is not None
+            and current_time - self._priorities_config_timestamp < CACHE_TTL_SECONDS
+        ):
+            return self._priorities_config
+
+        self._priorities_config = load_priorities_config()
+        self._priorities_config_timestamp = current_time
+        return self._priorities_config
+
+    def _ensure_providers_fresh(self) -> None:
+        """Ensure the providers config is fresh; reload it if expired."""
+        current_time = time.time()
+        if current_time - self._providers_config_timestamp >= CACHE_TTL_SECONDS:
+            logger.debug("Providers config expired, reloading")
+            self._load_providers()
+            # Also invalidate the models cache since providers may have changed
+            self._models_cache.clear()
+            self._cache_initialized = False
+
+    def _parse_model_key(self, model_key: str) -> tuple[str, str]:
+        """Parse a model key into provider and model.
+
+        Args:
+            model_key: Model key in the format 'provider/model'
+
+        Returns:
+            Tuple of (provider_name, model_id)
+        """
+        if "/" in model_key:
+            parts = model_key.split("/", 1)
+            return parts[0], parts[1]
+        return "", model_key
+
+    async def _ensure_models_cache(self) -> None:
+        """Ensure the models cache is populated and not expired."""
+        # First ensure the providers config is fresh
+        self._ensure_providers_fresh()
+
+        # Check whether the cache is still valid (initialized and not expired)
+        if self._cache_initialized:
+            age = time.time() - self._cache_timestamp
+            if age < CACHE_TTL_SECONDS:
+                return
+            logger.debug("Cache expired (age=%.1fs), refreshing", age)
+            self._models_cache.clear()
+
+        logger.debug("Initializing models cache")
+        for provider_name, provider in self.providers.items():
+            if provider.is_authenticated():
+                try:
+                    await provider.ensure_token()
+                    models = await provider.list_models()
+                    for model in models:
+                        # Store by model_id only (without the provider prefix).
+                        # If the same model_id exists in multiple providers, the first one wins.
+                        if model.id not in self._models_cache:
+                            self._models_cache[model.id] = (provider_name, model)
+                        # Also store with the provider prefix for explicit lookups
+                        full_key = f"{provider_name}/{model.id}"
+                        self._models_cache[full_key] = (provider_name, model)
+                    logger.debug("Cached %d models from %s", len(models), provider_name)
+                except ProviderError as e:
+                    logger.warning("Failed to load models from %s: %s", provider_name, e)
+                    continue
+
+        self._cache_initialized = True
+        self._cache_timestamp = time.time()
+        logger.info("Models cache initialized with %d entries", len(self._models_cache))
+
+    async def _resolve_provider(self, model_id: str) -> tuple[str, str, BaseProvider]:
+        """Resolve a model_id to its provider.
+
+        Args:
+            model_id: Model ID (can be 'router-maestro', 'provider/model', or just 'model')
+
+        Returns:
+            Tuple of (provider_name, actual_model_id, provider)
+
+        Raises:
+            ProviderError: If the model is not found or no models are available
+        """
+        # Check for auto-routing
+        if model_id == AUTO_ROUTE_MODEL:
+            result = await self._get_auto_route_model()
+            if not result:
+                logger.error("No models available for auto-routing")
+                raise ProviderError("No models available for auto-routing", status_code=503)
+            return result
+
+        # Explicit model specified - find it in the cache
+        result = await self._find_model_in_cache(model_id)
+        if not result:
+            logger.warning("Model not found: %s", model_id)
+            raise ProviderError(
+                f"Model '{model_id}' not found in any provider",
+                status_code=404,
+            )
+        return result
+
+    def _create_request_with_model(
+        self, original_request: ChatRequest, model_id: str
+    ) -> ChatRequest:
+        """Create a new ChatRequest with a different model ID.
+
+        Args:
+            original_request: The original request
+            model_id: The new model ID to use
+
+        Returns:
+            New ChatRequest with the updated model
+        """
+        return ChatRequest(
+            model=model_id,
+            messages=original_request.messages,
+            temperature=original_request.temperature,
+            max_tokens=original_request.max_tokens,
+            stream=original_request.stream,
+            tools=original_request.tools,
+            tool_choice=original_request.tool_choice,
+        )
+
+    async def _get_auto_route_model(self) -> tuple[str, str, BaseProvider] | None:
+        """Get the highest-priority available model for auto-routing.
+
+        Returns:
+            Tuple of (provider_name, model_id, provider), or None if no model is available
+        """
+        await self._ensure_models_cache()
+        priorities_config = self._get_priorities_config()
+
+        # Try each priority in order
+        for priority_key in priorities_config.priorities:
+            provider_name, model_id = self._parse_model_key(priority_key)
+            if provider_name in self.providers:
+                provider = self.providers[provider_name]
+                if provider.is_authenticated():
+                    # Verify the model exists in the cache
+                    if priority_key in self._models_cache:
+                        logger.debug("Auto-route selected: %s", priority_key)
+                        return provider_name, model_id, provider
+
+        # Fallback: return the first available model from any provider
+        for model_id, (provider_name, _) in self._models_cache.items():
+            if "/" not in model_id:  # Skip full keys; only use simple model_ids
+                provider = self.providers.get(provider_name)
+                if provider and provider.is_authenticated():
+                    logger.debug("Auto-route fallback: %s/%s", provider_name, model_id)
+                    return provider_name, model_id, provider
+
+        return None
+
+    async def _find_model_in_cache(self, model_id: str) -> tuple[str, str, BaseProvider] | None:
+        """Find a model in the cache.
+
+        Args:
+            model_id: Model ID (can be 'provider/model' or just 'model')
+
+        Returns:
+            Tuple of (provider_name, actual_model_id, provider), or None
+        """
+        await self._ensure_models_cache()
+
+        # If model_id includes a provider prefix (e.g., "github-copilot/gpt-4o")
+        if "/" in model_id:
+            provider_name, actual_model_id = self._parse_model_key(model_id)
+            if provider_name in self.providers:
+                provider = self.providers[provider_name]
+                if provider.is_authenticated():
+                    # Check whether the model exists for this provider
+                    if model_id in self._models_cache:
+                        return provider_name, actual_model_id, provider
+            return None
+
+        # Simple model_id (e.g., "gpt-4o") - look it up in the cache
+        if model_id in self._models_cache:
+            provider_name, _ = self._models_cache[model_id]
+            provider = self.providers.get(provider_name)
+            if provider and provider.is_authenticated():
+                return provider_name, model_id, provider
+
+        return None
+
+    def _get_fallback_candidates(
+        self,
+        current_provider: str,
+        current_model: str,
+        strategy: FallbackStrategy,
+    ) -> list[tuple[str, str, BaseProvider]]:
+        """Get an ordered list of fallback candidates based on the strategy.
+
+        Args:
+            current_provider: The provider that just failed
+            current_model: The model that was requested
+            strategy: The fallback strategy to use
+
+        Returns:
+            List of (provider_name, model_id, provider) tuples to try
+        """
+        if strategy == FallbackStrategy.NONE:
+            return []
+
+        candidates: list[tuple[str, str, BaseProvider]] = []
+        current_key = f"{current_provider}/{current_model}"
+
+        if strategy == FallbackStrategy.PRIORITY:
+            # Follow the priorities list order, starting after the current entry
+            priorities_config = self._get_priorities_config()
+            found_current = False
+
+            for priority_key in priorities_config.priorities:
+                if priority_key == current_key:
+                    found_current = True
+                    continue
+
+                if found_current:
+                    provider_name, model_id = self._parse_model_key(priority_key)
+                    if provider_name in self.providers:
+                        provider = self.providers[provider_name]
+                        if provider.is_authenticated():
+                            if priority_key in self._models_cache:
+                                candidates.append((provider_name, model_id, provider))
+
+        elif strategy == FallbackStrategy.SAME_MODEL:
+            # Only try other providers that offer the same model
+            for other_name, other_provider in self.providers.items():
+                if other_name == current_provider:
+                    continue
+                if not other_provider.is_authenticated():
+                    continue
+                other_key = f"{other_name}/{current_model}"
+                if other_key in self._models_cache:
+                    candidates.append((other_name, current_model, other_provider))
+
+        return candidates
+
+    async def _execute_with_fallback(
+        self,
+        request: ChatRequest,
+        provider_name: str,
+        actual_model_id: str,
+        provider: BaseProvider,
+        fallback: bool,
+        is_stream: bool,
+    ) -> tuple[ChatResponse | AsyncIterator[ChatStreamChunk], str]:
+        """Execute a request with fallback support.
+
+        Args:
+            request: Original chat request
+            provider_name: Name of the primary provider
+            actual_model_id: The actual model ID to use
+            provider: The primary provider instance
+            fallback: Whether to try fallback providers on error
+            is_stream: Whether this is a streaming request
+
+        Returns:
+            Tuple of (response or stream, provider_name)
+
+        Raises:
+            ProviderError: If all providers fail
+        """
+        actual_request = self._create_request_with_model(request, actual_model_id)
+
+        try:
+            await provider.ensure_token()
+            if is_stream:
+                stream = provider.chat_completion_stream(actual_request)
+                logger.info("Stream request routed to %s", provider_name)
+                return stream, provider_name
+            else:
+                response = await provider.chat_completion(actual_request)
+                logger.info("Request completed via %s", provider_name)
+                return response, provider_name
+        except ProviderError as e:
+            logger.warning("Provider %s failed: %s", provider_name, e)
+            if not fallback or not e.retryable:
+                raise
+
+            # Load the fallback config
+            priorities_config = self._get_priorities_config()
+            fallback_config = priorities_config.fallback
+
+            if fallback_config.strategy == FallbackStrategy.NONE:
+                raise
+
+            # Get fallback candidates
+            candidates = self._get_fallback_candidates(
+                provider_name, actual_model_id, fallback_config.strategy
+            )
+
+            # Try fallback candidates, up to maxRetries of them
+            for i, (other_name, other_model_id, other_provider) in enumerate(candidates):
+                if i >= fallback_config.maxRetries:
+                    break
+
+                logger.info("Trying fallback: %s/%s", other_name, other_model_id)
+                fallback_request = self._create_request_with_model(request, other_model_id)
+
+                try:
+                    await other_provider.ensure_token()
+                    if is_stream:
+                        stream = other_provider.chat_completion_stream(fallback_request)
+                        logger.info("Stream fallback succeeded via %s", other_name)
+                        return stream, other_name
+                    else:
+                        response = await other_provider.chat_completion(fallback_request)
+                        logger.info("Fallback succeeded via %s", other_name)
+                        return response, other_name
+                except ProviderError as fallback_error:
+                    logger.warning("Fallback %s failed: %s", other_name, fallback_error)
+                    continue
+            raise
+
+    async def chat_completion(
+        self,
+        request: ChatRequest,
+        fallback: bool = True,
+    ) -> tuple[ChatResponse, str]:
+        """Route a chat completion request.
+
+        Args:
+            request: Chat completion request
+            fallback: Whether to try fallback providers on error
+
+        Returns:
+            Tuple of (response, provider_name)
+
+        Raises:
+            ProviderError: If the model is not found or all providers fail
+        """
+        provider_name, actual_model_id, provider = await self._resolve_provider(request.model)
+        logger.info("Routing request to %s/%s", provider_name, actual_model_id)
+
+        result, used_provider = await self._execute_with_fallback(
+            request, provider_name, actual_model_id, provider, fallback, is_stream=False
+        )
+        return result, used_provider  # type: ignore
+
+    async def chat_completion_stream(
+        self,
+        request: ChatRequest,
+        fallback: bool = True,
+    ) -> tuple[AsyncIterator[ChatStreamChunk], str]:
+        """Route a streaming chat completion request.
+
+        Args:
+            request: Chat completion request
+            fallback: Whether to try fallback providers on error
+
+        Returns:
+            Tuple of (stream iterator, provider_name)
+
+        Raises:
+            ProviderError: If the model is not found or all providers fail
+        """
+        provider_name, actual_model_id, provider = await self._resolve_provider(request.model)
+        logger.info("Routing stream request to %s/%s", provider_name, actual_model_id)
+
+        result, used_provider = await self._execute_with_fallback(
+            request, provider_name, actual_model_id, provider, fallback, is_stream=True
+        )
+        return result, used_provider  # type: ignore
+
+    async def list_models(self) -> list[ModelInfo]:
+        """List all available models from all authenticated providers.
+
+        Models are sorted by the priority configuration.
+
+        Returns:
+            List of available models
+        """
+        await self._ensure_models_cache()
+        priorities_config = self._get_priorities_config()
+
+        models: list[ModelInfo] = []
+        seen: set[str] = set()
+
+        # Collect all models with their full keys
+        all_models: dict[str, ModelInfo] = {}
+        for key, (_, model_info) in self._models_cache.items():
+            # Only include full keys (provider/model)
+            if "/" in key:
+                all_models[key] = model_info
+
+        # Add prioritized models first
+        for priority_key in priorities_config.priorities:
+            if priority_key in all_models and priority_key not in seen:
+                models.append(all_models[priority_key])
+                seen.add(priority_key)
+
+        # Add the remaining models
+        for key, model in all_models.items():
+            if key not in seen:
+                models.append(model)
+                seen.add(key)
+
+        logger.debug("Listed %d models", len(models))
+        return models
+
+    def invalidate_cache(self) -> None:
+        """Invalidate all caches to force a refresh."""
+        self._models_cache.clear()
+        self._cache_initialized = False
+        self._cache_timestamp = 0.0
+        self._priorities_config = None
+        self._priorities_config_timestamp = 0.0
+        self._providers_config_timestamp = 0.0
+        logger.debug("All caches invalidated")
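Reading router.py end to end, a request takes three steps: refresh the provider and model caches (both on a 5-minute TTL), resolve the requested model (the reserved name "router-maestro" auto-routes to the highest-priority authenticated model, while "provider/model" and bare "model" keys are looked up in the cache), and finally execute with priority- or same-model fallback. A minimal usage sketch follows; it is not part of the package, and it assumes ChatRequest accepts OpenAI-style message dicts, which this diff does not show.

import asyncio

from router_maestro.providers import ChatRequest
from router_maestro.routing import get_router


async def main() -> None:
    router = get_router()  # process-wide singleton; caches survive across calls
    request = ChatRequest(
        model="router-maestro",  # AUTO_ROUTE_MODEL: picks the highest-priority available model
        messages=[{"role": "user", "content": "Hello"}],  # assumed message shape
        stream=False,
    )
    # Returns the response plus the name of the provider that actually served it,
    # which may differ from the first choice if fallback kicked in.
    response, provider_name = await router.chat_completion(request)
    print(provider_name, response)


asyncio.run(main())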
router_maestro/server/__init__.py
@@ -0,0 +1,5 @@
+"""Server module for router-maestro."""
+
+from router_maestro.server.app import app, create_app
+
+__all__ = ["app", "create_app"]
router_maestro/server/app.py
@@ -0,0 +1,87 @@
+"""FastAPI application for router-maestro."""
+
+import os
+from contextlib import asynccontextmanager
+
+from fastapi import Depends, FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from router_maestro import __version__
+from router_maestro.routing import get_router
+from router_maestro.server.middleware import verify_api_key
+from router_maestro.server.routes import admin_router, anthropic_router, chat_router, models_router
+from router_maestro.utils import get_logger, setup_logging
+
+logger = get_logger("server")
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Application lifespan handler."""
+    # Startup - initialize logging
+    log_level = os.environ.get("ROUTER_MAESTRO_LOG_LEVEL", "INFO")
+    setup_logging(level=log_level)
+    logger.info("Router-Maestro server starting up")
+
+    # Pre-warm the model cache if any providers are authenticated
+    router = get_router()
+    authenticated_providers = [
+        name for name, provider in router.providers.items() if provider.is_authenticated()
+    ]
+    if authenticated_providers:
+        logger.info(
+            "Pre-warming model cache for authenticated providers: %s", authenticated_providers
+        )
+        try:
+            models = await router.list_models()
+            logger.info("Model cache pre-warmed with %d models", len(models))
+        except Exception as e:
+            logger.warning("Failed to pre-warm model cache: %s", e)
+
+    yield
+    # Shutdown
+    logger.info("Router-Maestro server shutting down")
+
+
+def create_app() -> FastAPI:
+    """Create the FastAPI application."""
+    app = FastAPI(
+        title="Router-Maestro",
+        description="Multi-model routing and load balancing with OpenAI-compatible API",
+        version=__version__,
+        lifespan=lifespan,
+    )
+
+    # Add CORS middleware
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    # Include routers with API key verification
+    app.include_router(chat_router, dependencies=[Depends(verify_api_key)])
+    app.include_router(models_router, dependencies=[Depends(verify_api_key)])
+    app.include_router(anthropic_router, dependencies=[Depends(verify_api_key)])
+    app.include_router(admin_router, dependencies=[Depends(verify_api_key)])
+
+    @app.get("/")
+    async def root():
+        """Root endpoint."""
+        return {
+            "name": "Router-Maestro",
+            "version": __version__,
+            "status": "running",
+        }
+
+    @app.get("/health")
+    async def health():
+        """Health check endpoint."""
+        return {"status": "healthy"}
+
+    return app
+
+
+app = create_app()
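create_app() wires the OpenAI-compatible, Anthropic-compatible, and admin routers behind the same verify_api_key dependency, and the lifespan handler pre-warms the model cache at startup. Because the module also exports a ready-made app instance, any ASGI server can serve it directly. A launch sketch, not taken from the package (the bundled CLI in router_maestro/cli/server.py is not shown in this section; uvicorn, host, and port are illustrative choices here):

import os

import uvicorn

from router_maestro.server.app import app

if __name__ == "__main__":
    # Both variables are read by the app: the log level in the lifespan handler,
    # the API key by the auth middleware (router_maestro/server/middleware/auth.py).
    os.environ.setdefault("ROUTER_MAESTRO_LOG_LEVEL", "INFO")
    os.environ.setdefault("ROUTER_MAESTRO_API_KEY", "change-me")
    uvicorn.run(app, host="127.0.0.1", port=8000)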
router_maestro/server/middleware/__init__.py
@@ -0,0 +1,11 @@
+"""Middleware module."""
+
+from router_maestro.server.middleware.auth import (
+    get_server_api_key,
+    verify_api_key,
+)
+
+__all__ = [
+    "verify_api_key",
+    "get_server_api_key",
+]
router_maestro/server/middleware/auth.py
@@ -0,0 +1,66 @@
+"""Authentication middleware for API key validation."""
+
+import os
+
+from fastapi import Depends, HTTPException, Request, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+
+security = HTTPBearer(auto_error=False)
+
+
+def get_server_api_key() -> str | None:
+    """Get the server API key from the environment variable."""
+    return os.environ.get("ROUTER_MAESTRO_API_KEY")
+
+
+async def verify_api_key(
+    request: Request,
+    credentials: HTTPAuthorizationCredentials | None = Depends(security),
+) -> None:
+    """Verify the API key from the request headers.
+
+    Accepts any of:
+    - Authorization: Bearer <api_key>
+    - Authorization: <api_key> (an x-api-key header is also checked)
+    """
+    server_api_key = get_server_api_key()
+
+    if server_api_key is None:
+        # No API key configured; allow all requests
+        return
+
+    # Skip auth for health, root, and docs endpoints
+    if request.url.path in ("/", "/health", "/docs", "/openapi.json", "/redoc"):
+        return
+
+    # Get the API key from the headers
+    api_key: str | None = None
+
+    if credentials:
+        api_key = credentials.credentials
+    else:
+        # Try the Authorization header directly (without the Bearer prefix)
+        auth_header = request.headers.get("Authorization")
+        if auth_header:
+            if auth_header.startswith("Bearer "):
+                api_key = auth_header[7:]
+            else:
+                api_key = auth_header
+
+    # Also check the x-api-key header for Anthropic API compatibility
+    if not api_key:
+        api_key = request.headers.get("x-api-key")
+
+    if not api_key:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Missing API key. Use 'Authorization: Bearer <api_key>' header.",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+
+    if api_key != server_api_key:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid API key",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
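When ROUTER_MAESTRO_API_KEY is unset the middleware is a no-op; when set, it accepts the key in any of three places, which keeps both OpenAI-style and Anthropic-style clients working. A client-side sketch, not part of the package (the /v1/models path is assumed from the OpenAI-compatible surface; the base URL and key are illustrative):

import httpx

BASE_URL = "http://127.0.0.1:8000"
API_KEY = "change-me"

# 1. Standard bearer token, parsed by the HTTPBearer security scheme
httpx.get(f"{BASE_URL}/v1/models", headers={"Authorization": f"Bearer {API_KEY}"})

# 2. Bare Authorization header, handled by the manual fallback branch
httpx.get(f"{BASE_URL}/v1/models", headers={"Authorization": API_KEY})

# 3. x-api-key header, for Anthropic-style clients
httpx.get(f"{BASE_URL}/v1/models", headers={"x-api-key": API_KEY})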