maxllm-gate 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_scheduler/__init__.py +8 -0
- llm_scheduler/api/__init__.py +6 -0
- llm_scheduler/api/dependencies.py +10 -0
- llm_scheduler/api/routes.py +275 -0
- llm_scheduler/api/schemas.py +135 -0
- llm_scheduler/config.py +117 -0
- llm_scheduler/core/__init__.py +8 -0
- llm_scheduler/core/dispatcher.py +225 -0
- llm_scheduler/core/queue_manager.py +251 -0
- llm_scheduler/core/scheduler.py +236 -0
- llm_scheduler/core/token_estimator.py +201 -0
- llm_scheduler/main.py +86 -0
- llm_scheduler/models/__init__.py +6 -0
- llm_scheduler/models/provider.py +103 -0
- llm_scheduler/models/request.py +101 -0
- llm_scheduler/observability/__init__.py +6 -0
- llm_scheduler/observability/logging.py +65 -0
- llm_scheduler/observability/metrics.py +92 -0
- llm_scheduler/rate_limiting/__init__.py +7 -0
- llm_scheduler/rate_limiting/key_manager.py +252 -0
- llm_scheduler/rate_limiting/token_bucket.py +152 -0
- llm_scheduler/rate_limiting/tracker.py +281 -0
- llm_scheduler/strategies/__init__.py +7 -0
- llm_scheduler/strategies/base.py +56 -0
- llm_scheduler/strategies/fallback.py +52 -0
- llm_scheduler/strategies/least_utilized.py +30 -0
- llm_scheduler/strategies/round_robin.py +29 -0
- llm_scheduler/strategies/token_aware.py +46 -0
- llm_scheduler/utils/__init__.py +6 -0
- llm_scheduler/utils/retry.py +136 -0
- llm_scheduler/utils/time_utils.py +115 -0
- maxllm/__init__.py +77 -0
- maxllm/client.py +598 -0
- maxllm/config.py +181 -0
- maxllm/rate_limiter.py +432 -0
- maxllm/redis_backend.py +495 -0
- maxllm/scheduler.py +559 -0
- maxllm/validation.py +183 -0
- maxllm_gate-0.2.0.dist-info/METADATA +771 -0
- maxllm_gate-0.2.0.dist-info/RECORD +43 -0
- maxllm_gate-0.2.0.dist-info/WHEEL +4 -0
- maxllm_gate-0.2.0.dist-info/entry_points.txt +2 -0
- maxllm_gate-0.2.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
"""FastAPI route definitions."""
|
|
2
|
+
|
|
3
|
+
from fastapi import APIRouter, Depends, HTTPException
|
|
4
|
+
from fastapi.responses import StreamingResponse
|
|
5
|
+
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
|
6
|
+
|
|
7
|
+
from llm_scheduler.api.schemas import (
|
|
8
|
+
ChatRequest,
|
|
9
|
+
ChatResponse,
|
|
10
|
+
BatchRequest,
|
|
11
|
+
BatchResponse,
|
|
12
|
+
HealthResponse,
|
|
13
|
+
StatusResponse,
|
|
14
|
+
CapacityResponse,
|
|
15
|
+
KeyStatus,
|
|
16
|
+
)
|
|
17
|
+
from llm_scheduler.api.dependencies import get_scheduler
|
|
18
|
+
from llm_scheduler.core.scheduler import Scheduler, SchedulerError
|
|
19
|
+
from llm_scheduler.observability.logging import get_logger
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Router collecting every endpoint defined in this module.
router = APIRouter()
# Module logger; called with keyword context (e.g. ``error=...``).
logger = get_logger()


