indoxrouter 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indoxRouter/__init__.py +83 -0
- indoxRouter/client.py +564 -218
- indoxRouter/client_resourses/__init__.py +20 -0
- indoxRouter/client_resourses/base.py +67 -0
- indoxRouter/client_resourses/chat.py +144 -0
- indoxRouter/client_resourses/completion.py +138 -0
- indoxRouter/client_resourses/embedding.py +83 -0
- indoxRouter/client_resourses/image.py +116 -0
- indoxRouter/client_resourses/models.py +114 -0
- indoxRouter/config.py +151 -0
- indoxRouter/constants/__init__.py +81 -0
- indoxRouter/exceptions/__init__.py +70 -0
- indoxRouter/models/__init__.py +111 -0
- indoxRouter/providers/__init__.py +50 -50
- indoxRouter/providers/ai21labs.json +128 -0
- indoxRouter/providers/base_provider.py +62 -30
- indoxRouter/providers/claude.json +164 -0
- indoxRouter/providers/cohere.json +116 -0
- indoxRouter/providers/databricks.json +110 -0
- indoxRouter/providers/deepseek.json +110 -0
- indoxRouter/providers/google.json +128 -0
- indoxRouter/providers/meta.json +128 -0
- indoxRouter/providers/mistral.json +146 -0
- indoxRouter/providers/nvidia.json +110 -0
- indoxRouter/providers/openai.json +308 -0
- indoxRouter/providers/openai.py +471 -72
- indoxRouter/providers/qwen.json +110 -0
- indoxRouter/utils/__init__.py +240 -0
- indoxrouter-0.1.2.dist-info/LICENSE +21 -0
- indoxrouter-0.1.2.dist-info/METADATA +259 -0
- indoxrouter-0.1.2.dist-info/RECORD +33 -0
- indoxRouter/api_endpoints.py +0 -336
- indoxRouter/client_package.py +0 -138
- indoxRouter/init_db.py +0 -71
- indoxRouter/main.py +0 -711
- indoxRouter/migrations/__init__.py +0 -1
- indoxRouter/migrations/env.py +0 -98
- indoxRouter/migrations/versions/__init__.py +0 -1
- indoxRouter/migrations/versions/initial_schema.py +0 -84
- indoxRouter/providers/ai21.py +0 -268
- indoxRouter/providers/claude.py +0 -177
- indoxRouter/providers/cohere.py +0 -171
- indoxRouter/providers/databricks.py +0 -166
- indoxRouter/providers/deepseek.py +0 -166
- indoxRouter/providers/google.py +0 -216
- indoxRouter/providers/llama.py +0 -164
- indoxRouter/providers/meta.py +0 -227
- indoxRouter/providers/mistral.py +0 -182
- indoxRouter/providers/nvidia.py +0 -164
- indoxrouter-0.1.0.dist-info/METADATA +0 -179
- indoxrouter-0.1.0.dist-info/RECORD +0 -27
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/WHEEL +0 -0
- {indoxrouter-0.1.0.dist-info → indoxrouter-0.1.2.dist-info}/top_level.txt +0 -0
indoxRouter/main.py
DELETED
@@ -1,711 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Main FastAPI application for IndoxRouter.
|
3
|
-
"""
|
4
|
-
|
5
|
-
import inspect
import json
import logging
import os
import time
import uuid
from datetime import datetime
from typing import Dict, Any, Optional, List

from fastapi import FastAPI, Depends, Header, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field

# Import utility modules
from .utils.database import execute_query, close_all_connections
from .utils.auth import (
    verify_api_key,
    authenticate_user,
    generate_jwt_token,
    verify_jwt_token,
    AuthManager,
)
from .models.database import User, ApiKey, RequestLog, ProviderConfig, Credit
from .utils.config import get_config
from .providers import get_provider
|
29
|
-
|
30
|
-
# Import API endpoints router (will be created later)
|
31
|
-
# from .api_endpoints import router as api_router
|
32
|
-
|
33
|
-
# Configure root logging at import time (INFO level, timestamped records)
# and create the logger used throughout this module.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
|
38
|
-
|
39
|
-
|
40
|
-
# Load configuration
|
41
|
-
def load_config() -> Dict[str, Any]:
    """
    Load the application configuration from ``config.json`` next to this module.

    Returns:
        The parsed JSON object. Returns an empty dict (after logging an error)
        when the file is missing, is not valid JSON, or does not contain a JSON
        object. Callers use ``.get`` on the result, so a list or scalar must
        not be returned.
    """
    config_path = os.path.join(os.path.dirname(__file__), "config.json")
    try:
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(config_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        logger.error(f"Configuration file not found at {config_path}")
        return {}
    except json.JSONDecodeError:
        logger.error(f"Invalid JSON in configuration file at {config_path}")
        return {}

    if not isinstance(data, dict):
        logger.error(
            f"Configuration file at {config_path} must contain a JSON object"
        )
        return {}
    return data
|
53
|
-
|
54
|
-
|
55
|
-
config = load_config()

# Create FastAPI app
app = FastAPI(
    title="IndoxRouter API",
    description="A unified API for multiple LLM providers",
    version="0.1.0",
)

# Add CORS middleware.
cors_origins = config.get("api", {}).get("cors_origins", ["*"])
# Starlette documents that credentials may only be allowed when the allowed
# origins are listed explicitly, not given as the "*" wildcard. Enable
# credentials only for an explicit allow-list of origins.
allow_cors_credentials = "*" not in cors_origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_credentials=allow_cors_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)

# API key security
# NOTE(review): defined but not referenced by any endpoint in this module.
api_key_header = APIKeyHeader(name="X-API-Key")

# Create the auth manager
auth_manager = AuthManager()
|
79
|
-
|
80
|
-
|
81
|
-
# Define request and response models
|
82
|
-
# Request body for POST /v1/completions: the provider and model to call, the
# prompt, and optional sampling controls.
class CompletionRequest(BaseModel):
    provider: str = Field(default=..., description="The LLM provider to use")
    model: str = Field(default=..., description="The model to use")
    prompt: str = Field(
        default=..., description="The prompt to send to the model"
    )
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum number of tokens to generate"
    )
    temperature: Optional[float] = Field(
        default=0.7, description="Temperature for sampling"
    )
    top_p: Optional[float] = Field(
        default=1.0, description="Top-p sampling parameter"
    )
    stop: Optional[List[str]] = Field(
        default=None, description="Sequences where the API will stop generating"
    )
    # NOTE(review): accepted but not forwarded by create_completion.
    stream: Optional[bool] = Field(
        default=False, description="Whether to stream the response"
    )
|
95
|
-
|
96
|
-
|
97
|
-
# Response body for POST /v1/completions.
class CompletionResponse(BaseModel):
    id: str = Field(
        default=..., description="Unique identifier for the completion"
    )
    provider: str = Field(default=..., description="The provider used")
    model: str = Field(default=..., description="The model used")
    text: str = Field(default=..., description="The generated text")
    # Keys are prompt_tokens / completion_tokens / total_tokens as filled in
    # by create_completion.
    usage: Dict[str, int] = Field(
        default=..., description="Token usage information"
    )
    created_at: datetime = Field(
        default=..., description="Timestamp of creation"
    )
|
104
|
-
|
105
|
-
|
106
|
-
# Credentials posted to /auth/login.
class LoginRequest(BaseModel):
    email: str = Field(default=..., description="User email")
    password: str = Field(default=..., description="User password")
|
109
|
-
|
110
|
-
|
111
|
-
# Token pair and identity returned by /auth/login.
class LoginResponse(BaseModel):
    access_token: str = Field(default=..., description="JWT access token")
    refresh_token: str = Field(default=..., description="JWT refresh token")
    token_type: str = Field(default="bearer", description="Token type")
    expires_in: int = Field(
        default=..., description="Token expiry in seconds"
    )
    user_id: int = Field(default=..., description="User ID")
    is_admin: bool = Field(
        default=False, description="Whether the user is an admin"
    )
|
118
|
-
|
119
|
-
|
120
|
-
# Request body for POST /api-keys.
class ApiKeyRequest(BaseModel):
    key_name: str = Field(default=..., description="Name for the API key")
    # None means the key never expires.
    expires_days: Optional[int] = Field(
        default=None, description="Days until expiry (None for no expiry)"
    )
|
125
|
-
|
126
|
-
|
127
|
-
# Response body for POST /api-keys, including the plaintext key.
class ApiKeyResponse(BaseModel):
    key: str = Field(default=..., description="The generated API key")
    key_id: int = Field(default=..., description="ID of the API key")
    key_name: str = Field(default=..., description="Name of the API key")
    expires_at: Optional[datetime] = Field(
        default=None, description="Expiry date"
    )
|
132
|
-
|
133
|
-
|
134
|
-
# Authentication dependency
|
135
|
-
async def get_current_user(authorization: Optional[str] = Header(None)):
    """
    Resolve the calling user from an ``Authorization: Bearer <api key>`` header.

    Args:
        authorization: Raw value of the ``Authorization`` request header.
            It must be declared with ``Header``. A bare ``str = None`` default
            makes FastAPI read a *query* parameter named ``authorization``
            instead of the header.

    Returns:
        A dict copy of the data returned by ``auth_manager.verify_api_key``.
        The user id is exposed under both ``"id"`` and ``"user_id"``.

    Raises:
        HTTPException: 401 if the header is missing, does not use the Bearer
            scheme, or the key is rejected by ``auth_manager.verify_api_key``.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="API key is required")

    prefix = "Bearer "
    if not authorization.startswith(prefix):
        raise HTTPException(status_code=401, detail="Invalid authorization header")

    # Strip only the leading scheme. str.replace would also remove any later
    # "Bearer " occurrence inside the value.
    api_key = authorization[len(prefix):].strip()

    user_data = auth_manager.verify_api_key(api_key)
    if not user_data:
        raise HTTPException(status_code=401, detail="Invalid API key")

    # Endpoints in this module read the user id under two different keys:
    # "id" (/v1/completions, /api-keys) and "user_id" (/api/v1/*).
    # Expose both so neither group fails with KeyError.
    # NOTE(review): confirm which key verify_api_key actually returns.
    user_data = dict(user_data)
    if "id" in user_data:
        user_data.setdefault("user_id", user_data["id"])
    elif "user_id" in user_data:
        user_data["id"] = user_data["user_id"]
    return user_data
