router-maestro 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- router_maestro/__init__.py +3 -0
- router_maestro/__main__.py +6 -0
- router_maestro/auth/__init__.py +18 -0
- router_maestro/auth/github_oauth.py +181 -0
- router_maestro/auth/manager.py +136 -0
- router_maestro/auth/storage.py +91 -0
- router_maestro/cli/__init__.py +1 -0
- router_maestro/cli/auth.py +167 -0
- router_maestro/cli/client.py +322 -0
- router_maestro/cli/config.py +132 -0
- router_maestro/cli/context.py +146 -0
- router_maestro/cli/main.py +42 -0
- router_maestro/cli/model.py +288 -0
- router_maestro/cli/server.py +117 -0
- router_maestro/cli/stats.py +76 -0
- router_maestro/config/__init__.py +72 -0
- router_maestro/config/contexts.py +29 -0
- router_maestro/config/paths.py +50 -0
- router_maestro/config/priorities.py +93 -0
- router_maestro/config/providers.py +34 -0
- router_maestro/config/server.py +115 -0
- router_maestro/config/settings.py +76 -0
- router_maestro/providers/__init__.py +31 -0
- router_maestro/providers/anthropic.py +203 -0
- router_maestro/providers/base.py +123 -0
- router_maestro/providers/copilot.py +346 -0
- router_maestro/providers/openai.py +188 -0
- router_maestro/providers/openai_compat.py +175 -0
- router_maestro/routing/__init__.py +5 -0
- router_maestro/routing/router.py +526 -0
- router_maestro/server/__init__.py +5 -0
- router_maestro/server/app.py +87 -0
- router_maestro/server/middleware/__init__.py +11 -0
- router_maestro/server/middleware/auth.py +66 -0
- router_maestro/server/oauth_sessions.py +159 -0
- router_maestro/server/routes/__init__.py +8 -0
- router_maestro/server/routes/admin.py +358 -0
- router_maestro/server/routes/anthropic.py +228 -0
- router_maestro/server/routes/chat.py +142 -0
- router_maestro/server/routes/models.py +34 -0
- router_maestro/server/schemas/__init__.py +57 -0
- router_maestro/server/schemas/admin.py +87 -0
- router_maestro/server/schemas/anthropic.py +246 -0
- router_maestro/server/schemas/openai.py +107 -0
- router_maestro/server/translation.py +636 -0
- router_maestro/stats/__init__.py +14 -0
- router_maestro/stats/heatmap.py +154 -0
- router_maestro/stats/storage.py +228 -0
- router_maestro/stats/tracker.py +73 -0
- router_maestro/utils/__init__.py +16 -0
- router_maestro/utils/logging.py +81 -0
- router_maestro/utils/tokens.py +51 -0
- router_maestro-0.1.2.dist-info/METADATA +383 -0
- router_maestro-0.1.2.dist-info/RECORD +57 -0
- router_maestro-0.1.2.dist-info/WHEEL +4 -0
- router_maestro-0.1.2.dist-info/entry_points.txt +2 -0
- router_maestro-0.1.2.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""OAuth session management for remote OAuth flows."""
|
|
2
|
+
|
|
3
|
+
import secrets
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from threading import Lock
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class OAuthSession:
|
|
11
|
+
"""An OAuth session for device flow authentication."""
|
|
12
|
+
|
|
13
|
+
session_id: str
|
|
14
|
+
provider: str
|
|
15
|
+
device_code: str
|
|
16
|
+
user_code: str
|
|
17
|
+
verification_uri: str
|
|
18
|
+
expires_at: float
|
|
19
|
+
interval: int
|
|
20
|
+
status: str = "pending" # pending, complete, expired, error
|
|
21
|
+
error: str | None = None
|
|
22
|
+
# Result data when complete
|
|
23
|
+
access_token: str | None = None
|
|
24
|
+
refresh_token: str | None = None
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class OAuthSessionManager:
|
|
28
|
+
"""Manages OAuth sessions for remote device flow authentication."""
|
|
29
|
+
|
|
30
|
+
def __init__(self, session_timeout: int = 900) -> None:
|
|
31
|
+
"""Initialize the session manager.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
session_timeout: Default session timeout in seconds (default: 15 minutes)
|
|
35
|
+
"""
|
|
36
|
+
self._sessions: dict[str, OAuthSession] = {}
|
|
37
|
+
self._lock = Lock()
|
|
38
|
+
self._session_timeout = session_timeout
|
|
39
|
+
|
|
40
|
+
def create_session(
|
|
41
|
+
self,
|
|
42
|
+
provider: str,
|
|
43
|
+
device_code: str,
|
|
44
|
+
user_code: str,
|
|
45
|
+
verification_uri: str,
|
|
46
|
+
expires_in: int,
|
|
47
|
+
interval: int = 5,
|
|
48
|
+
) -> OAuthSession:
|
|
49
|
+
"""Create a new OAuth session.
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
provider: Provider name (e.g., 'github-copilot')
|
|
53
|
+
device_code: Device code from OAuth provider
|
|
54
|
+
user_code: User code to display
|
|
55
|
+
verification_uri: URL for user to visit
|
|
56
|
+
expires_in: Seconds until expiration
|
|
57
|
+
interval: Polling interval in seconds
|
|
58
|
+
|
|
59
|
+
Returns:
|
|
60
|
+
The created OAuth session
|
|
61
|
+
"""
|
|
62
|
+
session_id = secrets.token_urlsafe(16)
|
|
63
|
+
session = OAuthSession(
|
|
64
|
+
session_id=session_id,
|
|
65
|
+
provider=provider,
|
|
66
|
+
device_code=device_code,
|
|
67
|
+
user_code=user_code,
|
|
68
|
+
verification_uri=verification_uri,
|
|
69
|
+
expires_at=time.time() + expires_in,
|
|
70
|
+
interval=interval,
|
|
71
|
+
)
|
|
72
|
+
|
|
73
|
+
with self._lock:
|
|
74
|
+
self._sessions[session_id] = session
|
|
75
|
+
self._cleanup_expired()
|
|
76
|
+
|
|
77
|
+
return session
|
|
78
|
+
|
|
79
|
+
def get_session(self, session_id: str) -> OAuthSession | None:
|
|
80
|
+
"""Get a session by ID.
|
|
81
|
+
|
|
82
|
+
Args:
|
|
83
|
+
session_id: Session ID to look up
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
The session if found and not expired, None otherwise
|
|
87
|
+
"""
|
|
88
|
+
with self._lock:
|
|
89
|
+
session = self._sessions.get(session_id)
|
|
90
|
+
if session is None:
|
|
91
|
+
return None
|
|
92
|
+
|
|
93
|
+
# Check if expired
|
|
94
|
+
if session.status == "pending" and time.time() > session.expires_at:
|
|
95
|
+
session.status = "expired"
|
|
96
|
+
|
|
97
|
+
return session
|
|
98
|
+
|
|
99
|
+
def update_session_status(
|
|
100
|
+
self,
|
|
101
|
+
session_id: str,
|
|
102
|
+
status: str,
|
|
103
|
+
error: str | None = None,
|
|
104
|
+
access_token: str | None = None,
|
|
105
|
+
refresh_token: str | None = None,
|
|
106
|
+
) -> bool:
|
|
107
|
+
"""Update session status.
|
|
108
|
+
|
|
109
|
+
Args:
|
|
110
|
+
session_id: Session ID to update
|
|
111
|
+
status: New status
|
|
112
|
+
error: Error message if status is 'error'
|
|
113
|
+
access_token: Access token if status is 'complete'
|
|
114
|
+
refresh_token: Refresh token if status is 'complete'
|
|
115
|
+
|
|
116
|
+
Returns:
|
|
117
|
+
True if session was updated, False if not found
|
|
118
|
+
"""
|
|
119
|
+
with self._lock:
|
|
120
|
+
session = self._sessions.get(session_id)
|
|
121
|
+
if session is None:
|
|
122
|
+
return False
|
|
123
|
+
|
|
124
|
+
session.status = status
|
|
125
|
+
session.error = error
|
|
126
|
+
session.access_token = access_token
|
|
127
|
+
session.refresh_token = refresh_token
|
|
128
|
+
return True
|
|
129
|
+
|
|
130
|
+
def remove_session(self, session_id: str) -> bool:
|
|
131
|
+
"""Remove a session.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
session_id: Session ID to remove
|
|
135
|
+
|
|
136
|
+
Returns:
|
|
137
|
+
True if session was removed, False if not found
|
|
138
|
+
"""
|
|
139
|
+
with self._lock:
|
|
140
|
+
if session_id in self._sessions:
|
|
141
|
+
del self._sessions[session_id]
|
|
142
|
+
return True
|
|
143
|
+
return False
|
|
144
|
+
|
|
145
|
+
def _cleanup_expired(self) -> None:
|
|
146
|
+
"""Remove expired sessions. Must be called with lock held."""
|
|
147
|
+
now = time.time()
|
|
148
|
+
expired = [
|
|
149
|
+
sid
|
|
150
|
+
for sid, session in self._sessions.items()
|
|
151
|
+
if session.status in ("complete", "error", "expired")
|
|
152
|
+
or now > session.expires_at + 60 # Keep for 1 minute after expiry
|
|
153
|
+
]
|
|
154
|
+
for sid in expired:
|
|
155
|
+
del self._sessions[sid]
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Global session manager instance
|
|
159
|
+
# Module-level singleton so every request shares one session table; the admin
# routes import it from router_maestro.server.oauth_sessions.
oauth_sessions = OAuthSessionManager()
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""Server routes module."""
|
|
2
|
+
|
|
3
|
+
from router_maestro.server.routes.admin import router as admin_router
|
|
4
|
+
from router_maestro.server.routes.anthropic import router as anthropic_router
|
|
5
|
+
from router_maestro.server.routes.chat import router as chat_router
|
|
6
|
+
from router_maestro.server.routes.models import router as models_router
|
|
7
|
+
|
|
8
|
+
# Public API of the routes package: the router objects re-exported from each
# route module above.
__all__ = ["admin_router", "anthropic_router", "chat_router", "models_router"]
|
|
@@ -0,0 +1,358 @@
|
|
|
1
|
+
"""Admin API routes for remote management."""
|
|
2
|
+
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
import httpx
|
|
6
|
+
from fastapi import APIRouter, BackgroundTasks, HTTPException, Query
|
|
7
|
+
|
|
8
|
+
from router_maestro.auth import AuthManager, AuthType
|
|
9
|
+
from router_maestro.auth.github_oauth import (
|
|
10
|
+
GitHubOAuthError,
|
|
11
|
+
get_copilot_token,
|
|
12
|
+
poll_access_token,
|
|
13
|
+
request_device_code,
|
|
14
|
+
)
|
|
15
|
+
from router_maestro.auth.storage import OAuthCredential
|
|
16
|
+
from router_maestro.config import (
|
|
17
|
+
load_priorities_config,
|
|
18
|
+
save_priorities_config,
|
|
19
|
+
)
|
|
20
|
+
from router_maestro.routing import get_router, reset_router
|
|
21
|
+
from router_maestro.server.oauth_sessions import oauth_sessions
|
|
22
|
+
from router_maestro.server.schemas.admin import (
|
|
23
|
+
AuthListResponse,
|
|
24
|
+
AuthProviderInfo,
|
|
25
|
+
LoginRequest,
|
|
26
|
+
ModelInfo,
|
|
27
|
+
ModelsResponse,
|
|
28
|
+
OAuthInitResponse,
|
|
29
|
+
OAuthStatusResponse,
|
|
30
|
+
PrioritiesResponse,
|
|
31
|
+
PrioritiesUpdateRequest,
|
|
32
|
+
StatsResponse,
|
|
33
|
+
)
|
|
34
|
+
from router_maestro.stats import StatsStorage
|
|
35
|
+
|
|
36
|
+
# All admin endpoints are mounted under /api/admin and grouped under the
# "admin" tag in the generated OpenAPI schema.
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# ============================================================================
|
|
40
|
+
# Auth endpoints
|
|
41
|
+
# ============================================================================
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@router.get("/auth", response_model=AuthListResponse)
async def list_auth() -> AuthListResponse:
    """List all authenticated providers.

    Each provider is reported with its auth type (``"oauth"`` or ``"api"``)
    and a status. The status is ``"expired"`` for an OAuth credential whose
    ``expires`` timestamp is set (> 0) and already in the past, and
    ``"active"`` otherwise.
    """
    import time

    manager = AuthManager()
    # Read the clock once so every credential is judged at the same instant.
    now = time.time()
    providers = []

    for provider_name in manager.list_authenticated():
        cred = manager.get_credential(provider_name)
        if not cred:
            continue

        auth_type = "oauth" if cred.type == AuthType.OAUTH else "api"
        # expires == 0 means "no expiry recorded", so only positive values count.
        is_expired = isinstance(cred, OAuthCredential) and 0 < cred.expires < now

        providers.append(
            AuthProviderInfo(
                provider=provider_name,
                auth_type=auth_type,
                status="expired" if is_expired else "active",
            )
        )

    return AuthListResponse(providers=providers)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@router.post("/auth/login")
