@aws/ml-container-creator 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/LICENSE-THIRD-PARTY +68620 -0
- package/NOTICE +2 -0
- package/README.md +106 -0
- package/bin/cli.js +365 -0
- package/config/defaults.json +32 -0
- package/config/presets/transformers-djl.json +26 -0
- package/config/presets/transformers-gpu.json +24 -0
- package/config/presets/transformers-lmi.json +27 -0
- package/package.json +129 -0
- package/servers/README.md +419 -0
- package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
- package/servers/base-image-picker/catalogs/python-slim.json +38 -0
- package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
- package/servers/base-image-picker/catalogs/triton.json +38 -0
- package/servers/base-image-picker/index.js +495 -0
- package/servers/base-image-picker/manifest.json +17 -0
- package/servers/base-image-picker/package.json +15 -0
- package/servers/hyperpod-cluster-picker/LICENSE +202 -0
- package/servers/hyperpod-cluster-picker/index.js +424 -0
- package/servers/hyperpod-cluster-picker/manifest.json +14 -0
- package/servers/hyperpod-cluster-picker/package.json +17 -0
- package/servers/instance-recommender/LICENSE +202 -0
- package/servers/instance-recommender/catalogs/instances.json +852 -0
- package/servers/instance-recommender/index.js +284 -0
- package/servers/instance-recommender/manifest.json +16 -0
- package/servers/instance-recommender/package.json +15 -0
- package/servers/lib/LICENSE +202 -0
- package/servers/lib/bedrock-client.js +160 -0
- package/servers/lib/custom-validators.js +46 -0
- package/servers/lib/dynamic-resolver.js +36 -0
- package/servers/lib/package.json +11 -0
- package/servers/lib/schemas/image-catalog.schema.json +185 -0
- package/servers/lib/schemas/instances.schema.json +124 -0
- package/servers/lib/schemas/manifest.schema.json +64 -0
- package/servers/lib/schemas/model-catalog.schema.json +91 -0
- package/servers/lib/schemas/regions.schema.json +26 -0
- package/servers/lib/schemas/triton-backends.schema.json +51 -0
- package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
- package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
- package/servers/model-picker/catalogs/popular-transformers.json +226 -0
- package/servers/model-picker/index.js +1693 -0
- package/servers/model-picker/manifest.json +18 -0
- package/servers/model-picker/package.json +20 -0
- package/servers/region-picker/LICENSE +202 -0
- package/servers/region-picker/catalogs/regions.json +263 -0
- package/servers/region-picker/index.js +230 -0
- package/servers/region-picker/manifest.json +16 -0
- package/servers/region-picker/package.json +15 -0
- package/src/app.js +1007 -0
- package/src/copy-tpl.js +77 -0
- package/src/lib/accelerator-validator.js +39 -0
- package/src/lib/asset-manager.js +385 -0
- package/src/lib/aws-profile-parser.js +181 -0
- package/src/lib/bootstrap-command-handler.js +1647 -0
- package/src/lib/bootstrap-config.js +238 -0
- package/src/lib/ci-register-helpers.js +124 -0
- package/src/lib/ci-report-helpers.js +158 -0
- package/src/lib/ci-stage-helpers.js +268 -0
- package/src/lib/cli-handler.js +529 -0
- package/src/lib/comment-generator.js +544 -0
- package/src/lib/community-reports-validator.js +91 -0
- package/src/lib/config-manager.js +2106 -0
- package/src/lib/configuration-exporter.js +204 -0
- package/src/lib/configuration-manager.js +695 -0
- package/src/lib/configuration-matcher.js +221 -0
- package/src/lib/cpu-validator.js +36 -0
- package/src/lib/cuda-validator.js +57 -0
- package/src/lib/deployment-config-resolver.js +103 -0
- package/src/lib/deployment-entry-schema.js +125 -0
- package/src/lib/deployment-registry.js +598 -0
- package/src/lib/docker-introspection-validator.js +51 -0
- package/src/lib/engine-prefix-resolver.js +60 -0
- package/src/lib/huggingface-client.js +172 -0
- package/src/lib/key-value-parser.js +37 -0
- package/src/lib/known-flags-validator.js +200 -0
- package/src/lib/manifest-cli.js +280 -0
- package/src/lib/mcp-client.js +303 -0
- package/src/lib/mcp-command-handler.js +532 -0
- package/src/lib/neuron-validator.js +80 -0
- package/src/lib/parameter-schema-validator.js +284 -0
- package/src/lib/prompt-runner.js +1349 -0
- package/src/lib/prompts.js +1138 -0
- package/src/lib/registry-command-handler.js +519 -0
- package/src/lib/registry-loader.js +198 -0
- package/src/lib/rocm-validator.js +80 -0
- package/src/lib/schema-validator.js +157 -0
- package/src/lib/sensitive-redactor.js +59 -0
- package/src/lib/template-engine.js +156 -0
- package/src/lib/template-manager.js +341 -0
- package/src/lib/validation-engine.js +314 -0
- package/src/prompt-adapter.js +63 -0
- package/templates/Dockerfile +300 -0
- package/templates/IAM_PERMISSIONS.md +84 -0
- package/templates/MIGRATION.md +488 -0
- package/templates/PROJECT_README.md +439 -0
- package/templates/TEMPLATE_SYSTEM.md +243 -0
- package/templates/buildspec.yml +64 -0
- package/templates/code/chat_template.jinja +1 -0
- package/templates/code/flask/gunicorn_config.py +35 -0
- package/templates/code/flask/wsgi.py +10 -0
- package/templates/code/model_handler.py +387 -0
- package/templates/code/serve +300 -0
- package/templates/code/serve.py +175 -0
- package/templates/code/serving.properties +105 -0
- package/templates/code/start_server.py +39 -0
- package/templates/code/start_server.sh +39 -0
- package/templates/diffusors/Dockerfile +72 -0
- package/templates/diffusors/patch_image_api.py +35 -0
- package/templates/diffusors/serve +115 -0
- package/templates/diffusors/start_server.sh +114 -0
- package/templates/do/.gitkeep +1 -0
- package/templates/do/README.md +541 -0
- package/templates/do/build +83 -0
- package/templates/do/ci +681 -0
- package/templates/do/clean +811 -0
- package/templates/do/config +260 -0
- package/templates/do/deploy +1560 -0
- package/templates/do/export +306 -0
- package/templates/do/logs +319 -0
- package/templates/do/manifest +12 -0
- package/templates/do/push +119 -0
- package/templates/do/register +580 -0
- package/templates/do/run +113 -0
- package/templates/do/submit +417 -0
- package/templates/do/test +1147 -0
- package/templates/hyperpod/configmap.yaml +24 -0
- package/templates/hyperpod/deployment.yaml +71 -0
- package/templates/hyperpod/pvc.yaml +42 -0
- package/templates/hyperpod/service.yaml +17 -0
- package/templates/nginx-diffusors.conf +74 -0
- package/templates/nginx-predictors.conf +47 -0
- package/templates/nginx-tensorrt.conf +74 -0
- package/templates/requirements.txt +61 -0
- package/templates/sample_model/test_inference.py +123 -0
- package/templates/sample_model/train_abalone.py +252 -0
- package/templates/test/test_endpoint.sh +79 -0
- package/templates/test/test_local_image.sh +80 -0
- package/templates/test/test_model_handler.py +180 -0
- package/templates/triton/Dockerfile +128 -0
- package/templates/triton/config.pbtxt +163 -0
- package/templates/triton/model.py +130 -0
- package/templates/triton/requirements.txt +11 -0
|
@@ -0,0 +1,1138 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Prompt definitions organized by phase for better maintainability.
|
|
6
|
+
* Each phase handles a specific aspect of project configuration.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
import Table from 'cli-table3';
|
|
10
|
+
import chalk from 'chalk';
|
|
11
|
+
import { readFileSync } from 'node:fs';
|
|
12
|
+
import { resolve, dirname } from 'node:path';
|
|
13
|
+
import { fileURLToPath } from 'node:url';
|
|
14
|
+
|
|
15
|
+
// Resolve this module's own location so the catalog path below works no
// matter what the process CWD is (ESM equivalent of __filename/__dirname).
const __promptsFilename = fileURLToPath(import.meta.url);
const __promptsDir = dirname(__promptsFilename);
// Absolute path to the instance-recommender catalog bundled with the package.
const instancesCatalogPath = resolve(__promptsDir, '../../servers/instance-recommender/catalogs/instances.json');
|
|
18
|
+
|
|
19
|
+
/**
 * Load instance types from the instances.json catalog and reshape them into
 * the display form the prompts expect (type, vcpus, memory, accelerator,
 * useCase, category).
 *
 * @returns {Object<string, object>} Map of instance type name to display
 *   entry; empty object when the catalog cannot be read or parsed.
 */