|
158
|
-
|
159
|
-
|
160
|
-
# Routes
|
161
|
-
@app.get("/")
async def root():
    """
    Return a welcome message and the application version.

    The version is read from ``app.version`` so it always matches the value
    passed to the ``FastAPI`` constructor.
    """
    return {"message": "Welcome to IndoxRouter", "version": app.version}
|
165
|
-
|
166
|
-
|
167
|
-
@app.post("/v1/completions", response_model=CompletionResponse)
async def create_completion(
    request: CompletionRequest,
    user_data: Dict[str, Any] = Depends(get_current_user),
    req: Request = None,
):
    """
    Create a completion with the requested provider and record it in the request log.

    Args:
        request: Validated completion parameters.
        user_data: Authenticated user data supplied by ``get_current_user``.
        req: The raw request. Used only for the client IP and User-Agent in
            the log entry.

    Returns:
        A dict matching ``CompletionResponse``.

    Raises:
        HTTPException: 400 if the provider is unknown; 500 for any other failure.
    """
    try:
        provider_instance = get_provider(request.provider)
        if not provider_instance:
            raise HTTPException(
                status_code=400, detail=f"Provider '{request.provider}' not found"
            )

        start_time = time.time()
        response = provider_instance.generate(
            model=request.model,
            prompt=request.prompt,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            top_p=request.top_p,
            stop=request.stop,
        )
        # /api/v1/generate awaits provider.generate() while this endpoint did
        # not. Accept both sync and async providers.
        if inspect.isawaitable(response):
            response = await response
        process_time = time.time() - start_time

        # NOTE(review): assumes generate() returns the completion text as a str.
        # Whitespace-split word counts only approximate real token counts.
        prompt_tokens = len(request.prompt.split())
        completion_tokens = len(response.split())

        # Request logging is best-effort: a failure here is logged and the
        # completion is still returned.
        from .utils.database import get_session

        session = get_session()
        try:
            log = RequestLog(
                user_id=user_data["id"],
                api_key_id=user_data["api_key_id"],
                provider=request.provider,
                model=request.model,
                prompt=request.prompt,
                response=response,
                tokens_input=prompt_tokens,
                tokens_output=completion_tokens,
                latency_ms=int(process_time * 1000),
                status_code=200,
                ip_address=req.client.host if req else None,
                user_agent=req.headers.get("User-Agent") if req else None,
            )
            session.add(log)
            session.commit()
        except Exception as e:
            logger.error(f"Error logging request: {e}")
            session.rollback()
        finally:
            session.close()

        return {
            # A random suffix keeps ids unique even for requests issued within
            # the same second.
            "id": f"cmpl-{uuid.uuid4().hex}",
            "provider": request.provider,
            "model": request.model,
            "text": response,
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            "created_at": datetime.now(),
        }
    except HTTPException:
        # Preserve deliberate HTTP errors (e.g. the 400 above) instead of
        # re-wrapping them as 500s.
        raise
    except Exception as e:
        logger.exception(f"Error creating completion: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error creating completion: {str(e)}",
        ) from e
