@aws/ml-container-creator 0.9.1 → 0.10.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +9304 -0
- package/bin/cli.js +2 -0
- package/config/bootstrap-e2e-stack.json +341 -0
- package/config/bootstrap-stack.json +40 -3
- package/config/parameter-schema-v2.json +2049 -0
- package/config/tune-catalog.json +1781 -0
- package/infra/ci-harness/buildspec.yml +1 -0
- package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
- package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +837 -7
- package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
- package/package.json +53 -68
- package/servers/base-image-picker/index.js +121 -121
- package/servers/e2e-status/index.js +297 -0
- package/servers/e2e-status/manifest.json +14 -0
- package/servers/e2e-status/package.json +15 -0
- package/servers/endpoint-picker/LICENSE +202 -0
- package/servers/endpoint-picker/index.js +536 -0
- package/servers/endpoint-picker/manifest.json +14 -0
- package/servers/endpoint-picker/package.json +18 -0
- package/servers/hyperpod-cluster-picker/index.js +125 -125
- package/servers/instance-sizer/index.js +138 -138
- package/servers/instance-sizer/lib/instance-ranker.js +76 -76
- package/servers/instance-sizer/lib/model-resolver.js +61 -61
- package/servers/instance-sizer/lib/quota-resolver.js +113 -113
- package/servers/instance-sizer/lib/vram-estimator.js +31 -31
- package/servers/lib/bedrock-client.js +38 -38
- package/servers/lib/catalogs/jumpstart-public.json +101 -16
- package/servers/lib/catalogs/model-servers.json +201 -3
- package/servers/lib/catalogs/models.json +182 -26
- package/servers/lib/custom-validators.js +13 -13
- package/servers/lib/dynamic-resolver.js +4 -4
- package/servers/marketplace-picker/index.js +342 -0
- package/servers/marketplace-picker/manifest.json +14 -0
- package/servers/marketplace-picker/package.json +18 -0
- package/servers/model-picker/index.js +382 -382
- package/servers/region-picker/index.js +56 -56
- package/servers/workload-picker/LICENSE +202 -0
- package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
- package/servers/workload-picker/index.js +171 -0
- package/servers/workload-picker/manifest.json +16 -0
- package/servers/workload-picker/package.json +16 -0
- package/src/app.js +4 -390
- package/src/lib/bootstrap-command-handler.js +710 -1148
- package/src/lib/bootstrap-config.js +36 -0
- package/src/lib/bootstrap-profile-manager.js +641 -0
- package/src/lib/bootstrap-provisioners.js +421 -0
- package/src/lib/ci-register-helpers.js +74 -0
- package/src/lib/config-loader.js +408 -0
- package/src/lib/config-manager.js +66 -1685
- package/src/lib/config-mcp-client.js +118 -0
- package/src/lib/config-validator.js +634 -0
- package/src/lib/cuda-resolver.js +149 -0
- package/src/lib/e2e-catalog-validator.js +251 -3
- package/src/lib/e2e-ci-recorder.js +103 -0
- package/src/lib/generated/cli-options.js +315 -311
- package/src/lib/generated/parameter-matrix.js +671 -0
- package/src/lib/generated/validation-rules.js +71 -71
- package/src/lib/marketplace-flow.js +276 -0
- package/src/lib/mcp-query-runner.js +768 -0
- package/src/lib/parameter-schema-validator.js +62 -18
- package/src/lib/path-prover-brain.js +607 -0
- package/src/lib/prompt-runner.js +41 -1504
- package/src/lib/prompts/feature-prompts.js +172 -0
- package/src/lib/prompts/index.js +48 -0
- package/src/lib/prompts/infrastructure-prompts.js +690 -0
- package/src/lib/prompts/model-prompts.js +552 -0
- package/src/lib/prompts/project-prompts.js +82 -0
- package/src/lib/prompts.js +2 -1446
- package/src/lib/registry-command-handler.js +135 -3
- package/src/lib/secrets-prompt-runner.js +251 -0
- package/src/lib/template-variable-resolver.js +422 -0
- package/src/lib/tune-catalog-validator.js +37 -4
- package/templates/Dockerfile +9 -0
- package/templates/code/adapter_sidecar.py +444 -0
- package/templates/code/serve +6 -0
- package/templates/code/serve.d/vllm.ejs +1 -1
- package/templates/do/.benchmark_writer.py +1476 -0
- package/templates/do/.tune_helper.py +982 -57
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/adapter +149 -0
- package/templates/do/benchmark +639 -85
- package/templates/do/config +108 -5
- package/templates/do/deploy.d/managed-inference.ejs +192 -11
- package/templates/do/optimize +106 -37
- package/templates/do/register +89 -0
- package/templates/do/test +13 -0
- package/templates/do/tune +378 -59
- package/templates/do/validate +44 -4
- package/config/parameter-schema.json +0 -88
|
@@ -0,0 +1,444 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Adapter Sidecar — SageMaker AI adapter contract implementation.
|
|
7
|
+
|
|
8
|
+
Lightweight aiohttp HTTP server that sits between SageMaker (port 8080) and the
|
|
9
|
+
model server (port 8081). Implements POST /adapters and DELETE /adapters by
|
|
10
|
+
translating them into the model server's native LoRA API, while proxying all
|
|
11
|
+
other traffic transparently.
|
|
12
|
+
|
|
13
|
+
Configuration (environment variables):
|
|
14
|
+
MODEL_SERVER_PORT - Internal model server port (default: 8081)
|
|
15
|
+
MODEL_SERVER_TYPE - Model server type: vllm or sglang (default: vllm)
|
|
16
|
+
SIDECAR_PORT - Port sidecar listens on (default: 8080)
|
|
17
|
+
MAX_LORAS - Maximum concurrent adapters (default: 64)
|
|
18
|
+
HEALTH_POLL_INTERVAL - Seconds between health polls (default: 2)
|
|
19
|
+
HEALTH_TIMEOUT - Seconds to wait for model server readiness (default: 600)
|
|
20
|
+
"""
|
|
21
|
+
|
|
22
|
+
import asyncio
|
|
23
|
+
import os
|
|
24
|
+
import tarfile
|
|
25
|
+
import time
|
|
26
|
+
from datetime import datetime, timezone
|
|
27
|
+
|
|
28
|
+
from aiohttp import web, ClientSession, ClientTimeout
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# ── Configuration ─────────────────────────────────────────────────────────────
|
|
32
|
+
|
|
33
|
+
def _env_int(name, default):
    """Return environment variable *name* as an int, or *default* when unset."""
    return int(os.environ.get(name, default))


# All settings are read once, at import time.
MODEL_SERVER_PORT = _env_int('MODEL_SERVER_PORT', '8081')
MODEL_SERVER_TYPE = os.environ.get('MODEL_SERVER_TYPE', 'vllm')
SIDECAR_PORT = _env_int('SIDECAR_PORT', '8080')
MAX_LORAS = _env_int('MAX_LORAS', '64')
HEALTH_POLL_INTERVAL = _env_int('HEALTH_POLL_INTERVAL', '2')
HEALTH_TIMEOUT = _env_int('HEALTH_TIMEOUT', '600')

# Base URL used for every request the sidecar sends to the model server.
MODEL_SERVER_BASE = 'http://localhost:' + str(MODEL_SERVER_PORT)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# ── Logging ───────────────────────────────────────────────────────────────────
|
|
44
|
+
|
|
45
|
+
def log(message, stream='stdout'):
    """Emit a log message with ISO 8601 timestamp and [adapter-sidecar] prefix.

    Args:
        message: Text to log.
        stream: 'stderr' writes to standard error; any other value writes
            to standard output.
    """
    import sys  # local import: sys is not imported at module level

    ts = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    target = sys.stderr if stream == 'stderr' else sys.stdout
    # Flush on every line: when stdout is a pipe Python block-buffers it, so
    # unflushed messages can show up late or be lost if the process dies.
    print(f'{ts} [adapter-sidecar] {message}', file=target, flush=True)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# ── Artifact Resolution ───────────────────────────────────────────────────────
|
|
57
|
+
|
|
58
|
+
class ArtifactResolver:
    """Resolves adapter artifacts from a source path.

    Handles these cases:
    1. Path is a tar.gz file, or a directory holding exactly one tar.gz file:
       extract in place and return the directory.
    2. Path is a directory containing adapter_config.json: use it directly.
    3. Path does not exist, is empty, or yields no adapter_config.json:
       raise FileNotFoundError.
    """

    @staticmethod
    def resolve(src):
        """Resolve the adapter artifact path.

        Args:
            src: Filesystem path where SageMaker placed adapter artifacts.

        Returns:
            Resolved directory path containing adapter files.

        Raises:
            FileNotFoundError: If the path does not exist, is empty, or does
                not contain adapter_config.json (also after extraction).
            RuntimeError: If tar.gz extraction fails.
        """
        if not os.path.exists(src):
            raise FileNotFoundError(f'Adapter artifact path does not exist: {src}')

        # Direct path to an archive: extract next to it.
        if os.path.isfile(src) and src.endswith('.tar.gz'):
            # dirname() is '' for a bare file name; use the current directory
            # so the caller never receives an empty path.
            extract_dir = os.path.dirname(src) or '.'
            ArtifactResolver._extract_tar_gz(src, extract_dir)
            ArtifactResolver._require_config(extract_dir, src)
            return extract_dir

        if not os.path.isdir(src):
            raise FileNotFoundError(f'Adapter artifact path is not a directory: {src}')

        contents = os.listdir(src)
        if not contents:
            raise FileNotFoundError(f'Adapter artifact path is empty: {src}')

        # Already-extracted adapter.
        if 'adapter_config.json' in contents:
            return src

        # Exactly one archive: extract it, then confirm it actually produced
        # an adapter (previously the directory was returned unchecked).
        tar_files = [f for f in contents if f.endswith('.tar.gz')]
        if len(tar_files) == 1:
            tar_path = os.path.join(src, tar_files[0])
            ArtifactResolver._extract_tar_gz(tar_path, src)
            ArtifactResolver._require_config(src, tar_path)
            return src

        raise FileNotFoundError(
            f'Adapter artifact path does not contain adapter_config.json or a tar.gz archive: {src}'
        )

    @staticmethod
    def _require_config(directory, tar_path):
        """Raise FileNotFoundError unless extraction left adapter_config.json in *directory*."""
        if 'adapter_config.json' not in os.listdir(directory):
            raise FileNotFoundError(
                f'Archive {tar_path} did not produce adapter_config.json in {directory}'
            )

    @staticmethod
    def _checked_members(tar, extract_dir):
        """Validate archive members for interpreters without extraction filters.

        Rejects anything that is not a regular file or directory, and any
        member whose path would land outside *extract_dir*. The whole archive
        is validated before anything is written.
        """
        root = os.path.realpath(extract_dir)
        members = tar.getmembers()
        for member in members:
            if not (member.isfile() or member.isdir()):
                raise tarfile.TarError(f'unsupported member type: {member.name}')
            target = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, target]) != root:
                raise tarfile.TarError(f'member escapes extraction directory: {member.name}')
        return members

    @staticmethod
    def _extract_tar_gz(tar_path, extract_dir):
        """Extract a tar.gz archive to the specified directory.

        Args:
            tar_path: Path to the tar.gz file.
            extract_dir: Directory to extract files into.

        Raises:
            RuntimeError: If extraction fails due to corruption, an unsafe
                member, or permission issues.
        """
        try:
            with tarfile.open(tar_path, 'r:gz') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # Extraction filters available: let tarfile block
                    # traversal, links out of the tree, devices, etc.
                    tar.extractall(path=extract_dir, filter='data')
                else:
                    # Older interpreter: validate members ourselves.
                    members = ArtifactResolver._checked_members(tar, extract_dir)
                    tar.extractall(path=extract_dir, members=members)
        except (tarfile.TarError, OSError, EOFError) as e:
            # EOFError is raised for a truncated gzip stream; PermissionError
            # is already covered by OSError.
            raise RuntimeError(f'Failed to extract tar.gz archive {tar_path}: {e}') from e
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
# ── Model Server Client (Strategy Pattern) ────────────────────────────────────
|
|
143
|
+
|
|
144
|
+
class ModelServerClient:
    """Base strategy for translating adapter operations to a model server.

    Each supported model server type provides a subclass that issues the
    server's own HTTP calls for loading and unloading LoRA adapters.
    """

    def __init__(self, session, base_url):
        # session: HTTP client session used for outgoing calls.
        # base_url: root URL of the model server.
        self.session, self.base_url = session, base_url

    async def load_adapter(self, name, path):
        """Load a LoRA adapter into the model server.

        Args:
            name: Adapter identifier.
            path: Resolved filesystem path to adapter artifacts.

        Returns:
            dict with response data from the model server.

        Raises:
            RuntimeError: If the model server returns an error or is unreachable.
        """
        raise NotImplementedError

    async def unload_adapter(self, name):
        """Unload a LoRA adapter from the model server.

        Args:
            name: Adapter identifier.

        Returns:
            dict with response data from the model server.

        Raises:
            RuntimeError: If the model server returns an error or is unreachable.
        """
        raise NotImplementedError
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class VLLMClient(ModelServerClient):
    """vLLM-specific adapter API translation.

    Load:   POST /v1/load_lora_adapter   {"lora_name": name, "lora_path": path}
    Unload: POST /v1/unload_lora_adapter {"lora_name": name}
    """

    async def load_adapter(self, name, path):
        """Load a LoRA adapter via vLLM's native API.

        Returns:
            {'status': 'success', 'response': <response body text>}

        Raises:
            RuntimeError: On a non-200 response or if the request fails.
        """
        return await self._post('load_lora_adapter', {'lora_name': name, 'lora_path': path})

    async def unload_adapter(self, name):
        """Unload a LoRA adapter via vLLM's native API.

        Returns:
            {'status': 'success', 'response': <response body text>}

        Raises:
            RuntimeError: On a non-200 response or if the request fails.
        """
        return await self._post('unload_lora_adapter', {'lora_name': name})

    async def _post(self, operation, payload):
        """POST *payload* as JSON to /v1/<operation> and map the outcome."""
        url = f'{self.base_url}/v1/{operation}'
        try:
            async with self.session.post(url, json=payload) as resp:
                status = resp.status
                body = await resp.text()
        except Exception as e:
            # Callers only handle RuntimeError, so every transport failure is
            # wrapped; the original exception stays attached as the cause.
            raise RuntimeError(f'Failed to connect to vLLM: {e}') from e

        # Status is checked outside the try so the HTTP-error RuntimeError is
        # never caught and re-raised by the transport handler above.
        if status == 200:
            return {'status': 'success', 'response': body}
        raise RuntimeError(f'vLLM {operation} failed (HTTP {status}): {body}')
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
class SGLangClient(ModelServerClient):
    """Placeholder client for SGLang. (Deferred)

    SGLang adapter support has not been implemented yet; both operations
    raise NotImplementedError.
    """

    async def load_adapter(self, name, path):
        """Always raises: SGLang adapter loading is not implemented."""
        raise NotImplementedError('SGLang adapter loading is not yet implemented')

    async def unload_adapter(self, name):
        """Always raises: SGLang adapter unloading is not implemented."""
        raise NotImplementedError('SGLang adapter unloading is not yet implemented')
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def create_model_server_client(session, base_url, server_type):
    """Factory function to create the appropriate ModelServerClient.

    Args:
        session: aiohttp.ClientSession for HTTP calls.
        base_url: Model server base URL (e.g., http://localhost:8081).
        server_type: Model server type ('vllm' or 'sglang'). Matching ignores
            case and surrounding whitespace, so 'vLLM' or ' SGLang ' work.

    Returns:
        ModelServerClient instance.

    Raises:
        ValueError: If server_type is not a supported model server type.
    """
    # Normalise only for matching; the error reports the value as given.
    normalized = server_type.strip().lower() if isinstance(server_type, str) else server_type

    if normalized == 'vllm':
        return VLLMClient(session, base_url)
    if normalized == 'sglang':
        return SGLangClient(session, base_url)
    raise ValueError(f'Unsupported model server type: {server_type}')
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
# ── State ─────────────────────────────────────────────────────────────────────
|
|
258
|
+
|
|
259
|
+
# Adapter name -> resolved artifact path, maintained by the /adapters handlers.
adapter_registry = dict()

# False until the health poller sees the model server healthy or gives up
# after HEALTH_TIMEOUT; GET /ping answers 503 while it is False.
model_server_ready = False
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# ── Health Polling (Readiness Gating) ─────────────────────────────────────────
|
|
264
|
+
|
|
265
|
+
async def poll_model_server_health(app):
    """Background task that polls the model server health endpoint.

    Sets model_server_ready to True once GET /health returns 200. After
    HEALTH_TIMEOUT seconds it logs a warning (including the outcome of the
    last probe) and sets ready to True anyway to avoid blocking forever.

    Args:
        app: Application mapping; app['session'] is the HTTP client session.
    """
    global model_server_ready
    session = app['session']
    start_time = time.monotonic()
    # Outcome of the most recent probe, kept so a timeout is diagnosable.
    last_result = 'no probe completed'

    log(f'Starting health polling — interval={HEALTH_POLL_INTERVAL}s, timeout={HEALTH_TIMEOUT}s')

    while True:
        elapsed = time.monotonic() - start_time

        # Timeout: log warning and begin accepting requests
        if elapsed >= HEALTH_TIMEOUT:
            log(f'Health timeout reached ({HEALTH_TIMEOUT}s) — model server did not become ready. '
                f'Accepting requests anyway. Last probe result: {last_result}', stream='stderr')
            model_server_ready = True
            return

        try:
            async with session.get(f'{MODEL_SERVER_BASE}/health', timeout=ClientTimeout(total=5)) as resp:
                if resp.status == 200:
                    model_server_ready = True
                    log('Model server is ready (health endpoint returned 200)')
                    return
                last_result = f'HTTP {resp.status}'
        except Exception as e:
            # Failures are expected while the model server is still starting;
            # keep polling, but remember why the probe failed.
            last_result = f'{type(e).__name__}: {e}'

        await asyncio.sleep(HEALTH_POLL_INTERVAL)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
# ── Handlers ──────────────────────────────────────────────────────────────────
|
|
300
|
+
|
|
301
|
+
async def handle_ping(request):
    """GET /ping — readiness gating.

    Returns 503 while model_server_ready is False. Once ready, forwards the
    check to the model server's /health endpoint and relays its status, body
    and Content-Type. Returns 503 if that call fails or times out.
    """
    if not model_server_ready:
        return web.Response(status=503, text='Service Unavailable')

    session = request.app['session']
    try:
        # Same 5 s bound as the background health poller, so a hung model
        # server cannot hold a ping open for the session's default timeout.
        async with session.get(f'{MODEL_SERVER_BASE}/health', timeout=ClientTimeout(total=5)) as resp:
            body = await resp.read()
            content_type = resp.headers.get('Content-Type', 'text/plain')
            return web.Response(status=resp.status, body=body,
                                headers={'Content-Type': content_type})
    except Exception as e:
        # Timeout errors have an empty message; fall back to the class name.
        reason = str(e) or type(e).__name__
        return web.Response(status=503, text=f'Model server unreachable: {reason}')
|
|
319
|
+
|
|
320
|
+
|
|
321
|
+
async def handle_adapters_post(request):
    """POST /adapters — load a LoRA adapter.

    Query parameters:
        name: Adapter identifier (required).
        src: Filesystem path to the adapter artifacts (required).

    Responses:
        200 {'status': 'loaded', 'adapter', 'path'} on success.
        400 when name or src is missing.
        404 when the artifacts cannot be found.
        500 when extraction or the model server load fails.
        507 when MAX_LORAS adapters are already registered.
    """
    name = request.query.get('name')
    src = request.query.get('src')

    if not name:
        return web.json_response({'status': 'error', 'error': 'Missing required query parameter: name'}, status=400)
    if not src:
        return web.json_response({'status': 'error', 'error': 'Missing required query parameter: src'}, status=400)

    # Check MAX_LORAS limit
    if len(adapter_registry) >= MAX_LORAS:
        return web.json_response(
            {'status': 'error', 'adapter': name, 'error': f'Maximum concurrent adapters ({MAX_LORAS}) reached'},
            status=507
        )

    # Resolve adapter artifacts.
    # NOTE(review): resolve() does blocking filesystem/tar work on the event
    # loop; other requests wait while a large archive is extracted.
    try:
        resolved_path = ArtifactResolver.resolve(src)
    except FileNotFoundError as e:
        log(f'Adapter artifact resolution failed — name={name}, src={src}, error={e}', stream='stderr')
        return web.json_response({'status': 'error', 'adapter': name, 'error': str(e)}, status=404)
    except (RuntimeError, OSError) as e:
        # OSError covers e.g. PermissionError while listing the directory,
        # which would otherwise escape as a non-JSON error response.
        log(f'Adapter artifact resolution failed — name={name}, src={src}, error={e}', stream='stderr')
        return web.json_response({'status': 'error', 'adapter': name, 'error': str(e)}, status=500)

    # Call model server native LoRA API
    client = request.app['model_server_client']
    try:
        await client.load_adapter(name, resolved_path)
    except RuntimeError as e:
        log(f'Adapter load failed — name={name}, src={src}, error={e}', stream='stderr')
        return web.json_response({'status': 'error', 'adapter': name, 'error': str(e)}, status=500)

    # Register adapter and respond
    adapter_registry[name] = resolved_path
    log(f'Adapter loaded — name={name}, src={src}, resolved_path={resolved_path}')
    return web.json_response({'status': 'loaded', 'adapter': name, 'path': resolved_path})
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
async def handle_adapters_delete(request):
    """DELETE /adapters — unload the adapter named by the `name` query parameter.

    Responds 400 when `name` is missing, 500 when the model server reports a
    failure, and otherwise {'status': 'unloaded', 'adapter': name}.
    """
    adapter_name = request.query.get('name')
    if not adapter_name:
        missing = {'status': 'error', 'error': 'Missing required query parameter: name'}
        return web.json_response(missing, status=400)

    # Ask the model server to drop the adapter first; the registry entry is
    # removed only after that succeeds.
    model_client = request.app['model_server_client']
    try:
        await model_client.unload_adapter(adapter_name)
    except RuntimeError as exc:
        log(f'Adapter unload failed — name={adapter_name}, error={exc}', stream='stderr')
        failure = {'status': 'error', 'adapter': adapter_name, 'error': str(exc)}
        return web.json_response(failure, status=500)

    adapter_registry.pop(adapter_name, None)
    log(f'Adapter unloaded — name={adapter_name}')
    return web.json_response({'status': 'unloaded', 'adapter': adapter_name})
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
async def handle_proxy(request):
    """Proxy all non-/adapters requests to the model server transparently.

    Forwards method, path, query string, headers and body to the model
    server and relays its status, headers and body. Hop-by-hop headers are
    dropped in both directions. Returns a 500 JSON error if the model server
    cannot be reached.
    """
    # Headers that describe a single connection and must not be forwarded.
    hop_by_hop = frozenset((
        'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
        'te', 'trailer', 'transfer-encoding', 'upgrade',
    ))

    session = request.app['session']
    target_url = f'{MODEL_SERVER_BASE}{request.path_qs}'

    try:
        body = await request.read()
        # Host is dropped as well so the client sets it for the target URL.
        forwarded_headers = {
            k: v for k, v in request.headers.items()
            if k.lower() != 'host' and k.lower() not in hop_by_hop
        }
        async with session.request(
            method=request.method,
            url=target_url,
            headers=forwarded_headers,
            data=body if body else None
        ) as resp:
            # NOTE: the upstream body is read in full before replying, so
            # streamed responses reach the caller only once complete.
            resp_body = await resp.read()
            # Content-Encoding / Content-Length are dropped too (as before);
            # web.Response sets the length from resp_body.
            skipped = hop_by_hop | {'content-encoding', 'content-length'}
            response_headers = {k: v for k, v in resp.headers.items()
                                if k.lower() not in skipped}
            return web.Response(status=resp.status, body=resp_body, headers=response_headers)
    except Exception as e:
        return web.json_response({'status': 'error', 'error': f'Model server unreachable: {e}'}, status=500)
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
# ── Application Setup ─────────────────────────────────────────────────────────
|
|
403
|
+
|
|
404
|
+
async def on_startup(app):
    """Startup hook: open the HTTP session, build the model server client and launch health polling."""
    http_session = ClientSession()
    app['session'] = http_session
    app['model_server_client'] = create_model_server_client(
        http_session, MODEL_SERVER_BASE, MODEL_SERVER_TYPE
    )
    # Readiness polling runs in the background; the task is cancelled in on_cleanup.
    app['health_task'] = asyncio.create_task(poll_model_server_health(app))

    summary = (
        f'Sidecar started — port={SIDECAR_PORT}, model_server_port={MODEL_SERVER_PORT}, '
        f'model_server_type={MODEL_SERVER_TYPE}, max_loras={MAX_LORAS}'
    )
    log(summary)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
async def on_cleanup(app):
    """Cancel the health polling task and close the HTTP session.

    The session is closed even if waiting for the health task raises
    something other than CancelledError; that exception still propagates.
    """
    health_task = app['health_task']
    health_task.cancel()
    try:
        await health_task
    except asyncio.CancelledError:
        # Expected result of cancelling a task that was still running.
        pass
    finally:
        await app['session'].close()
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def create_app():
    """Build the aiohttp application: routes first, then lifecycle hooks."""
    application = web.Application()
    router = application.router

    # Explicit endpoints handled by the sidecar itself.
    router.add_get('/ping', handle_ping)
    router.add_post('/adapters', handle_adapters_post)
    router.add_delete('/adapters', handle_adapters_delete)

    # Every other method/path is proxied; registered last, after the
    # explicit routes.
    router.add_route('*', '/{path:.*}', handle_proxy)

    application.on_startup.append(on_startup)
    application.on_cleanup.append(on_cleanup)
    return application
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
if __name__ == '__main__':
    # Script entry point: serve on all interfaces at SIDECAR_PORT.
    app = create_app()
    web.run_app(
        app,
        host='0.0.0.0',
        port=SIDECAR_PORT,
        print=None,
    )
|
package/templates/code/serve
CHANGED
|
@@ -176,6 +176,12 @@ unset _MODEL_VAR _RESOLVED_MODEL
|
|
|
176
176
|
SERVER_ARGS=(--host 0.0.0.0 --port 8081)
|
|
177
177
|
<% } else { %>
|
|
178
178
|
SERVER_ARGS=(--host 0.0.0.0 --port 8080)
|
|
179
|
+
|
|
180
|
+
# ---------------------------------------------------------------------------
|
|
181
|
+
# Adapter Sidecar — DISABLED
|
|
182
|
+
# vLLM runs on port 8080 and handles /ping, /invocations, /v1/* natively.
|
|
183
|
+
# The /adapters route will be injected directly into vLLM's FastAPI app.
|
|
184
|
+
# ---------------------------------------------------------------------------
|
|
179
185
|
<% } %>
|
|
180
186
|
|
|
181
187
|
# --- Server-specific arg conversion and exec ---
|
|
@@ -10,7 +10,7 @@ PREFIX="VLLM_"
|
|
|
10
10
|
ARG_PREFIX="--"
|
|
11
11
|
|
|
12
12
|
# Internal variables set by the base image — not CLI args
|
|
13
|
-
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
|
|
13
|
+
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY" "VLLM_ALLOW_RUNTIME_LORA_UPDATING")
|
|
14
14
|
|
|
15
15
|
# Declare and populate array of matching environment variables
|
|
16
16
|
mapfile -t env_vars < <(env | grep "^${PREFIX}")
|