function loadInstanceTypeRegistry() {
  // Convert one raw catalog record into the display shape used by prompts.
  const toDisplayEntry = (instanceType, entry) => ({
    type: instanceType,
    vcpus: entry.vcpus || 0,
    memory: entry.memGb ? `${entry.memGb} GB` : '0 GB',
    accelerator: entry.hardware && entry.hardware !== 'None'
      ? entry.accelerator || entry.hardware
      : 'None',
    useCase: entry.notes || entry.tags?.join(', ') || '',
    category: entry.category || 'cpu',
  });

  try {
    const parsed = JSON.parse(readFileSync(instancesCatalogPath, 'utf8'));
    return Object.fromEntries(
      Object.entries(parsed?.catalog || {})
        .map(([name, entry]) => [name, toDisplayEntry(name, entry)])
    );
  } catch (error) {
    // Best-effort: a missing/corrupt catalog degrades to an empty registry.
    console.warn(`Failed to load instance type registry from catalog: ${error.message}`);
    return {};
  }
}
|
|
47
|
+
|
|
48
|
+
// Module-level cache: the catalog is read once at import time.
const instanceTypeRegistry = loadInstanceTypeRegistry();
|
|
49
|
+
|
|
50
|
+
/**
 * Build a pseudo-random project name of the form `adjective-framework-suffix`
 * (e.g. 'swift-xgb-engine'). Frameworks without a known alias table fall back
 * to a generic 'ml' middle segment.
 *
 * @param {string} framework - The ML framework key ('sklearn', 'xgboost', 'tensorflow', 'transformers').
 * @returns {string} Generated project name.
 */
function generateProjectName(framework) {
  // Uniform random pick from a non-empty array.
  const pick = (options) => options[Math.floor(Math.random() * options.length)];

  const adjectives = [
    'smart', 'fast', 'clever', 'bright', 'swift', 'agile', 'sharp', 'quick',
    'wise', 'keen', 'bold', 'sleek', 'neat', 'cool', 'fresh', 'prime'
  ];

  // Short aliases per framework used as the middle name segment.
  const frameworkNames = {
    'sklearn': ['sklearn', 'scikit', 'sk'],
    'xgboost': ['xgb', 'xgboost', 'boost'],
    'tensorflow': ['tf', 'tensorflow', 'tensor'],
    'transformers': ['llm', 'transformer', 'gpt', 'bert', 'ai']
  };

  const suffixes = [
    'model', 'predictor', 'classifier', 'engine', 'service', 'api',
    'container', 'deployment', 'inference', 'ml', 'ai', 'bot'
  ];

  const middle = frameworkNames[framework] ? pick(frameworkNames[framework]) : 'ml';
  return [pick(adjectives), middle, pick(suffixes)].join('-');
}
|
|
82
|
+
|
|
83
|
+
/**
 * Phase 1: Core ML configuration (asked first).
 * Flattened deployment configuration combining architecture + backend into a
 * single list choice, grouped by separator sections.
 * Requirements: 3.1, 3.2, 16.1, 16.2, 16.3, 16.4, 16.8, 16.9
 */

// Expand a section label plus [displayName, value] pairs into a separator
// entry followed by list choices (short mirrors value for inquirer display).
const deploymentSection = (label, entries) => [
  { type: 'separator', separator: `── ${label} ──` },
  ...entries.map(([name, value]) => ({ name, value, short: value })),
];

