lm-deluge 0.0.67__py3-none-any.whl → 0.0.90__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of lm-deluge might be problematic.

Files changed (108)
  1. lm_deluge/__init__.py +1 -2
  2. lm_deluge/api_requests/anthropic.py +117 -22
  3. lm_deluge/api_requests/base.py +84 -11
  4. lm_deluge/api_requests/bedrock.py +30 -6
  5. lm_deluge/api_requests/chat_reasoning.py +4 -0
  6. lm_deluge/api_requests/gemini.py +166 -20
  7. lm_deluge/api_requests/openai.py +145 -25
  8. lm_deluge/batches.py +15 -45
  9. lm_deluge/client.py +309 -50
  10. lm_deluge/config.py +15 -3
  11. lm_deluge/models/__init__.py +14 -1
  12. lm_deluge/models/anthropic.py +29 -14
  13. lm_deluge/models/arcee.py +16 -0
  14. lm_deluge/models/deepseek.py +36 -4
  15. lm_deluge/models/google.py +42 -0
  16. lm_deluge/models/grok.py +24 -0
  17. lm_deluge/models/kimi.py +36 -0
  18. lm_deluge/models/minimax.py +18 -0
  19. lm_deluge/models/openai.py +100 -0
  20. lm_deluge/models/openrouter.py +133 -7
  21. lm_deluge/models/together.py +11 -0
  22. lm_deluge/models/zai.py +50 -0
  23. lm_deluge/pipelines/gepa/__init__.py +95 -0
  24. lm_deluge/pipelines/gepa/core.py +354 -0
  25. lm_deluge/pipelines/gepa/docs/samples.py +705 -0
  26. lm_deluge/pipelines/gepa/examples/01_synthetic_keywords.py +140 -0
  27. lm_deluge/pipelines/gepa/examples/02_gsm8k_math.py +261 -0
  28. lm_deluge/pipelines/gepa/examples/03_hotpotqa_multihop.py +300 -0
  29. lm_deluge/pipelines/gepa/examples/04_batch_classification.py +271 -0
  30. lm_deluge/pipelines/gepa/examples/simple_qa.py +129 -0
  31. lm_deluge/pipelines/gepa/optimizer.py +435 -0
  32. lm_deluge/pipelines/gepa/proposer.py +235 -0
  33. lm_deluge/pipelines/gepa/util.py +165 -0
  34. lm_deluge/{llm_tools → pipelines}/score.py +2 -2
  35. lm_deluge/{llm_tools → pipelines}/translate.py +5 -3
  36. lm_deluge/prompt.py +537 -88
  37. lm_deluge/request_context.py +7 -2
  38. lm_deluge/server/__init__.py +24 -0
  39. lm_deluge/server/__main__.py +144 -0
  40. lm_deluge/server/adapters.py +369 -0
  41. lm_deluge/server/app.py +388 -0
  42. lm_deluge/server/auth.py +71 -0
  43. lm_deluge/server/model_policy.py +215 -0
  44. lm_deluge/server/models_anthropic.py +172 -0
  45. lm_deluge/server/models_openai.py +175 -0
  46. lm_deluge/tool/__init__.py +1130 -0
  47. lm_deluge/tool/builtin/anthropic/__init__.py +300 -0
  48. lm_deluge/tool/builtin/anthropic/bash.py +0 -0
  49. lm_deluge/tool/builtin/anthropic/computer_use.py +0 -0
  50. lm_deluge/tool/builtin/gemini.py +59 -0
  51. lm_deluge/tool/builtin/openai.py +74 -0
  52. lm_deluge/tool/cua/__init__.py +173 -0
  53. lm_deluge/tool/cua/actions.py +148 -0
  54. lm_deluge/tool/cua/base.py +27 -0
  55. lm_deluge/tool/cua/batch.py +215 -0
  56. lm_deluge/tool/cua/converters.py +466 -0
  57. lm_deluge/tool/cua/kernel.py +702 -0
  58. lm_deluge/tool/cua/trycua.py +989 -0
  59. lm_deluge/tool/prefab/__init__.py +45 -0
  60. lm_deluge/tool/prefab/batch_tool.py +156 -0
  61. lm_deluge/tool/prefab/docs.py +1119 -0
  62. lm_deluge/tool/prefab/email.py +294 -0
  63. lm_deluge/tool/prefab/filesystem.py +1711 -0
  64. lm_deluge/tool/prefab/full_text_search/__init__.py +285 -0
  65. lm_deluge/tool/prefab/full_text_search/tantivy_index.py +396 -0
  66. lm_deluge/tool/prefab/memory.py +458 -0
  67. lm_deluge/tool/prefab/otc/__init__.py +165 -0
  68. lm_deluge/tool/prefab/otc/executor.py +281 -0
  69. lm_deluge/tool/prefab/otc/parse.py +188 -0
  70. lm_deluge/tool/prefab/random.py +212 -0
  71. lm_deluge/tool/prefab/rlm/__init__.py +296 -0
  72. lm_deluge/tool/prefab/rlm/executor.py +349 -0
  73. lm_deluge/tool/prefab/rlm/parse.py +144 -0
  74. lm_deluge/tool/prefab/sandbox/__init__.py +19 -0
  75. lm_deluge/tool/prefab/sandbox/daytona_sandbox.py +483 -0
  76. lm_deluge/tool/prefab/sandbox/docker_sandbox.py +609 -0
  77. lm_deluge/tool/prefab/sandbox/fargate_sandbox.py +546 -0
  78. lm_deluge/tool/prefab/sandbox/modal_sandbox.py +469 -0
  79. lm_deluge/tool/prefab/sandbox/seatbelt_sandbox.py +827 -0
  80. lm_deluge/tool/prefab/sheets.py +385 -0
  81. lm_deluge/tool/prefab/skills.py +0 -0
  82. lm_deluge/tool/prefab/subagents.py +233 -0
  83. lm_deluge/tool/prefab/todos.py +342 -0
  84. lm_deluge/tool/prefab/tool_search.py +169 -0
  85. lm_deluge/tool/prefab/web_search.py +199 -0
  86. lm_deluge/tracker.py +16 -13
  87. lm_deluge/util/schema.py +412 -0
  88. lm_deluge/warnings.py +8 -0
  89. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/METADATA +23 -9
  90. lm_deluge-0.0.90.dist-info/RECORD +132 -0
  91. lm_deluge/built_in_tools/anthropic/__init__.py +0 -128
  92. lm_deluge/built_in_tools/openai.py +0 -28
  93. lm_deluge/presets/cerebras.py +0 -17
  94. lm_deluge/presets/meta.py +0 -13
  95. lm_deluge/tool.py +0 -849
  96. lm_deluge-0.0.67.dist-info/RECORD +0 -72
  97. lm_deluge/{llm_tools → pipelines}/__init__.py +1 -1
  98. /lm_deluge/{llm_tools → pipelines}/classify.py +0 -0
  99. /lm_deluge/{llm_tools → pipelines}/extract.py +0 -0
  100. /lm_deluge/{llm_tools → pipelines}/locate.py +0 -0
  101. /lm_deluge/{llm_tools → pipelines}/ocr.py +0 -0
  102. /lm_deluge/{built_in_tools/anthropic/bash.py → skills/anthropic.py} +0 -0
  103. /lm_deluge/{built_in_tools/anthropic/computer_use.py → skills/compat.py} +0 -0
  104. /lm_deluge/{built_in_tools → tool/builtin}/anthropic/editor.py +0 -0
  105. /lm_deluge/{built_in_tools → tool/builtin}/base.py +0 -0
  106. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/WHEEL +0 -0
  107. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/licenses/LICENSE +0 -0
  108. {lm_deluge-0.0.67.dist-info → lm_deluge-0.0.90.dist-info}/top_level.txt +0 -0
lm_deluge/server/app.py
@@ -0,0 +1,388 @@
+"""
+FastAPI application for the LM-Deluge proxy server.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import traceback
+from contextlib import asynccontextmanager
+
+import aiohttp
+from dotenv import load_dotenv
+from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse
+
+from lm_deluge.models import APIModel, registry
+from lm_deluge.prompt import CachePattern
+from lm_deluge.request_context import RequestContext
+from lm_deluge.tracker import StatusTracker
+
+from .adapters import (
+    anthropic_request_to_conversation,
+    anthropic_request_to_sampling_params,
+    anthropic_tools_to_lm_deluge,
+    api_response_to_anthropic,
+    api_response_to_openai,
+    openai_request_to_conversation,
+    openai_request_to_sampling_params,
+    openai_tools_to_lm_deluge,
+)
+from .auth import verify_anthropic_auth, verify_openai_auth
+from .model_policy import ModelRouter, ProxyModelPolicy
+from .models_anthropic import (
+    AnthropicErrorDetail,
+    AnthropicErrorResponse,
+    AnthropicMessagesRequest,
+    AnthropicMessagesResponse,
+)
+from .models_openai import (
+    OpenAIChatCompletionsRequest,
+    OpenAIChatCompletionsResponse,
+    OpenAIErrorDetail,
+    OpenAIErrorResponse,
+    OpenAIModelInfo,
+    OpenAIModelsResponse,
+)
+
+# Valid cache patterns
+_VALID_CACHE_PATTERNS = {
+    "tools_only",
+    "system_and_tools",
+    "last_user_message",
+    "last_2_user_messages",
+    "last_3_user_messages",
+}
+
+_TRUTHY_VALUES = {"1", "true", "yes", "on"}
+
+
+def get_cache_pattern() -> CachePattern | None:
+    """
+    Get cache pattern from DELUGE_CACHE_PATTERN environment variable.
+
+    Valid values:
+    - none / NONE / unset → no caching
+    - tools_only → cache tools definition
+    - system_and_tools → cache system prompt and tools
+    - last_user_message → cache last user message
+    - last_2_user_messages → cache last 2 user messages
+    - last_3_user_messages → cache last 3 user messages
+    """
+    pattern = os.getenv("DELUGE_CACHE_PATTERN", "").lower().strip()
+    if not pattern or pattern == "none":
+        return None
+    if pattern in _VALID_CACHE_PATTERNS:
+        return pattern  # type: ignore
+    return None
+
+
+def _is_truthy_env(name: str) -> bool:
+    return os.getenv(name, "").strip().lower() in _TRUTHY_VALUES
+
+
+# Global aiohttp session for connection reuse
+_http_session: aiohttp.ClientSession | None = None
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Manage application lifespan - startup and shutdown."""
+    global _http_session
+
+    # Load .env file if present
+    load_dotenv()
+
+    # Create shared aiohttp session
+    connector = aiohttp.TCPConnector(
+        limit=100,
+        limit_per_host=20,
+        keepalive_timeout=30,
+        enable_cleanup_closed=True,
+    )
+    _http_session = aiohttp.ClientSession(connector=connector)
+
+    yield
+
+    # Cleanup
+    if _http_session:
+        await _http_session.close()
+        _http_session = None
+
+
+def _is_model_available(api_model: APIModel) -> bool:
+    """Check if model is available based on configured API keys."""
+    env_var = api_model.api_key_env_var
+    if not env_var:
+        if api_model.api_spec == "bedrock":
+            return bool(os.getenv("AWS_ACCESS_KEY_ID") or os.getenv("AWS_PROFILE"))
+        return False
+    return bool(os.getenv(env_var))
+
+
+def create_app(policy: ProxyModelPolicy | None = None) -> FastAPI:
+    """Create and configure the FastAPI application."""
+    policy = policy or ProxyModelPolicy()
+    router = ModelRouter(policy, registry)
+
+    app = FastAPI(
+        title="LM-Deluge Proxy Server",
+        description="OpenAI and Anthropic compatible API proxy backed by lm-deluge",
+        version="0.1.0",
+        lifespan=lifespan,
+    )
+
+    @app.middleware("http")
+    async def log_requests(request: Request, call_next):
+        if _is_truthy_env("DELUGE_PROXY_LOG_REQUESTS"):
+            body = await request.body()
+            body_text = body.decode("utf-8", errors="replace")
+            if body_text:
+                try:
+                    body_text = json.dumps(json.loads(body_text), indent=2)
+                except Exception:
+                    pass
+            print("DELUGE_PROXY_REQUEST")
+            print(f"{request.method} {request.url}")
+            print("Headers:")
+            print(dict(request.headers))
+            if body_text:
+                print("Body:")
+                print(body_text)
+        return await call_next(request)
+
+    # ========================================================================
+    # Health Check
+    # ========================================================================
+
+    @app.get("/health")
+    async def health_check():
+        """Health check endpoint."""
+        return {"status": "ok"}
+
+    # ========================================================================
+    # OpenAI-Compatible Endpoints
+    # ========================================================================
+
+    @app.get("/v1/models", dependencies=[Depends(verify_openai_auth)])
+    async def list_models(all: bool = False) -> OpenAIModelsResponse:
+        """
+        List available models (OpenAI-compatible).
+
+        By default, only returns models for which the required API key
+        is set in the environment. Use ?all=true to list all registered models.
+        """
+        models = []
+        model_ids = router.list_model_ids(
+            only_available=not all,
+            is_available=lambda model_id: _is_model_available(registry[model_id]),
+        )
+        for model_id in model_ids:
+            models.append(OpenAIModelInfo(id=model_id, owned_by="lm-deluge"))
+
+        return OpenAIModelsResponse(data=models)
+
+    @app.post(
+        "/v1/chat/completions",
+        dependencies=[Depends(verify_openai_auth)],
+        response_model=None,
+    )
+    async def openai_chat_completions(
+        req: OpenAIChatCompletionsRequest,
+    ) -> OpenAIChatCompletionsResponse | JSONResponse:
+        """OpenAI-compatible chat completions endpoint."""
+        # Reject streaming
+        if req.stream:
+            raise HTTPException(
+                status_code=400,
+                detail="Streaming is not supported. Set stream=false.",
+            )
+
+        try:
+            # Get model from registry
+            try:
+                resolved_model = router.resolve(req.model)
+                api_model = APIModel.from_registry(resolved_model)
+            except ValueError as e:
+                raise HTTPException(status_code=400, detail=str(e))
+
+            # Convert request to lm-deluge types
+            conversation = openai_request_to_conversation(req)
+            sampling_params = openai_request_to_sampling_params(req)
+
+            # Convert tools if provided
+            tools = None
+            if req.tools:
+                tools = openai_tools_to_lm_deluge(req.tools)
+
+            # Apply cache pattern only for Anthropic-compatible models
+            cache = None
+            if api_model.api_spec in ("anthropic", "bedrock"):
+                cache = get_cache_pattern()
+
+            # Build RequestContext
+            # We need a minimal StatusTracker for execute_once to work
+            tracker = StatusTracker(
+                max_requests_per_minute=1000,
+                max_tokens_per_minute=1_000_000,
+                max_concurrent_requests=100,
+                use_progress_bar=False,
+            )
+
+            context = RequestContext(
+                task_id=0,
+                model_name=resolved_model,
+                prompt=conversation,
+                sampling_params=sampling_params,
+                tools=tools,
+                cache=cache,
+                status_tracker=tracker,
+                request_timeout=int(os.getenv("DELUGE_PROXY_TIMEOUT", "120")),
+            )
+
+            # Create and execute request
+            request_obj = api_model.make_request(context)
+            response = await request_obj.execute_once()
+
+            # Check for errors
+            if response.is_error:
+                return JSONResponse(
+                    status_code=response.status_code or 500,
+                    content=OpenAIErrorResponse(
+                        error=OpenAIErrorDetail(
+                            message=response.error_message or "Unknown error",
+                            type="api_error",
+                            code=str(response.status_code)
+                            if response.status_code
+                            else None,
+                        )
+                    ).model_dump(),
+                )
+
+            # Convert to OpenAI format
+            return api_response_to_openai(response, resolved_model)
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            traceback.print_exc()
+            return JSONResponse(
+                status_code=500,
+                content=OpenAIErrorResponse(
+                    error=OpenAIErrorDetail(
+                        message=str(e),
+                        type="internal_error",
+                    )
+                ).model_dump(),
+            )
+
+    # ========================================================================
+    # Anthropic-Compatible Endpoints
+    # ========================================================================
+
+    # Support both /v1/messages and /messages for Anthropic SDK compatibility
+    # The Anthropic SDK constructs paths as {base_url}/v1/messages
+    @app.post(
+        "/v1/messages",
+        dependencies=[Depends(verify_anthropic_auth)],
+        response_model=None,
+    )
+    @app.post(
+        "/messages",
+        dependencies=[Depends(verify_anthropic_auth)],
+        response_model=None,
+    )
+    async def anthropic_messages(
+        request: Request,
+        req: AnthropicMessagesRequest,
+    ) -> AnthropicMessagesResponse | JSONResponse:
+        """Anthropic-compatible messages endpoint."""
+        # Reject streaming
+        if req.stream:
+            raise HTTPException(
+                status_code=400,
+                detail="Streaming is not supported. Set stream=false.",
+            )
+
+        try:
+            # Get model from registry
+            try:
+                resolved_model = router.resolve(req.model)
+                api_model = APIModel.from_registry(resolved_model)
+            except ValueError as e:
+                raise HTTPException(status_code=400, detail=str(e))
+
+            # Convert request to lm-deluge types
+            conversation = anthropic_request_to_conversation(req)
+            sampling_params = anthropic_request_to_sampling_params(req)
+
+            # Convert tools if provided
+            tools = None
+            if req.tools:
+                tools = anthropic_tools_to_lm_deluge(req.tools)
+
+            # Apply cache pattern only for Anthropic-compatible models
+            cache = None
+            if api_model.api_spec in ("anthropic", "bedrock"):
+                cache = get_cache_pattern()
+
+            # Build RequestContext
+            tracker = StatusTracker(
+                max_requests_per_minute=1000,
+                max_tokens_per_minute=1_000_000,
+                max_concurrent_requests=100,
+                use_progress_bar=False,
+            )
+
+            extra_headers = None
+            beta_header = request.headers.get("anthropic-beta")
+            if beta_header:
+                extra_headers = {"anthropic-beta": beta_header}
+
+            context = RequestContext(
+                task_id=0,
+                model_name=resolved_model,
+                prompt=conversation,
+                sampling_params=sampling_params,
+                tools=tools,
+                cache=cache,
+                status_tracker=tracker,
+                request_timeout=int(os.getenv("DELUGE_PROXY_TIMEOUT", "120")),
+                extra_headers=extra_headers,
+            )
+
+            # Create and execute request
+            request_obj = api_model.make_request(context)
+            response = await request_obj.execute_once()
+
+            # Check for errors
+            if response.is_error:
+                return JSONResponse(
+                    status_code=response.status_code or 500,
+                    content=AnthropicErrorResponse(
+                        error=AnthropicErrorDetail(
+                            type="api_error",
+                            message=response.error_message or "Unknown error",
+                        )
+                    ).model_dump(),
+                )
+
+            # Convert to Anthropic format
+            return api_response_to_anthropic(response, resolved_model)
+
+        except HTTPException:
+            raise
+        except Exception as e:
+            traceback.print_exc()
+            return JSONResponse(
+                status_code=500,
+                content=AnthropicErrorResponse(
+                    error=AnthropicErrorDetail(
+                        type="internal_error",
+                        message=str(e),
+                    )
+                ).model_dump(),
+            )
+
+    return app
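
Taken together, create_app() exposes non-streaming OpenAI- and Anthropic-compatible endpoints on top of lm-deluge's single-request machinery. The following is a minimal sketch of calling the proxy with the stock OpenAI client; the address, the proxy key value, and the model id are illustrative assumptions, not values taken from the package.

# Sketch only; values below are assumptions for illustration.
# Serve the app from a separate process, e.g. uvicorn.run(create_app(), port=8000).
from openai import OpenAI

client = OpenAI(
    base_url="http://127.0.0.1:8000/v1",  # proxy address, assumed
    api_key="secret-proxy-key",  # must match DELUGE_PROXY_API_KEY if auth is enabled
)
completion = client.chat.completions.create(
    model="gpt-4o-mini",  # hypothetical registry id; any id ModelRouter.resolve accepts works
    messages=[{"role": "user", "content": "Hello through the proxy"}],
    stream=False,  # the proxy rejects stream=true with a 400
)
print(completion.choices[0].message.content)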
lm_deluge/server/auth.py
@@ -0,0 +1,71 @@
+"""
+Optional authentication for the proxy server.
+"""
+
+from __future__ import annotations
+
+import os
+
+from fastapi import Header, HTTPException
+
+
+def get_proxy_api_key() -> str | None:
+    """Get the configured proxy API key from environment."""
+    return os.getenv("DELUGE_PROXY_API_KEY")
+
+
+async def verify_openai_auth(
+    authorization: str | None = Header(default=None),
+) -> None:
+    """
+    Verify OpenAI-style Bearer token authentication.
+    Only enforced if DELUGE_PROXY_API_KEY is set.
+    """
+    expected_key = get_proxy_api_key()
+    if not expected_key:
+        # No auth configured, allow all requests
+        return
+
+    if not authorization:
+        raise HTTPException(
+            status_code=401,
+            detail="Missing Authorization header",
+        )
+
+    if not authorization.startswith("Bearer "):
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid Authorization header format. Expected 'Bearer <token>'",
+        )
+
+    token = authorization.removeprefix("Bearer ").strip()
+    if token != expected_key:
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid API key",
+        )
+
+
+async def verify_anthropic_auth(
+    x_api_key: str | None = Header(default=None, alias="x-api-key"),
+) -> None:
+    """
+    Verify Anthropic-style x-api-key header authentication.
+    Only enforced if DELUGE_PROXY_API_KEY is set.
+    """
+    expected_key = get_proxy_api_key()
+    if not expected_key:
+        # No auth configured, allow all requests
+        return
+
+    if not x_api_key:
+        raise HTTPException(
+            status_code=401,
+            detail="Missing x-api-key header",
+        )
+
+    if x_api_key != expected_key:
+        raise HTTPException(
+            status_code=401,
+            detail="Invalid API key",
+        )
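
Authentication is opt-in: with DELUGE_PROXY_API_KEY unset every request passes, and once it is set the OpenAI routes check a Bearer token while the Anthropic routes check an x-api-key header. A short sketch of both styles, assuming the server from app.py is running locally; the URL, key, and model id are illustrative assumptions.

# Sketch only; the URL, key, and model id are assumptions.
import requests

PROXY = "http://127.0.0.1:8000"
KEY = "secret-proxy-key"  # must equal DELUGE_PROXY_API_KEY on the server

# OpenAI-style routes read the Authorization header.
models = requests.get(
    f"{PROXY}/v1/models",
    headers={"Authorization": f"Bearer {KEY}"},
)
print(models.status_code)  # 200 with a matching key, 401 otherwise

# Anthropic-style routes read the x-api-key header instead.
message = requests.post(
    f"{PROXY}/v1/messages",
    headers={"x-api-key": KEY},
    json={
        "model": "claude-sonnet-4",  # hypothetical registry id
        "max_tokens": 128,
        "messages": [{"role": "user", "content": "ping"}],
    },
)
print(message.status_code)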
lm_deluge/server/model_policy.py
@@ -0,0 +1,215 @@
+from __future__ import annotations
+
+import random
+from typing import Callable, Literal
+
+import yaml
+from pydantic import BaseModel, Field, model_validator
+
+from lm_deluge.models import APIModel, registry
+
+RouteStrategy = Literal["round_robin", "random", "weighted"]
+PolicyMode = Literal["allow_user_pick", "force_default", "alias_only"]
+
+
+def _dedupe_keep_order(values: list[str]) -> list[str]:
+    seen: set[str] = set()
+    output = []
+    for value in values:
+        if value in seen:
+            continue
+        seen.add(value)
+        output.append(value)
+    return output
+
+
+class RouteConfig(BaseModel):
+    models: list[str]
+    strategy: RouteStrategy = "round_robin"
+    weights: list[float] | None = None
+
+    @model_validator(mode="after")
+    def _validate_route(self) -> "RouteConfig":
+        if not self.models:
+            raise ValueError("route models must not be empty")
+        if self.strategy == "weighted":
+            if not self.weights:
+                raise ValueError("weighted strategy requires weights")
+            if len(self.weights) != len(self.models):
+                raise ValueError("weights must match models length")
+            if any(weight < 0 for weight in self.weights):
+                raise ValueError("weights must be non-negative")
+            if sum(self.weights) <= 0:
+                raise ValueError("weights must sum to a positive value")
+        return self
+
+
+class ProxyModelPolicy(BaseModel):
+    mode: PolicyMode = "allow_user_pick"
+    allowed_models: list[str] | None = None
+    default_model: str | None = None
+    routes: dict[str, RouteConfig] = Field(default_factory=dict)
+    expose_aliases: bool = False
+
+    def validate_against_registry(self, model_registry: dict[str, APIModel]) -> None:
+        registry_keys = set(model_registry.keys())
+        allowed_models = (
+            _dedupe_keep_order(self.allowed_models)
+            if self.allowed_models is not None
+            else None
+        )
+
+        if allowed_models is not None:
+            unknown = [model for model in allowed_models if model not in registry_keys]
+            if unknown:
+                raise ValueError(f"Unknown allowed models: {', '.join(unknown)}")
+
+        for alias, route in self.routes.items():
+            if alias in registry_keys:
+                raise ValueError(
+                    f"Route alias '{alias}' conflicts with a registry model id"
+                )
+            for model_id in route.models:
+                if model_id not in registry_keys:
+                    raise ValueError(
+                        f"Route '{alias}' references unknown model '{model_id}'"
+                    )
+                if allowed_models is not None and model_id not in allowed_models:
+                    raise ValueError(
+                        f"Route '{alias}' uses model '{model_id}' not in allowlist"
+                    )
+
+        if self.mode == "force_default" and not self.default_model:
+            raise ValueError("force_default mode requires default_model")
+
+        if self.default_model:
+            if self.default_model not in self.routes:
+                if self.default_model not in registry_keys:
+                    raise ValueError(f"Default model '{self.default_model}' is unknown")
+                if (
+                    allowed_models is not None
+                    and self.default_model not in allowed_models
+                ):
+                    raise ValueError(
+                        f"Default model '{self.default_model}' not in allowlist"
+                    )
+
+        if self.mode == "alias_only" and not self.routes:
+            raise ValueError("alias_only mode requires at least one route alias")
+
+    def allowed_raw_models(self, model_registry: dict[str, APIModel]) -> list[str]:
+        if self.allowed_models is None:
+            return list(model_registry.keys())
+        return _dedupe_keep_order(self.allowed_models)
+
+
+class ModelRouter:
+    def __init__(
+        self,
+        policy: ProxyModelPolicy,
+        model_registry: dict[str, APIModel] | None = None,
+        *,
+        rng: random.Random | None = None,
+    ) -> None:
+        self.policy = policy
+        self.model_registry = model_registry or registry
+        self.policy.validate_against_registry(self.model_registry)
+        self._rng = rng or random.Random()
+        self._round_robin_index: dict[str, int] = {}
+
+    def resolve(self, requested_model: str) -> str:
+        target = requested_model
+        if self.policy.mode == "force_default":
+            if not self.policy.default_model:
+                raise ValueError("No default model configured")
+            target = self.policy.default_model
+
+        if target in self.policy.routes:
+            return self._select_from_route(target)
+
+        if self.policy.mode == "alias_only":
+            raise ValueError(f"Model '{requested_model}' is not an exposed alias")
+
+        if target not in self.model_registry:
+            raise ValueError(f"Model '{target}' not found in registry")
+        if (
+            self.policy.allowed_models is not None
+            and target not in self.policy.allowed_models
+        ):
+            raise ValueError(f"Model '{target}' is not allowed by proxy policy")
+        return target
+
+    def list_model_ids(
+        self,
+        *,
+        only_available: bool,
+        is_available: Callable[[str], bool],
+    ) -> list[str]:
+        models: list[str] = []
+
+        if self.policy.mode != "alias_only":
+            raw_models = self.policy.allowed_raw_models(self.model_registry)
+            if only_available:
+                raw_models = [model for model in raw_models if is_available(model)]
+            models.extend(raw_models)
+
+        if self.policy.mode == "alias_only" or self.policy.expose_aliases:
+            aliases = list(self.policy.routes.keys())
+            if only_available:
+                aliases = [
+                    alias
+                    for alias in aliases
+                    if any(
+                        is_available(model)
+                        for model in self.policy.routes[alias].models
+                    )
+                ]
+            models.extend(aliases)
+
+        return models
+
+    def _select_from_route(self, alias: str) -> str:
+        route = self.policy.routes[alias]
+        models = route.models
+        if len(models) == 1:
+            return models[0]
+
+        if route.strategy == "round_robin":
+            index = self._round_robin_index.get(alias, 0)
+            selected = models[index % len(models)]
+            self._round_robin_index[alias] = index + 1
+            return selected
+
+        if route.strategy == "weighted":
+            assert route.weights is not None
+            return self._rng.choices(models, weights=route.weights, k=1)[0]
+
+        return self._rng.choice(models)
+
+
+def load_policy_data(path: str | None) -> dict:
+    if not path:
+        return {}
+    with open(path, "r", encoding="utf-8") as handle:
+        data = yaml.safe_load(handle) or {}
+    if not isinstance(data, dict):
+        return {}
+    if "model_policy" in data:
+        return data["model_policy"] or {}
+    proxy_block = data.get("proxy")
+    if isinstance(proxy_block, dict) and "model_policy" in proxy_block:
+        return proxy_block["model_policy"] or {}
+    return data
+
+
+def build_policy(
+    *,
+    path: str | None = None,
+    overrides: dict | None = None,
+) -> ProxyModelPolicy:
+    data = load_policy_data(path)
+    if overrides:
+        data.update(overrides)
+    policy = ProxyModelPolicy(**data)
+    policy.validate_against_registry(registry)
+    return policy
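
ProxyModelPolicy plus ModelRouter is what lets the proxy hide raw registry ids behind aliases and spread traffic across several backends. Below is a short sketch of an alias-only policy with weighted routing; the model ids are placeholders and must exist in lm-deluge's registry for validation to pass.

# Sketch only: "gpt-4o-mini" and "claude-haiku" are placeholder registry ids.
from lm_deluge.server.model_policy import ModelRouter, ProxyModelPolicy, RouteConfig

policy = ProxyModelPolicy(
    mode="alias_only",
    expose_aliases=True,
    routes={
        "cheap": RouteConfig(
            models=["gpt-4o-mini", "claude-haiku"],
            strategy="weighted",
            weights=[0.8, 0.2],  # validated: same length as models, positive sum
        )
    },
)

router = ModelRouter(policy)  # validates the policy against the global registry
print(router.resolve("cheap"))   # returns one of the two backing model ids
# router.resolve("gpt-4o-mini")  # raises ValueError: alias_only hides raw ids

The same settings can live in YAML under a top-level model_policy: (or proxy: model_policy:) key and be loaded with build_policy(path=...).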