maxllm-gate 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. llm_scheduler/__init__.py +8 -0
  2. llm_scheduler/api/__init__.py +6 -0
  3. llm_scheduler/api/dependencies.py +10 -0
  4. llm_scheduler/api/routes.py +275 -0
  5. llm_scheduler/api/schemas.py +135 -0
  6. llm_scheduler/config.py +117 -0
  7. llm_scheduler/core/__init__.py +8 -0
  8. llm_scheduler/core/dispatcher.py +225 -0
  9. llm_scheduler/core/queue_manager.py +251 -0
  10. llm_scheduler/core/scheduler.py +236 -0
  11. llm_scheduler/core/token_estimator.py +201 -0
  12. llm_scheduler/main.py +86 -0
  13. llm_scheduler/models/__init__.py +6 -0
  14. llm_scheduler/models/provider.py +103 -0
  15. llm_scheduler/models/request.py +101 -0
  16. llm_scheduler/observability/__init__.py +6 -0
  17. llm_scheduler/observability/logging.py +65 -0
  18. llm_scheduler/observability/metrics.py +92 -0
  19. llm_scheduler/rate_limiting/__init__.py +7 -0
  20. llm_scheduler/rate_limiting/key_manager.py +252 -0
  21. llm_scheduler/rate_limiting/token_bucket.py +152 -0
  22. llm_scheduler/rate_limiting/tracker.py +281 -0
  23. llm_scheduler/strategies/__init__.py +7 -0
  24. llm_scheduler/strategies/base.py +56 -0
  25. llm_scheduler/strategies/fallback.py +52 -0
  26. llm_scheduler/strategies/least_utilized.py +30 -0
  27. llm_scheduler/strategies/round_robin.py +29 -0
  28. llm_scheduler/strategies/token_aware.py +46 -0
  29. llm_scheduler/utils/__init__.py +6 -0
  30. llm_scheduler/utils/retry.py +136 -0
  31. llm_scheduler/utils/time_utils.py +115 -0
  32. maxllm/__init__.py +77 -0
  33. maxllm/client.py +598 -0
  34. maxllm/config.py +181 -0
  35. maxllm/rate_limiter.py +432 -0
  36. maxllm/redis_backend.py +495 -0
  37. maxllm/scheduler.py +559 -0
  38. maxllm/validation.py +183 -0
  39. maxllm_gate-0.2.0.dist-info/METADATA +771 -0
  40. maxllm_gate-0.2.0.dist-info/RECORD +43 -0
  41. maxllm_gate-0.2.0.dist-info/WHEEL +4 -0
  42. maxllm_gate-0.2.0.dist-info/entry_points.txt +2 -0
  43. maxllm_gate-0.2.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,8 @@
"""LLM Rate Limit Scheduler - Intelligent scheduling layer on top of LiteLLM.

Re-exports the global ``settings`` object and the ``Scheduler`` class so
callers can write ``from llm_scheduler import Scheduler, settings``.
"""

# Version string of the ``llm_scheduler`` package.
__version__ = "0.1.0"

from llm_scheduler.config import settings
from llm_scheduler.core.scheduler import Scheduler

__all__ = [
    "settings",
    "Scheduler",
    "__version__",
]
@@ -0,0 +1,6 @@
"""API module initialization.

Exposes the FastAPI ``router`` plus the chat request/response schemas used
by its endpoints.
"""

from llm_scheduler.api.routes import router
from llm_scheduler.api.schemas import ChatRequest, ChatResponse

__all__ = [
    "router",
    "ChatRequest",
    "ChatResponse",
]
@@ -0,0 +1,10 @@
1
+ """FastAPI dependencies."""
2
+
3
+ from fastapi import Request
4
+
5
+ from llm_scheduler.core.scheduler import Scheduler
6
+
7
+
8
def get_scheduler(request: Request) -> Scheduler:
    """Return the ``Scheduler`` instance held in the application's state.

    Used by the route handlers as ``Depends(get_scheduler)``. The instance
    is presumably attached to ``app.state.scheduler`` at startup (not
    visible here — confirm in main.py).
    """
    app_state = request.app.state
    return app_state.scheduler
