@aws/ml-container-creator 0.10.0 → 0.12.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90) hide show
  1. package/LICENSE-THIRD-PARTY +9304 -0
  2. package/bin/cli.js +2 -0
  3. package/config/bootstrap-e2e-stack.json +341 -0
  4. package/config/bootstrap-stack.json +40 -3
  5. package/config/parameter-schema-v2.json +33 -22
  6. package/config/tune-catalog.json +1781 -0
  7. package/infra/ci-harness/buildspec.yml +1 -0
  8. package/infra/ci-harness/lambda/path-prover/brain.ts +306 -0
  9. package/infra/ci-harness/lambda/path-prover/write-results.ts +152 -0
  10. package/infra/ci-harness/lib/ci-harness-stack.ts +851 -7
  11. package/infra/ci-harness/state-machines/path-prover.asl.json +496 -0
  12. package/package.json +53 -67
  13. package/servers/base-image-picker/index.js +121 -121
  14. package/servers/e2e-status/index.js +297 -0
  15. package/servers/e2e-status/manifest.json +14 -0
  16. package/servers/e2e-status/package.json +15 -0
  17. package/servers/endpoint-picker/LICENSE +202 -0
  18. package/servers/endpoint-picker/index.js +536 -0
  19. package/servers/endpoint-picker/manifest.json +14 -0
  20. package/servers/endpoint-picker/package.json +18 -0
  21. package/servers/hyperpod-cluster-picker/index.js +125 -125
  22. package/servers/instance-sizer/index.js +166 -153
  23. package/servers/instance-sizer/lib/instance-ranker.js +120 -76
  24. package/servers/instance-sizer/lib/model-resolver.js +61 -61
  25. package/servers/instance-sizer/lib/quota-resolver.js +113 -113
  26. package/servers/instance-sizer/lib/vram-estimator.js +31 -31
  27. package/servers/lib/bedrock-client.js +38 -38
  28. package/servers/lib/catalogs/instances.json +27 -0
  29. package/servers/lib/catalogs/model-servers.json +201 -3
  30. package/servers/lib/custom-validators.js +13 -13
  31. package/servers/lib/dynamic-resolver.js +4 -4
  32. package/servers/marketplace-picker/index.js +342 -0
  33. package/servers/marketplace-picker/manifest.json +14 -0
  34. package/servers/marketplace-picker/package.json +18 -0
  35. package/servers/model-picker/index.js +382 -382
  36. package/servers/region-picker/index.js +56 -56
  37. package/servers/workload-picker/LICENSE +202 -0
  38. package/servers/workload-picker/catalogs/workload-profiles.json +67 -0
  39. package/servers/workload-picker/index.js +171 -0
  40. package/servers/workload-picker/manifest.json +16 -0
  41. package/servers/workload-picker/package.json +16 -0
  42. package/src/app.js +12 -3
  43. package/src/lib/bootstrap-command-handler.js +609 -15
  44. package/src/lib/bootstrap-config.js +36 -0
  45. package/src/lib/bootstrap-profile-manager.js +48 -41
  46. package/src/lib/ci-register-helpers.js +74 -0
  47. package/src/lib/config-loader.js +3 -0
  48. package/src/lib/config-manager.js +7 -0
  49. package/src/lib/config-validator.js +1 -1
  50. package/src/lib/cuda-resolver.js +17 -8
  51. package/src/lib/generated/cli-options.js +319 -314
  52. package/src/lib/generated/parameter-matrix.js +672 -661
  53. package/src/lib/generated/validation-rules.js +76 -72
  54. package/src/lib/path-prover-brain.js +664 -0
  55. package/src/lib/prompts/infrastructure-prompts.js +2 -2
  56. package/src/lib/prompts/model-prompts.js +6 -0
  57. package/src/lib/prompts/project-prompts.js +12 -0
  58. package/src/lib/secrets-prompt-runner.js +4 -0
  59. package/src/lib/template-manager.js +1 -1
  60. package/src/lib/template-variable-resolver.js +87 -1
  61. package/src/lib/tune-catalog-validator.js +37 -4
  62. package/templates/Dockerfile +9 -0
  63. package/templates/code/adapter_sidecar.py +444 -0
  64. package/templates/code/serve +6 -0
  65. package/templates/code/serve.d/vllm.ejs +1 -1
  66. package/templates/do/.benchmark_writer.py +1476 -0
  67. package/templates/do/.tune_helper.py +982 -57
  68. package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
  69. package/templates/do/adapter +154 -0
  70. package/templates/do/benchmark +639 -85
  71. package/templates/do/build +5 -0
  72. package/templates/do/clean.d/async-inference.ejs +5 -0
  73. package/templates/do/clean.d/batch-transform.ejs +5 -0
  74. package/templates/do/clean.d/hyperpod-eks.ejs +5 -0
  75. package/templates/do/clean.d/managed-inference.ejs +5 -0
  76. package/templates/do/config +115 -45
  77. package/templates/do/deploy.d/async-inference.ejs +30 -3
  78. package/templates/do/deploy.d/batch-transform.ejs +29 -3
  79. package/templates/do/deploy.d/hyperpod-eks.ejs +4 -0
  80. package/templates/do/deploy.d/managed-inference.ejs +216 -14
  81. package/templates/do/lib/endpoint-config.sh +1 -1
  82. package/templates/do/lib/profile.sh +44 -0
  83. package/templates/do/optimize +106 -37
  84. package/templates/do/push +5 -0
  85. package/templates/do/register +94 -0
  86. package/templates/do/stage +567 -0
  87. package/templates/do/submit +7 -0
  88. package/templates/do/test +14 -0
  89. package/templates/do/tune +382 -59
  90. package/templates/do/validate +44 -4