|
240
|
-
|
241
|
-
|
242
|
-
@app.post("/auth/login", response_model=LoginResponse)
async def login(request: LoginRequest):
    """
    Exchange an email/password pair for a JWT access/refresh token pair.

    Raises:
        HTTPException: 401 when ``authenticate_user`` rejects the credentials.
    """
    user = authenticate_user(request.email, request.password)
    if user:
        user_id = user["id"]
        is_admin = user["is_admin"]
        access_jwt, refresh_jwt, lifetime, _unused = generate_jwt_token(
            user_id, is_admin
        )
        return {
            "access_token": access_jwt,
            "refresh_token": refresh_jwt,
            "token_type": "bearer",
            "expires_in": lifetime,
            "user_id": user_id,
            "is_admin": is_admin,
        }

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid email or password",
        headers={"WWW-Authenticate": "Bearer"},
    )
|
265
|
-
|
266
|
-
|
267
|
-
@app.post("/auth/refresh", response_model=Dict[str, Any])
async def refresh_token(refresh_token: str):
    """
    Exchange a valid refresh token for a new access token.

    Args:
        refresh_token: The refresh JWT.

    Returns:
        ``{"access_token", "token_type", "expires_in"}``.

    Raises:
        HTTPException: 401 if the token does not verify, is not a refresh
            token, or carries no user identifier.
    """
    invalid = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid refresh token",
        headers={"WWW-Authenticate": "Bearer"},
    )

    result = verify_jwt_token(refresh_token)
    # The payload is inspected with .get(), so anything other than a dict
    # (including a falsy/None result) is rejected up front.
    if not isinstance(result, dict) or result.get("type") != "refresh":
        raise invalid

    # Unpacking the decoded payload dict would yield its *keys*, not a token.
    # Mint a fresh access token for the token's subject instead.
    # NOTE(review): claim names "user_id"/"sub" and "is_admin" are assumed —
    # confirm against verify_jwt_token / generate_jwt_token.
    user_id = result.get("user_id", result.get("sub"))
    if user_id is None:
        raise invalid

    access_token, _, expires_in, _ = generate_jwt_token(
        user_id, result.get("is_admin", False)
    )

    return {
        "access_token": access_token,
        "token_type": "bearer",
        "expires_in": expires_in,
    }