async def login(
    request: LoginRequest,
    background_tasks: BackgroundTasks,
) -> OAuthInitResponse | dict:
    """Initiate login for a provider.

    For OAuth providers (github-copilot): Returns session info for device flow polling.
    For API key providers: Saves the key and returns success.

    Raises:
        HTTPException: 502 if the device code cannot be obtained from GitHub;
            400 if a non-OAuth provider is given without an API key.
    """
    if request.provider == "github-copilot":
        # OAuth device flow
        async with httpx.AsyncClient() as client:
            try:
                device_code = await request_device_code(client)
            except (httpx.HTTPError, GitHubOAuthError) as e:
                # GitHubOAuthError is mapped too, in case the helper reports
                # OAuth-level failures that way, so it does not surface as a 500.
                raise HTTPException(
                    status_code=502, detail=f"Failed to get device code: {e}"
                ) from e

        # Create session for polling
        session = oauth_sessions.create_session(
            provider=request.provider,
            device_code=device_code.device_code,
            user_code=device_code.user_code,
            verification_uri=device_code.verification_uri,
            expires_in=device_code.expires_in,
            interval=device_code.interval,
        )

        # Poll for the token after the response has been sent
        background_tasks.add_task(
            _poll_oauth_completion,
            session.session_id,
            device_code.device_code,
            device_code.interval,
        )

        return OAuthInitResponse(
            session_id=session.session_id,
            user_code=device_code.user_code,
            verification_uri=device_code.verification_uri,
            expires_in=device_code.expires_in,
        )

    if request.api_key:
        # API key auth; the manager is only needed on this path
        manager = AuthManager()
        manager.login_api_key(request.provider, request.api_key)
        # Reset router to pick up new authentication
        reset_router()
        return {"success": True, "provider": request.provider}

    raise HTTPException(
        status_code=400,
        detail=f"Provider '{request.provider}' requires an API key",
    )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