@@ -48,6 +48,18 @@ const projectPrompts = [
48
48
  // Derive framework from deploymentConfig if not already set
49
49
  const framework = answers.framework || answers.deploymentConfig?.split('-')[0];
50
50
  return generateProjectName(framework);
51
+ },
52
+ validate: (input) => {
53
+ if (!input || input.length < 2) {
54
+ return 'Project name must be at least 2 characters.';
55
+ }
56
+ if (input.length > 63) {
57
+ return 'Project name must be 63 characters or fewer.';
58
+ }
59
+ if (!/^[a-z0-9][a-z0-9-]*[a-z0-9]$/.test(input)) {
60
+ return 'Project name must be lowercase alphanumeric with hyphens (e.g. "qwen3-0-6b-v1-test"). No uppercase, dots, or underscores.';
61
+ }
62
+ return true;
51
63
  }
52
64
  }
53
65
  ];
@@ -70,6 +70,10 @@ export default class SecretsPromptRunner {
70
70
  const modelSource = answers.modelSource;
71
71
  if (modelSource && modelSource !== 'huggingface') return false;
72
72
 
73
+ // Skip HF token when model name is an S3 URI (no HF download needed)
74
+ const modelName = answers.customModelName || answers.modelName;
75
+ if (modelName && modelName.startsWith('s3://')) return false;
76
+
73
77
  return true;
74
78
  }
75
79
 