|
286
|
-
|
287
|
-
|
288
|
-
@app.post("/api-keys", response_model=ApiKeyResponse)
async def create_api_key(
    request: ApiKeyRequest, user_data: Dict[str, Any] = Depends(get_current_user)
):
    """
    Issue a new API key for the authenticated user.

    The plaintext key comes from ``auth_manager.generate_api_key``. The stored
    name and expiry are read back from the ``ApiKey`` record.

    Raises:
        HTTPException: 500 if key generation or lookup fails.
    """
    try:
        new_key, new_key_id = auth_manager.generate_api_key(
            user_id=user_data["id"],
            key_name=request.key_name,
            expires_days=request.expires_days,
        )
        stored = ApiKey.get_by_id(new_key_id)

        payload = {"key": new_key, "key_id": new_key_id}
        payload["key_name"] = stored["key_name"]
        payload["expires_at"] = stored["expires_at"]
        return payload
    except Exception as err:
        logger.error(f"Error creating API key: {err}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error creating API key: {str(err)}",
        )
|
315
|
-
|
316
|
-
|
317
|
-
@app.get("/api-keys", response_model=List[Dict[str, Any]])
async def list_api_keys(user_data: Dict[str, Any] = Depends(get_current_user)):
    """Return every API key record that ``ApiKey.list_by_user`` reports for the caller."""
    try:
        return ApiKey.list_by_user(user_data["id"])
    except Exception as err:
        logger.error(f"Error listing API keys: {err}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error listing API keys: {str(err)}",
        )
