lm-deluge 0.0.88__py3-none-any.whl → 0.0.90__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of lm-deluge might be problematic.
- lm_deluge/__init__.py +0 -24
- lm_deluge/api_requests/anthropic.py +25 -5
- lm_deluge/api_requests/base.py +37 -0
- lm_deluge/api_requests/bedrock.py +23 -2
- lm_deluge/api_requests/gemini.py +36 -10
- lm_deluge/api_requests/openai.py +31 -4
- lm_deluge/batches.py +15 -45
- lm_deluge/client.py +27 -1
- lm_deluge/models/__init__.py +2 -0
- lm_deluge/models/anthropic.py +12 -12
- lm_deluge/models/google.py +13 -0
- lm_deluge/models/minimax.py +9 -1
- lm_deluge/models/openrouter.py +48 -0
- lm_deluge/models/zai.py +50 -1
- lm_deluge/pipelines/gepa/docs/samples.py +19 -10
- lm_deluge/prompt.py +333 -68
- lm_deluge/server/__init__.py +24 -0
- lm_deluge/server/__main__.py +144 -0
- lm_deluge/server/adapters.py +369 -0
- lm_deluge/server/app.py +388 -0
- lm_deluge/server/auth.py +71 -0
- lm_deluge/server/model_policy.py +215 -0
- lm_deluge/server/models_anthropic.py +172 -0
- lm_deluge/server/models_openai.py +175 -0
- lm_deluge/skills/anthropic.py +0 -0
- lm_deluge/skills/compat.py +0 -0
- lm_deluge/tool/__init__.py +13 -1
- lm_deluge/tool/prefab/sandbox/__init__.py +19 -0
- lm_deluge/tool/prefab/sandbox/daytona_sandbox.py +483 -0
- lm_deluge/tool/prefab/sandbox/docker_sandbox.py +609 -0
- lm_deluge/tool/prefab/sandbox/fargate_sandbox.py +546 -0
- lm_deluge/tool/prefab/sandbox/modal_sandbox.py +469 -0
- lm_deluge/tool/prefab/sandbox/seatbelt_sandbox.py +827 -0
- lm_deluge/tool/prefab/skills.py +0 -0
- {lm_deluge-0.0.88.dist-info → lm_deluge-0.0.90.dist-info}/METADATA +4 -3
- {lm_deluge-0.0.88.dist-info → lm_deluge-0.0.90.dist-info}/RECORD +39 -24
- lm_deluge/mock_openai.py +0 -643
- lm_deluge/tool/prefab/sandbox.py +0 -1621
- {lm_deluge-0.0.88.dist-info → lm_deluge-0.0.90.dist-info}/WHEEL +0 -0
- {lm_deluge-0.0.88.dist-info → lm_deluge-0.0.90.dist-info}/licenses/LICENSE +0 -0
- {lm_deluge-0.0.88.dist-info → lm_deluge-0.0.90.dist-info}/top_level.txt +0 -0
lm_deluge/server/app.py
ADDED
@@ -0,0 +1,388 @@
"""
FastAPI application for the LM-Deluge proxy server.
"""

from __future__ import annotations

import json
import os
import traceback
from contextlib import asynccontextmanager

import aiohttp
from dotenv import load_dotenv
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse

from lm_deluge.models import APIModel, registry
from lm_deluge.prompt import CachePattern
from lm_deluge.request_context import RequestContext
from lm_deluge.tracker import StatusTracker

from .adapters import (
    anthropic_request_to_conversation,
    anthropic_request_to_sampling_params,
    anthropic_tools_to_lm_deluge,
    api_response_to_anthropic,
    api_response_to_openai,
    openai_request_to_conversation,
    openai_request_to_sampling_params,
    openai_tools_to_lm_deluge,
)
from .auth import verify_anthropic_auth, verify_openai_auth
from .model_policy import ModelRouter, ProxyModelPolicy
from .models_anthropic import (
    AnthropicErrorDetail,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from .models_openai import (
    OpenAIChatCompletionsRequest,
    OpenAIChatCompletionsResponse,
    OpenAIErrorDetail,
    OpenAIErrorResponse,
    OpenAIModelInfo,
    OpenAIModelsResponse,
)

# Valid cache patterns
_VALID_CACHE_PATTERNS = {
    "tools_only",
    "system_and_tools",
    "last_user_message",
    "last_2_user_messages",
    "last_3_user_messages",
}

_TRUTHY_VALUES = {"1", "true", "yes", "on"}


def get_cache_pattern() -> CachePattern | None:
    """
    Get cache pattern from DELUGE_CACHE_PATTERN environment variable.

    Valid values:
    - none / NONE / unset → no caching
    - tools_only → cache tools definition
    - system_and_tools → cache system prompt and tools
    - last_user_message → cache last user message
    - last_2_user_messages → cache last 2 user messages
    - last_3_user_messages → cache last 3 user messages
    """
    pattern = os.getenv("DELUGE_CACHE_PATTERN", "").lower().strip()
    if not pattern or pattern == "none":
        return None
    if pattern in _VALID_CACHE_PATTERNS:
        return pattern  # type: ignore
    return None


def _is_truthy_env(name: str) -> bool:
    return os.getenv(name, "").strip().lower() in _TRUTHY_VALUES


# Global aiohttp session for connection reuse
_http_session: aiohttp.ClientSession | None = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan - startup and shutdown."""
    global _http_session

    # Load .env file if present
    load_dotenv()

    # Create shared aiohttp session
    connector = aiohttp.TCPConnector(
        limit=100,
        limit_per_host=20,
        keepalive_timeout=30,
        enable_cleanup_closed=True,
    )
    _http_session = aiohttp.ClientSession(connector=connector)

    yield

    # Cleanup
    if _http_session:
        await _http_session.close()
        _http_session = None


def _is_model_available(api_model: APIModel) -> bool:
    """Check if model is available based on configured API keys."""
    env_var = api_model.api_key_env_var
    if not env_var:
        if api_model.api_spec == "bedrock":
            return bool(os.getenv("AWS_ACCESS_KEY_ID") or os.getenv("AWS_PROFILE"))
        return False
    return bool(os.getenv(env_var))


def create_app(policy: ProxyModelPolicy | None = None) -> FastAPI:
    """Create and configure the FastAPI application."""
    policy = policy or ProxyModelPolicy()
    router = ModelRouter(policy, registry)

    app = FastAPI(
        title="LM-Deluge Proxy Server",
        description="OpenAI and Anthropic compatible API proxy backed by lm-deluge",
        version="0.1.0",
        lifespan=lifespan,
    )

    @app.middleware("http")
    async def log_requests(request: Request, call_next):
        if _is_truthy_env("DELUGE_PROXY_LOG_REQUESTS"):
            body = await request.body()
            body_text = body.decode("utf-8", errors="replace")
            if body_text:
                try:
                    body_text = json.dumps(json.loads(body_text), indent=2)
                except Exception:
                    pass
            print("DELUGE_PROXY_REQUEST")
            print(f"{request.method} {request.url}")
            print("Headers:")
            print(dict(request.headers))
            if body_text:
                print("Body:")
                print(body_text)
        return await call_next(request)

    # ========================================================================
    # Health Check
    # ========================================================================

    @app.get("/health")
    async def health_check():
        """Health check endpoint."""
        return {"status": "ok"}

    # ========================================================================
    # OpenAI-Compatible Endpoints
    # ========================================================================

    @app.get("/v1/models", dependencies=[Depends(verify_openai_auth)])
    async def list_models(all: bool = False) -> OpenAIModelsResponse:
        """
        List available models (OpenAI-compatible).

        By default, only returns models for which the required API key
        is set in the environment. Use ?all=true to list all registered models.
        """
        models = []
        model_ids = router.list_model_ids(
            only_available=not all,
            is_available=lambda model_id: _is_model_available(registry[model_id]),
        )
        for model_id in model_ids:
            models.append(OpenAIModelInfo(id=model_id, owned_by="lm-deluge"))

        return OpenAIModelsResponse(data=models)

    @app.post(
        "/v1/chat/completions",
        dependencies=[Depends(verify_openai_auth)],
        response_model=None,
    )
    async def openai_chat_completions(
        req: OpenAIChatCompletionsRequest,
    ) -> OpenAIChatCompletionsResponse | JSONResponse:
        """OpenAI-compatible chat completions endpoint."""
        # Reject streaming
        if req.stream:
            raise HTTPException(
                status_code=400,
                detail="Streaming is not supported. Set stream=false.",
            )

        try:
            # Get model from registry
            try:
                resolved_model = router.resolve(req.model)
                api_model = APIModel.from_registry(resolved_model)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e))

            # Convert request to lm-deluge types
            conversation = openai_request_to_conversation(req)
            sampling_params = openai_request_to_sampling_params(req)

            # Convert tools if provided
            tools = None
            if req.tools:
                tools = openai_tools_to_lm_deluge(req.tools)

            # Apply cache pattern only for Anthropic-compatible models
            cache = None
            if api_model.api_spec in ("anthropic", "bedrock"):
                cache = get_cache_pattern()

            # Build RequestContext
            # We need a minimal StatusTracker for execute_once to work
            tracker = StatusTracker(
                max_requests_per_minute=1000,
                max_tokens_per_minute=1_000_000,
                max_concurrent_requests=100,
                use_progress_bar=False,
            )

            context = RequestContext(
                task_id=0,
                model_name=resolved_model,
                prompt=conversation,
                sampling_params=sampling_params,
                tools=tools,
                cache=cache,
                status_tracker=tracker,
                request_timeout=int(os.getenv("DELUGE_PROXY_TIMEOUT", "120")),
            )

            # Create and execute request
            request_obj = api_model.make_request(context)
            response = await request_obj.execute_once()

            # Check for errors
            if response.is_error:
                return JSONResponse(
                    status_code=response.status_code or 500,
                    content=OpenAIErrorResponse(
                        error=OpenAIErrorDetail(
                            message=response.error_message or "Unknown error",
                            type="api_error",
                            code=str(response.status_code)
                            if response.status_code
                            else None,
                        )
                    ).model_dump(),
                )

            # Convert to OpenAI format
            return api_response_to_openai(response, resolved_model)

        except HTTPException:
            raise
        except Exception as e:
            traceback.print_exc()
            return JSONResponse(
                status_code=500,
                content=OpenAIErrorResponse(
                    error=OpenAIErrorDetail(
                        message=str(e),
                        type="internal_error",
                    )
                ).model_dump(),
            )

    # ========================================================================
    # Anthropic-Compatible Endpoints
    # ========================================================================

    # Support both /v1/messages and /messages for Anthropic SDK compatibility
    # The Anthropic SDK constructs paths as {base_url}/v1/messages
    @app.post(
        "/v1/messages",
        dependencies=[Depends(verify_anthropic_auth)],
        response_model=None,
    )
    @app.post(
        "/messages",
        dependencies=[Depends(verify_anthropic_auth)],
        response_model=None,
    )
    async def anthropic_messages(
        request: Request,
        req: AnthropicMessagesRequest,
    ) -> AnthropicMessagesResponse | JSONResponse:
        """Anthropic-compatible messages endpoint."""
        # Reject streaming
        if req.stream:
            raise HTTPException(
                status_code=400,
                detail="Streaming is not supported. Set stream=false.",
            )

        try:
            # Get model from registry
            try:
                resolved_model = router.resolve(req.model)
                api_model = APIModel.from_registry(resolved_model)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e))

            # Convert request to lm-deluge types
            conversation = anthropic_request_to_conversation(req)
            sampling_params = anthropic_request_to_sampling_params(req)

            # Convert tools if provided
            tools = None
            if req.tools:
                tools = anthropic_tools_to_lm_deluge(req.tools)

            # Apply cache pattern only for Anthropic-compatible models
            cache = None
            if api_model.api_spec in ("anthropic", "bedrock"):
                cache = get_cache_pattern()

            # Build RequestContext
            tracker = StatusTracker(
                max_requests_per_minute=1000,
                max_tokens_per_minute=1_000_000,
                max_concurrent_requests=100,
                use_progress_bar=False,
            )

            extra_headers = None
            beta_header = request.headers.get("anthropic-beta")
            if beta_header:
                extra_headers = {"anthropic-beta": beta_header}

            context = RequestContext(
                task_id=0,
                model_name=resolved_model,
                prompt=conversation,
                sampling_params=sampling_params,
                tools=tools,
                cache=cache,
                status_tracker=tracker,
                request_timeout=int(os.getenv("DELUGE_PROXY_TIMEOUT", "120")),
                extra_headers=extra_headers,
            )

            # Create and execute request
            request_obj = api_model.make_request(context)
            response = await request_obj.execute_once()

            # Check for errors
            if response.is_error:
                return JSONResponse(
                    status_code=response.status_code or 500,
                    content=AnthropicErrorResponse(
                        error=AnthropicErrorDetail(
                            type="api_error",
                            message=response.error_message or "Unknown error",
                        )
                    ).model_dump(),
                )

            # Convert to Anthropic format
            return api_response_to_anthropic(response, resolved_model)

        except HTTPException:
            raise
        except Exception as e:
            traceback.print_exc()
            return JSONResponse(
                status_code=500,
                content=AnthropicErrorResponse(
                    error=AnthropicErrorDetail(
                        type="internal_error",
                        message=str(e),
                    )
                ).model_dump(),
            )

    return app
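For orientation, a minimal sketch of serving the app that create_app builds, using uvicorn; the launcher file name, host, and port are illustrative assumptions, and the package also ships lm_deluge/server/__main__.py as its own entry point:

# run_proxy.py - hypothetical launcher, not part of the package
import uvicorn

from lm_deluge.server.app import create_app

app = create_app()  # default ProxyModelPolicy: mode="allow_user_pick", no allowlist
uvicorn.run(app, host="127.0.0.1", port=8000)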
lm_deluge/server/auth.py
ADDED
@@ -0,0 +1,71 @@
"""
Optional authentication for the proxy server.
"""

from __future__ import annotations

import os

from fastapi import Header, HTTPException


def get_proxy_api_key() -> str | None:
    """Get the configured proxy API key from environment."""
    return os.getenv("DELUGE_PROXY_API_KEY")


async def verify_openai_auth(
    authorization: str | None = Header(default=None),
) -> None:
    """
    Verify OpenAI-style Bearer token authentication.
    Only enforced if DELUGE_PROXY_API_KEY is set.
    """
    expected_key = get_proxy_api_key()
    if not expected_key:
        # No auth configured, allow all requests
        return

    if not authorization:
        raise HTTPException(
            status_code=401,
            detail="Missing Authorization header",
        )

    if not authorization.startswith("Bearer "):
        raise HTTPException(
            status_code=401,
            detail="Invalid Authorization header format. Expected 'Bearer <token>'",
        )

    token = authorization.removeprefix("Bearer ").strip()
    if token != expected_key:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key",
        )


async def verify_anthropic_auth(
    x_api_key: str | None = Header(default=None, alias="x-api-key"),
) -> None:
    """
    Verify Anthropic-style x-api-key header authentication.
    Only enforced if DELUGE_PROXY_API_KEY is set.
    """
    expected_key = get_proxy_api_key()
    if not expected_key:
        # No auth configured, allow all requests
        return

    if not x_api_key:
        raise HTTPException(
            status_code=401,
            detail="Missing x-api-key header",
        )

    if x_api_key != expected_key:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key",
        )
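To show how this auth looks from the client side, a minimal sketch assuming DELUGE_PROXY_API_KEY="my-proxy-key" is set on the server, the proxy listens on http://localhost:8000, and the model ids are illustrative placeholders; the official OpenAI and Anthropic SDKs already send the headers these dependencies check:

from anthropic import Anthropic
from openai import OpenAI

# OpenAI-style clients send "Authorization: Bearer <key>", checked by verify_openai_auth.
openai_client = OpenAI(base_url="http://localhost:8000/v1", api_key="my-proxy-key")
completion = openai_client.chat.completions.create(
    model="gpt-4.1-mini",  # any registry id or exposed route alias; illustrative
    messages=[{"role": "user", "content": "Hello"}],
)

# Anthropic-style clients send "x-api-key: <key>", checked by verify_anthropic_auth;
# the SDK builds {base_url}/v1/messages, which the proxy serves alongside /messages.
anthropic_client = Anthropic(base_url="http://localhost:8000", api_key="my-proxy-key")
message = anthropic_client.messages.create(
    model="claude-sonnet-4",  # illustrative
    max_tokens=256,
    messages=[{"role": "user", "content": "Hello"}],
)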
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import random
|
|
4
|
+
from typing import Callable, Literal
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
from pydantic import BaseModel, Field, model_validator
|
|
8
|
+
|
|
9
|
+
from lm_deluge.models import APIModel, registry
|
|
10
|
+
|
|
11
|
+
RouteStrategy = Literal["round_robin", "random", "weighted"]
|
|
12
|
+
PolicyMode = Literal["allow_user_pick", "force_default", "alias_only"]
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _dedupe_keep_order(values: list[str]) -> list[str]:
|
|
16
|
+
seen: set[str] = set()
|
|
17
|
+
output = []
|
|
18
|
+
for value in values:
|
|
19
|
+
if value in seen:
|
|
20
|
+
continue
|
|
21
|
+
seen.add(value)
|
|
22
|
+
output.append(value)
|
|
23
|
+
return output
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class RouteConfig(BaseModel):
|
|
27
|
+
models: list[str]
|
|
28
|
+
strategy: RouteStrategy = "round_robin"
|
|
29
|
+
weights: list[float] | None = None
|
|
30
|
+
|
|
31
|
+
@model_validator(mode="after")
|
|
32
|
+
def _validate_route(self) -> "RouteConfig":
|
|
33
|
+
if not self.models:
|
|
34
|
+
raise ValueError("route models must not be empty")
|
|
35
|
+
if self.strategy == "weighted":
|
|
36
|
+
if not self.weights:
|
|
37
|
+
raise ValueError("weighted strategy requires weights")
|
|
38
|
+
if len(self.weights) != len(self.models):
|
|
39
|
+
raise ValueError("weights must match models length")
|
|
40
|
+
if any(weight < 0 for weight in self.weights):
|
|
41
|
+
raise ValueError("weights must be non-negative")
|
|
42
|
+
if sum(self.weights) <= 0:
|
|
43
|
+
raise ValueError("weights must sum to a positive value")
|
|
44
|
+
return self
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class ProxyModelPolicy(BaseModel):
|
|
48
|
+
mode: PolicyMode = "allow_user_pick"
|
|
49
|
+
allowed_models: list[str] | None = None
|
|
50
|
+
default_model: str | None = None
|
|
51
|
+
routes: dict[str, RouteConfig] = Field(default_factory=dict)
|
|
52
|
+
expose_aliases: bool = False
|
|
53
|
+
|
|
54
|
+
def validate_against_registry(self, model_registry: dict[str, APIModel]) -> None:
|
|
55
|
+
registry_keys = set(model_registry.keys())
|
|
56
|
+
allowed_models = (
|
|
57
|
+
_dedupe_keep_order(self.allowed_models)
|
|
58
|
+
if self.allowed_models is not None
|
|
59
|
+
else None
|
|
60
|
+
)
|
|
61
|
+
|
|
62
|
+
if allowed_models is not None:
|
|
63
|
+
unknown = [model for model in allowed_models if model not in registry_keys]
|
|
64
|
+
if unknown:
|
|
65
|
+
raise ValueError(f"Unknown allowed models: {', '.join(unknown)}")
|
|
66
|
+
|
|
67
|
+
for alias, route in self.routes.items():
|
|
68
|
+
if alias in registry_keys:
|
|
69
|
+
raise ValueError(
|
|
70
|
+
f"Route alias '{alias}' conflicts with a registry model id"
|
|
71
|
+
)
|
|
72
|
+
for model_id in route.models:
|
|
73
|
+
if model_id not in registry_keys:
|
|
74
|
+
raise ValueError(
|
|
75
|
+
f"Route '{alias}' references unknown model '{model_id}'"
|
|
76
|
+
)
|
|
77
|
+
if allowed_models is not None and model_id not in allowed_models:
|
|
78
|
+
raise ValueError(
|
|
79
|
+
f"Route '{alias}' uses model '{model_id}' not in allowlist"
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
if self.mode == "force_default" and not self.default_model:
|
|
83
|
+
raise ValueError("force_default mode requires default_model")
|
|
84
|
+
|
|
85
|
+
if self.default_model:
|
|
86
|
+
if self.default_model not in self.routes:
|
|
87
|
+
if self.default_model not in registry_keys:
|
|
88
|
+
raise ValueError(f"Default model '{self.default_model}' is unknown")
|
|
89
|
+
if (
|
|
90
|
+
allowed_models is not None
|
|
91
|
+
and self.default_model not in allowed_models
|
|
92
|
+
):
|
|
93
|
+
raise ValueError(
|
|
94
|
+
f"Default model '{self.default_model}' not in allowlist"
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
if self.mode == "alias_only" and not self.routes:
|
|
98
|
+
raise ValueError("alias_only mode requires at least one route alias")
|
|
99
|
+
|
|
100
|
+
def allowed_raw_models(self, model_registry: dict[str, APIModel]) -> list[str]:
|
|
101
|
+
if self.allowed_models is None:
|
|
102
|
+
return list(model_registry.keys())
|
|
103
|
+
return _dedupe_keep_order(self.allowed_models)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class ModelRouter:
|
|
107
|
+
def __init__(
|
|
108
|
+
self,
|
|
109
|
+
policy: ProxyModelPolicy,
|
|
110
|
+
model_registry: dict[str, APIModel] | None = None,
|
|
111
|
+
*,
|
|
112
|
+
rng: random.Random | None = None,
|
|
113
|
+
) -> None:
|
|
114
|
+
self.policy = policy
|
|
115
|
+
self.model_registry = model_registry or registry
|
|
116
|
+
self.policy.validate_against_registry(self.model_registry)
|
|
117
|
+
self._rng = rng or random.Random()
|
|
118
|
+
self._round_robin_index: dict[str, int] = {}
|
|
119
|
+
|
|
120
|
+
def resolve(self, requested_model: str) -> str:
|
|
121
|
+
target = requested_model
|
|
122
|
+
if self.policy.mode == "force_default":
|
|
123
|
+
if not self.policy.default_model:
|
|
124
|
+
raise ValueError("No default model configured")
|
|
125
|
+
target = self.policy.default_model
|
|
126
|
+
|
|
127
|
+
if target in self.policy.routes:
|
|
128
|
+
return self._select_from_route(target)
|
|
129
|
+
|
|
130
|
+
if self.policy.mode == "alias_only":
|
|
131
|
+
raise ValueError(f"Model '{requested_model}' is not an exposed alias")
|
|
132
|
+
|
|
133
|
+
if target not in self.model_registry:
|
|
134
|
+
raise ValueError(f"Model '{target}' not found in registry")
|
|
135
|
+
if (
|
|
136
|
+
self.policy.allowed_models is not None
|
|
137
|
+
and target not in self.policy.allowed_models
|
|
138
|
+
):
|
|
139
|
+
raise ValueError(f"Model '{target}' is not allowed by proxy policy")
|
|
140
|
+
return target
|
|
141
|
+
|
|
142
|
+
def list_model_ids(
|
|
143
|
+
self,
|
|
144
|
+
*,
|
|
145
|
+
only_available: bool,
|
|
146
|
+
is_available: Callable[[str], bool],
|
|
147
|
+
) -> list[str]:
|
|
148
|
+
models: list[str] = []
|
|
149
|
+
|
|
150
|
+
if self.policy.mode != "alias_only":
|
|
151
|
+
raw_models = self.policy.allowed_raw_models(self.model_registry)
|
|
152
|
+
if only_available:
|
|
153
|
+
raw_models = [model for model in raw_models if is_available(model)]
|
|
154
|
+
models.extend(raw_models)
|
|
155
|
+
|
|
156
|
+
if self.policy.mode == "alias_only" or self.policy.expose_aliases:
|
|
157
|
+
aliases = list(self.policy.routes.keys())
|
|
158
|
+
if only_available:
|
|
159
|
+
aliases = [
|
|
160
|
+
alias
|
|
161
|
+
for alias in aliases
|
|
162
|
+
if any(
|
|
163
|
+
is_available(model)
|
|
164
|
+
for model in self.policy.routes[alias].models
|
|
165
|
+
)
|
|
166
|
+
]
|
|
167
|
+
models.extend(aliases)
|
|
168
|
+
|
|
169
|
+
return models
|
|
170
|
+
|
|
171
|
+
def _select_from_route(self, alias: str) -> str:
|
|
172
|
+
route = self.policy.routes[alias]
|
|
173
|
+
models = route.models
|
|
174
|
+
if len(models) == 1:
|
|
175
|
+
return models[0]
|
|
176
|
+
|
|
177
|
+
if route.strategy == "round_robin":
|
|
178
|
+
index = self._round_robin_index.get(alias, 0)
|
|
179
|
+
selected = models[index % len(models)]
|
|
180
|
+
self._round_robin_index[alias] = index + 1
|
|
181
|
+
return selected
|
|
182
|
+
|
|
183
|
+
if route.strategy == "weighted":
|
|
184
|
+
assert route.weights is not None
|
|
185
|
+
return self._rng.choices(models, weights=route.weights, k=1)[0]
|
|
186
|
+
|
|
187
|
+
return self._rng.choice(models)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def load_policy_data(path: str | None) -> dict:
|
|
191
|
+
if not path:
|
|
192
|
+
return {}
|
|
193
|
+
with open(path, "r", encoding="utf-8") as handle:
|
|
194
|
+
data = yaml.safe_load(handle) or {}
|
|
195
|
+
if not isinstance(data, dict):
|
|
196
|
+
return {}
|
|
197
|
+
if "model_policy" in data:
|
|
198
|
+
return data["model_policy"] or {}
|
|
199
|
+
proxy_block = data.get("proxy")
|
|
200
|
+
if isinstance(proxy_block, dict) and "model_policy" in proxy_block:
|
|
201
|
+
return proxy_block["model_policy"] or {}
|
|
202
|
+
return data
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def build_policy(
|
|
206
|
+
*,
|
|
207
|
+
path: str | None = None,
|
|
208
|
+
overrides: dict | None = None,
|
|
209
|
+
) -> ProxyModelPolicy:
|
|
210
|
+
data = load_policy_data(path)
|
|
211
|
+
if overrides:
|
|
212
|
+
data.update(overrides)
|
|
213
|
+
policy = ProxyModelPolicy(**data)
|
|
214
|
+
policy.validate_against_registry(registry)
|
|
215
|
+
return policy
|
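To tie the pieces together, a sketch of configuring a route alias in code; RouteConfig, ProxyModelPolicy, and ModelRouter are the classes defined above, while the alias name and model ids are placeholders that must correspond to real lm-deluge registry entries (otherwise validate_against_registry raises). The same policy could instead be loaded from a YAML file via build_policy(path=...).

from lm_deluge.server.model_policy import ModelRouter, ProxyModelPolicy, RouteConfig

# "fast" is a route alias; the model ids below are placeholders.
policy = ProxyModelPolicy(
    mode="allow_user_pick",
    expose_aliases=True,  # aliases also show up in /v1/models
    routes={
        "fast": RouteConfig(
            models=["model-a", "model-b"],
            strategy="weighted",
            weights=[0.8, 0.2],
        )
    },
)

router = ModelRouter(policy)  # validates the policy against the registry
print(router.resolve("fast"))  # "model-a" roughly 80% of the time, else "model-b"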