async def _poll_oauth_completion(
    session_id: str,
    device_code: str,
    interval: int,
) -> None:
    """Background task to poll for OAuth completion and save credentials.

    Polls GitHub for an access token (``timeout=900``, i.e. 15 minutes).
    Exchanges it for a Copilot token and stores both as the ``github-copilot``
    OAuth credential. Records the outcome on the OAuth session so that
    ``GET /auth/oauth/status/{session_id}`` can report it. Errors during
    polling, token exchange or saving are recorded on the session as status
    ``"error"``.

    Args:
        session_id: OAuth session to update with the outcome.
        device_code: Device code obtained from ``request_device_code``.
        interval: Polling interval in seconds.
    """
    import logging

    async with httpx.AsyncClient() as client:
        try:
            # Poll for access token
            access_token = await poll_access_token(
                client,
                device_code,
                interval=interval,
                timeout=900,  # 15 minutes
            )

            # Get Copilot token
            copilot_token = await get_copilot_token(client, access_token.access_token)

            # Build the manager only now, after the (possibly long) wait above.
            # NOTE(review): assumes AuthManager loads stored credentials when it
            # is constructed and save() writes them all back. An instance built
            # before polling could then overwrite credentials saved in the
            # meantime (e.g. an API key added via /auth/login). Confirm against
            # auth/manager.py.
            manager = AuthManager()
            manager.storage.set(
                "github-copilot",
                OAuthCredential(
                    refresh=access_token.access_token,
                    access=copilot_token.token,
                    expires=copilot_token.expires_at,
                ),
            )
            manager.save()

            # Update session status
            oauth_sessions.update_session_status(
                session_id,
                status="complete",
                access_token=copilot_token.token,
                refresh_token=access_token.access_token,
            )

            # Reset router to pick up new authentication
            reset_router()

        except GitHubOAuthError as e:
            oauth_sessions.update_session_status(
                session_id,
                status="error",
                error=str(e),
            )
        except Exception as e:
            # A background task has no caller to report to; keep the traceback
            # in the logs in addition to the summary stored on the session.
            logging.getLogger(__name__).exception(
                "Unexpected error while completing GitHub Copilot OAuth flow"
            )
            oauth_sessions.update_session_status(
                session_id,
                status="error",
                error=f"Unexpected error: {e}",
            )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