@router.post("/chat", response_model=ChatResponse, tags=["LLM"])
async def chat(
    request: ChatRequest,
    scheduler: Scheduler = Depends(get_scheduler),
) -> ChatResponse:
    """
    Send a chat completion request.

    The request is queued and scheduled based on:
    - Priority (high > medium > low)
    - Available API key capacity
    - Rate limits (TPM/RPM)

    If all keys are at capacity, the request is deferred until
    capacity becomes available.

    Raises:
        HTTPException: 503 when the scheduler raises ``SchedulerError``,
            500 for any other failure.
    """
    try:
        messages = [m.model_dump() for m in request.messages]

        result = await scheduler.schedule(
            model=request.model,
            messages=messages,
            priority=request.priority,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
        )

        # Accept a choices-style result ({"choices": [{"message": {...}}]} or
        # {"choices": [{"text": ...}]}) as well as a bare {"content": ...}.
        content = ""
        finish_reason = None
        # ``or []`` also covers a present-but-empty/None "choices", which the
        # previous ``[0]`` lookup turned into an IndexError/TypeError (500).
        choices = result.get("choices") or []
        if choices:
            choice = choices[0]
            finish_reason = choice.get("finish_reason")
            if "message" in choice:
                # "content" can be present but None; ChatResponse.content is a
                # required str, so normalise to "".
                content = (choice["message"] or {}).get("content") or ""
            elif "text" in choice:
                content = choice["text"] or ""
        elif "content" in result:
            content = result["content"] or ""

        return ChatResponse(
            id=result.get("id", ""),
            model=result.get("model", request.model),
            content=content,
            usage=result.get("usage"),
            finish_reason=finish_reason,
        )

    except SchedulerError as e:
        logger.warning("Scheduler error", error=str(e))
        raise HTTPException(status_code=503, detail=str(e)) from e
    except Exception as e:
        logger.error("Chat request failed", error=str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@router.post("/chat/stream", tags=["LLM"])
async def chat_stream(
    request: ChatRequest,
    scheduler: Scheduler = Depends(get_scheduler),
):
    """
    Send a streaming chat completion request.

    Returns a Server-Sent Events (SSE) stream of response chunks.

    The completion is first collected in full through
    ``scheduler.schedule`` and then replayed in 10-character chunks; this
    is not token-level streaming from the provider. The stream ends with
    ``data: [DONE]``. Failures are reported in-band as
    ``data: [ERROR] <message>`` because the response status has already
    been sent once the generator runs.
    """

    def sse_event(payload: str) -> str:
        # SSE ends a field at any CR/LF, so every line of a multi-line
        # payload needs its own "data:" prefix; clients rejoin them with "\n".
        normalized = payload.replace("\r\n", "\n").replace("\r", "\n")
        return "".join(f"data: {line}\n" for line in normalized.split("\n")) + "\n"

    async def generate():
        try:
            messages = [m.model_dump() for m in request.messages]

            # For streaming, we need direct access to dispatcher
            # This is a simplified version
            result = await scheduler.schedule(
                model=request.model,
                messages=messages,
                priority=request.priority,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
            )

            # Same extraction as /chat: choices-style result first, then a
            # bare "content" key. Reading only "content" streamed nothing for
            # choices-shaped results.
            content = ""
            choices = result.get("choices") or []
            if choices:
                choice = choices[0]
                if "message" in choice:
                    content = (choice["message"] or {}).get("content") or ""
                elif "text" in choice:
                    content = choice["text"] or ""
            elif "content" in result:
                content = result["content"] or ""

            # Normalise line endings before chunking so a CRLF pair is never
            # split across two events.
            content = content.replace("\r\n", "\n").replace("\r", "\n")

            # Simulate streaming for collected content
            for start in range(0, len(content), 10):
                yield sse_event(content[start:start + 10])

            yield "data: [DONE]\n\n"

        except Exception as e:
            yield sse_event(f"[ERROR] {e}")

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@router.post("/batch", response_model=BatchResponse, tags=["LLM"])
async def batch(
    request: BatchRequest,
    scheduler: Scheduler = Depends(get_scheduler),
) -> BatchResponse:
    """
    Process multiple chat requests in parallel.

    Returns results for all requests, with errors inline.

    Each entry of ``results`` is either a ``ChatResponse`` or an
    ``{"error": "<message>"}`` mapping. The mapping is used both for a
    request that failed and for a provider result that could not be
    converted into a ``ChatResponse``.

    Raises:
        HTTPException: 503 when ``schedule_batch`` itself raises
            ``SchedulerError`` (same mapping as ``/chat``).
    """
    batch_requests = [
        {
            "model": req.model,
            "messages": [m.model_dump() for m in req.messages],
            "priority": req.priority,
            "max_tokens": req.max_tokens,
            "temperature": req.temperature,
        }
        for req in request.requests
    ]

    try:
        raw_results = await scheduler.schedule_batch(batch_requests)
    except SchedulerError as e:
        logger.warning("Scheduler error", error=str(e))
        raise HTTPException(status_code=503, detail=str(e)) from e

    results: list[ChatResponse | dict[str, str]] = []
    successful = 0
    failed = 0

    for raw in raw_results:
        # Failed items come back as exception objects. BaseException also
        # covers asyncio.CancelledError, which is not an Exception subclass
        # and has no .get().
        if isinstance(raw, BaseException):
            failed += 1
            results.append({"error": str(raw)})
            continue

        # Same extraction as /chat; ``or []`` guards an empty/None "choices"
        # that previously raised IndexError and failed the whole batch.
        content = ""
        finish_reason = None
        choices = raw.get("choices") or []
        if choices:
            choice = choices[0]
            finish_reason = choice.get("finish_reason")
            if "message" in choice:
                content = (choice["message"] or {}).get("content") or ""
            elif "text" in choice:
                content = choice["text"] or ""
        elif "content" in raw:
            content = raw["content"] or ""

        try:
            response = ChatResponse(
                id=raw.get("id", ""),
                model=raw.get("model", ""),
                content=content,
                usage=raw.get("usage"),
                finish_reason=finish_reason,
            )
        except ValueError as e:
            # pydantic's ValidationError subclasses ValueError; report the bad
            # item inline instead of discarding every other result.
            failed += 1
            results.append({"error": str(e)})
            continue

        successful += 1
        results.append(response)

    return BatchResponse(
        results=results,
        total=len(request.requests),
        successful=successful,
        failed=failed,
    )
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@router.get("/health", response_model=HealthResponse, tags=["System"])
async def health(
    scheduler: Scheduler = Depends(get_scheduler),
) -> HealthResponse:
    """
    Health check endpoint.

    Derives an overall status from ``scheduler.get_status()``:

    * ``"unhealthy"`` if the scheduler is not running or no key reports
      ``is_healthy``;
    * ``"degraded"`` if the queue holds more than 90% of its ``max_size``;
    * ``"healthy"`` otherwise.
    """
    snapshot = scheduler.get_status()
    queue_info = snapshot["queue"]
    pending = queue_info["queue_size"]

    # A key counts as usable only when it explicitly reports is_healthy.
    per_key = snapshot["keys"].get("keys", {})
    usable_keys = len([info for info in per_key.values() if info.get("is_healthy", False)])

    is_running = snapshot["running"]
    if not is_running or usable_keys == 0:
        overall = "unhealthy"
    elif pending > queue_info["max_size"] * 0.9:
        overall = "degraded"
    else:
        overall = "healthy"

    return HealthResponse(
        status=overall,
        scheduler_running=is_running,
        queue_size=pending,
        keys_available=usable_keys,
    )
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@router.get("/status", response_model=StatusResponse, tags=["System"])
async def status(
    scheduler: Scheduler = Depends(get_scheduler),
) -> StatusResponse:
    """
    Get detailed scheduler status.

    Returns queue statistics, key states, and configuration.

    The mapping from ``scheduler.get_status()`` is validated into a
    ``StatusResponse`` here, so the returned value matches the annotated
    return type instead of being a bare dict. Keys that are not declared
    on the model are ignored (pydantic's default).
    """
    return StatusResponse.model_validate(scheduler.get_status())
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@router.get("/capacity", response_model=CapacityResponse, tags=["System"])
async def capacity(
    scheduler: Scheduler = Depends(get_scheduler),
) -> CapacityResponse:
    """
    Get current capacity across all API keys.

    Useful for monitoring and understanding rate limit state.

    Aggregate figures come from the ``total_capacity`` and
    ``available_capacity`` sections of ``scheduler.get_status()["keys"]``,
    defaulting to 0 when a section or field is absent. One ``KeyStatus`` is
    built from each entry under ``"keys"``.
    """
    key_section = scheduler.get_status()["keys"]

    totals = key_section.get("total_capacity", {})
    remaining = key_section.get("available_capacity", {})

    per_key: list[KeyStatus] = []
    for info in key_section.get("keys", {}).values():
        per_key.append(KeyStatus(**info))

    return CapacityResponse(
        total_tpm=totals.get("tpm", 0),
        available_tpm=remaining.get("tpm", 0),
        total_rpm=totals.get("rpm", 0),
        available_rpm=remaining.get("rpm", 0),
        keys=per_key,
    )
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
@router.get("/metrics", tags=["System"])
async def metrics():
    """
    Prometheus metrics endpoint.

    Serialises the default ``prometheus_client`` registry with
    ``generate_latest()`` and serves it using ``CONTENT_TYPE_LATEST``
    (the Prometheus text exposition format).
    """
    from starlette.responses import Response

    exposition = generate_latest()
    return Response(content=exposition, media_type=CONTENT_TYPE_LATEST)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
@router.get("/", tags=["System"])
async def root():
    """Root endpoint with API information.

    Returns a static mapping with the service name, version, a short
    description and the paths of the docs, health and metrics endpoints.
    """
    description = (
        "An intelligent scheduling and rate-limit-aware control layer "
        "on top of LiteLLM that maximizes throughput and prevents 429 errors."
    )
    # NOTE(review): version is hard-coded; the distribution ships as
    # maxllm-gate 0.2.0 — confirm whether this should track the package version.
    return dict(
        name="LLM Rate Limit Scheduler",
        version="0.1.0",
        description=description,
        docs_url="/docs",
        health_url="/health",
        metrics_url="/metrics",
    )
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""Pydantic schemas for API requests and responses."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
from pydantic import BaseModel, Field
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class Message(BaseModel):
    """One chat turn: the speaker's role plus the text they sent."""

    # Only these three roles are accepted; anything else fails validation.
    role: Literal["system", "user", "assistant"] = Field(description="Role of the message sender")
    content: str = Field(description="Message content")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Example payload surfaced in the generated OpenAPI schema for ChatRequest.
_CHAT_REQUEST_EXAMPLE = {
    "model": "mixtral-8x7b-32768",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
    ],
    "priority": "medium",
    "max_tokens": 1024,
    "temperature": 0.7,
}