const deploymentConfigPrompts = [
  {
    type: 'list',
    name: 'deploymentConfig',
    message: 'Select deployment configuration:',
    choices: [
      ...deploymentSection('Large Language Models', [
        ['Transformers with vLLM', 'transformers-vllm'],
        ['Transformers with SGLang', 'transformers-sglang'],
        ['Transformers with TensorRT-LLM', 'transformers-tensorrt-llm'],
        ['Transformers with LMI (Large Model Inference)', 'transformers-lmi'],
        ['Transformers with DJL (Deep Java Library)', 'transformers-djl'],
      ]),
      ...deploymentSection('HTTP Serving', [
        ['HTTP with Flask', 'http-flask'],
        ['HTTP with FastAPI', 'http-fastapi'],
      ]),
      ...deploymentSection('NVIDIA Triton Inference Server', [
        ['Triton FIL (XGBoost, LightGBM)', 'triton-fil'],
        ['Triton ONNX Runtime', 'triton-onnxruntime'],
        ['Triton TensorFlow', 'triton-tensorflow'],
        ['Triton PyTorch', 'triton-pytorch'],
        ['Triton vLLM', 'triton-vllm'],
        ['Triton TensorRT-LLM', 'triton-tensorrtllm'],
        ['Triton Python Backend', 'triton-python'],
      ]),
      ...deploymentSection('Diffusion Models', [
        ['Diffusors with vLLM Omni', 'diffusors-vllm-omni'],
      ]),
    ],
  },
];

// Deprecated alias kept for backward compatibility with older callers.
const frameworkPrompts = deploymentConfigPrompts;
|
|
179
|
+
|
|
180
|
+
/**
 * Engine selection prompt — only relevant for the 'http' architecture, where
 * the serving engine (sklearn/xgboost/tensorflow) is not implied by the
 * deployment configuration.
 * Requirements: 3.7
 */
const enginePrompts = [
  {
    type: 'list',
    name: 'engine',
    message: 'Select ML engine:',
    choices: [
      ['scikit-learn', 'sklearn'],
      ['XGBoost', 'xgboost'],
      ['TensorFlow', 'tensorflow'],
    ].map(([name, value]) => ({ name, value })),
    // Architecture is either explicit or the prefix of deploymentConfig.
    when: (answers) =>
      (answers.architecture || answers.deploymentConfig?.split('-')[0]) === 'http',
  },
];
|
|
200
|
+
|
|
201
|
+
/**
 * Framework version selection (registry system). The choice list is injected
 * by PromptRunner into answers._frameworkVersionChoices before prompting.
 * Requirements: 2.1, 2.6, 8.2, 8.3
 */
const frameworkVersionPrompts = [
  {
    type: 'list',
    name: 'frameworkVersion',
    message: ({ framework }) => `Which version of ${framework} are you using?`,
    // Populated by PromptRunner with registry data.
    choices: ({ _frameworkVersionChoices }) => _frameworkVersionChoices || [],
    // Skip entirely when the runner supplied no version choices.
    when: ({ _frameworkVersionChoices }) =>
      _frameworkVersionChoices && _frameworkVersionChoices.length > 0,
  },
];
|
|
220
|
+
|
|
221
|
+
/**
 * Framework configuration profile selection (registry system). The choice
 * list is injected by PromptRunner into answers._frameworkProfileChoices.
 * Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.10
 */
const frameworkProfilePrompts = [
  {
    type: 'list',
    name: 'frameworkProfile',
    message: 'Select a framework configuration profile:',
    // Populated by PromptRunner with registry data.
    choices: ({ _frameworkProfileChoices }) => _frameworkProfileChoices || [],
    // Ask only when at least one profile choice was supplied.
    when: ({ _frameworkProfileChoices }) =>
      _frameworkProfileChoices && _frameworkProfileChoices.length > 0,
  },
];
|
|
240
|
+
|
|
241
|
+
/**
 * Derive the architecture segment ('http', 'triton', 'transformers',
 * 'diffusors') from an explicit answers.architecture or from the prefix of
 * the flattened deploymentConfig value (e.g. 'triton-fil' → 'triton').
 * @param {object} answers - Accumulated prompt answers.
 * @returns {string|undefined} Architecture, or undefined if neither is set.
 */
const deriveArchitecture = (answers) =>
  answers.architecture || answers.deploymentConfig?.split('-')[0];

/**
 * Derive the backend segment from an explicit answers.backend or everything
 * after the first '-' of deploymentConfig
 * (e.g. 'transformers-tensorrt-llm' → 'tensorrt-llm').
 * @param {object} answers - Accumulated prompt answers.
 * @returns {string|undefined} Backend, or undefined if neither is set.
 */
const deriveBackend = (answers) =>
  answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');

// Serialization formats per HTTP engine; also reused by the legacy
// framework-based path (the two inline copies previously duplicated this map).
const ENGINE_FORMAT_MAP = {
  'xgboost': ['json', 'model', 'ubj'],
  'sklearn': ['pkl', 'joblib'],
  'tensorflow': ['keras', 'h5', 'SavedModel']
};

/**
 * Model serialization format, model selection, and custom model path prompts.
 * Formats depend on architecture/backend; model selection applies to
 * HuggingFace-backed configs (transformers, diffusors, Triton LLM backends).
 */