@@ -0,0 +1,275 @@
1
+ """FastAPI route definitions."""
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException
4
+ from fastapi.responses import StreamingResponse
5
+ from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
6
+
7
+ from llm_scheduler.api.schemas import (
8
+ ChatRequest,
9
+ ChatResponse,
10
+ BatchRequest,
11
+ BatchResponse,
12
+ HealthResponse,
13
+ StatusResponse,
14
+ CapacityResponse,
15
+ KeyStatus,
16
+ )
17
+ from llm_scheduler.api.dependencies import get_scheduler
18
+ from llm_scheduler.core.scheduler import Scheduler, SchedulerError
19
+ from llm_scheduler.observability.logging import get_logger
20
+
21
+
22
# Router collecting every endpoint declared in this module; presumably
# included into the FastAPI app elsewhere (confirm in main.py).
router: APIRouter = APIRouter()
# Project logger; call sites in this module pass context as keyword arguments.
logger = get_logger()
24
+
25
+
26
@router.post("/chat", response_model=ChatResponse, tags=["LLM"])
async def chat(
    request: ChatRequest,
    scheduler: Scheduler = Depends(get_scheduler),
) -> ChatResponse:
    """
    Send a chat completion request.

    The request is queued and scheduled based on:
    - Priority (high > medium > low)
    - Available API key capacity
    - Rate limits (TPM/RPM)

    If all keys are at capacity, the request is deferred until
    capacity becomes available.
    """
    # Error mapping: SchedulerError -> 503; any other failure (including a
    # malformed result that cannot be turned into a ChatResponse) -> 500.
    # Both paths log before raising.
    try:
        messages = [m.model_dump() for m in request.messages]

        result = await scheduler.schedule(
            model=request.model,
            messages=messages,
            priority=request.priority,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
        )

        # Accept either a ``choices``-style payload (chat ``message`` or
        # completion ``text``) or a flat ``{"content": ...}`` payload.
        # An empty ``choices`` list and keys that are present but null fall
        # back to defaults instead of raising (IndexError / validation error).
        choices = result.get("choices") or []
        first_choice = choices[0] if choices else {}

        content = ""
        if choices:
            if "message" in first_choice:
                content = (first_choice["message"] or {}).get("content") or ""
            elif "text" in first_choice:
                content = first_choice["text"] or ""
        elif "content" in result:
            content = result["content"] or ""

        return ChatResponse(
            id=result.get("id") or "",
            model=result.get("model") or request.model,
            content=content,
            usage=result.get("usage"),
            finish_reason=first_choice.get("finish_reason"),
        )

    except SchedulerError as e:
        logger.warning("Scheduler error", error=str(e))
        raise HTTPException(status_code=503, detail=str(e)) from e
    except Exception as e:
        # NOTE(review): str(e) is returned to the client verbatim; confirm
        # upstream/provider errors can never embed secrets.
        logger.error("Chat request failed", error=str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
78
+
79
+
80
def _sse_event(payload: str) -> str:
    """Encode *payload* as a single Server-Sent Events ``data`` event.

    SSE treats CR, LF and CRLF as line terminators, so a raw line break
    inside a ``data:`` field would end that field early and corrupt the
    event. Each payload line therefore gets its own ``data:`` prefix;
    SSE clients rejoin the lines of one event with a newline. A payload
    without line breaks is encoded as exactly ``"data: " + payload + "\\n\\n"``.
    """
    lines = payload.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    return "".join(f"data: {line}\n" for line in lines) + "\n"


@router.post("/chat/stream", tags=["LLM"])
async def chat_stream(
    request: ChatRequest,
    scheduler: Scheduler = Depends(get_scheduler),
):
    """
    Send a streaming chat completion request.

    Returns a Server-Sent Events (SSE) stream of response chunks.
    """
    # The complete result is obtained from scheduler.schedule() and then
    # re-emitted in fixed-size slices: simulated streaming, not
    # token-by-token streaming from the provider.
    chunk_size = 10

    async def generate():
        try:
            messages = [m.model_dump() for m in request.messages]

            result = await scheduler.schedule(
                model=request.model,
                messages=messages,
                priority=request.priority,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
            )

            # Prefer ``choices[0]`` (chat ``message`` or completion ``text``),
            # otherwise a flat ``content`` key; null values become "".
            choices = result.get("choices") or []
            content = ""
            if choices:
                first_choice = choices[0]
                if "message" in first_choice:
                    content = (first_choice["message"] or {}).get("content") or ""
                elif "text" in first_choice:
                    content = first_choice["text"] or ""
            elif "content" in result:
                content = result["content"] or ""

            for start in range(0, len(content), chunk_size):
                yield _sse_event(content[start:start + chunk_size])

            yield _sse_event("[DONE]")

        except Exception as e:
            # Scheduling runs inside the response body generator, after the
            # streaming response has started, so the failure is reported
            # in-band as a final event rather than as an HTTP status.
            logger.error("Streaming chat request failed", error=str(e))
            yield _sse_event(f"[ERROR] {e}")

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
    )
