indoxrouter 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- indoxRouter/__init__.py +0 -0
- indoxRouter/api_endpoints.py +336 -0
- indoxRouter/client.py +286 -0
- indoxRouter/client_package.py +138 -0
- indoxRouter/init_db.py +71 -0
- indoxRouter/main.py +711 -0
- indoxRouter/migrations/__init__.py +1 -0
- indoxRouter/migrations/env.py +98 -0
- indoxRouter/migrations/versions/__init__.py +1 -0
- indoxRouter/migrations/versions/initial_schema.py +84 -0
- indoxRouter/providers/__init__.py +108 -0
- indoxRouter/providers/ai21.py +268 -0
- indoxRouter/providers/base_provider.py +69 -0
- indoxRouter/providers/claude.py +177 -0
- indoxRouter/providers/cohere.py +171 -0
- indoxRouter/providers/databricks.py +166 -0
- indoxRouter/providers/deepseek.py +166 -0
- indoxRouter/providers/google.py +216 -0
- indoxRouter/providers/llama.py +164 -0
- indoxRouter/providers/meta.py +227 -0
- indoxRouter/providers/mistral.py +182 -0
- indoxRouter/providers/nvidia.py +164 -0
- indoxRouter/providers/openai.py +122 -0
- indoxrouter-0.1.0.dist-info/METADATA +179 -0
- indoxrouter-0.1.0.dist-info/RECORD +27 -0
- indoxrouter-0.1.0.dist-info/WHEEL +5 -0
- indoxrouter-0.1.0.dist-info/top_level.txt +1 -0
indoxRouter/__init__.py
ADDED
File without changes
|
@@ -0,0 +1,336 @@
|
|
1
|
+
"""
|
2
|
+
API endpoints for the IndoxRouter client package.
|
3
|
+
"""
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from typing import Dict, Any, Optional, List
|
7
|
+
from datetime import datetime
|
8
|
+
|
9
|
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
10
|
+
from pydantic import BaseModel, Field
|
11
|
+
|
12
|
+
from .utils.database import execute_query
|
13
|
+
from .utils.auth import get_current_user
|
14
|
+
from .models.database import User, RequestLog, Credit
|
15
|
+
from .utils.config import get_config
|
16
|
+
from .providers import get_provider
|
17
|
+
|
18
|
+
# Module-level logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)

# Every client-facing endpoint defined below is mounted under the /api/v1 prefix
# and grouped under the "Client API" tag in the generated OpenAPI schema.
router = APIRouter(tags=["Client API"], prefix="/api/v1")
|
23
|
+
|
24
|
+
|
25
|
+
# Model definitions
|
26
|
+
class GenerateRequest(BaseModel):
    """
    Body accepted by the ``/generate`` endpoint.

    Only ``prompt`` is mandatory; every other field carries a default. The
    endpoint substitutes configured defaults when ``model`` or ``provider``
    is left unset.
    """

    prompt: str = Field(
        default=...,
        description="The prompt to send to the model",
    )
    model: Optional[str] = Field(
        default=None,
        description="The model to use",
    )
    provider: Optional[str] = Field(
        default=None,
        description="The provider to use",
    )
    temperature: Optional[float] = Field(
        default=0.7,
        description="Temperature for sampling",
    )
    max_tokens: Optional[int] = Field(
        default=1000,
        description="Maximum number of tokens to generate",
    )
    top_p: Optional[float] = Field(
        default=1.0,
        description="Top-p sampling parameter",
    )
    stop: Optional[List[str]] = Field(
        default=None,
        description="Sequences where the API will stop generating",
    )
    # NOTE(review): this flag is forwarded to the provider, but the endpoint
    # always returns a single JSON body — confirm streaming is really supported.
    stream: Optional[bool] = Field(
        default=False,
        description="Whether to stream the response",
    )
|
41
|
+
|
42
|
+
|
43
|
+
class GenerateResponse(BaseModel):
    """
    Payload returned by the ``/generate`` endpoint.

    Every field is required; ``usage`` maps token-count names to integers.
    """

    id: str = Field(
        default=...,
        description="Unique identifier for the completion",
    )
    text: str = Field(
        default=...,
        description="The generated text",
    )
    provider: str = Field(
        default=...,
        description="The provider used",
    )
    model: str = Field(
        default=...,
        description="The model used",
    )
    usage: Dict[str, int] = Field(
        default=...,
        description="Token usage information",
    )
    created_at: datetime = Field(
        default=...,
        description="Timestamp of creation",
    )