const modelFormatPrompts = [
  {
    type: 'list',
    name: 'modelFormat',
    message: 'In which format is your model serialized?',
    choices: (answers) => {
      const architecture = deriveArchitecture(answers);
      const backend = deriveBackend(answers);

      // HTTP serving: formats are determined by the selected engine.
      if (architecture === 'http') {
        return ENGINE_FORMAT_MAP[answers.engine] || [];
      }

      if (architecture === 'triton') {
        // FIL and Python are the only Triton backends with multiple formats.
        if (backend === 'fil') {
          return ['xgboost_json', 'xgboost_ubj', 'lightgbm_txt'];
        }
        if (backend === 'python') {
          return ['pkl', 'joblib', 'custom'];
        }
        // Other Triton backends have auto-set formats (filtered by when()).
        return [];
      }

      // Legacy support for old configs that set framework directly.
      return ENGINE_FORMAT_MAP[answers.framework || architecture] || [];
    },
    when: (answers) => {
      const architecture = deriveArchitecture(answers);
      const backend = deriveBackend(answers);

      // Transformers and diffusors load from HF Hub — no local model format.
      if (architecture === 'transformers' || architecture === 'diffusors') {
        return false;
      }

      // HTTP serving always needs a serialization format.
      if (architecture === 'http') {
        return true;
      }

      // Triton: only FIL and Python offer a choice; others are auto-set.
      if (architecture === 'triton') {
        return backend === 'fil' || backend === 'python';
      }

      // Legacy support
      return (answers.framework || architecture) !== 'transformers';
    }
  },
  {
    type: 'list',
    name: 'modelName',
    message: 'Which model do you want to use?',
    choices: (answers) => {
      // Prefer MCP model-picker choices when available.
      if (answers._mcpModelChoices && answers._mcpModelChoices.length > 0) {
        return [...answers._mcpModelChoices, 'Custom (enter manually)'];
      }
      // Fallback to hardcoded defaults based on architecture.
      if (deriveArchitecture(answers) === 'diffusors') {
        return [
          'stabilityai/stable-diffusion-3.5-medium',
          'black-forest-labs/FLUX.1-schnell',
          'black-forest-labs/FLUX.1-dev',
          'Custom (enter manually)'
        ];
      }
      return [
        'openai/gpt-oss-20b',
        'meta-llama/Llama-3.2-3B-Instruct',
        'meta-llama/Llama-3.2-1B-Instruct',
        'Custom (enter manually)'
      ];
    },
    default: (answers) => {
      if (answers._mcpModelChoices && answers._mcpModelChoices.length > 0) {
        return answers._mcpModelChoices[0];
      }
      return deriveArchitecture(answers) === 'diffusors'
        ? 'stabilityai/stable-diffusion-3.5-medium'
        : 'openai/gpt-oss-20b';
    },
    when: (answers) => {
      const architecture = deriveArchitecture(answers);
      const backend = deriveBackend(answers);

      // Transformers and diffusors both select a HuggingFace model.
      if (architecture === 'transformers' || architecture === 'diffusors') {
        return true;
      }
      // Triton LLM backends (vllm, tensorrtllm) also need a model selection.
      return architecture === 'triton' && (backend === 'vllm' || backend === 'tensorrtllm');
    }
  },
  {
    type: 'input',
    name: 'customModelName',
    message: 'Enter the model path:',
    validate: (input) => {
      if (!input || input.trim() === '') {
        return 'Model name is required'
      }
      // Basic validation - must contain a slash (org/model, hub/model, s3://path, etc.)
      if (!input.includes('/')) {
        return 'Please use the full model path (e.g., microsoft/DialoGPT-medium, jumpstart-hub://my-hub/my-model)'
      }
      return true
    },
    when: (answers) => {
      // Only relevant when the user chose manual entry in modelName.
      if (answers.modelName !== 'Custom (enter manually)') {
        return false;
      }
      const architecture = deriveArchitecture(answers);
      const backend = deriveBackend(answers);
      if (architecture === 'transformers' || architecture === 'diffusors') {
        return true;
      }
      return architecture === 'triton' && (backend === 'vllm' || backend === 'tensorrtllm');
    }
  }
];
|
|
414
|
+
|
|
415
|
+
// Model server prompts are now deprecated - modelServer is derived from deploymentConfig
// NOTE(review): presumably kept as an empty list so callers iterating prompt
// phases keep working without a null check — verify before removing.
const modelServerPrompts = [];
|
|
417
|
+
|
|
418
|
+
/**
 * Model loading strategy prompt: bake the model into the image at build time,
 * or download it when the container starts. Only shown for HF-Hub-backed
 * architectures (transformers, diffusors).
 * Requirements: 13.1, 13.2, 13.3, 13.4, 13.5
 */
const modelLoadStrategyPrompts = [
  {
    type: 'list',
    name: 'modelLoadStrategy',
    message: 'How should the model be loaded?\n'
      + ' Build-time: Bakes model into image (larger image, faster startup)\n'
      + ' Runtime: Downloads at container startup (smaller image, slower startup)',
    choices: [
      { name: 'Runtime (download at startup)', value: 'runtime' },
      { name: 'Build-time (bake into image) [EXPERIMENTAL]', value: 'build-time' }
    ],
    default: 'runtime',
    when: (answers) => {
      const arch = answers.architecture || answers.deploymentConfig?.split('-')[0];
      return ['transformers', 'diffusors'].includes(arch);
    }
  }
];
|
|
441
|
+
|
|
442
|
+
/**
 * Model configuration profile selection (registry system). The choice list is
 * injected by PromptRunner into answers._modelProfileChoices before prompting.
 * Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.10
 */
const modelProfilePrompts = [
  {
    type: 'list',
    name: 'modelProfile',
    message: 'Select a model configuration profile:',
    // Populated by PromptRunner with registry data.
    choices: ({ _modelProfileChoices }) => _modelProfileChoices || [],
    // Ask only when at least one profile choice was supplied.
    when: ({ _modelProfileChoices }) =>
      _modelProfileChoices && _modelProfileChoices.length > 0,
  },
];
|
|
461
|
+
|
|
462
|
+
/**
 * List of example model IDs that don't require HF_TOKEN prompts
 * These are public models that don't need authentication
 */
// NOTE(review): meta-llama models are typically gated on HuggingFace and
// require license acceptance — verify these really stay token-free.
// eslint-disable-next-line no-unused-vars -- reference list for future use
const EXAMPLE_MODEL_IDS = [
  'openai/gpt-oss-20b',
  'meta-llama/Llama-3.2-3B-Instruct',
  'meta-llama/Llama-3.2-1B-Instruct'
];
|
|
472
|
+
|
|
473
|
+
/**
 * HuggingFace token prompt. Shown only for configurations that pull model
 * weights from the HuggingFace Hub; prints a security note before asking.
 * Requirements: 9.1, 9.2
 */