@@ -146,7 +146,7 @@ export default class TemplateManager {
146
146
 
147
147
  // Validate instance type format (ml.*.*) - only for realtime-inference
148
148
  if (this.answers.instanceType && this.answers.instanceType !== 'custom') {
149
- const instancePattern = /^ml\.[a-z0-9]+\.(nano|micro|small|medium|large|xlarge|[0-9]+xlarge)$/;
149
+ const instancePattern = /^ml\.[a-z0-9-]+\.(nano|micro|small|medium|large|xlarge|[0-9]+xlarge)$/;
150
150
  if (!instancePattern.test(this.answers.instanceType)) {
151
151
  throw new Error(`⚠️ Invalid instance type format: ${this.answers.instanceType}. Expected format: ml.{family}.{size} (e.g., ml.m5.large, ml.g5.xlarge)`);
152
152
  }
@@ -4,7 +4,7 @@
4
4
  import fs from 'fs';
5
5
  import path from 'path';
6
6
  import { fileURLToPath } from 'url';
7
- import { isTuneSupported } from './tune-catalog-validator.js';
7
+ import { isTuneSupported, lookupModel } from './tune-catalog-validator.js';
8
8
 
9
9
  const __filename = fileURLToPath(import.meta.url);
10
10
  const __dirname = path.dirname(__filename);
@@ -383,6 +383,68 @@ export async function _ensureTemplateVariables(answers, registryConfigManager =
383
383
  }
384
384
  }
385
385
 
386
+ // Auto-resolve tensor parallel degree from instance catalog GPU count.
387
+ // Only applies when:
388
+ // 1. The engine supports tensor parallelism (vLLM, SGLang, TensorRT-LLM, LMI)
389
+ // 2. The instance has multiple GPUs (gpus > 1)
390
+ // 3. The user has NOT explicitly set the TP env var via --server-env or --model-env
391
+ // This ensures multi-GPU instances default to full TP utilization without requiring
392
+ // the user to manually specify TENSOR_PARALLEL_SIZE.
393
+ // Requirements: FTP-1 (extension) — task 6.2
394
+ const _TP_ENGINE_MAP = {
395
+ 'vllm': 'VLLM_TENSOR_PARALLEL_SIZE',
396
+ 'vllm-omni': 'VLLM_OMNI_TENSOR_PARALLEL_SIZE',
397
+ 'sglang': 'SGLANG_TENSOR_PARALLEL_SIZE',
398
+ 'tensorrt-llm': 'TRTLLM_TENSOR_PARALLEL_SIZE',
399
+ 'lmi': 'OPTION_TENSOR_PARALLEL_DEGREE'
400
+ };
401
+
402
+ const tpEngine = answers.backend || answers.modelServer;
403
+ const tpEnvKey = tpEngine ? _TP_ENGINE_MAP[tpEngine] : null;
404
+
405
+ if (tpEnvKey && answers.instanceType) {
406
+ // Check if user explicitly set the TP value via --server-env (un-prefixed key)
407
+ const userServerEnvVars = answers.serverEnvVars || {};
408
+ const userExplicitlySetTP = (
409
+ userServerEnvVars['TENSOR_PARALLEL_SIZE'] !== undefined ||
410
+ userServerEnvVars['TENSOR_PARALLEL_DEGREE'] !== undefined ||
411
+ userServerEnvVars[tpEnvKey] !== undefined
412
+ );
413
+
414
+ if (!userExplicitlySetTP) {
415
+ // Look up GPU count from instance catalog
416
+ let instanceGpuCount = null;
417
+ if (answers.gpuCount) {
418
+ instanceGpuCount = answers.gpuCount;
419
+ } else if (answers.icGpuCount) {
420
+ instanceGpuCount = answers.icGpuCount;
421
+ } else {
422
+ try {
423
+ const catalogPath = path.resolve(__dirname, '..', '..', 'servers', 'lib', 'catalogs', 'instances.json');
424
+ const catalogData = JSON.parse(fs.readFileSync(catalogPath, 'utf-8'));
425
+ const instanceInfo = catalogData?.catalog?.[answers.instanceType];
426
+ if (instanceInfo?.gpus && instanceInfo.gpus > 0) {
427
+ instanceGpuCount = instanceInfo.gpus;
428
+ }
429
+ } catch {
430
+ // Silently continue
431
+ }
432
+ }
433
+
434
+ // Auto-set TP to GPU count when instance has multiple GPUs
435
+ if (instanceGpuCount && instanceGpuCount > 1) {
436
+ if (!answers.envVars) {
437
+ answers.envVars = {};
438
+ }
439
+ answers.envVars[tpEnvKey] = String(instanceGpuCount);
440
+ answers.tensorParallelSize = instanceGpuCount;
441
+ answers._tpAutoResolved = true;
442
+ answers._tpAutoResolvedFrom = answers.instanceType;
443
+ console.log(` ℹ️ TP degree: ${instanceGpuCount} (auto-detected from ${answers.instanceType})`);
444
+ }
445
+ }
446
+ }
447
+
386
448
  // Determine tune support based on model presence in the tune catalog.
387
449
  // Used by the do/config template to write TUNE_SUPPORTED=true|false.
388
450
  if (answers.tuneSupported === undefined) {
@@ -395,4 +457,28 @@ export async function _ensureTemplateVariables(answers, registryConfigManager =
395
457
  answers.tuneSupported = false;
396
458
  }
397
459
  }
460
+
461
+ // Resolve tuneModelId from the catalog — static lookup, no network calls.
462
+ // Maps the HuggingFace model ID to the Hub content name (catalog key).
463
+ if (answers.tuneModelId === undefined) {
464
+ if (answers.tuneSupported && answers.modelName) {
465
+ try {
466
+ const tuneCatalogPath = path.resolve(__dirname, '..', '..', 'config', 'tune-catalog.json');
467
+ const tuneCatalog = JSON.parse(fs.readFileSync(tuneCatalogPath, 'utf-8'));
468
+ const entry = lookupModel(answers.modelName, tuneCatalog);
469
+ if (entry) {
470
+ const hubContentName = Object.entries(tuneCatalog.models)
471
+ .find(([, v]) => v === entry)?.[0];
472
+ if (hubContentName) {
473
+ answers.tuneModelId = hubContentName;
474
+ }
475
+ }
476
+ } catch {
477
+ // Silently continue — tuneModelId will be set to null below
478
+ }
479
+ }
480
+ if (!answers.tuneModelId) {
481
+ answers.tuneModelId = null;
482
+ }
483
+ }
398
484
  }
@@ -13,7 +13,8 @@
13
13
 
14
14
  /**
15
15
  * Look up a model entry in the catalog by model ID.
16
- * @param {string} modelId - The model ID to look up
16
+ * Tries: direct key match, huggingFaceId field match, then normalized/suffix matching.
17
+ * @param {string} modelId - The model ID to look up (Hub content name or HuggingFace ID)
17
18
  * @param {Object} catalog - The tune catalog object with a `models` map
18
19
  * @returns {Object|null} The catalog entry for the model, or null if not found
19
20
  */
