@aws/ml-container-creator 0.9.0 → 0.10.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +31 -137
- package/config/parameter-schema-v2.json +2065 -0
- package/package.json +6 -3
- package/servers/lib/catalogs/jumpstart-public.json +101 -16
- package/servers/lib/catalogs/models.json +182 -26
- package/src/app.js +6 -389
- package/src/lib/bootstrap-command-handler.js +75 -1078
- package/src/lib/bootstrap-profile-manager.js +634 -0
- package/src/lib/bootstrap-provisioners.js +421 -0
- package/src/lib/config-loader.js +405 -0
- package/src/lib/config-manager.js +59 -1668
- package/src/lib/config-mcp-client.js +118 -0
- package/src/lib/config-validator.js +634 -0
- package/src/lib/cuda-resolver.js +140 -0
- package/src/lib/e2e-catalog-validator.js +251 -3
- package/src/lib/e2e-ci-recorder.js +103 -0
- package/src/lib/generated/cli-options.js +471 -0
- package/src/lib/generated/parameter-matrix.js +671 -0
- package/src/lib/generated/validation-rules.js +202 -0
- package/src/lib/marketplace-flow.js +276 -0
- package/src/lib/mcp-query-runner.js +768 -0
- package/src/lib/parameter-schema-validator.js +62 -18
- package/src/lib/prompt-runner.js +41 -1504
- package/src/lib/prompts/feature-prompts.js +172 -0
- package/src/lib/prompts/index.js +48 -0
- package/src/lib/prompts/infrastructure-prompts.js +690 -0
- package/src/lib/prompts/model-prompts.js +552 -0
- package/src/lib/prompts/project-prompts.js +70 -0
- package/src/lib/prompts.js +2 -1446
- package/src/lib/registry-command-handler.js +135 -3
- package/src/lib/secrets-prompt-runner.js +251 -0
- package/src/lib/template-variable-resolver.js +398 -0
- package/templates/code/serve +5 -134
- package/templates/code/serve.d/lmi.ejs +19 -0
- package/templates/code/serve.d/sglang.ejs +47 -0
- package/templates/code/serve.d/tensorrt-llm.ejs +53 -0
- package/templates/code/serve.d/vllm.ejs +48 -0
- package/templates/do/clean +1 -1387
- package/templates/do/clean.d/async-inference.ejs +508 -0
- package/templates/do/clean.d/batch-transform.ejs +512 -0
- package/templates/do/clean.d/hyperpod-eks.ejs +481 -0
- package/templates/do/clean.d/managed-inference.ejs +1043 -0
- package/templates/do/deploy +1 -1766
- package/templates/do/deploy.d/async-inference.ejs +501 -0
- package/templates/do/deploy.d/batch-transform.ejs +529 -0
- package/templates/do/deploy.d/hyperpod-eks.ejs +339 -0
- package/templates/do/deploy.d/managed-inference.ejs +726 -0
- package/config/parameter-schema.json +0 -88
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
import fs from 'fs';
|
|
5
|
+
import path from 'path';
|
|
6
|
+
import { fileURLToPath } from 'url';
|
|
7
|
+
import { isTuneSupported } from './tune-catalog-validator.js';
|
|
8
|
+
|
|
9
|
+
const __filename = fileURLToPath(import.meta.url);
|
|
10
|
+
const __dirname = path.dirname(__filename);
|
|
11
|
+
|
|
12
|
+
/**
 * Finds the model registry entry for a model ID.
 *
 * An exact key match wins. Otherwise the first registry key that contains `*`
 * and whose glob pattern matches the WHOLE model ID is used. In a pattern, `*`
 * matches any run of characters (including `/`); every other character,
 * including regex metacharacters such as `.`, `(` or `+`, matches literally.
 *
 * @param {string} modelName - Model ID to look up
 * @param {object} registryConfigManager - Registry configuration manager exposing `modelRegistry`
 * @returns {object|null} Model configuration, or null when nothing matches
 */