const hfTokenPrompts = [
  {
    type: 'input',
    name: 'hfToken',
    message: 'HuggingFace token (enter token, "$HF_TOKEN" for env var, or leave empty):',
    when: (answers) => {
      const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
      const backend = answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');

      // Configs that fetch weights from HF Hub: transformers, diffusors, and
      // the Triton LLM backends (vllm, tensorrtllm).
      const needsHfAuth =
        architecture === 'transformers'
        || architecture === 'diffusors'
        || (architecture === 'triton' && (backend === 'vllm' || backend === 'tensorrtllm'));
      if (!needsHfAuth) {
        return false;
      }

      // Non-HuggingFace model sources (S3, JumpStart, Private Hub, Registry)
      // need no HF authentication.
      if (answers.modelSource && answers.modelSource !== 'huggingface') {
        return false;
      }

      // Display the security warning right before the prompt appears.
      console.log('\n🔐 HuggingFace Authentication')
      console.log(' Many models (e.g. Llama, Mistral) are gated and require a token.')
      console.log('⚠️ Security Note: The token will be baked into the Docker image.')
      console.log(' Anyone with access to the image can extract the token using \'docker inspect\'.')
      console.log(' For CI/CD pipelines, use "$HF_TOKEN" to reference an environment variable.')
      console.log(' This keeps the token out of the image and allows rotation without rebuilding.\n')

      return true;
    },
    validate: (input) => {
      const value = input ? input.trim() : '';

      // Empty is fine (not all models require auth), and so is the
      // "$HF_TOKEN" environment-variable reference.
      if (value === '' || value === '$HF_TOKEN') {
        return true;
      }

      // Direct tokens normally start with "hf_" — warn but never block.
      if (!input.startsWith('hf_')) {
        console.warn('\n⚠️ Warning: HuggingFace tokens typically start with "hf_"')
        console.warn(' If this is intentional, you can ignore this warning.')
      }

      return true; // non-blocking validation
    }
  }
];
|
|
534
|
+
|
|
535
|
+
/**
 * NVIDIA NGC API key prompt. Only transformers-tensorrt-llm pulls gated NGC
 * base images; Triton and diffusors configs use public images and never need
 * an NGC key (Requirements: 9.2).
 */
const ngcApiKeyPrompts = [
  {
    type: 'input',
    name: 'ngcApiKey',
    message: 'NVIDIA NGC API key (enter key, "$NGC_API_KEY" for env var, or leave empty):',
    when: (answers) => {
      const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
      const backend = answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');

      // Everything except transformers-tensorrt-llm uses public images.
      if (architecture !== 'transformers' || backend !== 'tensorrt-llm') {
        return false;
      }

      console.log('\n🔐 NVIDIA NGC Authentication')
      console.log(' TensorRT-LLM base images are hosted on NVIDIA NGC and require an API key.')
      console.log(' 1. Create account at: https://ngc.nvidia.com/')
      console.log(' 2. Generate API key in account settings')
      console.log(' For CI/CD pipelines, use "$NGC_API_KEY" to reference an environment variable.\n')
      return true;
    },
    // Every input is accepted: empty, the "$NGC_API_KEY" reference, or a key.
    validate: () => true
  }
];
|
|
580
|
+
|
|
581
|
+
// LLM-only configs (transformers, Triton vllm/tensorrtllm) can only be
// exercised against a hosted endpoint; everything else also supports local
// CLI and server tests.
const hostedEndpointOnly = (answers) => {
  const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
  const backend = answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');
  return architecture === 'transformers'
    || (architecture === 'triton' && (backend === 'vllm' || backend === 'tensorrtllm'));
};

// Test-type options for the current deployment config (used for both the
// checkbox choices and its default so the two always agree).
const availableTestTypes = (answers) =>
  hostedEndpointOnly(answers)
    ? ['hosted-model-endpoint']
    : ['local-model-cli', 'local-model-server', 'hosted-model-endpoint'];

/**
 * Optional module prompts: sample model inclusion and test-type selection.
 */
const modulePrompts = [
  {
    type: 'confirm',
    name: 'includeSampleModel',
    message: 'Include sample Abalone classifier?',
    default: false,
    when: (answers) => {
      const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
      const backend = answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');

      // Transformers and diffusors use pretrained models — no inline sample.
      if (architecture === 'transformers' || architecture === 'diffusors') {
        return false;
      }
      // Triton: LLM and PyTorch backends don't support the sample model.
      if (architecture === 'triton') {
        return !(backend === 'vllm' || backend === 'tensorrtllm' || backend === 'pytorch');
      }
      // http architecture always offers the sample model.
      return true;
    }
  },
  {
    type: 'checkbox',
    name: 'testTypes',
    message: 'Test type?',
    choices: availableTestTypes,
    default: availableTestTypes
  }
];
|
|
648
|
+
|
|
649
|
+
/**
|
|
650
|
+
* Infrastructure prompts split into sub-phases so the prompt runner can
|
|
651
|
+
* interleave MCP queries between them (e.g. query instance-recommender
|
|
652
|
+
* only after we know the deployment target is managed-inference).
|
|
653
|
+
*
|
|
654
|
+
* Ordering: Region → Deployment Target → Instance/HyperPod → Build Target → Role
|
|
655
|
+
*/
|
|
656
|
+
|
|
657
|
+
// Sub-phase A: Region + Deployment Target (always asked first)
const infraRegionAndTargetPrompts = [
  {
    type: 'list',
    name: 'awsRegion',
    message: 'Target AWS region?',
    choices: (answers) => {
      // A bootstrap profile may carry a preferred region; surface it first.
      const profileRegion = answers._bootstrapRegion
      const regionOptions = ['us-east-1']
      if (profileRegion && profileRegion !== 'us-east-1') {
        regionOptions.unshift({
          name: `${profileRegion} (from bootstrap profile)`,
          value: profileRegion
        })
      }
      regionOptions.push({ name: 'Custom...', value: 'custom' })
      return regionOptions
    },
    default: (answers) => answers._bootstrapRegion || 'us-east-1'
  },
  {
    type: 'input',
    name: 'customAwsRegion',
    message: 'Enter AWS region (e.g., us-west-2, eu-west-1):',
    when: (answers) => answers.awsRegion === 'custom'
  },
  {
    type: 'list',
    name: 'deploymentTarget',
    message: 'Deployment target?',
    choices: [
      { name: 'SageMaker Managed Inference - Real Time', value: 'managed-inference' },
      { name: 'SageMaker Managed Inference - Async', value: 'async-inference' },
      { name: 'SageMaker Managed Inference - Batch', value: 'batch-transform' },
      { name: 'SageMaker HyperPod - EKS', value: 'hyperpod-eks' }
    ],
    default: 'managed-inference'
  }
];
|
|
694
|
+
|
|
695
|
+
// Sub-phase B: Instance type (asked for managed, async, batch, and HyperPod EKS targets)