|
329
|
-
|
330
|
-
|
331
|
-
@app.delete("/api-keys/{key_id}", response_model=Dict[str, bool])
async def delete_api_key(
    key_id: int, user_data: Dict[str, Any] = Depends(get_current_user)
):
    """
    Delete one of the caller's API keys.

    Keys that do not exist and keys owned by another user both produce a 404,
    so the endpoint does not reveal which key ids exist.

    Raises:
        HTTPException: 404 for a missing or foreign key; 500 on other errors.
    """
    try:
        record = ApiKey.get_by_id(key_id)
        is_owner = bool(record) and record["user_id"] == user_data["id"]
        if not is_owner:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND, detail="API key not found"
            )

        return {"success": ApiKey.delete_key(key_id)}
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"Error deleting API key: {err}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Error deleting API key: {str(err)}",
        )
|
355
|
-
|
356
|
-
|
357
|
-
@app.get("/health")
async def health_check():
    """Liveness probe: report "ok" plus the server's local time as an ISO 8601 string."""
    current = datetime.now()
    return {"status": "ok", "timestamp": current.isoformat()}
|
361
|
-
|
362
|
-
|
363
|
-
# Startup and shutdown events
|
364
|
-
@app.on_event("startup")
async def startup_event():
    """Application startup hook; at present it only writes a log line."""
    logger.info("Starting IndoxRouter application")
    # TODO: initialize any resources needed at startup.
|
370
|
-
|
371
|
-
|
372
|
-
@app.on_event("shutdown")
async def shutdown_event():
    """Application shutdown hook: log the event, then release database connections."""
    logger.info("Shutting down IndoxRouter application")
    close_all_connections()
|
379
|
-
|
380
|
-
|
381
|
-
# Exception handlers
|
382
|
-
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """
    Convert any unhandled exception into a generic 500 JSON response.

    The full traceback is logged server-side by passing the exception as
    ``exc_info``. The client only receives a generic message, so internal
    details are not exposed.
    """
    logger.error(f"Unhandled exception: {exc}", exc_info=exc)
    return JSONResponse(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        content={"detail": "An unexpected error occurred"},
    )