function _findModelConfig(modelName, registryConfigManager) {
  const registry = registryConfigManager?.modelRegistry;
  // A non-string ID would be coerced to e.g. "undefined" and could match a bare "*" pattern.
  if (!registry || typeof modelName !== 'string') return null;

  // Exact match first
  const exact = registry[modelName];
  if (exact) return exact;

  // Pattern matching with glob-style wildcards
  for (const [pattern, config] of Object.entries(registry)) {
    if (!pattern.includes('*')) continue;

    // Escape regex metacharacters in the literal segments so only `*` is a wildcard.
    // Unescaped, the `.` in "Llama-3.1*" would match any character, and a `(` or `[`
    // in a key would make `new RegExp` throw.
    const source = pattern
      .split('*')
      .map((segment) => segment.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'))
      .join('.*');

    if (new RegExp(`^${source}$`).test(modelName)) {
      return config;
    }
  }

  return null;
}
|
|
38
|
+
|
|
39
|
+
/**
 * Merges environment variables from all catalog sources into `answers.envVars`.
 *
 * Precedence (lowest → highest; each layer overrides earlier ones key by key):
 *   1. framework entry `envVars` (the resolved framework/version registry entry)
 *   2. framework profile `profiles[answers.frameworkProfile].envVars`
 *   3. model entry `envVars` (exact or glob match on `answers.modelName`)
 *   4. model profile `profiles[answers.modelProfile].envVars`
 *   5. the caller's existing `answers.envVars` (CLI/config input)
 *
 * Framework entry lookup: for Triton (`answers.architecture === 'triton'`) the
 * registry key `answers.deploymentConfig` (e.g. 'triton-fil') is preferred when
 * it exists in the registry. This matches the Triton lookups used for validation
 * and base-image resolution. Otherwise the key is `answers.framework`, falling
 * back to `answers.deploymentConfig`. If `answers.frameworkVersion` is missing
 * from that entry, the numerically highest version key is used.
 *
 * @param {object} answers - Configuration answers; `answers.envVars` is replaced with the merged map
 * @param {object|null} registryConfigManager - Registry configuration manager; no-op when null
 * @returns {Promise<void>}
 */
export async function _mergeEnvVarsWithPrecedence(answers, registryConfigManager) {
  if (!registryConfigManager) return;

  // Snapshot caller-provided values before overwriting: they win over every catalog layer.
  const cliEnvVars = { ...answers.envVars };

  const { frameworkRegistry, modelRegistry } = registryConfigManager;

  // Resolve the framework config for the selected framework + version
  let frameworkConfig = null;
  if (frameworkRegistry) {
    // `answers.framework` may already have been back-filled with the bare architecture
    // ('triton'), which is not the key Triton entries are stored under.
    const tritonKey = answers.architecture === 'triton' ? answers.deploymentConfig : null;
    const frameworkName = (tritonKey && frameworkRegistry[tritonKey])
      ? tritonKey
      : (answers.framework || answers.deploymentConfig);
    const frameworkVersions = frameworkName ? frameworkRegistry[frameworkName] : null;

    if (frameworkVersions) {
      const requestedVersion = answers.frameworkVersion;
      if (requestedVersion && frameworkVersions[requestedVersion]) {
        frameworkConfig = frameworkVersions[requestedVersion];
      } else {
        // Unknown/unspecified version: use the newest (numeric-aware descending sort).
        const [latest] = Object.keys(frameworkVersions).sort((a, b) =>
          b.localeCompare(a, undefined, { numeric: true })
        );
        if (latest !== undefined) {
          frameworkConfig = frameworkVersions[latest];
        }
      }
    }
  }

  // Resolve the model config (exact match or pattern match)
  const modelConfig = (answers.modelName && modelRegistry)
    ? _findModelConfig(answers.modelName, registryConfigManager)
    : null;

  const frameworkProfileEnvVars =
    (answers.frameworkProfile && frameworkConfig?.profiles?.[answers.frameworkProfile]?.envVars) || {};
  const modelProfileEnvVars =
    (answers.modelProfile && modelConfig?.profiles?.[answers.modelProfile]?.envVars) || {};

  // Merge in precedence order: each layer overrides the previous
  answers.envVars = {
    ...(frameworkConfig?.envVars || {}),
    ...frameworkProfileEnvVars,
    ...(modelConfig?.envVars || {}),
    ...modelProfileEnvVars,
    ...cliEnvVars,
  };
}
|
|
120
|
+
|
|
121
|
+
/**
 * Validates the selected framework's catalog environment variables via
 * `registryConfigManager.validateEnvironmentVariables(envVars, frameworkConfig)`.
 * It prints the reported errors, warnings and validation strategies, and throws
 * if any error was reported.
 *
 * Framework lookup: for Triton (`answers.architecture === 'triton'`) the registry
 * entry keyed by `answers.deploymentConfig` (e.g. 'triton-fil') is tried first,
 * using its numerically highest version. This is the same "latest version" rule
 * used by the other registry lookups in this module. Otherwise the lookup uses
 * `answers.framework` at `answers.frameworkVersion`.
 *
 * NOTE(review): this validates the framework entry's catalog `envVars`, not the
 * merged `answers.envVars` — confirm that user overrides are validated elsewhere.
 *
 * @param {object} answers - Configuration answers
 * @param {object|null} registryConfigManager - Registry configuration manager; no-op when null
 * @returns {Promise<void>}
 * @throws {Error} When the validator reports one or more errors
 */
export async function _validateEnvironmentVariables(answers, registryConfigManager) {
  const frameworkRegistry = registryConfigManager?.frameworkRegistry;
  if (!frameworkRegistry) return; // Nothing to look up

  let frameworkConfig;
  if (answers.architecture === 'triton' && answers.deploymentConfig) {
    const tritonEntry = frameworkRegistry[answers.deploymentConfig];
    if (tritonEntry) {
      // Pick the newest version rather than whichever key happens to be first.
      const [latest] = Object.keys(tritonEntry).sort((a, b) =>
        b.localeCompare(a, undefined, { numeric: true })
      );
      if (latest !== undefined) {
        frameworkConfig = tritonEntry[latest];
      }
    }
  }
  if (!frameworkConfig) {
    frameworkConfig = frameworkRegistry[answers.framework]?.[answers.frameworkVersion];
  }

  if (!frameworkConfig?.envVars) {
    return; // No env vars to validate
  }

  console.log('\n🔍 Validating environment variables...');

  const validationResult = registryConfigManager.validateEnvironmentVariables(
    frameworkConfig.envVars,
    frameworkConfig
  );
  const errors = validationResult.errors || [];
  const warnings = validationResult.warnings || [];
  const strategies = validationResult.strategiesUsed || [];

  if (errors.length > 0) {
    console.log('\n❌ Environment Variable Validation Errors:');
    for (const error of errors) {
      console.log(` • ${error.key}: ${error.message}`);
    }
  }

  if (warnings.length > 0) {
    console.log('\n⚠️ Environment Variable Validation Warnings:');
    for (const warning of warnings) {
      console.log(` • ${warning.key ? `${warning.key}: ` : ''}${warning.message}`);
    }
  }

  if (strategies.length > 0) {
    console.log(`\n✅ Validation methods used: ${strategies.join(', ')}`);
  }

  if (errors.length === 0 && warnings.length === 0) {
    console.log(' ✅ All environment variables validated successfully');
  }

  // Any reported error aborts generation (regardless of interactive mode).
  if (errors.length > 0) {
    throw new Error('Environment variable validation failed. Please fix the errors and try again.');
  }
}
|
|
187
|
+
|
|
188
|
+
/** Last-resort Triton image when neither the caller nor the registry supplies one. */
const DEFAULT_TRITON_BASE_IMAGE = 'nvcr.io/nvidia/tritonserver:24.08-py3';

/**
 * Returns `frameworkVersions[requestedVersion]` when that entry exists; otherwise
 * the entry under the numerically highest version key, or null when there are no keys.
 */
function _selectVersionConfig(frameworkVersions, requestedVersion) {
  if (requestedVersion && frameworkVersions[requestedVersion]) {
    return frameworkVersions[requestedVersion];
  }
  const [latest] = Object.keys(frameworkVersions).sort((a, b) =>
    b.localeCompare(a, undefined, { numeric: true })
  );
  return latest === undefined ? null : frameworkVersions[latest];
}

/** Assigns static defaults for keys whose value is `undefined` (explicit nulls are kept). */
function _applyTemplateDefaults(answers) {
  const defaults = {
    chatTemplate: null,
    chatTemplateSource: null,
    hfToken: null,
    hfTokenArn: null,
    ngcApiKey: null,
    ngcTokenArn: null,
    envVars: {},
    inferenceAmiVersion: null,
    accelerator: null,
    frameworkVersion: null,
    validationLevel: 'unknown',
    configSources: [],
    recommendedInstanceTypes: [],
    roleArn: null,
    deploymentConfig: '',
    architecture: null,
    backend: null,
    engine: null,
    codebuildComputeType: null,
    codebuildProjectName: null,
    modelName: null,
    modelFormat: null,
    includeSampleModel: true,
    includeTesting: true,
    testTypes: [],
    buildTimestamp: new Date().toISOString(),
    buildTarget: 'codebuild',
    deploymentTarget: 'realtime-inference',
    hyperPodCluster: null,
    hyperPodNamespace: 'default',
    hyperPodReplicas: 1,
    fsxVolumeHandle: null,
    baseImage: null,
    modelSource: 'huggingface',
    artifactUri: '',
    modelLoadStrategy: 'runtime',
    existingEndpointName: null,
    enableLora: false,
    maxLoras: 30,
    maxLoraRank: 64,
  };

  for (const [key, value] of Object.entries(defaults)) {
    if (answers[key] === undefined) {
      answers[key] = value;
    }
  }
}

/** Forces testing on and picks default test types (transformers only gets the hosted test). */
function _applyTestingDefaults(answers) {
  answers.includeTesting = true;
  if (!answers.testTypes || answers.testTypes.length === 0) {
    const isTransformers = answers.architecture === 'transformers' || answers.framework === 'transformers';
    answers.testTypes = isTransformers
      ? ['hosted-model-endpoint']
      : ['local-model-cli', 'local-model-server', 'hosted-model-endpoint'];
  }
}

/**
 * Triton only: fills baseImage/inferenceAmiVersion/accelerator from the latest version
 * of the registry entry keyed by `answers.deploymentConfig`, then falls back to
 * DEFAULT_TRITON_BASE_IMAGE if no base image was found.
 */
function _applyTritonImageDefaults(answers, registryConfigManager) {
  const registryKey = answers.deploymentConfig;
  const tritonVersions = registryKey ? registryConfigManager?.frameworkRegistry?.[registryKey] : undefined;

  if (tritonVersions) {
    const latestConfig = _selectVersionConfig(tritonVersions);
    if (latestConfig?.baseImage) {
      answers.baseImage = latestConfig.baseImage;
    }
    if (latestConfig?.inferenceAmiVersion && !answers.inferenceAmiVersion) {
      answers.inferenceAmiVersion = latestConfig.inferenceAmiVersion;
    }
    if (latestConfig?.accelerator) {
      answers.accelerator = latestConfig.accelerator;
    }
  }

  if (!answers.baseImage) {
    answers.baseImage = DEFAULT_TRITON_BASE_IMAGE;
  }
}

/**
 * Transformers only, best effort: chat template from HuggingFace (if unset), then a
 * model-registry chat template override, then framework-level AMI/accelerator metadata.
 * Any failure (e.g. the HuggingFace fetch) abandons the remaining steps and keeps the
 * defaults already in place.
 */
async function _enrichTransformersAnswers(answers, registryConfigManager) {
  try {
    const hfData = await registryConfigManager._fetchHuggingFaceData(answers.modelName);
    if (hfData?.chatTemplate && !answers.chatTemplate) {
      answers.chatTemplate = hfData.chatTemplate;
      answers.chatTemplateSource = 'HuggingFace_Hub_API';
    }

    // A model-registry template overrides whatever is set at this point.
    if (registryConfigManager.modelRegistry) {
      const modelConfig = _findModelConfig(answers.modelName, registryConfigManager);
      if (modelConfig?.chatTemplate) {
        answers.chatTemplate = modelConfig.chatTemplate;
        answers.chatTemplateSource = 'Model_Registry';
      }
    }

    if (answers.frameworkVersion && registryConfigManager.frameworkRegistry) {
      const frameworkConfig = registryConfigManager.frameworkRegistry[answers.framework]?.[answers.frameworkVersion];
      if (frameworkConfig) {
        if (frameworkConfig.inferenceAmiVersion && !answers.inferenceAmiVersion) {
          answers.inferenceAmiVersion = frameworkConfig.inferenceAmiVersion;
        }
        if (frameworkConfig.accelerator) {
          answers.accelerator = frameworkConfig.accelerator;
        }
      }
    }
  } catch {
    // Deliberately best-effort: enrichment is optional and defaults are already set.
  }
}

/**
 * Fills a still-empty baseImage from the backend's registry entry (requested version,
 * else latest). Runs after MCP/CLI/config values, so it acts as the catalog default.
 */
function _applyCatalogBaseImage(answers, registryConfigManager) {
  if (answers.baseImage || !registryConfigManager?.frameworkRegistry) return;

  const backendKey = answers.backend || answers.modelServer;
  if (!backendKey) return;

  const frameworkVersions = registryConfigManager.frameworkRegistry[backendKey];
  if (!frameworkVersions) return;

  const resolvedConfig = _selectVersionConfig(frameworkVersions, answers.frameworkVersion);
  if (resolvedConfig?.baseImage) {
    answers.baseImage = resolvedConfig.baseImage;
  }
}

/**
 * Fills icGpuCount (when null/undefined and an instance type is set) from
 * `answers.gpuCount`, else from the instances catalog. The deploy template uses
 * IC_GPU_COUNT for NumberOfAcceleratorDevicesRequired, so GPU deployments need a value.
 */
function _applyInstanceGpuCount(answers) {
  if (answers.icGpuCount != null || !answers.instanceType) return;

  if (answers.gpuCount) {
    answers.icGpuCount = answers.gpuCount;
    return;
  }

  try {
    const catalogPath = path.resolve(__dirname, '..', '..', 'servers', 'lib', 'catalogs', 'instances.json');
    const catalogData = JSON.parse(fs.readFileSync(catalogPath, 'utf-8'));
    const instanceInfo = catalogData?.catalog?.[answers.instanceType];
    if (instanceInfo?.gpus && instanceInfo.gpus > 0) {
      answers.icGpuCount = instanceInfo.gpus;
    }
  } catch {
    // Catalog missing or unreadable: leave icGpuCount unset (presumably handled by a
    // template-side fallback — NOTE(review): confirm).
  }
}

/** Sets tuneSupported (when undefined) from the tune catalog; false if the catalog can't be read. */
function _applyTuneSupport(answers) {
  if (answers.tuneSupported !== undefined) return;

  try {
    const tuneCatalogPath = path.resolve(__dirname, '..', '..', 'config', 'tune-catalog.json');
    const tuneCatalog = JSON.parse(fs.readFileSync(tuneCatalogPath, 'utf-8'));
    answers.tuneSupported = isTuneSupported(answers.modelName || '', tuneCatalog);
  } catch {
    answers.tuneSupported = false;
  }
}

/**
 * Ensures every variable the EJS templates reference has a value (preventing
 * "undefined" errors), then enriches `answers` from registry and catalog data.
 *
 * Steps, in order:
 *   1. static defaults for keys that are `undefined` (explicit nulls are kept);
 *   2. back-fill `framework`/`modelServer` from `architecture`/`backend`;
 *   3. force `includeTesting` and choose default `testTypes`;
 *   4. merge catalog env vars (`_mergeEnvVarsWithPrecedence`);
 *   5. Triton: base image / AMI / accelerator from the registry, hard-coded image as last resort;
 *   6. transformers: best-effort chat template and framework metadata;
 *   7. catalog default `baseImage` from the backend's registry entry;
 *   8. `icGpuCount` from `gpuCount` or the instances catalog;
 *   9. `tuneSupported` from the tune catalog.
 *
 * Mutates `answers` in place.
 *
 * @param {object} answers - Answers object to fill defaults into
 * @param {object|null} registryConfigManager - Registry configuration manager (or null)
 * @returns {Promise<void>}
 */
export async function _ensureTemplateVariables(answers, registryConfigManager = null) {
  _applyTemplateDefaults(answers);

  // Backward compatibility: populate framework and modelServer from architecture/backend
  if (!answers.framework && answers.architecture) {
    answers.framework = answers.architecture;
  }
  if (!answers.modelServer && answers.backend) {
    answers.modelServer = answers.backend;
  }

  _applyTestingDefaults(answers);

  // Merge catalog env vars into answers.envVars with correct precedence
  await _mergeEnvVarsWithPrecedence(answers, registryConfigManager);

  if (answers.architecture === 'triton' && !answers.baseImage) {
    _applyTritonImageDefaults(answers, registryConfigManager);
  }

  if (answers.framework === 'transformers' && answers.modelName && registryConfigManager) {
    await _enrichTransformersAnswers(answers, registryConfigManager);
  }

  _applyCatalogBaseImage(answers, registryConfigManager);
  _applyInstanceGpuCount(answers);
  _applyTuneSupport(answers);
}
|
package/templates/code/serve
CHANGED
|
@@ -10,35 +10,10 @@ echo "$(date -u '+%Y-%m-%dT%H:%M:%SZ') [serve] Container started — PID $$"
|
|
|
10
10
|
# CUDA compatibility setup (required for newer SageMaker inference AMIs)
|
|
11
11
|
source /usr/bin/cuda_compat.sh 2>/dev/null || true
|
|
12
12
|
|
|
13
|
-
|
|
14
|
-
echo "Starting vLLM server"
|
|
15
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
16
|
-
echo "Starting SGLang server"
|
|
17
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
18
|
-
echo "Starting TensorRT-LLM server"
|
|
19
|
-
<% } else if (modelServer === 'lmi') { %>
|
|
20
|
-
echo "Starting LMI (Large Model Inference) server"
|
|
21
|
-
<% } else if (modelServer === 'djl') { %>
|
|
22
|
-
echo "Starting DJL Serving server"
|
|
23
|
-
<% } %>
|
|
13
|
+
echo "Starting <%= modelServer %> server"
|
|
24
14
|
|
|
25
15
|
<% if (modelServer === 'lmi' || modelServer === 'djl') { %>
|
|
26
|
-
|
|
27
|
-
# The configuration file should be at /opt/ml/model/serving.properties
|
|
28
|
-
# DJL Serving will automatically start with this configuration
|
|
29
|
-
|
|
30
|
-
if [ ! -f /opt/ml/model/serving.properties ]; then
|
|
31
|
-
echo "Error: serving.properties not found at /opt/ml/model/serving.properties"
|
|
32
|
-
exit 1
|
|
33
|
-
fi
|
|
34
|
-
|
|
35
|
-
echo "Using configuration from /opt/ml/model/serving.properties"
|
|
36
|
-
cat /opt/ml/model/serving.properties
|
|
37
|
-
|
|
38
|
-
# DJL Serving is already configured in the base image
|
|
39
|
-
# This script is not typically needed for LMI/DJL as they have their own entrypoint
|
|
40
|
-
# But we provide it for consistency with other model servers
|
|
41
|
-
exit 0
|
|
16
|
+
<%- include('serve.d/lmi') %>
|
|
42
17
|
<% } else { %>
|
|
43
18
|
|
|
44
19
|
<% if (typeof modelSource !== 'undefined' && modelSource !== 'huggingface') { %>
|
|
@@ -60,7 +35,6 @@ download_model_from_s3() {
|
|
|
60
35
|
mkdir -p "${dest_path}"
|
|
61
36
|
|
|
62
37
|
if [[ "$s3_uri" == *.tar.gz ]] || [[ "$s3_uri" == *.tgz ]]; then
|
|
63
|
-
# Tarball: download and extract
|
|
64
38
|
if ! aws s3 cp "$s3_uri" /tmp/model_archive.tar.gz; then
|
|
65
39
|
echo "Error: Failed to download tarball from ${s3_uri}" >&2
|
|
66
40
|
return 1
|
|
@@ -72,13 +46,11 @@ download_model_from_s3() {
|
|
|
72
46
|
fi
|
|
73
47
|
rm -f /tmp/model_archive.tar.gz
|
|
74
48
|
elif [[ "$s3_uri" == */ ]] || ! aws s3 ls "$s3_uri" 2>/dev/null | grep -q "^[0-9]"; then
|
|
75
|
-
# Directory prefix: sync
|
|
76
49
|
if ! aws s3 sync "$s3_uri" "$dest_path"; then
|
|
77
50
|
echo "Error: Failed to sync from ${s3_uri}" >&2
|
|
78
51
|
return 1
|
|
79
52
|
fi
|
|
80
53
|
else
|
|
81
|
-
# Single file: copy
|
|
82
54
|
if ! aws s3 cp "$s3_uri" "$dest_path/"; then
|
|
83
55
|
echo "Error: Failed to copy ${s3_uri}" >&2
|
|
84
56
|
return 1
|
|
@@ -109,19 +81,16 @@ _MODEL_VAR="TRTLLM_MODEL"
|
|
|
109
81
|
resolve_model() {
|
|
110
82
|
case "$MODEL_SOURCE" in
|
|
111
83
|
huggingface)
|
|
112
|
-
# Pass model name directly — server fetches from HF Hub
|
|
113
84
|
echo "${!_MODEL_VAR}"
|
|
114
85
|
return
|
|
115
86
|
;;
|
|
116
87
|
s3|registry)
|
|
117
|
-
# Check for pre-mounted artifacts first
|
|
118
88
|
if [ -d "$LOCAL_MODEL_PATH" ] && [ "$(ls -A $LOCAL_MODEL_PATH 2>/dev/null)" ]; then
|
|
119
89
|
echo "Using pre-mounted model artifacts at $LOCAL_MODEL_PATH" >&2
|
|
120
90
|
echo "$LOCAL_MODEL_PATH"
|
|
121
91
|
return
|
|
122
92
|
fi
|
|
123
93
|
|
|
124
|
-
# For registry:// models, resolve artifact URI at runtime via SageMaker API
|
|
125
94
|
if [ "$MODEL_SOURCE" = "registry" ] && [ -z "$MODEL_ARTIFACT_URI" ]; then
|
|
126
95
|
local model_uri="${!_MODEL_VAR}"
|
|
127
96
|
local registry_prefix="registry://"
|
|
@@ -131,7 +100,6 @@ resolve_model() {
|
|
|
131
100
|
local version="${registry_path#*/}"
|
|
132
101
|
local region="${AWS_REGION:-${AWS_DEFAULT_REGION:-us-east-1}}"
|
|
133
102
|
|
|
134
|
-
# Get account ID for ARN construction
|
|
135
103
|
local account_id
|
|
136
104
|
account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null) || {
|
|
137
105
|
echo "Error: Failed to get AWS account ID for model package ARN" >&2
|
|
@@ -151,38 +119,22 @@ resolve_model() {
|
|
|
151
119
|
exit 1
|
|
152
120
|
}
|
|
153
121
|
|
|
154
|
-
# Try ModelDataUrl first, then S3DataSource.S3Uri, then description
|
|
155
122
|
MODEL_ARTIFACT_URI=$(echo "$describe_output" | python3 -c "
|
|
156
123
|
import sys, json, re
|
|
157
124
|
try:
|
|
158
125
|
pkg = json.load(sys.stdin)
|
|
159
126
|
uri = ''
|
|
160
|
-
# Check InferenceSpecification.Containers[0]
|
|
161
127
|
containers = pkg.get('InferenceSpecification', {}).get('Containers', [])
|
|
162
128
|
if containers:
|
|
163
129
|
c = containers[0]
|
|
164
130
|
uri = c.get('ModelDataUrl', '')
|
|
165
131
|
if not uri:
|
|
166
132
|
uri = c.get('ModelDataSource', {}).get('S3DataSource', {}).get('S3Uri', '')
|
|
167
|
-
# Fallback: extract S3 URI from ModelPackageDescription
|
|
168
133
|
if not uri:
|
|
169
134
|
desc = pkg.get('ModelPackageDescription', '')
|
|
170
135
|
m = re.search(r's3://[^\s]+', desc)
|
|
171
136
|
if m:
|
|
172
137
|
uri = m.group(0)
|
|
173
|
-
# Fallback: check ModelCard hyperparameters for model_artifacts_s3
|
|
174
|
-
if not uri:
|
|
175
|
-
try:
|
|
176
|
-
card = pkg.get('ModelCard', {})
|
|
177
|
-
content = card.get('ModelCardContent', '{}')
|
|
178
|
-
card_data = json.loads(content) if isinstance(content, str) else content
|
|
179
|
-
params = card_data.get('training_details', {}).get('training_job_details', {}).get('hyper_parameters', [])
|
|
180
|
-
for p in params:
|
|
181
|
-
if p.get('name') == 'model_artifacts_s3':
|
|
182
|
-
uri = p.get('value', '')
|
|
183
|
-
break
|
|
184
|
-
except:
|
|
185
|
-
pass
|
|
186
138
|
print(uri)
|
|
187
139
|
except:
|
|
188
140
|
print('')
|
|
@@ -192,19 +144,15 @@ except:
|
|
|
192
144
|
echo "Resolved artifact URI: ${MODEL_ARTIFACT_URI}" >&2
|
|
193
145
|
else
|
|
194
146
|
echo "Error: No model artifact URI found in model package: ${package_arn}" >&2
|
|
195
|
-
echo " Checked: InferenceSpecification.Containers[0].ModelDataUrl" >&2
|
|
196
|
-
echo " Checked: InferenceSpecification.Containers[0].ModelDataSource.S3DataSource.S3Uri" >&2
|
|
197
147
|
exit 1
|
|
198
148
|
fi
|
|
199
149
|
fi
|
|
200
150
|
fi
|
|
201
151
|
|
|
202
|
-
# Need artifact URI for download
|
|
203
152
|
if [ -z "$MODEL_ARTIFACT_URI" ]; then
|
|
204
153
|
echo "Error: ${MODEL_SOURCE} model requires artifact URI or pre-mounted artifacts at $LOCAL_MODEL_PATH" >&2
|
|
205
154
|
exit 1
|
|
206
155
|
fi
|
|
207
|
-
# Download from S3
|
|
208
156
|
if ! download_model_from_s3 "$MODEL_ARTIFACT_URI" "$LOCAL_MODEL_PATH"; then
|
|
209
157
|
echo "Error: Failed to download model from ${MODEL_ARTIFACT_URI}" >&2
|
|
210
158
|
exit 1
|
|
@@ -212,7 +160,6 @@ except:
|
|
|
212
160
|
echo "$LOCAL_MODEL_PATH"
|
|
213
161
|
;;
|
|
214
162
|
*)
|
|
215
|
-
# Unrecognized source — treat as huggingface
|
|
216
163
|
echo "${!_MODEL_VAR}"
|
|
217
164
|
return
|
|
218
165
|
;;
|
|
@@ -226,89 +173,13 @@ unset _MODEL_VAR _RESOLVED_MODEL
|
|
|
226
173
|
|
|
227
174
|
# Initialize server arguments
|
|
228
175
|
<% if (modelServer === 'tensorrt-llm') { %>
|
|
229
|
-
# port 8081 for internal TensorRT-LLM server (nginx proxies on 8080)
|
|
230
176
|
SERVER_ARGS=(--host 0.0.0.0 --port 8081)
|
|
231
177
|
<% } else { %>
|
|
232
|
-
# port 8080 required by SageMaker: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-inference-code.html#your-algorithms-inference-code-container-response
|
|
233
178
|
SERVER_ARGS=(--host 0.0.0.0 --port 8080)
|
|
234
179
|
<% } %>
|
|
235
180
|
|
|
236
|
-
#
|
|
237
|
-
<% if (
|
|
238
|
-
|
|
239
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
240
|
-
PREFIX="SGLANG_"
|
|
241
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
242
|
-
PREFIX="TRTLLM_"
|
|
243
|
-
<% } %>
|
|
244
|
-
ARG_PREFIX="--"
|
|
245
|
-
|
|
246
|
-
# Define environment variables to exclude (internal variables set by base images)
|
|
247
|
-
<% if (modelServer === 'vllm') { %>
|
|
248
|
-
EXCLUDE_VARS=("VLLM_USAGE_SOURCE" "VLLM_ENABLE_CUDA_COMPATIBILITY")
|
|
249
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
250
|
-
EXCLUDE_VARS=()
|
|
251
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
252
|
-
# Exclude TRTLLM_MODEL as it's used as the positional MODEL argument
|
|
253
|
-
EXCLUDE_VARS=("TRTLLM_MODEL")
|
|
254
|
-
<% } %>
|
|
255
|
-
|
|
256
|
-
# Declare and populate array of matching environment variables
|
|
257
|
-
mapfile -t env_vars < <(env | grep "^${PREFIX}")
|
|
258
|
-
|
|
259
|
-
# Loop through the array and convert to command-line arguments
|
|
260
|
-
for var in "${env_vars[@]}"; do
|
|
261
|
-
IFS='=' read -r key value <<< "$var"
|
|
262
|
-
|
|
263
|
-
# Skip excluded variables
|
|
264
|
-
skip=false
|
|
265
|
-
for exclude in "${EXCLUDE_VARS[@]}"; do
|
|
266
|
-
if [ "$key" = "$exclude" ]; then
|
|
267
|
-
skip=true
|
|
268
|
-
break
|
|
269
|
-
fi
|
|
270
|
-
done
|
|
271
|
-
|
|
272
|
-
if [ "$skip" = true ]; then
|
|
273
|
-
continue
|
|
274
|
-
fi
|
|
275
|
-
|
|
276
|
-
# Remove prefix, convert to lowercase, and replace underscores with dashes
|
|
277
|
-
arg_name=$(echo "${key#"${PREFIX}"}" | tr '[:upper:]' '[:lower:]' | tr '_' '-')
|
|
278
|
-
|
|
279
|
-
# Boolean handling: true = flag only, false = skip entirely
|
|
280
|
-
if [ "$value" = "false" ]; then
|
|
281
|
-
continue
|
|
282
|
-
fi
|
|
283
|
-
|
|
284
|
-
SERVER_ARGS+=("${ARG_PREFIX}${arg_name}")
|
|
285
|
-
if [ -n "$value" ] && [ "$value" != "true" ]; then
|
|
286
|
-
SERVER_ARGS+=("$value")
|
|
287
|
-
fi
|
|
288
|
-
done
|
|
289
|
-
|
|
290
|
-
echo "-------------------------------------------------------------------"
|
|
291
|
-
<% if (modelServer === 'vllm') { %>
|
|
292
|
-
echo "vLLM engine args: [${SERVER_ARGS[@]}]"
|
|
293
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
294
|
-
echo "SGLang engine args: [${SERVER_ARGS[@]}]"
|
|
295
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
296
|
-
echo "TensorRT-LLM engine args: [${SERVER_ARGS[@]}]"
|
|
297
|
-
<% } %>
|
|
298
|
-
echo "-------------------------------------------------------------------"
|
|
299
|
-
|
|
300
|
-
# Pass the collected arguments to the main entrypoint
|
|
301
|
-
<% if (modelServer === 'vllm') { %>
|
|
302
|
-
exec python3 -m vllm.entrypoints.openai.api_server "${SERVER_ARGS[@]}"
|
|
303
|
-
<% } else if (modelServer === 'sglang') { %>
|
|
304
|
-
exec python3 -m sglang.launch_server "${SERVER_ARGS[@]}"
|
|
305
|
-
<% } else if (modelServer === 'tensorrt-llm') { %>
|
|
306
|
-
# TensorRT-LLM requires the model as a positional argument
|
|
307
|
-
# Syntax: trtllm-serve serve MODEL [OPTIONS]
|
|
308
|
-
if [ -z "$TRTLLM_MODEL" ]; then
|
|
309
|
-
echo "Error: TRTLLM_MODEL environment variable is not set"
|
|
310
|
-
exit 1
|
|
311
|
-
fi
|
|
312
|
-
exec trtllm-serve serve "$TRTLLM_MODEL" "${SERVER_ARGS[@]}"
|
|
181
|
+
# --- Server-specific arg conversion and exec ---
|
|
182
|
+
<% if (['vllm', 'sglang', 'tensorrt-llm'].includes(modelServer)) { %>
|
|
183
|
+
<%- include('serve.d/' + modelServer) %>
|
|
313
184
|
<% } %>
|
|
314
185
|
<% } %>
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# ---------------------------------------------------------------------------
# LMI / DJL Server Configuration
# ---------------------------------------------------------------------------
# Config:     /opt/ml/model/serving.properties (must exist; checked below)
# Entrypoint: DJL Serving (built into base image)
# Port:       8080 (configured in serving.properties)
# ---------------------------------------------------------------------------

# LMI/DJL containers use serving.properties for configuration
if [ ! -f /opt/ml/model/serving.properties ]; then
    # Report on stderr, consistent with the other "Error:" messages in serve
    echo "Error: serving.properties not found at /opt/ml/model/serving.properties" >&2
    exit 1
fi

echo "Using configuration from /opt/ml/model/serving.properties"
cat /opt/ml/model/serving.properties

# DJL Serving is already configured in the base image entrypoint; this script
# only validates and echoes the configuration, then exits successfully.
exit 0
|