/**
 * Filter the instance registry by framework and (when present) the MCP
 * instance-recommender results. Shared by the table renderer in `message`
 * and the selectable `choices` so the two views cannot drift apart
 * (the logic was previously duplicated verbatim in both).
 *
 * @param {object} answers - Inquirer answers collected so far
 * @returns {object[]} instance registry entries to offer
 */
function filterInstanceChoices(answers) {
  const framework = answers.framework || answers.deploymentConfig?.split('-')[0];

  const allInstances = Object.values(instanceTypeRegistry);
  // Transformers require GPU-class instances; other frameworks may use any category.
  let filtered = framework === 'transformers'
    ? allInstances.filter(i => i.category === 'gpu')
    : allInstances;

  const mcpChoices = answers._mcpInstanceChoices;
  if (mcpChoices && mcpChoices.length > 0) {
    const recommended = new Set(mcpChoices);
    filtered = filtered.filter(i => recommended.has(i.type));
  }
  return filtered;
}

const infraInstancePrompts = [
  {
    type: 'list',
    name: 'instanceType',
    when: answers => answers.deploymentTarget === 'managed-inference' || answers.deploymentTarget === 'async-inference' || answers.deploymentTarget === 'batch-transform' || answers.deploymentTarget === 'hyperpod-eks',
    // `message` doubles as a side-effecting renderer: Inquirer evaluates it
    // right before showing the prompt, so the comparison table prints above
    // the selection list.
    message: (answers) => {
      const table = new Table({
        head: [
          chalk.cyan('Instance Type'),
          chalk.cyan('vCPUs'),
          chalk.cyan('Memory'),
          chalk.cyan('Accelerator'),
          chalk.cyan('Use Case')
        ],
        colWidths: [20, 8, 12, 20, 25]
      });

      filterInstanceChoices(answers).forEach(instance => {
        table.push([
          instance.type,
          instance.vcpus.toString(),
          instance.memory,
          instance.accelerator,
          instance.useCase
        ]);
      });

      table.push([
        chalk.yellow('Custom...'),
        '-',
        '-',
        '-',
        'Specify your own'
      ]);

      const mcpChoices = answers._mcpInstanceChoices;
      const header = mcpChoices && mcpChoices.length > 0
        ? 'Available Instance Types (filtered by MCP):'
        : 'Available Instance Types:';
      console.log(`\n${ chalk.bold(header)}`);
      console.log(table.toString());
      console.log('');

      return 'Select instance type:';
    },
    choices: (answers) => {
      const choices = filterInstanceChoices(answers).map(instance => ({
        name: instance.type,
        value: instance.type
      }));

      choices.push({
        name: 'Custom...',
        value: 'custom'
      });

      return choices;
    },
    default: (answers) => {
      const framework = answers.framework || answers.deploymentConfig?.split('-')[0];
      const modelServer = answers.modelServer || answers.deploymentConfig?.split('-')[1];

      if (framework === 'transformers') {
        // TensorRT-LLM defaults to a multi-GPU instance.
        if (modelServer === 'tensorrt-llm') {
          return 'ml.g5.12xlarge';
        }
        return 'ml.g5.2xlarge';
      }
      return 'ml.m5.xlarge';
    }
  },
  {
    type: 'input',
    name: 'customInstanceType',
    message: 'Enter AWS SageMaker instance type (e.g., ml.t3.medium, ml.g4dn.xlarge):',
    validate: (input) => {
      if (!input || input.trim() === '') {
        return 'Instance type is required';
      }
      const instancePattern = /^ml\.[a-z0-9]+\.(nano|micro|small|medium|large|xlarge|[0-9]+xlarge)$/;
      if (!instancePattern.test(input.trim())) {
        return 'Invalid instance type format. Expected format: ml.{family}.{size} (e.g., ml.m5.large, ml.g4dn.xlarge)';
      }
      return true;
    },
    when: answers => answers.instanceType === 'custom'
  }
];
|
|
809
|
+
|
|
810
|
+
// Sub-phase C: HyperPod EKS-specific prompts (only when deploymentTarget === 'hyperpod-eks')

// Shared guard: every prompt in this sub-phase applies only to HyperPod EKS.
const onlyForHyperPodEks = answers => answers.deploymentTarget === 'hyperpod-eks';