class ChatRequest(BaseModel):
    """Payload accepted by /chat and /chat/stream, and per item by /batch.

    ``model`` and ``messages`` are required. ``priority`` defaults to
    ``"medium"``, ``max_tokens`` to ``None``, and ``temperature`` to 0.7,
    which is constrained to the closed range [0.0, 2.0].
    """

    model: str = Field(
        description="Model name (e.g., 'mixtral', 'gpt-4o-mini', 'llama-3.1-70b')"
    )
    messages: list[Message] = Field(description="List of chat messages")
    priority: Literal["high", "medium", "low"] = Field(
        "medium", description="Request priority for queue ordering"
    )
    max_tokens: int | None = Field(None, description="Maximum tokens in response")
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)

    model_config = {"json_schema_extra": {"examples": [_CHAT_REQUEST_EXAMPLE]}}
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class ChatResponse(BaseModel):
    """Response body for /chat endpoint (also used per item by /batch)."""

    id: str = Field(description="Response ID")
    model: str = Field(description="Model used")
    content: str = Field(description="Generated content")
    # Values are Any rather than int: OpenAI-style usage objects may carry
    # nested objects (e.g. ``prompt_tokens_details``) or nulls, which a
    # dict[str, int] rejects, turning a successful completion into a 500.
    usage: dict[str, Any] | None = Field(
        default=None,
        description="Token usage statistics"
    )
    finish_reason: str | None = Field(
        default=None,
        description="Reason for completion"
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class BatchRequest(BaseModel):
    """Body of a /batch call: a list of independent chat requests."""

    # No size limit is enforced here; every entry is a full ChatRequest.
    requests: list[ChatRequest] = Field(description="List of chat requests to process")
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
class BatchResponse(BaseModel):
    """Outcome of a /batch call.

    Each entry of ``results`` is either a ``ChatResponse`` or a
    string-to-string error mapping. The counters summarise how many
    requests were submitted, succeeded and failed.
    """

    results: list[ChatResponse | dict[str, str]] = Field(description="Results for each request (response or error)")

    # Counters
    total: int = Field(description="Total requests")
    successful: int = Field(description="Successful requests")
    failed: int = Field(description="Failed requests")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class HealthResponse(BaseModel):
    """Body returned by /health: a coarse status plus the figures behind it."""

    # One of three fixed labels; other values fail validation.
    status: Literal["healthy", "degraded", "unhealthy"] = Field(description="Overall health status")
    scheduler_running: bool = Field(description="Whether scheduler is running")
    queue_size: int = Field(description="Current queue size")
    keys_available: int = Field(description="Number of healthy API keys")
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class StatusResponse(BaseModel):
    """Body returned by /status, validated from ``Scheduler.get_status()``.

    ``queue`` and ``keys`` are passed through as free-form mappings; their
    inner structure is defined by the scheduler, not by this schema.
    """

    # Top-level flags
    running: bool
    strategy: str

    # Nested sections, kept as opaque dicts
    queue: dict[str, Any]
    keys: dict[str, Any]
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class KeyStatus(BaseModel):
    """Per-key snapshot listed in the /capacity response."""

    # Identity
    key_id: str
    provider: str

    # Token (TPM) rate-limit budget.
    # NOTE(review): typed as int; if the tracker reports fractional
    # availability, validation will reject it — confirm against the tracker.
    tpm_available: int
    tpm_capacity: int

    # Request (RPM) rate-limit budget
    rpm_available: int
    rpm_capacity: int

    # Usage counters and health flag
    utilization: float
    total_requests: int
    total_tokens_used: int
    is_healthy: bool
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class CapacityResponse(BaseModel):
    """Body returned by /capacity: aggregate TPM/RPM plus one entry per key."""

    # Token-per-minute totals
    total_tpm: int
    available_tpm: int

    # Request-per-minute totals
    total_rpm: int
    available_rpm: int

    keys: list[KeyStatus]
|
llm_scheduler/config.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
"""Configuration management for LLM Rate Limit Scheduler."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, field_validator
|
|
6
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class APIKeyConfig:
    """Plain holder for one API key's settings, as built by ``Settings.get_api_keys``.

    ``repr()`` shows only ``key_id`` and ``provider``, so the secret
    ``api_key`` is not echoed into logs or tracebacks.
    """

    def __init__(
        self,
        key_id: str,
        api_key: str,
        provider: str,
        models: list[str],
        tpm_limit: int,
        rpm_limit: int,
    ):
        # Identity and credentials (stored as given, no validation).
        self.key_id, self.api_key, self.provider = key_id, api_key, provider
        # Model names this key may serve; the caller's list object is kept.
        self.models = models
        # Rate-limit budgets (TPM / RPM).
        self.tpm_limit, self.rpm_limit = tpm_limit, rpm_limit

    def __repr__(self) -> str:
        ident, backend = self.key_id, self.provider
        return f"APIKeyConfig(key_id={ident}, provider={backend})"
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values come from the process environment and from a UTF-8 ``.env``
    file; variables that match no field are ignored.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)
    debug: bool = Field(default=False)
    log_level: str = Field(default="INFO")

    # API Keys (JSON string)
    api_keys_config: str = Field(default="{}")

    # Scheduling
    default_strategy: str = Field(default="least_utilized")

    # Token estimation
    default_max_tokens: int = Field(default=1024)
    token_estimation_buffer: float = Field(default=1.1)

    # Retry
    max_retries: int = Field(default=3)
    retry_base_delay: float = Field(default=1.0)
    retry_max_delay: float = Field(default=60.0)

    # Queue
    max_queue_size: int = Field(default=10000)
    default_priority: str = Field(default="medium")

    # Redis (optional)
    redis_url: str | None = Field(default=None)
    use_redis_queue: bool = Field(default=False)

    @field_validator("default_strategy")
    @classmethod
    def validate_strategy(cls, v: str) -> str:
        """Reject strategy names outside the supported set."""
        valid = {"least_utilized", "round_robin", "token_aware"}
        if v not in valid:
            raise ValueError(f"Strategy must be one of {valid}")
        return v

    @field_validator("default_priority")
    @classmethod
    def validate_priority(cls, v: str) -> str:
        """Reject priorities other than high/medium/low."""
        valid = {"high", "medium", "low"}
        if v not in valid:
            raise ValueError(f"Priority must be one of {valid}")
        return v

    def get_api_keys(self) -> dict[str, APIKeyConfig]:
        """Parse API keys configuration into typed objects.

        ``api_keys_config`` is expected to be a JSON object mapping key ids
        to per-key objects. Parsing is best-effort and never raises:

        * invalid JSON, or a top-level value that is not an object, yields ``{}``;
        * an entry whose value is not an object is skipped.

        Missing fields fall back to defaults: empty ``api_key``/``provider``,
        no ``models``, ``tpm_limit=10000`` and ``rpm_limit=60``. A single
        model given as a string is wrapped in a list.

        Returns:
            Mapping of key id to ``APIKeyConfig``.
        """
        try:
            raw_config = json.loads(self.api_keys_config)
        except json.JSONDecodeError:
            return {}

        # json.loads also returns lists, strings and numbers; only an object
        # has the key-id -> config shape (``.items()`` would fail otherwise).
        if not isinstance(raw_config, dict):
            return {}

        result: dict[str, APIKeyConfig] = {}
        for key_id, config in raw_config.items():
            if not isinstance(config, dict):
                continue

            models = config.get("models", [])
            # A bare string would turn ``model in k.models`` into a substring
            # test (e.g. "gpt-4" matching "gpt-4o-mini").
            if isinstance(models, str):
                models = [models]

            result[key_id] = APIKeyConfig(
                key_id=key_id,
                api_key=config.get("api_key", ""),
                provider=config.get("provider", ""),
                models=models,
                tpm_limit=config.get("tpm_limit", 10000),
                rpm_limit=config.get("rpm_limit", 60),
            )
        return result

    def get_keys_for_model(self, model: str) -> list[APIKeyConfig]:
        """Get all API keys whose ``models`` list contains ``model``."""
        keys = self.get_api_keys()
        return [k for k in keys.values() if model in k.models]

    def get_keys_for_provider(self, provider: str) -> list[APIKeyConfig]:
        """Get all API keys whose ``provider`` equals ``provider``."""
        keys = self.get_api_keys()
        return [k for k in keys.values() if k.provider == provider]


# Global settings instance
settings = Settings()
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
"""Core module initialization."""
|
|
2
|
+
|
|
3
|
+
from llm_scheduler.core.scheduler import Scheduler
|
|
4
|
+
from llm_scheduler.core.queue_manager import QueueManager, QueuedRequest
|
|
5
|
+
from llm_scheduler.core.token_estimator import TokenEstimator
|
|
6
|
+
from llm_scheduler.core.dispatcher import Dispatcher
|
|
7
|
+
|
|
8
|
+
__all__ = ["Scheduler", "QueueManager", "QueuedRequest", "TokenEstimator", "Dispatcher"]
|