118
+
119
+
120
@router.post("/batch", response_model=BatchResponse, tags=["LLM"])
async def batch(
    request: BatchRequest,
    scheduler: Scheduler = Depends(get_scheduler),
) -> BatchResponse:
    """
    Process multiple chat requests in parallel.

    Returns results for all requests, with errors inline.
    """

    def to_chat_response(raw: dict, fallback_model: str) -> ChatResponse:
        # Convert one scheduler result into a ChatResponse. Accepts a
        # ``choices``-style payload (``message`` or ``text``) or a flat
        # ``{"content": ...}`` payload; an empty ``choices`` list and null
        # fields fall back to defaults instead of raising.
        choices = raw.get("choices") or []
        first_choice = choices[0] if choices else {}

        content = ""
        if choices:
            if "message" in first_choice:
                content = (first_choice["message"] or {}).get("content") or ""
            elif "text" in first_choice:
                content = first_choice["text"] or ""
        elif "content" in raw:
            content = raw["content"] or ""

        return ChatResponse(
            id=raw.get("id") or "",
            model=raw.get("model") or fallback_model,
            content=content,
            usage=raw.get("usage"),
            finish_reason=first_choice.get("finish_reason"),
        )

    batch_requests = [
        {
            "model": req.model,
            "messages": [m.model_dump() for m in req.messages],
            "priority": req.priority,
            "max_tokens": req.max_tokens,
            "temperature": req.temperature,
        }
        for req in request.requests
    ]

    raw_results = await scheduler.schedule_batch(batch_requests)

    results: list[ChatResponse | dict[str, str]] = []
    successful = 0
    failed = 0

    for index, raw in enumerate(raw_results):
        # schedule_batch hands back exceptions as values (checked below)
        # rather than raising them.
        if isinstance(raw, Exception):
            failed += 1
            results.append({"error": str(raw)})
            continue

        # Presumably one result per request, in request order (TODO confirm
        # against schedule_batch); the index is guarded regardless.
        fallback_model = (
            request.requests[index].model if index < len(request.requests) else ""
        )
        try:
            response = to_chat_response(raw, fallback_model)
        except Exception as e:
            # A single malformed result must not fail the whole batch.
            logger.warning("Batch item could not be parsed", index=index, error=str(e))
            failed += 1
            results.append({"error": str(e)})
        else:
            successful += 1
            results.append(response)

    return BatchResponse(
        results=results,
        total=len(request.requests),
        successful=successful,
        failed=failed,
    )
171
+
172
+
173
@router.get("/health", response_model=HealthResponse, tags=["System"])
async def health(
    scheduler: Scheduler = Depends(get_scheduler),
) -> HealthResponse:
    """
    Health check endpoint.

    Returns overall system health status.
    """
    snapshot = scheduler.get_status()
    queue_info = snapshot["queue"]
    current_depth = queue_info["queue_size"]

    # A key counts as healthy only when its entry explicitly reports it.
    per_key = snapshot["keys"].get("keys", {})
    healthy_count = len(
        [info for info in per_key.values() if info.get("is_healthy", False)]
    )

    # Stopped scheduler or no healthy key -> unhealthy; queue above 90 % of
    # its maximum -> degraded; otherwise healthy.
    is_running = snapshot["running"]
    if not is_running or healthy_count == 0:
        verdict = "unhealthy"
    elif current_depth > queue_info["max_size"] * 0.9:
        verdict = "degraded"
    else:
        verdict = "healthy"

    return HealthResponse(
        status=verdict,
        scheduler_running=is_running,
        queue_size=current_depth,
        keys_available=healthy_count,
    )
205
+
206
+
207
@router.get("/status", response_model=StatusResponse, tags=["System"])
async def status(
    scheduler: Scheduler = Depends(get_scheduler),
) -> StatusResponse:
    """
    Get detailed scheduler status.

    Returns queue statistics, key states, and configuration.
    """
    # The raw status mapping is returned as-is; FastAPI validates and
    # serializes it against ``response_model=StatusResponse``.
    snapshot = scheduler.get_status()
    return snapshot