@@ -21,10 +22,42 @@ export function lookupModel(modelId, catalog) {
21
22
  if (!catalog || !catalog.models) {
22
23
  return null;
23
24
  }
24
- if (!Object.hasOwn(catalog.models, modelId)) {
25
- return null;
25
+
26
+ // Direct key match (Hub content name)
27
+ if (Object.hasOwn(catalog.models, modelId)) {
28
+ return catalog.models[modelId] || null;
29
+ }
30
+
31
+ // Match by huggingFaceId field (e.g., "Qwen/Qwen3-0.6B")
32
+ for (const [, entry] of Object.entries(catalog.models)) {
33
+ if (entry.huggingFaceId === modelId) {
34
+ return entry;
35
+ }
36
+ }
37
+
38
+ // Normalized match: strip org prefix, lowercase, replace dots/spaces with hyphens
39
+ const normalized = modelId.split('/').pop().toLowerCase().replace(/[.\s]+/g, '-');
40
+ if (normalized && Object.hasOwn(catalog.models, normalized)) {
41
+ return catalog.models[normalized] || null;
26
42
  }
27
- return catalog.models[modelId] || null;
43
+
44
+ // Try without trailing suffixes like -instruct, -chat, -hf, -base
45
+ const base = normalized ? normalized.replace(/-(instruct|chat|hf|base)$/i, '') : '';
46
+ if (base && base !== normalized && Object.hasOwn(catalog.models, base)) {
47
+ return catalog.models[base] || null;
48
+ }
49
+
50
+ // Suffix match: catalog keys may have prefixes (e.g., "huggingface-reasoning-")
51
+ // Match if a catalog key ends with the normalized name (must be non-trivial match)
52
+ if (normalized && normalized.length >= 4) {
53
+ for (const [key, entry] of Object.entries(catalog.models)) {
54
+ if (key.endsWith(normalized) || (base && base.length >= 4 && key.endsWith(base))) {
55
+ return entry || null;
56
+ }
57
+ }
58
+ }
59
+
60
+ return null;
28
61
  }
29
62
 
30
63
  /**
@@ -243,6 +243,7 @@ ENV <%= key %>=<%= value %>
243
243
  ENV VLLM_ENABLE_LORA=true
244
244
  ENV VLLM_MAX_LORAS=<%= maxLoras %>
245
245
  ENV VLLM_MAX_LORA_RANK=<%= maxLoraRank %>
246
+ ENV VLLM_ALLOW_RUNTIME_LORA_UPDATING=true
246
247
  <% } %>
247
248
  <% if (enableLora && modelServer === 'sglang') { %>
248
249
  # LoRA adapter serving configuration
@@ -307,9 +308,17 @@ COPY code/serving.properties /opt/ml/model/serving.properties
307
308
  # LMI/DJL containers use their own entrypoint
308
309
  # The container will automatically start DJL Serving with the configuration
309
310
  <% } else { %>
311
+ <% if (enableLora && (modelServer === 'vllm' || modelServer === 'sglang')) { %>
312
+ # Install aiohttp for the adapter sidecar
313
+ RUN pip install --no-cache-dir aiohttp
314
+
315
+ <% } %>
310
316
  COPY code/cuda_compat.sh /usr/bin/cuda_compat.sh
311
317
  COPY code/cw_log_forwarder.py /usr/bin/cw_log_forwarder.py
312
318
  COPY code/serve /usr/bin/serve
319
+ <% if (enableLora && (modelServer === 'vllm' || modelServer === 'sglang')) { %>
320
+ COPY code/adapter_sidecar.py /usr/bin/adapter_sidecar.py
321
+ <% } %>
313
322
  RUN chmod 777 /usr/bin/serve /usr/bin/cuda_compat.sh
314
323
 
315
324
  <% if (comments && comments.troubleshooting) { %>
@@ -0,0 +1,444 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """
6
+ Adapter Sidecar — SageMaker AI adapter contract implementation.
7
+
8
+ Lightweight aiohttp HTTP server that sits between SageMaker (port 8080) and the
9
+ model server (port 8081). Implements POST /adapters and DELETE /adapters by
10
+ translating them into the model server's native LoRA API, while proxying all
11
+ other traffic transparently.
12
+
13
+ Configuration (environment variables):
14
+ MODEL_SERVER_PORT - Internal model server port (default: 8081)
15
+ MODEL_SERVER_TYPE - Model server type: vllm or sglang (default: vllm)
16
+ SIDECAR_PORT - Port sidecar listens on (default: 8080)
17
+ MAX_LORAS - Maximum concurrent adapters (default: 64)
18
+ HEALTH_POLL_INTERVAL - Seconds between health polls (default: 2)
19
+ HEALTH_TIMEOUT - Seconds to wait for model server readiness (default: 600)
20
+ """
21
+
22
+ import asyncio
23
+ import os
24
+ import tarfile
25
+ import time
26
+ from datetime import datetime, timezone
27
+
28
+ from aiohttp import web, ClientSession, ClientTimeout
29
+
30
+
31
+ # ── Configuration ─────────────────────────────────────────────────────────────
32
+
33
# Every setting is read once at import time; values are overridable through
# the environment and fall back to the defaults documented in the module
# docstring.

# Which model server the sidecar fronts, and how many adapters it may hold.
MODEL_SERVER_TYPE = os.environ.get('MODEL_SERVER_TYPE', 'vllm')
MAX_LORAS = int(os.environ.get('MAX_LORAS', '64'))

# Network layout: the sidecar listens on SIDECAR_PORT and forwards to the
# model server on MODEL_SERVER_PORT.
SIDECAR_PORT = int(os.environ.get('SIDECAR_PORT', '8080'))
MODEL_SERVER_PORT = int(os.environ.get('MODEL_SERVER_PORT', '8081'))

# Readiness polling cadence and overall deadline, both in seconds.
HEALTH_POLL_INTERVAL = int(os.environ.get('HEALTH_POLL_INTERVAL', '2'))
HEALTH_TIMEOUT = int(os.environ.get('HEALTH_TIMEOUT', '600'))

# Base URL used for every request sent to the model server.
MODEL_SERVER_BASE = f'http://localhost:{MODEL_SERVER_PORT}'
41
+
42
+
43
+ # ── Logging ───────────────────────────────────────────────────────────────────
44
+
45
def log(message, stream='stdout'):
    """Emit a log message with ISO 8601 timestamp and [adapter-sidecar] prefix.

    Args:
        message: Text to log.
        stream: 'stderr' writes to standard error; any other value writes
            to standard output.
    """
    # Local import: sys is only needed here.
    import sys

    ts = datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
    line = f'{ts} [adapter-sidecar] {message}'
    target = sys.stderr if stream == 'stderr' else sys.stdout
    # stdout is block-buffered when it is not a TTY (the normal case inside a
    # container), so flush each line to keep logs timely and avoid losing
    # buffered output if the process is killed.
    print(line, file=target, flush=True)
54
+
55
+
56
+ # ── Artifact Resolution ───────────────────────────────────────────────────────
57
+
58
class ArtifactResolver:
    """Resolves adapter artifacts from a source path.

    Handles these cases:
      1. Path is a tar.gz file — extract next to it, return its directory
      2. Path is a directory containing adapter_config.json — use it directly
      3. Path is a directory containing a single tar.gz — extract in place
      4. Path does not exist, is empty, or yields no adapter_config.json —
         raise FileNotFoundError
    """

    @staticmethod
    def resolve(src):
        """Resolve the adapter artifact path.

        Args:
            src: Filesystem path where SageMaker placed adapter artifacts.

        Returns:
            Resolved directory path containing adapter files.

        Raises:
            FileNotFoundError: If the path does not exist, is empty, or does
                not contain adapter_config.json (directly or after extraction).
            RuntimeError: If tar.gz extraction fails.
        """
        if not os.path.exists(src):
            raise FileNotFoundError(f'Adapter artifact path does not exist: {src}')

        # Direct path to an archive: extract alongside it.
        if os.path.isfile(src) and src.endswith('.tar.gz'):
            extract_dir = os.path.dirname(src)
            ArtifactResolver._extract_tar_gz(src, extract_dir)
            ArtifactResolver._require_adapter_config(extract_dir, src)
            return extract_dir

        if not os.path.isdir(src):
            raise FileNotFoundError(f'Adapter artifact path is not a directory: {src}')

        contents = os.listdir(src)
        if not contents:
            raise FileNotFoundError(f'Adapter artifact path is empty: {src}')

        # Already-extracted adapter files.
        if 'adapter_config.json' in contents:
            return src

        # A single archive inside the directory: extract in place, then make
        # sure the extraction actually produced the adapter config instead of
        # handing the model server a directory it cannot load.
        tar_files = [f for f in contents if f.endswith('.tar.gz')]
        if len(tar_files) == 1:
            tar_path = os.path.join(src, tar_files[0])
            ArtifactResolver._extract_tar_gz(tar_path, src)
            ArtifactResolver._require_adapter_config(src, tar_path)
            return src

        raise FileNotFoundError(
            f'Adapter artifact path does not contain adapter_config.json or a tar.gz archive: {src}'
        )

    @staticmethod
    def _require_adapter_config(directory, tar_path):
        """Raise FileNotFoundError unless *directory* holds adapter_config.json."""
        # An empty directory string (bare relative archive name) means the CWD.
        if 'adapter_config.json' not in os.listdir(directory or '.'):
            raise FileNotFoundError(
                f'Archive {tar_path} did not produce adapter_config.json in: {directory}'
            )

    @staticmethod
    def _reject_unsafe_members(tar, extract_dir):
        """Refuse archive members that would be written outside *extract_dir*.

        Only used on Python versions without tarfile extraction filters.

        Raises:
            tarfile.ExtractError: If a member path, link target, or member
                type is unsafe.
        """
        root = os.path.realpath(extract_dir)
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if os.path.commonpath([root, target]) != root:
                raise tarfile.ExtractError(f'unsafe path in archive: {member.name}')
            if member.isdev():
                raise tarfile.ExtractError(f'device file in archive: {member.name}')
            if member.issym() or member.islnk():
                # Symlink targets are relative to the link's own directory;
                # hard-link targets are relative to the archive root.
                base = os.path.dirname(target) if member.issym() else root
                link_target = os.path.realpath(os.path.join(base, member.linkname))
                if os.path.commonpath([root, link_target]) != root:
                    raise tarfile.ExtractError(f'unsafe link in archive: {member.name}')

    @staticmethod
    def _extract_tar_gz(tar_path, extract_dir):
        """Extract a tar.gz archive to the specified directory.

        Args:
            tar_path: Path to the tar.gz file.
            extract_dir: Directory to extract files into.

        Raises:
            RuntimeError: If extraction fails due to corruption, unsafe
                members, or permission issues.
        """
        try:
            with tarfile.open(tar_path, 'r:gz') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # Extraction filters are available: let the stdlib block
                    # path traversal, absolute paths and special files.
                    tar.extractall(path=extract_dir, filter='data')
                else:
                    # Older interpreter: validate members ourselves because
                    # adapter archives are untrusted input.
                    ArtifactResolver._reject_unsafe_members(tar, extract_dir)
                    tar.extractall(path=extract_dir)
        except (tarfile.TarError, OSError) as e:
            # PermissionError is an OSError subclass, so it is covered here.
            raise RuntimeError(f'Failed to extract tar.gz archive {tar_path}: {e}') from e
140
+
141
+
142
+ # ── Model Server Client (Strategy Pattern) ────────────────────────────────────
143
+
144
class ModelServerClient:
    """Base strategy for translating adapter operations to a model server.

    Concrete subclasses issue the HTTP calls specific to one model server
    type; this base only stores the shared connection details.
    """

    def __init__(self, session, base_url):
        # session: HTTP client used for outgoing calls.
        # base_url: root URL of the model server.
        self.session, self.base_url = session, base_url

    async def load_adapter(self, name, path):
        """Load a LoRA adapter into the model server.

        Args:
            name: Adapter identifier.
            path: Resolved filesystem path to adapter artifacts.

        Returns:
            dict with response data from the model server.

        Raises:
            RuntimeError: If the model server returns an error or is
                unreachable.
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    async def unload_adapter(self, name):
        """Unload a LoRA adapter from the model server.

        Args:
            name: Adapter identifier.

        Returns:
            dict with response data from the model server.

        Raises:
            RuntimeError: If the model server returns an error or is
                unreachable.
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
182
+
183
+
184
class VLLMClient(ModelServerClient):
    """Adapter operations expressed through vLLM's LoRA endpoints.

    Load:   POST /v1/load_lora_adapter   {"lora_name": name, "lora_path": path}
    Unload: POST /v1/unload_lora_adapter {"lora_name": name}
    """

    async def load_adapter(self, name, path):
        """Ask vLLM to load the adapter *name* from *path*.

        Returns:
            dict with 'status' and the raw response text on HTTP 200.

        Raises:
            RuntimeError: On any non-200 reply or when the request fails.
        """
        endpoint = f'{self.base_url}/v1/load_lora_adapter'
        request_body = {'lora_name': name, 'lora_path': path}
        try:
            async with self.session.post(endpoint, json=request_body) as resp:
                body = await resp.text()
                if resp.status != 200:
                    raise RuntimeError(f'vLLM load_lora_adapter failed (HTTP {resp.status}): {body}')
                return {'status': 'success', 'response': body}
        except RuntimeError:
            # Already in the shape callers expect; pass through untouched.
            raise
        except Exception as e:
            raise RuntimeError(f'Failed to connect to vLLM: {e}')

    async def unload_adapter(self, name):
        """Ask vLLM to unload the adapter *name*.

        Returns:
            dict with 'status' and the raw response text on HTTP 200.

        Raises:
            RuntimeError: On any non-200 reply or when the request fails.
        """
        endpoint = f'{self.base_url}/v1/unload_lora_adapter'
        request_body = {'lora_name': name}
        try:
            async with self.session.post(endpoint, json=request_body) as resp:
                body = await resp.text()
                if resp.status != 200:
                    raise RuntimeError(f'vLLM unload_lora_adapter failed (HTTP {resp.status}): {body}')
                return {'status': 'success', 'response': body}
        except RuntimeError:
            # Already in the shape callers expect; pass through untouched.
            raise
        except Exception as e:
            raise RuntimeError(f'Failed to connect to vLLM: {e}')
220
+
221
+
222
class SGLangClient(ModelServerClient):
    """Placeholder for SGLang adapter support (deferred to a follow-up).

    Both operations raise NotImplementedError until the SGLang translation
    is written.
    """

    async def load_adapter(self, name, path):
        """Not yet implemented for SGLang; always raises NotImplementedError."""
        raise NotImplementedError('SGLang adapter loading is not yet implemented')

    async def unload_adapter(self, name):
        """Not yet implemented for SGLang; always raises NotImplementedError."""
        raise NotImplementedError('SGLang adapter unloading is not yet implemented')
236
+
237
+
238
def create_model_server_client(session, base_url, server_type):
    """Build the ModelServerClient that matches *server_type*.

    Args:
        session: aiohttp.ClientSession for HTTP calls.
        base_url: Model server base URL (e.g., http://localhost:8081).
        server_type: Model server type ('vllm' or 'sglang').

    Returns:
        ModelServerClient instance.

    Raises:
        ValueError: If *server_type* is neither 'vllm' nor 'sglang'.
    """
    # Guard clause first so the happy path below is a single expression.
    if server_type not in ('vllm', 'sglang'):
        raise ValueError(f'Unsupported model server type: {server_type}')

    client_cls = VLLMClient if server_type == 'vllm' else SGLangClient
    return client_cls(session, base_url)
255
+
256
+
257
+ # ── State ─────────────────────────────────────────────────────────────────────
258
+
259
# Adapters currently registered by the sidecar: adapter name -> resolved
# artifact path. Written by the POST /adapters handler and removed by the
# DELETE /adapters handler.
adapter_registry = {}
# Flipped to True by the health-polling task (on a 200 from the model server's
# health endpoint, or when the health timeout expires); gates GET /ping.
model_server_ready = False
261
+
262
+
263
+ # ── Health Polling (Readiness Gating) ─────────────────────────────────────────
264
+
265
async def poll_model_server_health(app):
    """Poll the model server's health endpoint until it is ready or time runs out.

    Marks the module-level model_server_ready flag True as soon as the
    health endpoint answers 200. If HEALTH_TIMEOUT seconds pass first, a
    warning is logged and the flag is set anyway so requests are not held
    back indefinitely.

    Args:
        app: Application mapping that holds the shared HTTP session under
            the 'session' key.
    """
    global model_server_ready
    http = app['session']
    started = time.monotonic()

    log(f'Starting health polling — interval={HEALTH_POLL_INTERVAL}s, timeout={HEALTH_TIMEOUT}s')

    while time.monotonic() - started < HEALTH_TIMEOUT:
        try:
            async with http.get(f'{MODEL_SERVER_BASE}/health', timeout=ClientTimeout(total=5)) as resp:
                if resp.status == 200:
                    model_server_ready = True
                    log('Model server is ready (health endpoint returned 200)')
                    return
        except Exception:
            # Failures are expected while the model server is still starting;
            # just try again after the poll interval.
            pass

        await asyncio.sleep(HEALTH_POLL_INTERVAL)

    # Deadline reached without a healthy response: warn and open the gate.
    log(f'Health timeout reached ({HEALTH_TIMEOUT}s) — model server did not become ready. Accepting requests anyway.', stream='stderr')
    model_server_ready = True
297
+
298
+
299
+ # ── Handlers ──────────────────────────────────────────────────────────────────
300
+
301
async def handle_ping(request):
    """GET /ping — readiness gate in front of the model server.

    Answers 503 while model_server_ready is False. Afterwards the model
    server's /health response (status, body, content type) is relayed; if
    that call fails, 503 is returned with the failure text.
    """
    if not model_server_ready:
        return web.Response(status=503, text='Service Unavailable')

    http = request.app['session']
    try:
        async with http.get(f'{MODEL_SERVER_BASE}/health') as resp:
            payload = await resp.read()
            relay_headers = {'Content-Type': resp.headers.get('Content-Type', 'text/plain')}
            return web.Response(status=resp.status, body=payload, headers=relay_headers)
    except Exception as e:
        return web.Response(status=503, text=f'Model server unreachable: {e}')
319
+
320
+
321
async def handle_adapters_post(request):
    """POST /adapters — load a LoRA adapter.

    Query parameters:
        name: Adapter identifier (required).
        src: Filesystem path to the adapter artifacts (required).

    Responses:
        200 — adapter loaded and registered.
        400 — a required query parameter is missing.
        404 — artifacts not found at ``src``.
        500 — archive extraction or the model server call failed.
        507 — the registry already holds MAX_LORAS adapters.
    """
    name = request.query.get('name')
    src = request.query.get('src')

    if not name:
        return web.json_response({'status': 'error', 'error': 'Missing required query parameter: name'}, status=400)
    if not src:
        return web.json_response({'status': 'error', 'error': 'Missing required query parameter: src'}, status=400)

    # Enforce the MAX_LORAS limit. A name that is already registered does not
    # take an additional slot, so it is not rejected for capacity reasons.
    if name not in adapter_registry and len(adapter_registry) >= MAX_LORAS:
        return web.json_response(
            {'status': 'error', 'adapter': name, 'error': f'Maximum concurrent adapters ({MAX_LORAS}) reached'},
            status=507
        )

    # Resolve adapter artifacts. Resolution may extract a tar.gz, which is
    # blocking disk work; run it in the default executor so the event loop
    # keeps serving /ping and proxied requests in the meantime.
    loop = asyncio.get_running_loop()
    try:
        resolved_path = await loop.run_in_executor(None, ArtifactResolver.resolve, src)
    except FileNotFoundError as e:
        return web.json_response({'status': 'error', 'adapter': name, 'error': str(e)}, status=404)
    except RuntimeError as e:
        return web.json_response({'status': 'error', 'adapter': name, 'error': str(e)}, status=500)

    # Call model server native LoRA API
    client = request.app['model_server_client']
    try:
        await client.load_adapter(name, resolved_path)
    except RuntimeError as e:
        # NotImplementedError is a RuntimeError subclass, so unsupported
        # model server types also end up here.
        log(f'Adapter load failed — name={name}, src={src}, error={e}', stream='stderr')
        return web.json_response({'status': 'error', 'adapter': name, 'error': str(e)}, status=500)

    # Register adapter and respond
    adapter_registry[name] = resolved_path
    log(f'Adapter loaded — name={name}, src={src}, resolved_path={resolved_path}')
    return web.json_response({'status': 'loaded', 'adapter': name, 'path': resolved_path})
358
+
359
+
360
async def handle_adapters_delete(request):
    """DELETE /adapters — unload a LoRA adapter.

    Requires the ``name`` query parameter (400 when absent). The unload is
    delegated to the configured model server client; a RuntimeError from it
    is logged and reported as 500. On success the name is dropped from the
    registry (if present) and a confirmation is returned.
    """
    name = request.query.get('name')
    if not name:
        missing = {'status': 'error', 'error': 'Missing required query parameter: name'}
        return web.json_response(missing, status=400)

    backend = request.app['model_server_client']
    try:
        await backend.unload_adapter(name)
    except RuntimeError as e:
        log(f'Adapter unload failed — name={name}, error={e}', stream='stderr')
        failure = {'status': 'error', 'adapter': name, 'error': str(e)}
        return web.json_response(failure, status=500)

    # Unknown names are tolerated here: pop with a default never raises.
    adapter_registry.pop(name, None)
    log(f'Adapter unloaded — name={name}')
    return web.json_response({'status': 'unloaded', 'adapter': name})
379
+
380
+
381
async def handle_proxy(request):
    """Proxy all non-/adapters requests to the model server transparently.

    The request body is read in full and re-sent; the upstream response is
    read in full and returned with the upstream status code. Any failure
    while talking to the model server yields a 500 JSON error.
    """
    # Hop-by-hop headers describe a single connection and must not be
    # forwarded by a proxy. In particular, forwarding
    # "Transfer-Encoding: chunked" together with an already-buffered body
    # would advertise a framing that the re-sent request does not use.
    hop_by_hop = frozenset((
        'connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization',
        'te', 'trailer', 'transfer-encoding', 'upgrade',
    ))
    # The response body is re-framed (and already decoded) by the time it is
    # returned, so its original length and encoding headers no longer apply.
    response_excluded = hop_by_hop | {'content-encoding', 'content-length'}

    session = request.app['session']
    target_url = f'{MODEL_SERVER_BASE}{request.path_qs}'

    try:
        body = await request.read()
        forward_headers = {
            k: v for k, v in request.headers.items()
            if k.lower() != 'host' and k.lower() not in hop_by_hop
        }
        async with session.request(
            method=request.method,
            url=target_url,
            headers=forward_headers,
            data=body if body else None
        ) as resp:
            resp_body = await resp.read()
            response_headers = {k: v for k, v in resp.headers.items()
                                if k.lower() not in response_excluded}
            return web.Response(status=resp.status, body=resp_body, headers=response_headers)
    except Exception as e:
        return web.json_response({'status': 'error', 'error': f'Model server unreachable: {e}'}, status=500)
400
+
401
+
402
+ # ── Application Setup ─────────────────────────────────────────────────────────
403
+
404
async def on_startup(app):
    """Startup hook: wire up shared resources on the application.

    Stores an HTTP client session, the model server client chosen by
    MODEL_SERVER_TYPE, and the background health-polling task on *app*,
    then logs the effective configuration.
    """
    http = ClientSession()
    app['session'] = http
    app['model_server_client'] = create_model_server_client(http, MODEL_SERVER_BASE, MODEL_SERVER_TYPE)
    # Polling runs concurrently with request handling until readiness is known.
    app['health_task'] = asyncio.create_task(poll_model_server_health(app))
    log(f'Sidecar started — port={SIDECAR_PORT}, model_server_port={MODEL_SERVER_PORT}, '
        f'model_server_type={MODEL_SERVER_TYPE}, max_loras={MAX_LORAS}')
411
+
412
+
413
async def on_cleanup(app):
    """Cleanup HTTP session and cancel background tasks.

    Cancels the health-polling task stored under 'health_task', waits for it
    to finish, and closes the HTTP session stored under 'session'.

    Raises:
        Exception: Whatever non-cancellation exception the health task ended
            with, re-raised after the session has been closed.
    """
    health_task = app['health_task']
    health_task.cancel()
    try:
        await health_task
    except asyncio.CancelledError:
        # Expected outcome of cancelling a still-running task.
        pass
    finally:
        # Close the session even when the task ended with an error, so the
        # connection pool is not leaked on a failed shutdown.
        await app['session'].close()
421
+
422
+
423
def create_app():
    """Build the aiohttp application with its routes and lifecycle hooks.

    Returns:
        The configured web.Application.
    """
    app = web.Application()
    router = app.router

    # Explicit endpoints are registered first so they take precedence over
    # the catch-all proxy route added afterwards.
    router.add_get('/ping', handle_ping)
    router.add_post('/adapters', handle_adapters_post)
    router.add_delete('/adapters', handle_adapters_delete)
    router.add_route('*', '/{path:.*}', handle_proxy)

    # Resource setup and teardown.
    app.on_startup.append(on_startup)
    app.on_cleanup.append(on_cleanup)

    return app
440
+
441
+
442
# Script entry point: build the application and serve it on SIDECAR_PORT,
# listening on all interfaces. print=None suppresses aiohttp's startup banner.
if __name__ == '__main__':
    app = create_app()
    web.run_app(app, host='0.0.0.0', port=SIDECAR_PORT, print=None)