|
52
|
+
|
53
|
+
|
54
|
+
class ModelInfo(BaseModel):
    """
    One entry in the list returned by the ``/models`` endpoint.

    ``id``, ``name`` and ``provider`` are required. Description, token limit
    and pricing are optional.
    """

    id: str = Field(
        default=...,
        description="Model identifier",
    )
    name: str = Field(
        default=...,
        description="Model name",
    )
    provider: str = Field(
        default=...,
        description="Provider name",
    )
    description: Optional[str] = Field(
        default=None,
        description="Model description",
    )
    max_tokens: Optional[int] = Field(
        default=None,
        description="Maximum tokens supported",
    )
    pricing: Optional[Dict[str, float]] = Field(
        default=None,
        description="Pricing information",
    )
|
63
|
+
|
64
|
+
|
65
|
+
class ProviderInfo(BaseModel):
    """
    One entry in the list returned by the ``/providers`` endpoint.

    ``models`` holds the identifiers of the models the provider exposes.
    """

    id: str = Field(
        default=...,
        description="Provider identifier",
    )
    name: str = Field(
        default=...,
        description="Provider name",
    )
    description: Optional[str] = Field(
        default=None,
        description="Provider description",
    )
    website: Optional[str] = Field(
        default=None,
        description="Provider website",
    )
    models: List[str] = Field(
        default=...,
        description="Available models",
    )
|
73
|
+
|
74
|
+
|
75
|
+
class UserInfo(BaseModel):
    """
    Profile of the authenticated user, as returned by the ``/user`` endpoint.

    Names are optional, and ``is_admin`` defaults to ``False``.
    """

    id: int = Field(
        default=...,
        description="User ID",
    )
    email: str = Field(
        default=...,
        description="User email",
    )
    first_name: Optional[str] = Field(
        default=None,
        description="User first name",
    )
    last_name: Optional[str] = Field(
        default=None,
        description="User last name",
    )
    is_admin: bool = Field(
        default=False,
        description="Whether the user is an admin",
    )
    created_at: datetime = Field(
        default=...,
        description="Account creation timestamp",
    )
|
84
|
+
|
85
|
+
|
86
|
+
class BalanceInfo(BaseModel):
    """
    Credit balance and usage summary returned by ``/user/balance``.

    ``usage`` maps provider names to accumulated usage figures.
    """

    credits: float = Field(
        default=...,
        description="Available credits",
    )
    usage: Dict[str, float] = Field(
        default=...,
        description="Usage statistics",
    )
    last_updated: datetime = Field(
        default=...,
        description="Last updated timestamp",
    )