217
+
218
+
219
@router.get("/capacity", response_model=CapacityResponse, tags=["System"])
async def capacity(
    scheduler: Scheduler = Depends(get_scheduler),
) -> CapacityResponse:
    """
    Get current capacity across all API keys.

    Useful for monitoring and understanding rate limit state.
    """
    snapshot = scheduler.get_status()
    key_section = snapshot["keys"]

    # Aggregate figures; missing sections or fields read as 0.
    totals = key_section.get("total_capacity", {})
    remaining = key_section.get("available_capacity", {})

    # One KeyStatus per key entry; each entry's fields map 1:1 onto KeyStatus.
    per_key_status: list[KeyStatus] = []
    for info in key_section.get("keys", {}).values():
        per_key_status.append(KeyStatus(**info))

    return CapacityResponse(
        total_tpm=totals.get("tpm", 0),
        available_tpm=remaining.get("tpm", 0),
        total_rpm=totals.get("rpm", 0),
        available_rpm=remaining.get("rpm", 0),
        keys=per_key_status,
    )
246
+
247
+
248
@router.get("/metrics", tags=["System"])
async def metrics():
    """
    Prometheus metrics endpoint.

    Returns metrics in Prometheus exposition format.
    """
    from starlette.responses import Response

    # generate_latest() with no argument renders prometheus_client's
    # default registry.
    payload = generate_latest()
    return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
260
+
261
+
262
@router.get("/", tags=["System"])
async def root():
    """Root endpoint with API information."""
    description = (
        "An intelligent scheduling and rate-limit-aware control layer "
        "on top of LiteLLM that maximizes throughput and prevents 429 errors."
    )
    # NOTE(review): "0.1.0" duplicates llm_scheduler.__version__ and must be
    # kept in sync by hand.
    info = dict(
        name="LLM Rate Limit Scheduler",
        version="0.1.0",
        description=description,
        docs_url="/docs",
        health_url="/health",
        metrics_url="/metrics",
    )
    return info
@@ -0,0 +1,135 @@
1
+ """Pydantic schemas for API requests and responses."""
2
+
3
+ from typing import Any, Literal
4
+ from pydantic import BaseModel, Field
5
+
6
+
7
class Message(BaseModel):
    """Chat message."""

    # Only these three roles pass validation.
    role: Literal["system", "user", "assistant"] = Field(description="Role of the message sender")

    # Plain string body of the message.
    content: str = Field(description="Message content")
14
+
15
+
16
class ChatRequest(BaseModel):
    """Request body for /chat endpoint."""

    # Target model name, passed through to the scheduler by the routes.
    model: str = Field(description="Model name (e.g., 'mixtral', 'gpt-4o-mini', 'llama-3.1-70b')")

    # Conversation history in order.
    messages: list[Message] = Field(description="List of chat messages")

    # Queue ordering; "medium" when omitted.
    priority: Literal["high", "medium", "low"] = Field(
        "medium", description="Request priority for queue ordering"
    )

    # Upper bound on generated tokens; None leaves the choice to downstream code.
    max_tokens: int | None = Field(None, description="Maximum tokens in response")

    # Sampling temperature, validated to the closed range [0.0, 2.0].
    temperature: float = Field(0.7, description="Sampling temperature", ge=0.0, le=2.0)

    # Example payload surfaced in the generated JSON schema / OpenAPI docs.
    model_config = {
        "json_schema_extra": {
            "examples": [
                {
                    "model": "mixtral-8x7b-32768",
                    "messages": [
                        {"role": "system", "content": "You are a helpful assistant."},
                        {"role": "user", "content": "Hello, how are you?"},
                    ],
                    "priority": "medium",
                    "max_tokens": 1024,
                    "temperature": 0.7,
                }
            ]
        }
    }
56
+
57
+
58
class ChatResponse(BaseModel):
    """Response body for /chat endpoint."""

    # Required core fields.
    id: str = Field(description="Response ID")
    model: str = Field(description="Model used")
    content: str = Field(description="Generated content")

    # Optional token accounting; every value must be an int.
    usage: dict[str, int] | None = Field(None, description="Token usage statistics")

    # Optional stop reason reported by the provider.
    finish_reason: str | None = Field(None, description="Reason for completion")
72
+
73
+
74
class BatchRequest(BaseModel):
    """Request body for /batch endpoint."""

    # Every entry is validated exactly like a /chat request body.
    requests: list[ChatRequest] = Field(description="List of chat requests to process")