const infraHyperPodPrompts = [
  {
    type: 'list',
    name: 'hyperPodCluster',
    message: 'Select HyperPod EKS cluster:',
    choices: (answers) => {
      const discovered = answers._mcpHyperPodChoices || [];
      if (discovered.length === 0) {
        // No MCP results — offer manual entry as the only option
        return [{ name: 'Enter cluster name manually', value: 'custom' }];
      }
      return [...discovered, { name: 'Custom (enter manually)', value: 'custom' }];
    },
    when: onlyForHyperPodEks
  },
  {
    type: 'input',
    name: 'customHyperPodCluster',
    message: 'Enter HyperPod EKS cluster name:',
    validate: (input) => {
      const hasText = Boolean(input) && input.trim() !== '';
      return hasText ? true : 'Cluster name is required';
    },
    when: answers => onlyForHyperPodEks(answers) && answers.hyperPodCluster === 'custom'
  },
  {
    type: 'input',
    name: 'hyperPodNamespace',
    message: 'Kubernetes namespace?',
    default: 'default',
    when: onlyForHyperPodEks
  },
  {
    type: 'number',
    name: 'hyperPodReplicas',
    message: 'Number of pod replicas?',
    default: 1,
    when: onlyForHyperPodEks
  },
  {
    type: 'input',
    name: 'fsxVolumeHandle',
    message: 'FSx for Lustre volume handle (optional, press Enter to skip):',
    when: onlyForHyperPodEks
  }
];
|
|
859
|
+
|
|
860
|
+
// Sub-phase D: Build target + role ARN (always asked last)
const infraBuildPrompts = [
  {
    type: 'list',
    name: 'buildTarget',
    message: 'Build target?',
    // Single option today; kept as a list so additional build targets can be added.
    choices: [
      { name: 'CodeBuild (recommended)', value: 'codebuild' }
    ],
    default: 'codebuild'
  },
  {
    type: 'list',
    name: 'codebuildComputeType',
    message: 'CodeBuild compute type?',
    choices: [
      'BUILD_GENERAL1_SMALL',
      'BUILD_GENERAL1_MEDIUM',
      'BUILD_GENERAL1_LARGE'
    ],
    default: 'BUILD_GENERAL1_MEDIUM',
    when: answers => answers.buildTarget === 'codebuild'
  },
  {
    type: 'input',
    name: 'awsRoleArn',
    message: 'AWS IAM Role ARN for SageMaker execution (optional)?',
    // Empty input is allowed (a role is resolved elsewhere). Otherwise the
    // value must be a syntactically valid IAM role ARN. The pattern accepts
    // role paths (role/path/Name) and non-default partitions (aws-cn,
    // aws-us-gov), and the input is trimmed before matching so pasted ARNs
    // with surrounding whitespace are not rejected.
    validate: (input) => {
      if (!input || input.trim() === '') {
        return true;
      }
      const arnPattern = /^arn:aws(?:-[a-z-]+)?:iam::\d{12}:role\/[\w+=,.@-]+(?:\/[\w+=,.@-]+)*$/;
      if (!arnPattern.test(input.trim())) {
        return 'Invalid ARN format. Expected: arn:aws:iam::123456789012:role/RoleName';
      }
      return true;
    }
  }
];
|
|
899
|
+
|
|
900
|
+
/**
 * Sub-phase: Async-specific prompts (only when deploymentTarget === 'async-inference')
 * Requirements: 2.1, 2.2, 2.3, 2.4
 */

// Shared guard: these prompts apply only to async inference deployments.
const onlyForAsyncInference = answers => answers.deploymentTarget === 'async-inference';

const infraAsyncPrompts = [
  {
    type: 'input',
    name: 'asyncS3OutputPath',
    message: 'S3 output path for async results (leave empty for default: s3://ml-container-creator-async-{region}-{account-id}/{project-name}/output/):',
    when: onlyForAsyncInference
  },
  {
    type: 'input',
    name: 'asyncSnsSuccessTopic',
    message: 'SNS success topic ARN (leave empty for auto-created per-project topic):',
    when: onlyForAsyncInference
  },
  {
    type: 'input',
    name: 'asyncSnsErrorTopic',
    message: 'SNS error topic ARN (leave empty for auto-created per-project topic):',
    when: onlyForAsyncInference
  },
  {
    type: 'number',
    name: 'asyncMaxConcurrentInvocations',
    message: 'Max concurrent invocations per instance?',
    default: 1,
    when: onlyForAsyncInference
  }
];
|
|
931
|
+
|
|
932
|
+
/**
 * Sub-phase: Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
 * Requirements: 2.1, 2.2, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9
 */
const infraBatchTransformPrompts = [
  {
    type: 'input',
    name: 'batchInputPath',
    message: 'S3 input path for batch transform data (leave empty for default: s3://ml-container-creator-batch-{region}-{account-id}/{project-name}/input/):',
    when: answers => answers.deploymentTarget === 'batch-transform'
  },
  {
    type: 'input',
    name: 'batchOutputPath',
    message: 'S3 output path for batch transform results (leave empty for default: s3://ml-container-creator-batch-{region}-{account-id}/{project-name}/output/):',
    when: answers => answers.deploymentTarget === 'batch-transform'
  },
  {
    type: 'number',
    name: 'batchInstanceCount',
    message: 'How many instances should run the batch job in parallel?',
    default: 1,
    // Reject zero, negative, and fractional counts up front.
    validate: (input) => (Number.isInteger(input) && input >= 1)
      ? true
      : 'Instance count must be a positive integer',
    when: answers => answers.deploymentTarget === 'batch-transform'
  },
  {
    type: 'list',
    name: 'batchSplitType',
    message: 'Input file format — how should SageMaker read your input files?',
    choices: [
      { name: 'Line — one record per line (JSON lines, CSV)', value: 'Line' },
      { name: 'RecordIO — Amazon RecordIO format', value: 'RecordIO' },
      { name: 'None — send each file as a single request', value: 'None' }
    ],
    default: 'Line',
    when: answers => answers.deploymentTarget === 'batch-transform'
  },
  {
    type: 'list',
    name: 'batchStrategy',
    message: 'How many records should be sent per inference request?',
    choices: [
      { name: 'MultiRecord — batch multiple records per request (higher throughput)', value: 'MultiRecord' },
      { name: 'SingleRecord — one record per request (simpler, more predictable)', value: 'SingleRecord' }
    ],
    default: 'MultiRecord',
    when: answers => answers.deploymentTarget === 'batch-transform'
  },
  {
    type: 'list',
    name: 'batchJoinSource',
    message: 'Include original input data alongside predictions in the output?',
    choices: [
      { name: 'No — output predictions only', value: 'None' },
      { name: 'Yes — merge input with predictions (useful for traceability)', value: 'Input' }
    ],
    default: 'None',
    when: answers => answers.deploymentTarget === 'batch-transform'
  },
  {
    type: 'number',
    name: 'batchMaxConcurrentTransforms',
    message: 'Max concurrent inference requests per instance?',
    default: 1,
    // Reject zero, negative, and fractional values.
    validate: (input) => (Number.isInteger(input) && input >= 1)
      ? true
      : 'Max concurrent transforms must be a positive integer',
    when: answers => answers.deploymentTarget === 'batch-transform'
  },
  {
    type: 'number',
    name: 'batchMaxPayloadInMB',
    message: 'Max request payload size in MB (0-100)?',
    default: 6,
    // The message promises a 0-100 range; enforce it instead of accepting anything.
    validate: (input) => (Number.isInteger(input) && input >= 0 && input <= 100)
      ? true
      : 'Payload size must be an integer between 0 and 100 MB',
    when: answers => answers.deploymentTarget === 'batch-transform'
  }
];
|
|
1005
|
+
|
|
1006
|
+
// Combined view for tests and backward compatibility.
// NOTE(review): the async- and batch-specific sub-phases are not folded in
// here — confirm callers that need them use the split arrays directly.
const infrastructurePrompts = [
  infraRegionAndTargetPrompts,
  infraInstancePrompts,
  infraHyperPodPrompts,
  infraBuildPrompts
].flat();
|
|
1013
|
+
|
|
1014
|
+
// Project name prompt; the default is a generated name derived from the
// selected framework (falling back to the deploymentConfig prefix).
const projectPrompts = [
  {
    type: 'input',
    name: 'projectName',
    message: 'What is the Project Name?',
    default: (answers) => {
      const selectedFramework = answers.framework || answers.deploymentConfig?.split('-')[0];
      return generateProjectName(selectedFramework);
    }
  }
];
|
|
1026
|
+
|
|
1027
|
+
// Output directory prompt; defaults to a timestamped folder named after the
// project, relative to the current working directory.
const destinationPrompts = [
  {
    type: 'input',
    name: 'destinationDir',
    message: 'Where will the output directory be?',
    default: (answers) => {
      // ISO timestamp truncated to second precision, with ':' and '.'
      // replaced so the result is filesystem-safe (YYYY-MM-DDTHH-MM-SS).
      const stamp = new Date().toISOString().slice(0, 19).replace(/[:.]/g, '-');
      return `./${answers.projectName}-${stamp}`;
    }
  }
];
|
|
1038
|
+
|
|
1039
|
+
/**
 * Format ImageEntry[] into Inquirer list choices with tabular display.
 *
 * Each choice name is a fixed-width row: repository, tag, architecture,
 * (CUDA version for transformer images), Python version, and created date.
 * Entries with missing `labels` or `created` metadata render '-' in the
 * affected columns instead of throwing a TypeError.
 *
 * @param {ImageEntry[]} entries - Image entries from the resolver
 * @param {boolean} isTransformer - Whether to show CUDA column
 * @returns {Array<{name: string, value: string}>} Inquirer choices
 */