|
390
|
-
|
391
|
-
|
392
|
-
# New API endpoints for the client package
|
393
|
-
|
394
|
-
|
395
|
-
class GenerateRequest(BaseModel):
    """Request model for text generation."""

    prompt: str = Field(
        default=..., description="The prompt to send to the model"
    )
    # model/provider fall back to configured defaults in the /api/v1/generate endpoint.
    model: Optional[str] = Field(default=None, description="The model to use")
    provider: Optional[str] = Field(
        default=None, description="The provider to use"
    )
    temperature: Optional[float] = Field(
        default=0.7, description="Temperature for sampling"
    )
    max_tokens: Optional[int] = Field(
        default=1000, description="Maximum number of tokens to generate"
    )
    top_p: Optional[float] = Field(
        default=1.0, description="Top-p sampling parameter"
    )
    stop: Optional[List[str]] = Field(
        default=None, description="Sequences where the API will stop generating"
    )
    stream: Optional[bool] = Field(
        default=False, description="Whether to stream the response"
    )
|
410
|
-
|
411
|
-
|
412
|
-
class GenerateResponse(BaseModel):
    """Response model for text generation."""

    id: str = Field(
        default=..., description="Unique identifier for the completion"
    )
    text: str = Field(default=..., description="The generated text")
    provider: str = Field(default=..., description="The provider used")
    model: str = Field(default=..., description="The model used")
    usage: Dict[str, int] = Field(
        default=..., description="Token usage information"
    )
    created_at: datetime = Field(
        default=..., description="Timestamp of creation"
    )
|
421
|
-
|
422
|
-
|
423
|
-
class ModelInfo(BaseModel):
    """Model for LLM model information."""

    id: str = Field(default=..., description="Model identifier")
    name: str = Field(default=..., description="Model name")
    provider: str = Field(default=..., description="Provider name")
    description: Optional[str] = Field(
        default=None, description="Model description"
    )
    max_tokens: Optional[int] = Field(
        default=None, description="Maximum tokens supported"
    )
    pricing: Optional[Dict[str, float]] = Field(
        default=None, description="Pricing information"
    )
|
432
|
-
|
433
|
-
|
434
|
-
class ProviderInfo(BaseModel):
    """Model for LLM provider information."""

    id: str = Field(default=..., description="Provider identifier")
    name: str = Field(default=..., description="Provider name")
    description: Optional[str] = Field(
        default=None, description="Provider description"
    )
    website: Optional[str] = Field(default=None, description="Provider website")
    # Model identifiers only; full details are served by /api/v1/models.
    models: List[str] = Field(default=..., description="Available models")
|
442
|
-
|
443
|
-
|
444
|
-
class UserInfo(BaseModel):
    """Model for user information."""

    id: int = Field(default=..., description="User ID")
    email: str = Field(default=..., description="User email")
    first_name: Optional[str] = Field(
        default=None, description="User first name"
    )
    last_name: Optional[str] = Field(
        default=None, description="User last name"
    )
    is_admin: bool = Field(
        default=False, description="Whether the user is an admin"
    )
    created_at: datetime = Field(
        default=..., description="Account creation timestamp"
    )
|
453
|
-
|
454
|
-
|
455
|
-
class BalanceInfo(BaseModel):
    """Model for user balance information."""

    credits: float = Field(default=..., description="Available credits")
    # Mapping of provider name -> accumulated total tokens (see get_balance).
    usage: Dict[str, float] = Field(default=..., description="Usage statistics")
    last_updated: datetime = Field(
        default=..., description="Last updated timestamp"
    )