80
+
81
+
82
class BatchResponse(BaseModel):
    """Response body for /batch endpoint."""

    # Per-request outcome: a ChatResponse on success, or a string-valued
    # dict (the /batch route uses {"error": ...}) on failure.
    results: list[ChatResponse | dict[str, str]] = Field(
        description="Results for each request (response or error)",
    )

    # Counters for the batch as a whole.
    total: int = Field(description="Total requests")
    successful: int = Field(description="Successful requests")
    failed: int = Field(description="Failed requests")
91
+
92
+
93
class HealthResponse(BaseModel):
    """Response body for /health endpoint."""

    # Coarse verdict; restricted to these three literals.
    status: Literal["healthy", "degraded", "unhealthy"] = Field(description="Overall health status")

    # Raw signals behind the verdict.
    scheduler_running: bool = Field(description="Whether scheduler is running")
    queue_size: int = Field(description="Current queue size")
    keys_available: int = Field(description="Number of healthy API keys")
102
+
103
+
104
class StatusResponse(BaseModel):
    """Response body for /status endpoint."""

    # Whether the scheduler reports itself as running.
    running: bool
    # Free-form queue statistics; the /health route reads "queue_size" and
    # "max_size" from this mapping.
    queue: dict[str, Any]
    # Free-form key-manager state; the /health and /capacity routes read
    # "keys", "total_capacity" and "available_capacity" from it.
    keys: dict[str, Any]
    # Active scheduling strategy name (presumably one of the values accepted
    # by Settings.default_strategy — confirm).
    strategy: str
111
+
112
+
113
class KeyStatus(BaseModel):
    """Status of a single API key."""

    # Key identifier, not the secret (presumably APIKeyConfig.key_id — confirm).
    key_id: str
    # Provider name the key belongs to.
    provider: str
    # Remaining vs. configured tokens-per-minute for this key.
    tpm_available: int
    tpm_capacity: int
    # Remaining vs. configured requests-per-minute for this key.
    rpm_available: int
    rpm_capacity: int
    # NOTE(review): scale not visible here (0..1 fraction or percentage?) — confirm.
    utilization: float
    # Lifetime counters for this key.
    total_requests: int
    total_tokens_used: int
    # The /health route counts keys with is_healthy=True as available.
    is_healthy: bool
126
+
127
+
128
class CapacityResponse(BaseModel):
    """Response body for /capacity endpoint."""

    # Aggregate tokens-per-minute figures (presumably summed across keys by
    # the key manager — confirm).
    total_tpm: int
    available_tpm: int
    # Aggregate requests-per-minute figures.
    total_rpm: int
    available_rpm: int
    # Per-key breakdown.
    keys: list[KeyStatus]
@@ -0,0 +1,117 @@
1
+ """Configuration management for LLM Rate Limit Scheduler."""
2
+
3
+ import json
4
+
5
+ from pydantic import Field, field_validator
6
+ from pydantic_settings import BaseSettings, SettingsConfigDict
7
+
8
+
9
class APIKeyConfig:
    """Configuration for a single API key.

    Holds the key's identifier, secret, provider name, supported model names
    and its per-key TPM/RPM limits. ``repr()`` shows only ``key_id`` and
    ``provider``, so the secret never appears in its output.
    """

    def __init__(
        self,
        key_id: str,
        api_key: str,
        provider: str,
        models: list[str],
        tpm_limit: int,
        rpm_limit: int,
    ):
        # Identity and credentials.
        self.key_id = key_id
        self.provider = provider
        self.api_key = api_key
        # Capabilities and rate limits.
        self.models = models
        self.tpm_limit = tpm_limit
        self.rpm_limit = rpm_limit

    def __repr__(self) -> str:
        return "APIKeyConfig(key_id={}, provider={})".format(self.key_id, self.provider)
