@aws/ml-container-creator 0.9.0 → 0.10.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/bin/cli.js +31 -137
  2. package/config/parameter-schema-v2.json +2065 -0
  3. package/package.json +6 -3
  4. package/servers/lib/catalogs/jumpstart-public.json +101 -16
  5. package/servers/lib/catalogs/models.json +182 -26
  6. package/src/app.js +6 -389
  7. package/src/lib/bootstrap-command-handler.js +75 -1078
  8. package/src/lib/bootstrap-profile-manager.js +634 -0
  9. package/src/lib/bootstrap-provisioners.js +421 -0
  10. package/src/lib/config-loader.js +405 -0
  11. package/src/lib/config-manager.js +59 -1668
  12. package/src/lib/config-mcp-client.js +118 -0
  13. package/src/lib/config-validator.js +634 -0
  14. package/src/lib/cuda-resolver.js +140 -0
  15. package/src/lib/e2e-catalog-validator.js +251 -3
  16. package/src/lib/e2e-ci-recorder.js +103 -0
  17. package/src/lib/generated/cli-options.js +471 -0
  18. package/src/lib/generated/parameter-matrix.js +671 -0
  19. package/src/lib/generated/validation-rules.js +202 -0
  20. package/src/lib/marketplace-flow.js +276 -0
  21. package/src/lib/mcp-query-runner.js +768 -0
  22. package/src/lib/parameter-schema-validator.js +62 -18
  23. package/src/lib/prompt-runner.js +41 -1504
  24. package/src/lib/prompts/feature-prompts.js +172 -0
  25. package/src/lib/prompts/index.js +48 -0
  26. package/src/lib/prompts/infrastructure-prompts.js +690 -0
  27. package/src/lib/prompts/model-prompts.js +552 -0
  28. package/src/lib/prompts/project-prompts.js +70 -0
  29. package/src/lib/prompts.js +2 -1446
  30. package/src/lib/registry-command-handler.js +135 -3
  31. package/src/lib/secrets-prompt-runner.js +251 -0
  32. package/src/lib/template-variable-resolver.js +398 -0
  33. package/templates/code/serve +5 -134
  34. package/templates/code/serve.d/lmi.ejs +19 -0
  35. package/templates/code/serve.d/sglang.ejs +47 -0
  36. package/templates/code/serve.d/tensorrt-llm.ejs +53 -0
  37. package/templates/code/serve.d/vllm.ejs +48 -0
  38. package/templates/do/clean +1 -1387
  39. package/templates/do/clean.d/async-inference.ejs +508 -0
  40. package/templates/do/clean.d/batch-transform.ejs +512 -0
  41. package/templates/do/clean.d/hyperpod-eks.ejs +481 -0
  42. package/templates/do/clean.d/managed-inference.ejs +1043 -0
  43. package/templates/do/deploy +1 -1766
  44. package/templates/do/deploy.d/async-inference.ejs +501 -0
  45. package/templates/do/deploy.d/batch-transform.ejs +529 -0
  46. package/templates/do/deploy.d/hyperpod-eks.ejs +339 -0
  47. package/templates/do/deploy.d/managed-inference.ejs +726 -0
  48. package/config/parameter-schema.json +0 -88
@@ -0,0 +1,398 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ import fs from 'fs';
5
+ import path from 'path';
6
+ import { fileURLToPath } from 'url';
7
+ import { isTuneSupported } from './tune-catalog-validator.js';
8
+
9
// ES modules do not define the CommonJS `__filename` / `__dirname` globals;
// derive them from `import.meta.url` so the catalog JSON files read below
// (instances.json, tune-catalog.json) can be resolved relative to this module.
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
11
+
12
/**
 * Finds a model configuration by exact key or by glob-pattern key.
 *
 * An exact match on an own (non-inherited) registry key wins.
 * Otherwise, each registry key containing `*` is treated as a glob: `*` matches
 * any run of characters and every other character matches literally. For
 * example, the `.` in `llama-3.1*` is a literal dot. Patterns are tried in the
 * registry's key order and the first match is returned.
 *
 * @param {string} modelName - Model ID to look up
 * @param {object} registryConfigManager - Registry configuration manager
 * @returns {object|null} Model configuration or null
 */