function formatImageChoices(entries, isTransformer) {
  return entries.map(entry => {
    // Tolerate entries with absent metadata rather than crashing the prompt.
    const cuda = entry.labels?.cuda_version || '-'
    const python = entry.labels?.python_version || '-'
    const date = entry.created?.slice(0, 10) || '-'

    const name = isTransformer
      ? `${entry.repository.padEnd(30)} ${entry.tag.padEnd(16)} ${entry.architecture.padEnd(7)} ${cuda.padEnd(6)} ${python.padEnd(8)} ${date}`
      : `${entry.repository.padEnd(30)} ${entry.tag.padEnd(16)} ${entry.architecture.padEnd(7)} ${python.padEnd(8)} ${date}`

    return { name, value: entry.image }
  })
}
|
|
1059
|
+
|
|
1060
|
+
/**
 * Base image search prompt (non-transformer, non-triton only)
 * Requirements: 5.2, 5.4
 */
const baseImageSearchPrompts = [
  {
    type: 'input',
    name: 'baseImageSearch',
    message: '🔌 Search for a Python base image (e.g. "3.11", "3.10", or leave empty for all):',
    default: '',
    when: (answers) => {
      const arch = answers.architecture || answers.deploymentConfig?.split('-')[0]
      // Transformers use model-server images and Triton uses NGC images, so
      // neither needs a generic Python base image search.
      return !['transformers', 'triton'].includes(arch)
    }
  }
]
|
|
1077
|
+
|
|
1078
|
+
/**
 * Base image selection prompt (all frameworks)
 * Requirements: 5.2, 5.4, 10.1, 10.2, 10.3
 */
const baseImagePrompts = [
  {
    type: 'list',
    name: 'baseImage',
    message: 'Select base container image:',
    // MCP-discovered images first, then a manual-entry escape hatch.
    choices: (answers) => {
      const discovered = answers._mcpBaseImageChoices || []
      return discovered.concat([{ name: 'Custom (enter your own)', value: 'custom' }])
    },
    // Only shown when the MCP lookup produced at least one candidate.
    when: (answers) => {
      const discovered = answers._mcpBaseImageChoices
      return discovered && discovered.length > 0
    }
  },
  {
    type: 'input',
    name: 'customBaseImage',
    message: 'Enter custom base container image (e.g. myrepo/myimage:v1):',
    validate: (input) => {
      const trimmed = (input || '').trim()
      if (trimmed === '') {
        return 'Base image is required'
      }
      const pattern = /^[a-zA-Z0-9][a-zA-Z0-9._\-\/]*(:[a-zA-Z0-9._\-]+)?$/
      return pattern.test(trimmed)
        ? true
        : 'Invalid image format. Expected: [registry/]repository[:tag]'
    },
    when: (answers) => answers.baseImage === 'custom'
  }
]
|
|
1112
|
+
|
|
1113
|
+
// Public surface of the prompt module. The prompt runner imports the split
// infra sub-phase arrays so it can interleave MCP queries between phases;
// `infrastructurePrompts` is the combined view.
export {
  deploymentConfigPrompts,
  frameworkPrompts, // Deprecated: kept for backward compatibility
  enginePrompts,
  frameworkVersionPrompts,
  frameworkProfilePrompts,
  modelFormatPrompts,
  modelServerPrompts, // Deprecated: now empty, modelServer derived from deploymentConfig
  modelLoadStrategyPrompts,
  modelProfilePrompts,
  hfTokenPrompts,
  ngcApiKeyPrompts,
  modulePrompts,
  // Infrastructure: combined view plus per-sub-phase arrays
  infrastructurePrompts,
  infraRegionAndTargetPrompts,
  infraInstancePrompts,
  infraAsyncPrompts,
  infraBatchTransformPrompts,
  infraHyperPodPrompts,
  infraBuildPrompts,
  projectPrompts,
  destinationPrompts,
  // Base image selection
  baseImageSearchPrompts,
  baseImagePrompts,
  formatImageChoices
};
|