30
+
31
+
32
class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Values come from the process environment and, when present, a UTF-8
    # ``.env`` file; extra inputs that match no field are ignored.
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Server
    host: str = Field(default="0.0.0.0")
    port: int = Field(default=8000)
    debug: bool = Field(default=False)
    log_level: str = Field(default="INFO")

    # API Keys (JSON string); parsed on demand by get_api_keys().
    api_keys_config: str = Field(default="{}")

    # Scheduling
    default_strategy: str = Field(default="least_utilized")

    # Token estimation
    default_max_tokens: int = Field(default=1024)
    token_estimation_buffer: float = Field(default=1.1)

    # Retry
    max_retries: int = Field(default=3)
    retry_base_delay: float = Field(default=1.0)
    retry_max_delay: float = Field(default=60.0)

    # Queue
    max_queue_size: int = Field(default=10000)
    default_priority: str = Field(default="medium")

    # Redis (optional)
    redis_url: str | None = Field(default=None)
    use_redis_queue: bool = Field(default=False)

    @field_validator("default_strategy")
    @classmethod
    def validate_strategy(cls, v: str) -> str:
        """Reject strategy names outside the supported set."""
        valid = {"least_utilized", "round_robin", "token_aware"}
        if v not in valid:
            raise ValueError(f"Strategy must be one of {valid}")
        return v

    @field_validator("default_priority")
    @classmethod
    def validate_priority(cls, v: str) -> str:
        """Reject priorities other than high / medium / low."""
        valid = {"high", "medium", "low"}
        if v not in valid:
            raise ValueError(f"Priority must be one of {valid}")
        return v

    def get_api_keys(self) -> dict[str, APIKeyConfig]:
        """Parse ``api_keys_config`` into typed ``APIKeyConfig`` objects.

        Expected input is a JSON object mapping key ids to objects::

            {"<key_id>": {"api_key": "...", "provider": "...",
                          "models": ["..."], "tpm_limit": 10000,
                          "rpm_limit": 60}}

        Missing fields default to ``""`` (api_key, provider), ``[]`` (models),
        ``10000`` (tpm_limit) and ``60`` (rpm_limit). A single model given as
        a plain string is wrapped in a list.

        Returns:
            Mapping of key id to ``APIKeyConfig``. Never raises on malformed
            input: invalid JSON or a non-object top level yields ``{}``, and
            non-object entries are skipped; each case logs a warning.
        """
        import logging

        log = logging.getLogger(__name__)

        try:
            raw_config = json.loads(self.api_keys_config)
        except json.JSONDecodeError:
            log.warning("api_keys_config is not valid JSON; no API keys loaded")
            return {}

        # Valid JSON can still be the wrong shape ([], null, "..."), which
        # would otherwise fail on .items() below.
        if not isinstance(raw_config, dict):
            log.warning("api_keys_config must be a JSON object; no API keys loaded")
            return {}

        result: dict[str, APIKeyConfig] = {}
        for key_id, config in raw_config.items():
            if not isinstance(config, dict):
                log.warning("Skipping API key %r: entry is not a JSON object", key_id)
                continue

            # A bare string would make ``model in k.models`` a substring test.
            models = config.get("models") or []
            if isinstance(models, str):
                models = [models]

            result[key_id] = APIKeyConfig(
                key_id=key_id,
                api_key=config.get("api_key", ""),
                provider=config.get("provider", ""),
                models=models,
                tpm_limit=config.get("tpm_limit", 10000),
                rpm_limit=config.get("rpm_limit", 60),
            )
        return result

    def get_keys_for_model(self, model: str) -> list[APIKeyConfig]:
        """Get all API keys that support a given model.

        Re-parses ``api_keys_config`` on every call.
        """
        keys = self.get_api_keys()
        return [k for k in keys.values() if model in k.models]

    def get_keys_for_provider(self, provider: str) -> list[APIKeyConfig]:
        """Get all API keys for a given provider.

        Re-parses ``api_keys_config`` on every call; matching is exact string
        equality on the provider name.
        """
        keys = self.get_api_keys()
        return [k for k in keys.values() if k.provider == provider]
114
+
115
+
116
# Global settings instance, created once when this module is imported.
# Construction reads environment variables and the ``.env`` file named in
# ``Settings.model_config``. A ``default_strategy`` or ``default_priority``
# value rejected by the field validators makes construction, and therefore
# the import of this module, fail with a pydantic validation error.
settings = Settings()
@@ -0,0 +1,8 @@
"""Core module initialization.

Re-exports the scheduler, queue management, token estimation and dispatch
building blocks of the ``llm_scheduler.core`` package.
"""

from llm_scheduler.core.scheduler import Scheduler
from llm_scheduler.core.queue_manager import QueueManager, QueuedRequest
from llm_scheduler.core.token_estimator import TokenEstimator
from llm_scheduler.core.dispatcher import Dispatcher

__all__ = [
    "Scheduler",
    "QueueManager",
    "QueuedRequest",
    "TokenEstimator",
    "Dispatcher",
]