function _findModelConfig(modelName, registryConfigManager) {
  const registry = registryConfigManager?.modelRegistry;
  if (!registry) return null;

  // Own-property check so IDs like "constructor" don't resolve to
  // Object.prototype members.
  if (Object.prototype.hasOwnProperty.call(registry, modelName) && registry[modelName]) {
    return registry[modelName];
  }

  for (const [pattern, config] of Object.entries(registry)) {
    if (!pattern.includes('*')) continue;

    // Escape regex metacharacters in the literal segments, then join them with `.*`.
    const source = pattern
      .split('*')
      .map((literal) => literal.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'))
      .join('.*');

    if (new RegExp(`^${source}$`).test(modelName)) {
      return config;
    }
  }

  return null;
}
38
+
39
/**
 * Rebuilds `answers.envVars` by stacking catalog layers, where later layers win.
 *
 * Layers (lowest to highest):
 *   1. framework config `envVars`
 *   2. framework profile `profiles[answers.frameworkProfile].envVars`
 *   3. model registry entry `envVars`
 *   4. model profile `profiles[answers.modelProfile].envVars`
 *   5. whatever `answers.envVars` already held on entry (CLI/user input)
 *
 * The framework config is `frameworkRegistry[answers.framework || answers.deploymentConfig]`
 * at `answers.frameworkVersion`. When that version is missing, the highest
 * version key is used (numeric-aware comparison).
 *
 * Mutates `answers` in place. Does nothing when no registry manager is given.
 *
 * @param {object} answers - Configuration answers (mutated)
 * @param {object|null} registryConfigManager - Registry configuration manager
 * @returns {Promise<void>}
 */
export async function _mergeEnvVarsWithPrecedence(answers, registryConfigManager) {
  if (!registryConfigManager) return;

  // Snapshot the user-supplied values before overwriting: they are the top layer.
  const userLayer = { ...answers.envVars };

  const { frameworkRegistry, modelRegistry } = registryConfigManager;

  const frameworkKey = answers.framework || answers.deploymentConfig;
  const versionMap = frameworkKey && frameworkRegistry ? frameworkRegistry[frameworkKey] : undefined;

  let frameworkConfig = null;
  if (versionMap) {
    const wanted = answers.frameworkVersion;
    if (wanted && versionMap[wanted]) {
      frameworkConfig = versionMap[wanted];
    } else {
      // No usable requested version: take the newest key.
      const [newest] = Object.keys(versionMap).sort((x, y) =>
        y.localeCompare(x, undefined, { numeric: true })
      );
      if (newest !== undefined) {
        frameworkConfig = versionMap[newest];
      }
    }
  }

  const modelConfig = answers.modelName && modelRegistry
    ? _findModelConfig(answers.modelName, registryConfigManager)
    : null;

  // A profile contributes only when one is selected and the config declares profiles.
  const profileLayer = (config, profileName) =>
    (profileName && config?.profiles ? config.profiles[profileName]?.envVars : undefined) || {};

  answers.envVars = {
    ...(frameworkConfig?.envVars || {}),
    ...profileLayer(frameworkConfig, answers.frameworkProfile),
    ...(modelConfig?.envVars || {}),
    ...profileLayer(modelConfig, answers.modelProfile),
    ...userLayer
  };
}
120
+
121
/**
 * Validates the selected framework's catalog environment variables using the
 * registry system, and prints errors and warnings to the console.
 *
 * Framework config lookup:
 *  - Triton (`answers.architecture === 'triton'` with a `deploymentConfig`):
 *    `frameworkRegistry[deploymentConfig]` at its highest version key
 *    (numeric-aware). This matches how the other lookups in this module pick
 *    the "latest" version.
 *  - Otherwise, or when the Triton lookup finds nothing:
 *    `frameworkRegistry[framework][frameworkVersion]`.
 *
 * NOTE(review): this validates the catalog's `frameworkConfig.envVars`, not the
 * merged `answers.envVars`. Confirm that this is the intended scope.
 *
 * @param {object} answers - Configuration answers
 * @param {object|null} registryConfigManager - Registry configuration manager;
 *   validation is skipped when it (or its framework registry) is missing
 * @returns {Promise<void>}
 * @throws {Error} When the registry reports one or more validation errors
 */
export async function _validateEnvironmentVariables(answers, registryConfigManager) {
  const registry = registryConfigManager?.frameworkRegistry;
  if (!registry) return; // Nothing to look up

  let frameworkConfig;
  if (answers.architecture === 'triton' && answers.deploymentConfig) {
    const tritonEntry = registry[answers.deploymentConfig];
    if (tritonEntry) {
      const [latest] = Object.keys(tritonEntry).sort((a, b) =>
        b.localeCompare(a, undefined, { numeric: true })
      );
      if (latest !== undefined) {
        frameworkConfig = tritonEntry[latest];
      }
    }
  }
  if (!frameworkConfig) {
    frameworkConfig = registry[answers.framework]?.[answers.frameworkVersion];
  }

  if (!frameworkConfig?.envVars) {
    return; // No env vars to validate
  }

  console.log('\n🔍 Validating environment variables...');

  const validationResult = registryConfigManager.validateEnvironmentVariables(
    frameworkConfig.envVars,
    frameworkConfig
  );
  const errors = validationResult.errors || [];
  const warnings = validationResult.warnings || [];
  const strategiesUsed = validationResult.strategiesUsed || [];

  // Errors and warnings share one format; the key prefix is omitted when absent.
  const formatIssue = ({ key, message }) => ` • ${key ? `${key}: ` : ''}${message}`;

  if (errors.length > 0) {
    console.log('\n❌ Environment Variable Validation Errors:');
    errors.forEach((error) => console.log(formatIssue(error)));
  }

  if (warnings.length > 0) {
    console.log('\n⚠️ Environment Variable Validation Warnings:');
    warnings.forEach((warning) => console.log(formatIssue(warning)));
  }

  if (strategiesUsed.length > 0) {
    console.log(`\n✅ Validation methods used: ${strategiesUsed.join(', ')}`);
  }

  if (errors.length === 0 && warnings.length === 0) {
    console.log(' ✅ All environment variables validated successfully');
  }

  // Any reported error is fatal; there is no interactive-mode check here.
  if (errors.length > 0) {
    throw new Error('Environment variable validation failed. Please fix the errors and try again.');
  }
}
187
+
188
/**
 * Picks a config from a `{ version: config }` map: the requested version when
 * it is present and truthy, otherwise the highest version key (numeric-aware).
 *
 * @param {object} versionMap - Map of version string → config
 * @param {string|null|undefined} requestedVersion - Preferred version, if any
 * @returns {*} The selected config, or null when the map has no keys
 */
function _pickVersionConfig(versionMap, requestedVersion) {
  if (requestedVersion && versionMap[requestedVersion]) {
    return versionMap[requestedVersion];
  }
  const [latest] = Object.keys(versionMap).sort((a, b) =>
    b.localeCompare(a, undefined, { numeric: true })
  );
  return latest === undefined ? null : versionMap[latest];
}

/**
 * Fills Triton image defaults from `frameworkRegistry[answers.deploymentConfig]`
 * at its latest version. Falls back to a hard-coded Triton base image.
 *
 * @param {object} answers - Answers object (mutated)
 * @param {object|null} registryConfigManager - Registry configuration manager (or null)
 */
function _applyTritonDefaults(answers, registryConfigManager) {
  const tritonRegistryKey = answers.deploymentConfig;
  const versionMap = tritonRegistryKey && registryConfigManager?.frameworkRegistry
    ? registryConfigManager.frameworkRegistry[tritonRegistryKey]
    : undefined;

  if (versionMap) {
    // Guarded with `?.` so a null version entry is skipped instead of throwing.
    const latestConfig = _pickVersionConfig(versionMap, null);
    if (latestConfig?.baseImage) {
      answers.baseImage = latestConfig.baseImage;
    }
    if (latestConfig?.inferenceAmiVersion && !answers.inferenceAmiVersion) {
      answers.inferenceAmiVersion = latestConfig.inferenceAmiVersion;
    }
    if (latestConfig?.accelerator) {
      answers.accelerator = latestConfig.accelerator;
    }
  }

  // Final fallback: hardcoded default Triton base image
  if (!answers.baseImage) {
    answers.baseImage = 'nvcr.io/nvidia/tritonserver:24.08-py3';
  }
}

/**
 * Best-effort enrichment for transformer models. It fills:
 *  - the chat template (HuggingFace first, then a Model Registry override);
 *  - framework-level metadata (inferenceAmiVersion, accelerator).
 * Any failure is swallowed, because defaults are already in place.
 *
 * @param {object} answers - Answers object (mutated)
 * @param {object} registryConfigManager - Registry configuration manager (non-null)
 */
async function _enrichTransformersMetadata(answers, registryConfigManager) {
  try {
    // Fetch HuggingFace data for model-specific info
    const hfData = await registryConfigManager._fetchHuggingFaceData(answers.modelName);

    // Use the HuggingFace chat template only when none is set yet
    if (hfData?.chatTemplate && !answers.chatTemplate) {
      answers.chatTemplate = hfData.chatTemplate;
      answers.chatTemplateSource = 'HuggingFace_Hub_API';
    }

    // Model Registry chat templates override everything else
    if (registryConfigManager.modelRegistry) {
      const modelConfig = _findModelConfig(answers.modelName, registryConfigManager);
      if (modelConfig?.chatTemplate) {
        answers.chatTemplate = modelConfig.chatTemplate;
        answers.chatTemplateSource = 'Model_Registry';
      }
    }

    // Framework-level metadata (non-envVar fields)
    if (answers.frameworkVersion && registryConfigManager.frameworkRegistry) {
      const frameworkConfig =
        registryConfigManager.frameworkRegistry[answers.framework]?.[answers.frameworkVersion];
      if (frameworkConfig) {
        if (frameworkConfig.inferenceAmiVersion && !answers.inferenceAmiVersion) {
          answers.inferenceAmiVersion = frameworkConfig.inferenceAmiVersion;
        }
        if (frameworkConfig.accelerator) {
          answers.accelerator = frameworkConfig.accelerator;
        }
      }
    }
  } catch {
    // Deliberate best-effort: continue silently, defaults are already set
  }
}

/**
 * Populates `answers.baseImage` from the catalog when it is still falsy. This
 * covers --skip-prompts and cases where MCP/CLI/config gave no base image.
 * Precedence: MCP > CLI > config > catalog default (this helper).
 *
 * @param {object} answers - Answers object (mutated)
 * @param {object|null} registryConfigManager - Registry configuration manager (or null)
 */
function _applyCatalogBaseImage(answers, registryConfigManager) {
  if (answers.baseImage || !registryConfigManager?.frameworkRegistry) return;

  const backendKey = answers.backend || answers.modelServer;
  if (!backendKey) return;

  const versionMap = registryConfigManager.frameworkRegistry[backendKey];
  if (!versionMap) return;

  const resolvedConfig = _pickVersionConfig(versionMap, answers.frameworkVersion);
  if (resolvedConfig?.baseImage) {
    answers.baseImage = resolvedConfig.baseImage;
  }
}

/**
 * Populates `answers.icGpuCount` when it is not explicitly set and an instance
 * type is known. The deploy template uses IC_GPU_COUNT unconditionally for
 * NumberOfAcceleratorDevicesRequired, so GPU deployments need a value.
 * Source order: instance-sizer `gpuCount`, then the instances catalog.
 *
 * @param {object} answers - Answers object (mutated)
 */
function _resolveIcGpuCount(answers) {
  if (answers.icGpuCount != null || !answers.instanceType) return;

  if (answers.gpuCount) {
    answers.icGpuCount = answers.gpuCount;
    return;
  }

  try {
    const catalogPath = path.resolve(__dirname, '..', '..', 'servers', 'lib', 'catalogs', 'instances.json');
    const catalogData = JSON.parse(fs.readFileSync(catalogPath, 'utf-8'));
    const instanceInfo = catalogData?.catalog?.[answers.instanceType];
    if (instanceInfo?.gpus && instanceInfo.gpus > 0) {
      answers.icGpuCount = instanceInfo.gpus;
    }
  } catch {
    // Silently continue — template fallback handles missing value
  }
}

/**
 * Sets `answers.tuneSupported` from the tune catalog when it is not already set.
 * Used by the do/config template to write TUNE_SUPPORTED=true|false.
 * It falls back to false if the catalog cannot be read or parsed.
 *
 * @param {object} answers - Answers object (mutated)
 */
function _resolveTuneSupport(answers) {
  if (answers.tuneSupported !== undefined) return;

  try {
    const tuneCatalogPath = path.resolve(__dirname, '..', '..', 'config', 'tune-catalog.json');
    const tuneCatalog = JSON.parse(fs.readFileSync(tuneCatalogPath, 'utf-8'));
    answers.tuneSupported = isTuneSupported(answers.modelName || '', tuneCatalog);
  } catch {
    answers.tuneSupported = false;
  }
}

/**
 * Ensures all template variables have proper defaults, which prevents
 * "undefined" errors in EJS templates. Also enriches answers with registry data.
 *
 * Steps, in order:
 *  1. Fill every key that is `undefined` from a fixed defaults table.
 *  2. Backfill `framework`/`modelServer` from `architecture`/`backend`.
 *  3. Force `includeTesting` on and pick default `testTypes`.
 *  4. Merge catalog env vars (see `_mergeEnvVarsWithPrecedence`).
 *  5. Apply Triton image defaults, then transformer HuggingFace/registry
 *     metadata, then the catalog base image, the GPU count and tune support.
 *
 * @param {object} answers - Answers object to fill defaults into (mutated)
 * @param {object|null} registryConfigManager - Registry configuration manager (or null)
 * @returns {Promise<void>}
 */
export async function _ensureTemplateVariables(answers, registryConfigManager = null) {
  const defaults = {
    chatTemplate: null,
    chatTemplateSource: null,
    hfToken: null,
    hfTokenArn: null,
    ngcApiKey: null,
    ngcTokenArn: null,
    envVars: {},
    inferenceAmiVersion: null,
    accelerator: null,
    frameworkVersion: null,
    validationLevel: 'unknown',
    configSources: [],
    recommendedInstanceTypes: [],
    roleArn: null,
    deploymentConfig: '',
    architecture: null,
    backend: null,
    engine: null,
    codebuildComputeType: null,
    codebuildProjectName: null,
    modelName: null,
    modelFormat: null,
    includeSampleModel: true,
    includeTesting: true,
    testTypes: [],
    buildTimestamp: new Date().toISOString(),
    buildTarget: 'codebuild',
    deploymentTarget: 'realtime-inference',
    hyperPodCluster: null,
    hyperPodNamespace: 'default',
    hyperPodReplicas: 1,
    fsxVolumeHandle: null,
    baseImage: null,
    modelSource: 'huggingface',
    artifactUri: '',
    modelLoadStrategy: 'runtime',
    existingEndpointName: null,
    enableLora: false,
    maxLoras: 30,
    maxLoraRank: 64
  };

  // Only `undefined` is replaced; explicit nulls and falsy values are kept.
  for (const [key, value] of Object.entries(defaults)) {
    if (answers[key] === undefined) {
      answers[key] = value;
    }
  }

  // Backward compatibility: populate framework and modelServer from architecture/backend
  if (!answers.framework && answers.architecture) {
    answers.framework = answers.architecture;
  }
  if (!answers.modelServer && answers.backend) {
    answers.modelServer = answers.backend;
  }

  // Always include testing with all available test types
  answers.includeTesting = true;
  if (!answers.testTypes || answers.testTypes.length === 0) {
    answers.testTypes = answers.architecture === 'transformers' || answers.framework === 'transformers'
      ? ['hosted-model-endpoint']
      : ['local-model-cli', 'local-model-server', 'hosted-model-endpoint'];
  }

  // Merge catalog env vars into answers.envVars with correct precedence
  await _mergeEnvVarsWithPrecedence(answers, registryConfigManager);

  if (answers.architecture === 'triton' && !answers.baseImage) {
    _applyTritonDefaults(answers, registryConfigManager);
  }

  if (answers.framework === 'transformers' && answers.modelName && registryConfigManager) {
    await _enrichTransformersMetadata(answers, registryConfigManager);
  }

  _applyCatalogBaseImage(answers, registryConfigManager);
  _resolveIcGpuCount(answers);
  _resolveTuneSupport(answers);
}
@@ -10,35 +10,10 @@ echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"
10
10
  # CUDA compatibility setup (required for newer SageMaker inference AMIs)
11
11
  source /usr/bin/cuda_compat.sh 2>/dev/null || true
12
12
 
13
- <% if (modelServer === 'vllm') { %>
14
- echo "Starting vLLM server"
15
- <% } else if (modelServer === 'sglang') { %>
16
- echo "Starting SGLang server"
17
- <% } else if (modelServer === 'tensorrt-llm') { %>
18
- echo "Starting TensorRT-LLM server"
19
- <% } else if (modelServer === 'lmi') { %>
20
- echo "Starting LMI (Large Model Inference) server"
21
- <% } else if (modelServer === 'djl') { %>
22
- echo "Starting DJL Serving server"
23
- <% } %>
13
+ echo "Starting <%= modelServer %> server"
24
14
 
25
15
  <% if (modelServer === 'lmi' || modelServer === 'djl') { %>
26
- # LMI/DJL containers use serving.properties for configuration
27
- # The configuration file should be at /opt/ml/model/serving.properties
28
- # DJL Serving will automatically start with this configuration
29
-
30
- if [ ! -f /opt/ml/model/serving.properties ]; then
31
- echo "Error: serving.properties not found at /opt/ml/model/serving.properties"
32
- exit 1
33
- fi
34
-
35
- echo "Using configuration from /opt/ml/model/serving.properties"
36
- cat /opt/ml/model/serving.properties
37
-
38
- # DJL Serving is already configured in the base image
39
- # This script is not typically needed for LMI/DJL as they have their own entrypoint
40
- # But we provide it for consistency with other model servers
41
- exit 0
16
+ <%- include('serve.d/lmi') %>
42
17
  <% } else { %>
43
18
 
44
19
  <% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
@@ -60,7 +35,6 @@ download_model_from_s3() {
60
35
  mkdir -p "${dest_path}"
61
36
 
62
37
  if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
63
- # Tarball: download and extract
64
38
  if ! aws s3 cp "$s3_uri" /tmp/model_archive.tar.gz; then
65
39
  echo "Error: Failed to download tarball from ${s3_uri}" >&2
66
40
  return 1
@@ -72,13 +46,11 @@ download_model_from_s3() {
72
46
  fi
73
47
  rm -f /tmp/model_archive.tar.gz
74
48
  elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
75
- # Directory prefix: sync
76
49
  if ! aws s3 sync "$s3_uri" "$dest_path"; then
77
50
  echo "Error: Failed to sync from ${s3_uri}" >&2
78
51
  return 1
79
52
  fi
80
53
  else
81
- # Single file: copy
82
54
  if ! aws s3 cp "$s3_uri" "$dest_path/"; then
83
55
  echo "Error: Failed to copy ${s3_uri}" >&2
84
56
  return 1
@@ -109,19 +81,16 @@ _MODEL_VAR="TRTLLM_MODEL"
109
81
  resolve_model() {
110
82
  case "$MODEL_SOURCE" in
111
83
  huggingface)
112
- # Pass model name directly — server fetches from HF Hub
113
84
  echo "${!_MODEL_VAR}"
114
85
  return
115
86
  ;;
116
87
  s3|registry)
117
- # Check for pre-mounted artifacts first
118
88
  if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
119
89
  echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
120
90
  echo "$LOCAL_MODEL_PATH"
121
91
  return
122
92
  fi
123
93
 
124
- # For registry:// models, resolve artifact URI at runtime via SageMaker API
125
94
  if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
126
95
  local model_uri="${!_MODEL_VAR}"
127
96
  local registry_prefix="registry://"
@@ -131,7 +100,6 @@ resolve_model() {
131
100
  local version="${registry_path#*/}"
132
101
  local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"
133
102
 
134
- # Get account ID for ARN construction
135
103
  local account_id
136
104
  account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
137
105
  echo "Error: Failed to get AWS account ID for model package ARN" >&2
@@ -151,38 +119,22 @@ resolve_model() {
151
119
  exit 1
152
120
  }
153
121
 
154
- # Try ModelDataUrl first, then S3DataSource.S3Uri, then description
155
122
  MODEL_ARTIFACT_URI=$(echo "$describe_output" | python3 -c "
156
123
  import sys, json, re
157
124
  try:
158
125
  pkg = json.load(sys.stdin)
159
126
  uri = ''
160
- # Check InferenceSpecification.Containers[0]
161
127
  containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
162
128
  if containers:
163
129
  c = containers[0]
164
130
  uri = c.get('ModelDataUrl', '')
165
131
  if not uri:
166
132
  uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
167
- # Fallback: extract S3 URI from ModelPackageDescription
168
133
  if not uri:
169
134
  desc = pkg.get('ModelPackageDescription', '')
170
135
  m = re.search(r's3://[^\s]+', desc)
171
136
  if m:
172
137
  uri = m.group(0)
173
- # Fallback: check ModelCard hyperparameters for model_artifacts_s3
174
- if not uri:
175
- try:
176
- card = pkg.get('ModelCard', {})
177
- content = card.get('ModelCardContent', '{}')
178
- card_data = json.loads(content) if isinstance(content, str) else content
179
- params = card_data.get('training_details', {}).get('training_job_details', {}).get('hyper_parameters', [])
180
- for p in params:
181
- if p.get('name') == 'model_artifacts_s3':
182
- uri = p.get('value', '')
183
- break
184
- except:
185
- pass
186
138
  print(uri)
187
139
  except:
188
140
  print('')
@@ -192,19 +144,15 @@ except:
192
144
  echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
193
145
  else
194
146
  echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
195
- echo " Checked: InferenceSpecification.Containers[0].ModelDataUrl" >&2
196
- echo " Checked: InferenceSpecification.Containers[0].ModelDataSource.S3DataSource.S3Uri" >&2
197
147
  exit 1
198
148
  fi
199
149
  fi
200
150
  fi
201
151
 
202
- # Need artifact URI for download
203
152
  if [ -z "$MODEL_ARTIFACT_URI" ]; then
204
153
  echo "Error: ${MODEL_SOURCE} model requires artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
205
154
  exit 1
206
155
  fi
207
- # Download from S3
208
156
  if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH"; then
209
157
  echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
210
158
  exit 1
@@ -212,7 +160,6 @@ except:
212
160
  echo "$LOCAL_MODEL_PATH"
213
161
  ;;
214
162
  *)
215
- # Unrecognized source — treat as huggingface
216
163
  echo "${!_MODEL_VAR}"
217
164
  return
218
165
  ;;
@@ -226,89 +173,13 @@ unset _MODEL_VAR _RESOLVED_MODEL
226
173
 
227
174
  # Initialize server arguments
228
175
  <% if (modelServer === 'tensorrt-llm') { %>
229
- # port 8081 for internal TensorRT-LLM server (nginx proxies on 8080)
230
176
  SERVER_ARGS=(--host 0.0.0.0 --port 8081)
231
177
  <% } else { %>
232
- # port 8080 required by SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
233
178
  SERVER_ARGS=(--host 0.0.0.0 --port 8080)
234
179
  <% } %>
235
180
 
236
- # Define the prefix for environment variables to look for
237
- <% if (modelServer === 'vllm') { %>
238
- PREFIX="VLLM_"
239
- <% } else if (modelServer === 'sglang') { %>
240
- PREFIX="SGLANG_"
241
- <% } else if (modelServer === 'tensorrt-llm') { %>
242
- PREFIX="TRTLLM_"
243
- <% } %>
244
- ARG_PREFIX="--"
245
-
246
- # Define environment variables to exclude (internal variables set by base images)
247
- <% if (modelServer === 'vllm') { %>
248
- EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
249
- <% } else if (modelServer === 'sglang') { %>
250
- EXCLUDE_VARS=()
251
- <% } else if (modelServer === 'tensorrt-llm') { %>
252
- # Exclude TRTLLM_MODEL as it's used as the positional MODEL argument
253
- EXCLUDE_VARS=("TRTLLM_MODEL")
254
- <% } %>
255
-
256
- # Declare and populate array of matching environment variables
257
- mapfile -t env_vars < <(env | grep "^${PREFIX}")
258
-
259
- # Loop through the array and convert to command-line arguments
260
- for var in "${env_vars[@]}"; do
261
- IFS='=' read -r key value <<< "$var"
262
-
263
- # Skip excluded variables
264
- skip=false
265
- for exclude in "${EXCLUDE_VARS[@]}"; do
266
- if [ "$key" = "$exclude" ]; then
267
- skip=true
268
- break
269
- fi
270
- done
271
-
272
- if [ "$skip" = true ]; then
273
- continue
274
- fi
275
-
276
- # Remove prefix, convert to lowercase, and replace underscores with dashes
277
- arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
278
-
279
- # Boolean handling: true = flag only, false = skip entirely
280
- if [ "$value" = "false" ]; then
281
- continue
282
- fi
283
-
284
- SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
285
- if [ -n "$value" ] && [ "$value" != "true" ]; then
286
- SERVER_ARGS+=("$value")
287
- fi
288
- done
289
-
290
- echo "-------------------------------------------------------------------"
291
- <% if (modelServer === 'vllm') { %>
292
- echo "vLLM engine args: [${SERVER_ARGS[@]}]"
293
- <% } else if (modelServer === 'sglang') { %>
294
- echo "SGLang engine args: [${SERVER_ARGS[@]}]"
295
- <% } else if (modelServer === 'tensorrt-llm') { %>
296
- echo "TensorRT-LLM engine args: [${SERVER_ARGS[@]}]"
297
- <% } %>
298
- echo "-------------------------------------------------------------------"
299
-
300
- # Pass the collected arguments to the main entrypoint
301
- <% if (modelServer === 'vllm') { %>
302
- exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
303
- <% } else if (modelServer === 'sglang') { %>
304
- exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
305
- <% } else if (modelServer === 'tensorrt-llm') { %>
306
- # TensorRT-LLM requires the model as a positional argument
307
- # Syntax: trtllm-serve serve MODEL [OPTIONS]
308
- if [ -z "$TRTLLM_MODEL" ]; then
309
- echo "Error: TRTLLM_MODEL environment variable is not set"
310
- exit 1
311
- fi
312
- exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
181
+ # --- Server-specific arg conversion and exec ---
182
+ <% if (['vllm', 'sglang', 'tensorrt-llm'].includes(modelServer)) { %>
183
+ <%- include('serve.d/' + modelServer) %>
313
184
  <% } %>
314
185
  <% } %>
@@ -0,0 +1,19 @@
1
# ---------------------------------------------------------------------------
# LMI / DJL Server Configuration
# ---------------------------------------------------------------------------
# Config:     /opt/ml/model/serving.properties
# Entrypoint: DJL Serving (built into base image)
# Port:       8080 (configured in serving.properties)
# ---------------------------------------------------------------------------

# LMI/DJL containers are configured via serving.properties; fail fast when it
# is missing. The error goes to stderr, like the other serve-script errors.
if [ ! -f /opt/ml/model/serving.properties ]; then
  echo "Error: serving.properties not found at /opt/ml/model/serving.properties" >&2
  exit 1
fi

echo "Using configuration from /opt/ml/model/serving.properties"
cat /opt/ml/model/serving.properties

# DJL Serving is already configured in the base image entrypoint; this
# fragment only checks for and logs the config, then exits successfully.
exit 0