|
92
|
+
|
93
|
+
|
94
|
+
# API endpoints
|
95
|
+
@router.post("/generate", response_model=GenerateResponse)
async def generate(
    request: GenerateRequest,
    user_data: Dict[str, Any] = Depends(get_current_user),
    req: Request = None,
):
    """
    Generate text using the specified (or configured default) model and provider.

    When the request omits them, the provider and model fall back to
    ``default_provider`` / ``default_model`` from ``get_config()``, and then
    to ``"openai"`` / ``"gpt-4o-mini"``. The provider's token usage is
    recorded as a ``RequestLog`` row for the authenticated user.

    Args:
        request: Generation parameters.
        user_data: Authenticated user; must contain ``"user_id"``.
        req: Raw request object (unused, kept for interface compatibility).

    Returns:
        A dict matching ``GenerateResponse``.

    Raises:
        HTTPException: 400 if the provider cannot be found, 500 on any other
            failure.
    """
    try:
        provider_name = request.provider or get_config().get(
            "default_provider", "openai"
        )
        model_name = request.model or get_config().get("default_model", "gpt-4o-mini")

        provider = get_provider(provider_name)
        if not provider:
            raise HTTPException(
                status_code=400, detail=f"Provider '{provider_name}' not found"
            )

        completion = await provider.generate(
            model=model_name,
            prompt=request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            stop=request.stop,
            stream=request.stream,
        )

        # A provider may omit "usage" or set it to None. Normalise it to a dict
        # so both the log entry and the response model get a mapping.
        usage = completion.get("usage") or {}
        # One timestamp shared by the log row and the response.
        created_at = datetime.utcnow()

        log_entry = RequestLog(
            user_id=user_data["user_id"],
            provider=provider_name,
            model=model_name,
            prompt_tokens=usage.get("prompt_tokens", 0),
            completion_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
            created_at=created_at,
        )
        execute_query(lambda session: session.add(log_entry))

        return {
            "id": completion.get("id", ""),
            "text": completion.get("text", ""),
            "provider": provider_name,
            "model": model_name,
            "usage": usage,
            "created_at": created_at,
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must keep their status
        # code instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        logger.exception("Error generating completion: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
151
|
+
|
152
|
+
|
153
|
+
def _model_info_dict(model: Dict[str, Any], provider_name: str) -> Dict[str, Any]:
    """Convert one entry of ``provider.list_models()`` into a ``ModelInfo``-shaped dict."""
    return {
        "id": model.get("id", ""),
        "name": model.get("name", ""),
        "provider": provider_name,
        "description": model.get("description", ""),
        "max_tokens": model.get("max_tokens", 0),
        "pricing": model.get("pricing", {}),
    }


@router.get("/models", response_model=List[ModelInfo])
async def list_models(
    provider: Optional[str] = None,
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    List available models, optionally filtered by provider.

    With ``provider`` set, only that provider is queried, and an unknown name
    yields a 400. Without it, every provider under ``providers`` in the
    configuration is queried. A provider that fails is logged and skipped, so
    one broken provider does not hide the others.

    Raises:
        HTTPException: 400 for an unknown ``provider``, 500 on any other failure.
    """
    try:
        config = get_config()

        if provider:
            provider_obj = get_provider(provider)
            if not provider_obj:
                raise HTTPException(
                    status_code=400, detail=f"Provider '{provider}' not found"
                )
            return [
                _model_info_dict(model, provider)
                for model in provider_obj.list_models()
            ]

        models = []
        for provider_name in config.get("providers", {}):
            try:
                provider_obj = get_provider(provider_name)
                if provider_obj:
                    models.extend(
                        _model_info_dict(model, provider_name)
                        for model in provider_obj.list_models()
                    )
            except Exception as e:
                # Best effort: skip this provider, keep the others.
                logger.error(
                    f"Error getting models for provider {provider_name}: {str(e)}"
                )
        return models
    except HTTPException:
        # Preserve the deliberate 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        logger.exception("Error listing models: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
210
|
+
|
211
|
+
|
212
|
+
@router.get("/providers", response_model=List[ProviderInfo])
async def list_providers(
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    Describe every provider listed under ``providers`` in the configuration.

    Each configured provider is resolved with ``get_provider``. Providers that
    resolve to nothing are skipped. A provider whose lookup or model listing
    raises is logged and skipped rather than failing the whole request.
    Display metadata comes from the provider's config entry, falling back to
    the provider id or to empty strings.
    """
    try:
        provider_settings = get_config().get("providers", {})
        summaries = []

        for provider_id, settings in provider_settings.items():
            try:
                provider_obj = get_provider(provider_id)
                if not provider_obj:
                    continue
                available = provider_obj.list_models()
                summaries.append(
                    dict(
                        id=provider_id,
                        name=settings.get("name", provider_id),
                        description=settings.get("description", ""),
                        website=settings.get("website", ""),
                        models=[entry.get("id", "") for entry in available],
                    )
                )
            except Exception as e:
                logger.error(f"Error getting provider {provider_id}: {str(e)}")

        return summaries
    except Exception as e:
        logger.error(f"Error listing providers: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
246
|
+
|
247
|
+
|
248
|
+
@router.get("/user", response_model=UserInfo)
async def get_user(
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    Get information about the authenticated user.

    Returns:
        A dict matching ``UserInfo`` for the ``User`` row whose id equals
        ``user_data["user_id"]``.

    Raises:
        HTTPException: 404 if no such user exists, 500 on any other failure.
    """
    try:
        user_info: Optional[Dict[str, Any]] = None

        def get_user_from_db(session):
            nonlocal user_info
            user = session.query(User).filter(User.id == user_data["user_id"]).first()
            if user is not None:
                # Copy the columns while the session is still in use.
                # NOTE(review): execute_query's session lifecycle isn't visible
                # here; reading ORM attributes after it returns could hit an
                # expired/detached instance.
                user_info = {
                    "id": user.id,
                    "email": user.email,
                    "first_name": user.first_name,
                    "last_name": user.last_name,
                    "is_admin": user.is_admin,
                    "created_at": user.created_at,
                }
            return user

        execute_query(get_user_from_db)

        if user_info is None:
            raise HTTPException(status_code=404, detail="User not found")

        return user_info
    except HTTPException:
        # Keep the 404 as a 404 instead of rewrapping it as a 500 below.
        raise
    except Exception as e:
        logger.exception("Error getting user: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
278
|
+
|
279
|
+
|
280
|
+
@router.get("/user/balance", response_model=BalanceInfo)
async def get_balance(
    user_data: Dict[str, Any] = Depends(get_current_user),
):
    """
    Get the user's current credit balance and token usage per provider.

    If the user has no ``Credit`` row yet, one is created with an amount of
    ``0.0``. Usage is the sum of ``RequestLog.total_tokens`` grouped by
    provider; missing (``None``) token counts count as 0.

    Raises:
        HTTPException: 500 on any failure.
    """
    try:
        user_id = user_data["user_id"]
        # Plain values copied out of the session, so nothing is read from ORM
        # instances after execute_query returns.
        balance: Optional[Dict[str, Any]] = None
        usage_data: Dict[str, float] = {}

        def get_credit_from_db(session):
            nonlocal balance
            credit = session.query(Credit).filter(Credit.user_id == user_id).first()
            if credit is not None:
                balance = {
                    "credits": credit.amount,
                    "last_updated": credit.last_updated,
                }
            return credit

        def get_usage_from_db(session):
            nonlocal usage_data
            usage_by_provider: Dict[str, float] = {}
            # Select only the two columns needed instead of full RequestLog rows.
            rows = (
                session.query(RequestLog.provider, RequestLog.total_tokens)
                .filter(RequestLog.user_id == user_id)
                .all()
            )
            for provider, tokens in rows:
                usage_by_provider[provider] = usage_by_provider.get(provider, 0) + (
                    tokens or 0
                )
            usage_data = usage_by_provider
            return usage_by_provider

        execute_query(get_credit_from_db)
        execute_query(get_usage_from_db)

        if balance is None:
            # Create a new credit entry with default values.
            now = datetime.utcnow()
            credit = Credit(user_id=user_id, amount=0.0, last_updated=now)
            execute_query(lambda session: session.add(credit))
            balance = {"credits": 0.0, "last_updated": now}

        return {
            "credits": balance["credits"],
            "usage": usage_data,
            "last_updated": balance["last_updated"],
        }
    except Exception as e:
        logger.exception("Error getting balance: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
indoxRouter/client.py
ADDED
@@ -0,0 +1,286 @@
|
|
1
|
+
from typing import Dict, Optional, Any, List
|
2
|
+
import os
|
3
|
+
import sys
|
4
|
+
import json
|
5
|
+
import requests
|
6
|
+
|
7
|
+
# Put the directory that contains the ``indoxRouter`` package on sys.path so
# the absolute ``indoxRouter.*`` imports below resolve. Presumably this is for
# running this file outside an installed package.
# NOTE(review): mutating sys.path at import time inside a distributed wheel is
# discouraged — consider removing once imports are package-relative.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
# Only append once, so repeated imports/reloads don't grow sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
|
11
|
+
|
12
|
+
# Use absolute imports
|
13
|
+
from indoxRouter.utils.auth import AuthManager
|
14
|
+
from indoxRouter.providers.base_provider import BaseProvider
|
15
|
+
|
16
|
+
|
17
|
+
class Client:
    """
    Client for the IndoxRouter service.

    The API key is verified locally through ``AuthManager`` when the client is
    created. Text can then be generated in two ways:

    * ``generate`` loads a provider class from ``indoxRouter.providers``. It
      checks the user's credit balance against the provider's cost estimate
      and deducts the actual cost afterwards.
    * ``generate_remote`` posts the request to an IndoxRouter HTTP server at
      ``base_url``.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = None,
        timeout: Optional[float] = None,
    ):
        """
        Initialize the client.

        Args:
            api_key: API key for authentication.
            base_url: Base URL for the API (default: http://localhost:8000).
                A trailing slash is stripped so endpoint paths join cleanly.
            timeout: Timeout in seconds for HTTP calls (``generate_remote``,
                ``list_models``, ``list_providers``). ``None`` waits
                indefinitely.

        Raises:
            ValueError: If ``AuthManager.verify_api_key`` rejects the key.
        """
        self.api_key = api_key
        self.base_url = (base_url or "http://localhost:8000").rstrip("/")
        self.timeout = timeout
        # Provider instances keyed by full "provider/model" name. Must exist
        # before the first _get_provider call, which reads and fills it.
        self.provider_cache = {}
        self.auth_manager = AuthManager()

        # Verify the API key
        self.user_data = self.auth_manager.verify_api_key(api_key)
        if not self.user_data:
            raise ValueError("Invalid API key")

    # ------------------------------------------------------------------
    # HTTP API
    # ------------------------------------------------------------------

    @staticmethod
    def _extract_error_message(payload: Any) -> str:
        """
        Return a readable message from a decoded error response body.

        Handles the shapes ``{"error": {"message": ...}}``, ``{"error": "..."}``
        and FastAPI's ``{"detail": ...}``. Anything else, including ``None``
        for a body that was not JSON, yields ``"Unknown error"``.
        """
        if isinstance(payload, dict):
            error = payload.get("error")
            if isinstance(error, dict) and error.get("message"):
                return str(error["message"])
            if isinstance(error, str) and error:
                return error
            if payload.get("detail"):
                return str(payload["detail"])
        return "Unknown error"

    def _request(
        self,
        method: str,
        path: str,
        *,
        body: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
    ) -> Any:
        """
        Send an authenticated request to ``base_url + path`` and decode the reply.

        Args:
            method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
            path: Endpoint path beginning with ``/``.
            body: Optional JSON body.
            params: Optional query parameters (URL-encoded by ``requests``).

        Returns:
            The decoded JSON response.

        Raises:
            RuntimeError: If the server answers with a status other than 200.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        if body is not None:
            headers["Content-Type"] = "application/json"

        response = requests.request(
            method,
            f"{self.base_url}{path}",
            headers=headers,
            json=body,
            params=params,
            timeout=self.timeout,
        )

        if response.status_code != 200:
            try:
                payload = response.json()
            except ValueError:
                # Error bodies are not always JSON (e.g. a proxy's HTML page).
                payload = None
            raise RuntimeError(f"Error: {self._extract_error_message(payload)}")

        return response.json()

    def generate_remote(
        self,
        provider: str,
        model: str,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        **kwargs,
    ) -> str:
        """
        Generate a response through the IndoxRouter HTTP API.

        This was originally a first definition of ``generate``. The later
        credit-aware ``generate`` shadowed it, so it is exposed under its own
        name.

        Args:
            provider: Provider name.
            model: Model name.
            prompt: Prompt.
            temperature: Temperature.
            max_tokens: Maximum tokens.
            **kwargs: Additional parameters, sent in the JSON body.

        Returns:
            The ``text`` of the first choice, or ``""`` if there is none.

        Raises:
            RuntimeError: If the server returns a non-200 status.
        """
        data = {
            "provider": provider,
            "model": model,
            "prompt": prompt,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        }
        payload = self._request("POST", "/v1/completions", body=data)
        # Guard against a missing or empty "choices" list.
        choices = payload.get("choices") or [{}]
        return choices[0].get("text", "")

    def list_models(self, provider: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        List available models.

        Args:
            provider: Provider name to filter by (optional).

        Returns:
            The ``data`` list from the server response.

        Raises:
            RuntimeError: If the server returns a non-200 status.
        """
        params = {"provider": provider} if provider else None
        return self._request("GET", "/v1/models", params=params).get("data", [])

    def list_providers(self) -> List[str]:
        """
        List available providers.

        Returns:
            The ``data`` list from the server response.

        Raises:
            RuntimeError: If the server returns a non-200 status.
        """
        return self._request("GET", "/v1/providers").get("data", [])

    # ------------------------------------------------------------------
    # Direct provider access
    # ------------------------------------------------------------------

    def _parse_model_name(self, model_name: str) -> tuple:
        """
        Parse model name into provider and model parts.

        Only the first ``/`` separates the two, so ``"a/b/c"`` yields
        ``("a", "b/c")``.

        Args:
            model_name: Full model name (e.g., 'openai/gpt-4').

        Returns:
            Tuple of (provider_name, model_part).

        Raises:
            ValueError: If ``model_name`` contains no ``/``.
        """
        if "/" not in model_name:
            raise ValueError(
                f"Invalid model name format: {model_name}. Expected format: 'provider/model'"
            )

        provider_name, model_part = model_name.split("/", 1)
        return provider_name, model_part

    def _load_provider_class(self, provider_name: str):
        """
        Dynamically load the ``Provider`` class from ``indoxRouter.providers.<name>``.

        Args:
            provider_name: Name of the provider module.

        Returns:
            Provider class.

        Raises:
            ValueError: If the name is not a valid module name, the module
                cannot be imported, or it defines no ``Provider`` attribute.
        """
        import importlib  # only needed the first time a provider is loaded

        # Reject names such as "x.y" or "../x" so user-supplied model strings
        # can only select a direct submodule of indoxRouter.providers.
        if not provider_name.isidentifier():
            raise ValueError(f"Provider not supported: {provider_name}")

        try:
            # The module's own imports are absolute ("indoxRouter.…"). The old
            # __import__(".providers.x") call used level 0, so it treated the
            # leading dot as part of the module name and always failed.
            provider_module = importlib.import_module(
                f"indoxRouter.providers.{provider_name}"
            )
            # NOTE(review): assumes each provider module exports a class named
            # ``Provider`` — confirm against indoxRouter/providers/*.py.
            return provider_module.Provider
        except (ImportError, AttributeError) as e:
            raise ValueError(f"Provider not supported: {provider_name}") from e

    def _get_provider(self, model_name: str) -> "BaseProvider":
        """
        Get a provider instance for ``model_name``, cached per full model name.

        Args:
            model_name: Full model name (e.g., 'openai/gpt-4').

        Returns:
            Provider instance.
        """
        if model_name in self.provider_cache:
            return self.provider_cache[model_name]

        provider_name, model_part = self._parse_model_name(model_name)
        provider_class = self._load_provider_class(provider_name)

        # Get provider API key from secure storage
        provider_api_key = self._get_provider_credentials(provider_name)

        instance = provider_class(api_key=provider_api_key, model_name=model_part)
        self.provider_cache[model_name] = instance
        return instance

    def _get_provider_credentials(self, provider_name: str) -> str:
        """
        Retrieve the provider API key from the ``<PROVIDER>_API_KEY`` env var.

        Args:
            provider_name: Name of the provider.

        Returns:
            Provider API key.

        Raises:
            ValueError: If the environment variable is not set.
        """
        # Replace with a real secret store (e.g. AWS Secrets Manager) if needed.
        env_var = f"{provider_name.upper()}_API_KEY"
        if env_var not in os.environ:
            raise ValueError(
                f"Missing API key for provider: {provider_name}. Set {env_var} environment variable."
            )

        return os.environ[env_var]

    def generate(self, model_name: str, prompt: str, **kwargs) -> Dict[str, Any]:
        """
        Generate a completion with credit checking and deduction.

        The provider's ``estimate_cost(prompt, max_tokens)`` is compared with
        the cached balance before the call. ``max_tokens`` defaults to 2048
        for the estimate. The ``cost`` the provider reports is then deducted
        through ``AuthManager``.

        Args:
            model_name: Provider/model name (e.g., 'openai/gpt-4').
            prompt: User input prompt.
            **kwargs: Generation parameters passed to the provider.

        Returns:
            Dictionary with ``text``, ``cost``, ``remaining_credits`` and ``model``.

        Raises:
            ValueError: If the estimated cost exceeds the available balance.
            RuntimeError: If the credit deduction fails.
        """
        provider = self._get_provider(model_name)

        # Estimate max possible cost
        max_tokens = kwargs.get("max_tokens", 2048)
        estimated_cost = provider.estimate_cost(prompt, max_tokens)

        # Check balance
        if self.user_data["balance"] < estimated_cost:
            raise ValueError(
                f"Insufficient credits. Required: {estimated_cost:.6f}, Available: {self.user_data['balance']:.6f}"
            )

        # Make API call; the provider result is expected to carry "text" and "cost".
        response = provider.generate(prompt, **kwargs)

        # Deduct actual cost
        success = self.auth_manager.deduct_credits(
            self.user_data["id"], response["cost"]
        )

        if not success:
            raise RuntimeError("Credit deduction failed")

        # Refresh user data so remaining_credits reflects the deduction.
        self.user_data = self.auth_manager.get_user_by_id(self.user_data["id"])

        return {
            "text": response["text"],
            "cost": response["cost"],
            "remaining_credits": self.user_data["balance"],
            "model": model_name,
        }

    def get_balance(self) -> float:
        """
        Get the current user balance, refreshed from ``AuthManager``.

        Returns:
            Current credit balance.
        """
        self.user_data = self.auth_manager.get_user_by_id(self.user_data["id"])
        return self.user_data["balance"]

    def get_user_info(self) -> Dict[str, Any]:
        """
        Get current user information, refreshed from ``AuthManager``.

        Returns:
            User information dictionary.
        """
        self.user_data = self.auth_manager.get_user_by_id(self.user_data["id"])
        return self.user_data
|