@router.get("/auth/oauth/status/{session_id}", response_model=OAuthStatusResponse)
async def get_oauth_status(session_id: str) -> OAuthStatusResponse:
    """Get OAuth session status for polling.

    Only the status and error message are returned; tokens held by the
    session are never exposed here.

    Raises:
        HTTPException: 404 if the session id is not tracked.
    """
    found = oauth_sessions.get_session(session_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return OAuthStatusResponse(status=found.status, error=found.error)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@router.delete("/auth/{provider}")
async def logout(provider: str) -> dict:
    """Log out from a provider.

    Raises:
        HTTPException: 404 if there is no stored authentication for ``provider``.
    """
    auth_manager = AuthManager()
    if not auth_manager.logout(provider):
        raise HTTPException(status_code=404, detail=f"Not authenticated with {provider}")

    # Reset the router so the authentication change is picked up
    reset_router()
    return {"success": True, "provider": provider}
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# ============================================================================
|
|
216
|
+
# Model endpoints
|
|
217
|
+
# ============================================================================
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
@router.get("/models", response_model=ModelsResponse)
async def list_models() -> ModelsResponse:
    """List all available models from authenticated providers.

    Raises:
        HTTPException: 500 if the router fails to list models. The original
            exception is chained as the cause for server-side tracebacks.
    """
    router_instance = get_router()

    try:
        models = await router_instance.list_models()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to list models: {e}") from e

    return ModelsResponse(
        models=[
            ModelInfo(provider=model.provider, id=model.id, name=model.name)
            for model in models
        ]
    )
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
@router.post("/models/refresh")
async def refresh_models() -> dict:
    """Force refresh the models cache.

    Invalidates the router's model cache, then lists models again to
    repopulate it.

    Returns:
        ``{"success": True, "models_count": <number of models listed>}``.

    Raises:
        HTTPException: 500 if listing the models fails. The original exception
            is chained as the cause.
    """
    router_instance = get_router()
    router_instance.invalidate_cache()

    try:
        # Trigger re-population of the cache that was just invalidated
        models = await router_instance.list_models()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to refresh models: {e}") from e

    return {"success": True, "models_count": len(models)}
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
# ============================================================================
|
|
256
|
+
# Priority endpoints
|
|
257
|
+
# ============================================================================
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
@router.get("/priorities", response_model=PrioritiesResponse)
async def get_priorities() -> PrioritiesResponse:
    """Get current priority configuration.

    Returns the configured priorities and the fallback settings, the latter
    serialized with ``model_dump()``.
    """
    current = load_priorities_config()
    priorities = current.priorities
    fallback_dict = current.fallback.model_dump()
    return PrioritiesResponse(priorities=priorities, fallback=fallback_dict)
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@router.put("/priorities", response_model=PrioritiesResponse)
async def update_priorities(request: PrioritiesUpdateRequest) -> PrioritiesResponse:
    """Update priority configuration.

    Replaces the priorities, optionally replaces the fallback settings, and
    saves the result. Nothing is saved if the fallback payload is invalid.

    Raises:
        HTTPException: 422 if ``request.fallback`` does not validate as a
            ``FallbackConfig``.
    """
    config = load_priorities_config()

    # Update priorities
    config.priorities = request.priorities

    # Update fallback if provided
    if request.fallback is not None:
        from router_maestro.config import FallbackConfig

        try:
            config.fallback = FallbackConfig.model_validate(request.fallback)
        except ValueError as e:
            # pydantic's ValidationError subclasses ValueError; a bad payload
            # is a client error, not an unhandled 500.
            raise HTTPException(
                status_code=422, detail=f"Invalid fallback configuration: {e}"
            ) from e

    save_priorities_config(config)

    return PrioritiesResponse(
        priorities=config.priorities,
        fallback=config.fallback.model_dump(),
    )
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
# ============================================================================
|
|
294
|
+
# Stats endpoints
|
|
295
|
+
# ============================================================================
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
@router.get("/stats", response_model=StatsResponse)
async def get_stats(
    days: Annotated[int, Query(ge=1, le=365)] = 7,
    provider: str | None = None,
    model: str | None = None,
) -> StatsResponse:
    """Get usage statistics.

    Args:
        days: Look-back window in days (validated to 1-365).
        provider: If set, only usage recorded for this provider is included.
        model: If set, only usage recorded for this model name is included.

    Returns:
        Overall totals plus per-provider and per-model (keyed
        ``"provider/model"``) breakdowns. With no filter, the totals come from
        ``StatsStorage.get_total_usage``. With a ``provider`` or ``model``
        filter, they are summed from the filtered records so that they agree
        with the breakdowns.
    """
    counter_keys = ("total_tokens", "prompt_tokens", "completion_tokens", "request_count")

    def _value(row, key: str):
        # A missing key and a None value (e.g. an empty SQL aggregate) both count as 0.
        return row.get(key, 0) or 0

    storage = StatsStorage()
    filtered = bool(provider or model)

    # Storage totals ignore the filters, so they are only usable unfiltered.
    total = None if filtered else storage.get_total_usage(days=days)

    by_provider: dict[str, dict] = {}
    by_model: dict[str, dict] = {}

    # Rows are per model and carry the provider name
    for record in storage.get_usage_by_model(days=days):
        provider_name = record["provider"]
        model_name = record["model"]

        if provider and provider_name != provider:
            continue
        if model and model_name != model:
            continue

        provider_totals = by_provider.setdefault(provider_name, dict.fromkeys(counter_keys, 0))
        for key in counter_keys:
            provider_totals[key] += _value(record, key)

        model_stats = {key: _value(record, key) for key in counter_keys}
        model_stats["avg_latency_ms"] = _value(record, "avg_latency_ms")
        by_model[f"{provider_name}/{model_name}"] = model_stats

    if total is None:
        total = {
            key: sum(stats[key] for stats in by_provider.values())
            for key in counter_keys
        }

    return StatsResponse(
        total_requests=_value(total, "request_count"),
        total_tokens=_value(total, "total_tokens"),
        prompt_tokens=_value(total, "prompt_tokens"),
        completion_tokens=_value(total, "completion_tokens"),
        by_provider=by_provider,
        by_model=by_model,
    )
|