|
461
|
-
|
462
|
-
|
463
|
-
@app.post("/api/v1/generate", response_model=GenerateResponse)
async def generate(
    request: GenerateRequest,
    user_data: Dict[str, Any] = Depends(get_current_user),
    req: Request = None,
):
    """
    Generate text with the requested (or configured default) provider and model.

    Args:
        request: Prompt and sampling parameters. ``provider``/``model`` fall
            back to ``default_provider``/``default_model`` from ``get_config()``
            ("openai"/"gpt-4o-mini" if those are unset).
        user_data: Authenticated user data supplied by ``get_current_user``.
        req: The raw request (currently unused).

    Returns:
        A dict matching ``GenerateResponse``.

    Raises:
        HTTPException: 400 if the provider is unknown; 500 for any other failure.
    """
    try:
        provider_name = request.provider
        model_name = request.model
        if not provider_name or not model_name:
            # Read the configuration once, and only when a default is needed.
            app_config = get_config()
            provider_name = provider_name or app_config.get(
                "default_provider", "openai"
            )
            model_name = model_name or app_config.get(
                "default_model", "gpt-4o-mini"
            )

        provider = get_provider(provider_name)
        if not provider:
            raise HTTPException(
                status_code=400, detail=f"Provider '{provider_name}' not found"
            )

        completion = provider.generate(
            model=model_name,
            prompt=request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
            stream=request.stream,
        )
        # /v1/completions calls generate() without awaiting it, while this
        # endpoint awaits it. Accept both sync and async providers.
        if inspect.isawaitable(completion):
            completion = await completion
        # NOTE(review): assumes the provider returns a dict with "id", "text"
        # and "usage" keys.
        usage = completion.get("usage", {})

        # Request logging is best-effort, as in /v1/completions: a logging
        # failure must not discard a completion that was already generated.
        # NOTE(review): these RequestLog fields differ from the ones
        # /v1/completions writes (tokens_input/tokens_output/...); confirm
        # against the RequestLog model.
        try:
            log_entry = RequestLog(
                user_id=user_data["user_id"],
                provider=provider_name,
                model=model_name,
                prompt_tokens=usage.get("prompt_tokens", 0),
                completion_tokens=usage.get("completion_tokens", 0),
                total_tokens=usage.get("total_tokens", 0),
                created_at=datetime.utcnow(),
            )
            execute_query(lambda session: session.add(log_entry))
        except Exception as e:
            logger.error(f"Error logging request: {e}")

        return {
            "id": completion.get("id", ""),
            "text": completion.get("text", ""),
            "provider": provider_name,
            "model": model_name,
            "usage": usage,
            "created_at": datetime.utcnow(),
        }
    except HTTPException:
        # Keep deliberate HTTP errors (e.g. the 400 above) intact.
        raise
    except Exception as e:
        logger.exception(f"Error generating completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
519
|
-
|
520
|
-
|
521
|
-
def _model_info(model: Dict[str, Any], provider_name: str) -> Dict[str, Any]:
    """Convert one entry from ``provider.list_models()`` into a ``ModelInfo``-shaped dict."""
    return {
        "id": model.get("id", ""),
        "name": model.get("name", ""),
        "provider": provider_name,
        "description": model.get("description", ""),
        "max_tokens": model.get("max_tokens", 0),
        "pricing": model.get("pricing", {}),
    }


@app.get("/api/v1/models", response_model=List[ModelInfo])
async def list_models(
    provider: Optional[str] = None,
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    List available models, optionally restricted to a single provider.

    Args:
        provider: If given, only this provider's models are returned and an
            unknown name is an error. Otherwise every provider under
            ``providers`` in ``get_config()`` is queried, and providers that
            fail are logged and skipped.
        user_data: Authenticated user data supplied by ``get_current_user``.

    Raises:
        HTTPException: 400 for an unknown ``provider``; 500 on other failures.
    """
    try:
        if provider:
            provider_obj = get_provider(provider)
            if not provider_obj:
                raise HTTPException(
                    status_code=400, detail=f"Provider '{provider}' not found"
                )
            return [_model_info(model, provider) for model in provider_obj.list_models()]

        models = []
        for provider_name in get_config().get("providers", {}):
            try:
                provider_obj = get_provider(provider_name)
                if provider_obj:
                    models.extend(
                        _model_info(model, provider_name)
                        for model in provider_obj.list_models()
                    )
            except Exception as e:
                logger.error(
                    f"Error getting models for provider {provider_name}: {str(e)}"
                )
        return models
    except HTTPException:
        # Keep the 400 for an unknown provider instead of turning it into a 500.
        raise
    except Exception as e:
        logger.exception(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
578
|
-
|
579
|
-
|
580
|
-
@app.get("/api/v1/providers", response_model=List[ProviderInfo])
async def list_providers(
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    Describe every provider under ``providers`` in ``get_config()``.

    Each entry combines the provider's configuration (name, description,
    website) with the ids of the models it reports. Providers that cannot be
    loaded are logged and omitted.

    Raises:
        HTTPException: 500 if the configuration itself cannot be read.
    """
    try:
        provider_settings = get_config().get("providers", {})
        catalogue = []

        for name, settings in provider_settings.items():
            try:
                instance = get_provider(name)
                if not instance:
                    continue
                available = instance.list_models()
                catalogue.append(
                    {
                        "id": name,
                        "name": settings.get("name", name),
                        "description": settings.get("description", ""),
                        "website": settings.get("website", ""),
                        "models": [entry.get("id", "") for entry in available],
                    }
                )
            except Exception as err:
                logger.error(f"Error getting provider {name}: {str(err)}")

        return catalogue
    except Exception as err:
        logger.error(f"Error listing providers: {str(err)}")
        raise HTTPException(status_code=500, detail=str(err))
|
614
|
-
|
615
|
-
|
616
|
-
@app.get("/api/v1/user", response_model=UserInfo)
async def get_user(
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    Return profile information for the authenticated user.

    Raises:
        HTTPException: 404 if no ``User`` row matches; 500 on other failures.
    """
    try:
        user_info = None

        def read_user(session):
            nonlocal user_info
            user = session.query(User).filter(User.id == user_data["user_id"]).first()
            if user is not None:
                # Copy the columns while the session passed in by execute_query
                # is still in use, so the response does not depend on the ORM
                # instance remaining loaded afterwards.
                # NOTE(review): confirm execute_query's session/commit lifecycle.
                user_info = {
                    "id": user.id,
                    "email": user.email,
                    "first_name": user.first_name,
                    "last_name": user.last_name,
                    "is_admin": user.is_admin,
                    "created_at": user.created_at,
                }
            return user

        execute_query(read_user)

        if user_info is None:
            raise HTTPException(status_code=404, detail="User not found")
        return user_info
    except HTTPException:
        # Keep the 404 instead of re-wrapping it as a 500.
        raise
    except Exception as e:
        logger.exception(f"Error getting user: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
646
|
-
|
647
|
-
|
648
|
-
@app.get("/api/v1/user/balance", response_model=BalanceInfo)
async def get_balance(
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    Return the user's credit balance and total tokens used per provider.

    If the user has no ``Credit`` row yet, one is created with an amount of
    0.0 and that is reported.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        user_id = user_data["user_id"]
        # (amount, last_updated), copied while the query session is in use.
        balance = None
        usage_data: Dict[str, float] = {}

        def read_credit(session):
            nonlocal balance
            credit = session.query(Credit).filter(Credit.user_id == user_id).first()
            if credit is not None:
                balance = (credit.amount, credit.last_updated)
            return credit

        def read_usage(session):
            nonlocal usage_data
            totals: Dict[str, float] = {}
            logs = session.query(RequestLog).filter(RequestLog.user_id == user_id).all()
            for log in logs:
                # Rows written by /v1/completions never set total_tokens
                # (that endpoint records tokens_input/tokens_output), so the
                # value may be None.
                totals[log.provider] = totals.get(log.provider, 0) + (
                    log.total_tokens or 0
                )
            usage_data = totals
            return totals

        execute_query(read_credit)
        execute_query(read_usage)

        if balance is None:
            # Create a new credit entry with default values.
            now = datetime.utcnow()
            new_credit = Credit(user_id=user_id, amount=0.0, last_updated=now)
            execute_query(lambda session: session.add(new_credit))
            balance = (0.0, now)

        credits, last_updated = balance
        return {
            "credits": credits,
            "usage": usage_data,
            "last_updated": last_updated,
        }
    except Exception as e:
        logger.exception(f"Error getting balance: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
705
|
-
|
706
|
-
|
707
|
-
# If this file is run directly, start the application with Uvicorn
|
708
|
-
if __name__ == "__main__":
    # Direct execution: serve the app with Uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|