@aws/ml-container-creator 0.2.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +202 -0
- package/LICENSE-THIRD-PARTY +68620 -0
- package/NOTICE +2 -0
- package/README.md +106 -0
- package/bin/cli.js +365 -0
- package/config/defaults.json +32 -0
- package/config/presets/transformers-djl.json +26 -0
- package/config/presets/transformers-gpu.json +24 -0
- package/config/presets/transformers-lmi.json +27 -0
- package/package.json +129 -0
- package/servers/README.md +419 -0
- package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
- package/servers/base-image-picker/catalogs/python-slim.json +38 -0
- package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
- package/servers/base-image-picker/catalogs/triton.json +38 -0
- package/servers/base-image-picker/index.js +495 -0
- package/servers/base-image-picker/manifest.json +17 -0
- package/servers/base-image-picker/package.json +15 -0
- package/servers/hyperpod-cluster-picker/LICENSE +202 -0
- package/servers/hyperpod-cluster-picker/index.js +424 -0
- package/servers/hyperpod-cluster-picker/manifest.json +14 -0
- package/servers/hyperpod-cluster-picker/package.json +17 -0
- package/servers/instance-recommender/LICENSE +202 -0
- package/servers/instance-recommender/catalogs/instances.json +852 -0
- package/servers/instance-recommender/index.js +284 -0
- package/servers/instance-recommender/manifest.json +16 -0
- package/servers/instance-recommender/package.json +15 -0
- package/servers/lib/LICENSE +202 -0
- package/servers/lib/bedrock-client.js +160 -0
- package/servers/lib/custom-validators.js +46 -0
- package/servers/lib/dynamic-resolver.js +36 -0
- package/servers/lib/package.json +11 -0
- package/servers/lib/schemas/image-catalog.schema.json +185 -0
- package/servers/lib/schemas/instances.schema.json +124 -0
- package/servers/lib/schemas/manifest.schema.json +64 -0
- package/servers/lib/schemas/model-catalog.schema.json +91 -0
- package/servers/lib/schemas/regions.schema.json +26 -0
- package/servers/lib/schemas/triton-backends.schema.json +51 -0
- package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
- package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
- package/servers/model-picker/catalogs/popular-transformers.json +226 -0
- package/servers/model-picker/index.js +1693 -0
- package/servers/model-picker/manifest.json +18 -0
- package/servers/model-picker/package.json +20 -0
- package/servers/region-picker/LICENSE +202 -0
- package/servers/region-picker/catalogs/regions.json +263 -0
- package/servers/region-picker/index.js +230 -0
- package/servers/region-picker/manifest.json +16 -0
- package/servers/region-picker/package.json +15 -0
- package/src/app.js +1007 -0
- package/src/copy-tpl.js +77 -0
- package/src/lib/accelerator-validator.js +39 -0
- package/src/lib/asset-manager.js +385 -0
- package/src/lib/aws-profile-parser.js +181 -0
- package/src/lib/bootstrap-command-handler.js +1647 -0
- package/src/lib/bootstrap-config.js +238 -0
- package/src/lib/ci-register-helpers.js +124 -0
- package/src/lib/ci-report-helpers.js +158 -0
- package/src/lib/ci-stage-helpers.js +268 -0
- package/src/lib/cli-handler.js +529 -0
- package/src/lib/comment-generator.js +544 -0
- package/src/lib/community-reports-validator.js +91 -0
- package/src/lib/config-manager.js +2106 -0
- package/src/lib/configuration-exporter.js +204 -0
- package/src/lib/configuration-manager.js +695 -0
- package/src/lib/configuration-matcher.js +221 -0
- package/src/lib/cpu-validator.js +36 -0
- package/src/lib/cuda-validator.js +57 -0
- package/src/lib/deployment-config-resolver.js +103 -0
- package/src/lib/deployment-entry-schema.js +125 -0
- package/src/lib/deployment-registry.js +598 -0
- package/src/lib/docker-introspection-validator.js +51 -0
- package/src/lib/engine-prefix-resolver.js +60 -0
- package/src/lib/huggingface-client.js +172 -0
- package/src/lib/key-value-parser.js +37 -0
- package/src/lib/known-flags-validator.js +200 -0
- package/src/lib/manifest-cli.js +280 -0
- package/src/lib/mcp-client.js +303 -0
- package/src/lib/mcp-command-handler.js +532 -0
- package/src/lib/neuron-validator.js +80 -0
- package/src/lib/parameter-schema-validator.js +284 -0
- package/src/lib/prompt-runner.js +1349 -0
- package/src/lib/prompts.js +1138 -0
- package/src/lib/registry-command-handler.js +519 -0
- package/src/lib/registry-loader.js +198 -0
- package/src/lib/rocm-validator.js +80 -0
- package/src/lib/schema-validator.js +157 -0
- package/src/lib/sensitive-redactor.js +59 -0
- package/src/lib/template-engine.js +156 -0
- package/src/lib/template-manager.js +341 -0
- package/src/lib/validation-engine.js +314 -0
- package/src/prompt-adapter.js +63 -0
- package/templates/Dockerfile +300 -0
- package/templates/IAM_PERMISSIONS.md +84 -0
- package/templates/MIGRATION.md +488 -0
- package/templates/PROJECT_README.md +439 -0
- package/templates/TEMPLATE_SYSTEM.md +243 -0
- package/templates/buildspec.yml +64 -0
- package/templates/code/chat_template.jinja +1 -0
- package/templates/code/flask/gunicorn_config.py +35 -0
- package/templates/code/flask/wsgi.py +10 -0
- package/templates/code/model_handler.py +387 -0
- package/templates/code/serve +300 -0
- package/templates/code/serve.py +175 -0
- package/templates/code/serving.properties +105 -0
- package/templates/code/start_server.py +39 -0
- package/templates/code/start_server.sh +39 -0
- package/templates/diffusors/Dockerfile +72 -0
- package/templates/diffusors/patch_image_api.py +35 -0
- package/templates/diffusors/serve +115 -0
- package/templates/diffusors/start_server.sh +114 -0
- package/templates/do/.gitkeep +1 -0
- package/templates/do/README.md +541 -0
- package/templates/do/build +83 -0
- package/templates/do/ci +681 -0
- package/templates/do/clean +811 -0
- package/templates/do/config +260 -0
- package/templates/do/deploy +1560 -0
- package/templates/do/export +306 -0
- package/templates/do/logs +319 -0
- package/templates/do/manifest +12 -0
- package/templates/do/push +119 -0
- package/templates/do/register +580 -0
- package/templates/do/run +113 -0
- package/templates/do/submit +417 -0
- package/templates/do/test +1147 -0
- package/templates/hyperpod/configmap.yaml +24 -0
- package/templates/hyperpod/deployment.yaml +71 -0
- package/templates/hyperpod/pvc.yaml +42 -0
- package/templates/hyperpod/service.yaml +17 -0
- package/templates/nginx-diffusors.conf +74 -0
- package/templates/nginx-predictors.conf +47 -0
- package/templates/nginx-tensorrt.conf +74 -0
- package/templates/requirements.txt +61 -0
- package/templates/sample_model/test_inference.py +123 -0
- package/templates/sample_model/train_abalone.py +252 -0
- package/templates/test/test_endpoint.sh +79 -0
- package/templates/test/test_local_image.sh +80 -0
- package/templates/test/test_model_handler.py +180 -0
- package/templates/triton/Dockerfile +128 -0
- package/templates/triton/config.pbtxt +163 -0
- package/templates/triton/model.py +130 -0
- package/templates/triton/requirements.txt +11 -0
|
@@ -0,0 +1,1349 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Prompt Runner - Orchestrates the prompting phases with clear user feedback
|
|
6
|
+
*
|
|
7
|
+
* This module handles running prompts in organized phases with console output
|
|
8
|
+
* to guide users through the configuration process.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
import {
|
|
12
|
+
deploymentConfigPrompts,
|
|
13
|
+
enginePrompts,
|
|
14
|
+
frameworkVersionPrompts,
|
|
15
|
+
frameworkProfilePrompts,
|
|
16
|
+
modelFormatPrompts,
|
|
17
|
+
modelServerPrompts,
|
|
18
|
+
modelLoadStrategyPrompts,
|
|
19
|
+
modelProfilePrompts,
|
|
20
|
+
hfTokenPrompts,
|
|
21
|
+
ngcApiKeyPrompts,
|
|
22
|
+
modulePrompts,
|
|
23
|
+
infraRegionAndTargetPrompts,
|
|
24
|
+
infraInstancePrompts,
|
|
25
|
+
infraAsyncPrompts,
|
|
26
|
+
infraBatchTransformPrompts,
|
|
27
|
+
infraHyperPodPrompts,
|
|
28
|
+
infraBuildPrompts,
|
|
29
|
+
projectPrompts,
|
|
30
|
+
destinationPrompts,
|
|
31
|
+
baseImageSearchPrompts,
|
|
32
|
+
baseImagePrompts,
|
|
33
|
+
formatImageChoices
|
|
34
|
+
} from './prompts.js';
|
|
35
|
+
|
|
36
|
+
import fs from 'fs';
|
|
37
|
+
import path from 'path';
|
|
38
|
+
import { fileURLToPath } from 'node:url';
|
|
39
|
+
import RegistryLoader from './registry-loader.js';
|
|
40
|
+
import { runPrompts } from '../prompt-adapter.js';
|
|
41
|
+
|
|
42
|
+
// ES modules have no __filename/__dirname, so derive the equivalents from
// this module's URL.
const __pr_filename = fileURLToPath(new URL(import.meta.url));
const __pr_dirname = path.dirname(__pr_filename);
// Generator package root: two directory levels above this module's directory.
const GENERATOR_ROOT = path.resolve(__pr_dirname, '../..');
|
|
45
|
+
|
|
46
|
+
export default class PromptRunner {
|
|
47
|
+
constructor({ configManager, options, registryConfigManager, baseConfig, promptFn }) {
|
|
48
|
+
this.configManager = configManager;
|
|
49
|
+
this.options = options || {};
|
|
50
|
+
this.registryConfigManager = registryConfigManager || null;
|
|
51
|
+
this.baseConfig = baseConfig || {};
|
|
52
|
+
this._runPrompts = promptFn || runPrompts;
|
|
53
|
+
}
|
|
54
|
+
|
|
55
|
+
/**
|
|
56
|
+
* Runs all prompting phases and returns combined answers
|
|
57
|
+
* @returns {Promise<Object>} Combined answers from all phases
|
|
58
|
+
*/
|
|
59
|
+
async run() {
|
|
60
|
+
const buildTimestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19);
|
|
61
|
+
|
|
62
|
+
// Load catalog data via Registry_Loader
|
|
63
|
+
const registryLoader = new RegistryLoader()
|
|
64
|
+
this._tritonBackends = await registryLoader.loadTritonBackends()
|
|
65
|
+
this._instanceAcceleratorMapping = await registryLoader.loadInstanceAcceleratorMapping()
|
|
66
|
+
|
|
67
|
+
// Get existing configuration to use as defaults
|
|
68
|
+
const existingConfig = this.baseConfig || {};
|
|
69
|
+
|
|
70
|
+
// Get only explicit configuration (not defaults) for prompt skipping
|
|
71
|
+
const explicitConfig = this.configManager ? this.configManager.getExplicitConfiguration() : {};
|
|
72
|
+
|
|
73
|
+
// Phase 1: Infrastructure & Deployment
|
|
74
|
+
// Requirements: 3.1 ā infrastructure prompts run first
|
|
75
|
+
// Ordering: Region ā Deployment Target ā Instance (if managed) ā HyperPod (if eks) ā Build Target
|
|
76
|
+
console.log('\nšŖ Infrastructure & Deployment');
|
|
77
|
+
|
|
78
|
+
// 1a. Query region MCP, then prompt for region + deployment target
|
|
79
|
+
await this._queryMcpForRegion({}, explicitConfig);
|
|
80
|
+
const bootstrapRegion = existingConfig.awsRegion || explicitConfig.awsRegion
|
|
81
|
+
const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {}
|
|
82
|
+
const regionAndTargetAnswers = await this._runPhase(infraRegionAndTargetPrompts, regionPreviousAnswers, explicitConfig, existingConfig);
|
|
83
|
+
|
|
84
|
+
// 1b. Instance type ā query MCP and prompt for managed-inference, async-inference, batch-transform, and hyperpod-eks
|
|
85
|
+
let instanceAnswers = {};
|
|
86
|
+
if (regionAndTargetAnswers.deploymentTarget === 'managed-inference' ||
|
|
87
|
+
regionAndTargetAnswers.deploymentTarget === 'async-inference' ||
|
|
88
|
+
regionAndTargetAnswers.deploymentTarget === 'batch-transform' ||
|
|
89
|
+
regionAndTargetAnswers.deploymentTarget === 'hyperpod-eks') {
|
|
90
|
+
await this._queryMcpForInstance({}, explicitConfig);
|
|
91
|
+
const mcpInstanceChoices = this.configManager?.mcpChoices?.instanceType;
|
|
92
|
+
const instancePreviousAnswers = {
|
|
93
|
+
...regionAndTargetAnswers,
|
|
94
|
+
...(mcpInstanceChoices && mcpInstanceChoices.length > 0 ? { _mcpInstanceChoices: mcpInstanceChoices } : {})
|
|
95
|
+
};
|
|
96
|
+
instanceAnswers = await this._runPhase(infraInstancePrompts, instancePreviousAnswers, explicitConfig, existingConfig);
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
// 1b-async. Async-specific prompts (only when deploymentTarget === 'async-inference')
|
|
100
|
+
let asyncAnswers = {};
|
|
101
|
+
if (regionAndTargetAnswers.deploymentTarget === 'async-inference') {
|
|
102
|
+
asyncAnswers = await this._runPhase(infraAsyncPrompts, { ...regionAndTargetAnswers }, explicitConfig, existingConfig);
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
// 1b-batch. Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
|
|
106
|
+
let batchTransformAnswers = {};
|
|
107
|
+
if (regionAndTargetAnswers.deploymentTarget === 'batch-transform') {
|
|
108
|
+
batchTransformAnswers = await this._runPhase(
|
|
109
|
+
infraBatchTransformPrompts,
|
|
110
|
+
{ ...regionAndTargetAnswers },
|
|
111
|
+
explicitConfig,
|
|
112
|
+
existingConfig
|
|
113
|
+
);
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
// 1c. HyperPod prompts ā only query MCP and prompt when deployment target is hyperpod-eks
|
|
117
|
+
let hyperPodAnswers = {};
|
|
118
|
+
if (regionAndTargetAnswers.deploymentTarget === 'hyperpod-eks') {
|
|
119
|
+
// Resolve the actual region (handle 'custom' selection)
|
|
120
|
+
const resolvedRegion = regionAndTargetAnswers.customAwsRegion || regionAndTargetAnswers.awsRegion;
|
|
121
|
+
await this._queryMcpForHyperPod({ ...regionAndTargetAnswers, awsRegion: resolvedRegion }, explicitConfig);
|
|
122
|
+
hyperPodAnswers = await this._runPhase(infraHyperPodPrompts, { ...regionAndTargetAnswers }, explicitConfig, existingConfig);
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
// 1d. Build target + role ARN (always)
|
|
126
|
+
const buildAnswers = await this._runPhase(infraBuildPrompts, { ...regionAndTargetAnswers, ...instanceAnswers, ...hyperPodAnswers }, explicitConfig, existingConfig);
|
|
127
|
+
|
|
128
|
+
// Combine all infrastructure answers
|
|
129
|
+
const infraAnswers = {
|
|
130
|
+
...regionAndTargetAnswers,
|
|
131
|
+
...instanceAnswers,
|
|
132
|
+
...asyncAnswers,
|
|
133
|
+
...batchTransformAnswers,
|
|
134
|
+
...hyperPodAnswers,
|
|
135
|
+
...buildAnswers
|
|
136
|
+
};
|
|
137
|
+
|
|
138
|
+
// Phase 2: Core ML Configuration
|
|
139
|
+
// Requirements: 3.1, 3.2 ā ML configuration prompts run after infrastructure
|
|
140
|
+
console.log('\nš§ Core ML Configuration');
|
|
141
|
+
const deploymentConfigAnswers = await this._runPhase(deploymentConfigPrompts, { ...infraAnswers }, explicitConfig, existingConfig);
|
|
142
|
+
|
|
143
|
+
// Derive architecture, backend, and legacy framework/modelServer from deploymentConfig
|
|
144
|
+
// Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7
|
|
145
|
+
let architecture, backend, framework, modelServer;
|
|
146
|
+
if (deploymentConfigAnswers.deploymentConfig) {
|
|
147
|
+
const parts = deploymentConfigAnswers.deploymentConfig.split('-');
|
|
148
|
+
architecture = parts[0];
|
|
149
|
+
backend = parts.slice(1).join('-');
|
|
150
|
+
// Legacy compatibility: derive framework and modelServer
|
|
151
|
+
framework = architecture;
|
|
152
|
+
modelServer = backend;
|
|
153
|
+
}
|
|
154
|
+
|
|
155
|
+
// Add derived values to answers
|
|
156
|
+
const frameworkAnswers = {
|
|
157
|
+
...deploymentConfigAnswers,
|
|
158
|
+
architecture: architecture || deploymentConfigAnswers.architecture,
|
|
159
|
+
backend: backend || deploymentConfigAnswers.backend,
|
|
160
|
+
framework: framework || deploymentConfigAnswers.framework,
|
|
161
|
+
modelServer: modelServer || deploymentConfigAnswers.modelServer
|
|
162
|
+
};
|
|
163
|
+
|
|
164
|
+
// Engine prompt for http architecture
|
|
165
|
+
// Requirements: 3.7
|
|
166
|
+
const engineAnswers = await this._runPhase(enginePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
|
|
167
|
+
|
|
168
|
+
// Auto-set model format for Triton backends with single format
|
|
169
|
+
// Requirements: 3.3, 3.4, 3.5
|
|
170
|
+
const tritonAutoFormat = this._getTritonAutoModelFormat(architecture, backend);
|
|
171
|
+
|
|
172
|
+
// Query base-image-picker MCP server for base image choices
|
|
173
|
+
// Requirements: 5.1, 5.2, 5.3
|
|
174
|
+
await this._queryMcpForBaseImage(frameworkAnswers, explicitConfig)
|
|
175
|
+
const baseImagePreviousAnswers = {
|
|
176
|
+
...frameworkAnswers,
|
|
177
|
+
...engineAnswers,
|
|
178
|
+
...(this._mcpBaseImageChoices ? { _mcpBaseImageChoices: this._mcpBaseImageChoices } : {})
|
|
179
|
+
}
|
|
180
|
+
const baseImageAnswers = await this._runPhase(
|
|
181
|
+
baseImagePrompts,
|
|
182
|
+
baseImagePreviousAnswers,
|
|
183
|
+
explicitConfig,
|
|
184
|
+
existingConfig
|
|
185
|
+
)
|
|
186
|
+
|
|
187
|
+
// Populate framework version choices from registry
|
|
188
|
+
const frameworkVersionChoices = this._getFrameworkVersionChoices(frameworkAnswers.framework);
|
|
189
|
+
const frameworkVersionAnswers = await this._runPhase(
|
|
190
|
+
frameworkVersionPrompts,
|
|
191
|
+
{...frameworkAnswers, ...engineAnswers, _frameworkVersionChoices: frameworkVersionChoices},
|
|
192
|
+
explicitConfig,
|
|
193
|
+
existingConfig
|
|
194
|
+
);
|
|
195
|
+
|
|
196
|
+
// Display validation information if version was selected
|
|
197
|
+
if (frameworkVersionAnswers.frameworkVersion) {
|
|
198
|
+
this._displayFrameworkValidationInfo(frameworkAnswers.framework, frameworkVersionAnswers.frameworkVersion);
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
// Populate framework profile choices from registry
|
|
202
|
+
const frameworkProfileChoices = this._getFrameworkProfileChoices(
|
|
203
|
+
frameworkAnswers.framework,
|
|
204
|
+
frameworkVersionAnswers.frameworkVersion
|
|
205
|
+
);
|
|
206
|
+
const frameworkProfileAnswers = await this._runPhase(
|
|
207
|
+
frameworkProfilePrompts,
|
|
208
|
+
{...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, _frameworkProfileChoices: frameworkProfileChoices},
|
|
209
|
+
explicitConfig,
|
|
210
|
+
existingConfig
|
|
211
|
+
);
|
|
212
|
+
|
|
213
|
+
// Query model-picker MCP server for model choices
|
|
214
|
+
this._queryMcpForModels(frameworkAnswers.architecture)
|
|
215
|
+
if (this._mcpModelChoices) {
|
|
216
|
+
console.log(` š Querying model-picker...`)
|
|
217
|
+
console.log(` ā ${this._mcpModelChoices.length} model(s) available from catalog`)
|
|
218
|
+
}
|
|
219
|
+
const modelFormatPreviousAnswers = {
|
|
220
|
+
...frameworkAnswers,
|
|
221
|
+
...engineAnswers,
|
|
222
|
+
...frameworkVersionAnswers,
|
|
223
|
+
...frameworkProfileAnswers,
|
|
224
|
+
...(this._mcpModelChoices ? { _mcpModelChoices: this._mcpModelChoices } : {})
|
|
225
|
+
}
|
|
226
|
+
const modelFormatAnswers = await this._runPhase(
|
|
227
|
+
modelFormatPrompts,
|
|
228
|
+
modelFormatPreviousAnswers,
|
|
229
|
+
explicitConfig,
|
|
230
|
+
existingConfig
|
|
231
|
+
);
|
|
232
|
+
|
|
233
|
+
// Model server prompts are now deprecated (empty array)
|
|
234
|
+
const modelServerAnswers = await this._runPhase(
|
|
235
|
+
modelServerPrompts,
|
|
236
|
+
{...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers},
|
|
237
|
+
explicitConfig,
|
|
238
|
+
existingConfig
|
|
239
|
+
);
|
|
240
|
+
|
|
241
|
+
// Populate model profile choices from registry (if model ID is available)
|
|
242
|
+
const currentAnswers = {...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers, ...modelFormatAnswers, ...modelServerAnswers};
|
|
243
|
+
const modelId = currentAnswers.customModelName || currentAnswers.modelName || explicitConfig.modelName;
|
|
244
|
+
|
|
245
|
+
// Fetch model information from HuggingFace and Model Registry
|
|
246
|
+
// Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.11, 11.1, 11.2, 11.3, 11.5, 11.6, 11.7
|
|
247
|
+
if (modelId && modelId !== 'Custom (enter manually)') {
|
|
248
|
+
await this._fetchAndDisplayModelInfo(modelId);
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
const modelProfileChoices = this._getModelProfileChoices(modelId);
|
|
252
|
+
const modelProfileAnswers = await this._runPhase(
|
|
253
|
+
modelProfilePrompts,
|
|
254
|
+
{...currentAnswers, _modelProfileChoices: modelProfileChoices},
|
|
255
|
+
explicitConfig,
|
|
256
|
+
existingConfig
|
|
257
|
+
);
|
|
258
|
+
|
|
259
|
+
// Model loading strategy prompt (build-time vs runtime)
|
|
260
|
+
// Requirements: 13.1, 13.2, 13.3, 13.4, 13.5
|
|
261
|
+
const modelLoadStrategyAnswers = await this._runPhase(
|
|
262
|
+
modelLoadStrategyPrompts,
|
|
263
|
+
{ ...frameworkAnswers, ...engineAnswers, ...modelFormatAnswers, ...modelServerAnswers, ...modelProfileAnswers },
|
|
264
|
+
explicitConfig,
|
|
265
|
+
existingConfig
|
|
266
|
+
);
|
|
267
|
+
|
|
268
|
+
const hfTokenAnswers = await this._runPhase(hfTokenPrompts,
|
|
269
|
+
{ ...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers, ...modelFormatAnswers, ...modelServerAnswers, ...modelProfileAnswers },
|
|
270
|
+
explicitConfig, existingConfig);
|
|
271
|
+
|
|
272
|
+
const ngcApiKeyAnswers = await this._runPhase(ngcApiKeyPrompts,
|
|
273
|
+
{ ...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers, ...modelFormatAnswers, ...modelServerAnswers, ...modelProfileAnswers },
|
|
274
|
+
explicitConfig, existingConfig);
|
|
275
|
+
|
|
276
|
+
// Validate instance type against framework requirements (now that framework is known)
|
|
277
|
+
// Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6
|
|
278
|
+
const instanceType = infraAnswers.customInstanceType || infraAnswers.instanceType;
|
|
279
|
+
if (instanceType && frameworkVersionAnswers.frameworkVersion) {
|
|
280
|
+
await this._validateAndDisplayInstanceType(
|
|
281
|
+
instanceType,
|
|
282
|
+
frameworkAnswers.framework,
|
|
283
|
+
frameworkVersionAnswers.frameworkVersion
|
|
284
|
+
);
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
// CUDA version selection: if the selected instance supports multiple CUDA versions,
|
|
288
|
+
// let the user pick which one. This transparently sets the inference AMI version.
|
|
289
|
+
const cudaAnswer = await this._promptCudaVersion(
|
|
290
|
+
instanceType,
|
|
291
|
+
frameworkAnswers.framework,
|
|
292
|
+
frameworkVersionAnswers.frameworkVersion
|
|
293
|
+
);
|
|
294
|
+
if (cudaAnswer) {
|
|
295
|
+
infraAnswers._selectedCudaVersion = cudaAnswer.cudaVersion;
|
|
296
|
+
infraAnswers._resolvedInferenceAmiVersion = cudaAnswer.inferenceAmiVersion;
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
// Phase 3: Module Selection
|
|
300
|
+
// Requirements: 3.3 ā module selection after ML configuration
|
|
301
|
+
console.log('\nš¦ Module Selection');
|
|
302
|
+
const moduleAnswers = await this._runPhase(modulePrompts, { ...frameworkAnswers, ...engineAnswers }, explicitConfig, existingConfig);
|
|
303
|
+
|
|
304
|
+
// Ensure transformers, diffusors, and ineligible Triton backends don't get sample model
|
|
305
|
+
if (frameworkAnswers.architecture === 'transformers' ||
|
|
306
|
+
frameworkAnswers.architecture === 'diffusors' ||
|
|
307
|
+
(frameworkAnswers.architecture === 'triton' &&
|
|
308
|
+
!this._tritonBackends[frameworkAnswers.backend]?.supportsSampleModel)) {
|
|
309
|
+
moduleAnswers.includeSampleModel = false;
|
|
310
|
+
}
|
|
311
|
+
|
|
312
|
+
// Phase 4: Project Configuration
|
|
313
|
+
// Requirements: 3.4 ā project configuration last
|
|
314
|
+
console.log('\nš Project Configuration');
|
|
315
|
+
const allTechnicalAnswers = {
|
|
316
|
+
...frameworkAnswers,
|
|
317
|
+
...engineAnswers,
|
|
318
|
+
...modelFormatAnswers,
|
|
319
|
+
...modelServerAnswers,
|
|
320
|
+
...moduleAnswers,
|
|
321
|
+
...infraAnswers
|
|
322
|
+
};
|
|
323
|
+
const projectAnswers = await this._runPhase(projectPrompts, allTechnicalAnswers, explicitConfig, existingConfig);
|
|
324
|
+
const destinationAnswers = await this._runPhase(destinationPrompts,
|
|
325
|
+
{ ...allTechnicalAnswers, ...projectAnswers }, explicitConfig, existingConfig);
|
|
326
|
+
|
|
327
|
+
// Combine all answers
|
|
328
|
+
const combinedAnswers = {
|
|
329
|
+
...infraAnswers,
|
|
330
|
+
...frameworkAnswers,
|
|
331
|
+
...engineAnswers,
|
|
332
|
+
...baseImageAnswers,
|
|
333
|
+
...frameworkVersionAnswers,
|
|
334
|
+
...frameworkProfileAnswers,
|
|
335
|
+
...modelFormatAnswers,
|
|
336
|
+
...modelServerAnswers,
|
|
337
|
+
...modelProfileAnswers,
|
|
338
|
+
...modelLoadStrategyAnswers,
|
|
339
|
+
...hfTokenAnswers,
|
|
340
|
+
...ngcApiKeyAnswers,
|
|
341
|
+
...moduleAnswers,
|
|
342
|
+
...projectAnswers,
|
|
343
|
+
...destinationAnswers,
|
|
344
|
+
buildTimestamp
|
|
345
|
+
};
|
|
346
|
+
|
|
347
|
+
// Ensure CLI-provided values that were skipped during prompting are in combinedAnswers
|
|
348
|
+
if (explicitConfig.modelName && !combinedAnswers.modelName) {
|
|
349
|
+
combinedAnswers.modelName = explicitConfig.modelName;
|
|
350
|
+
}
|
|
351
|
+
|
|
352
|
+
// Flow model source metadata from model-picker MCP response
|
|
353
|
+
// Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
|
|
354
|
+
if (this._mcpModelSource) {
|
|
355
|
+
combinedAnswers.modelSource = this._mcpModelSource;
|
|
356
|
+
}
|
|
357
|
+
if (this._mcpArtifactUri) {
|
|
358
|
+
combinedAnswers.artifactUri = this._mcpArtifactUri;
|
|
359
|
+
}
|
|
360
|
+
|
|
361
|
+
// Validate: non-HF model sources require an artifact URI
|
|
362
|
+
// Without it, the serve script can't download the model at runtime
|
|
363
|
+
// Infer modelSource from model name prefix if not set by MCP
|
|
364
|
+
const modelName = combinedAnswers.customModelName || combinedAnswers.modelName;
|
|
365
|
+
if (!combinedAnswers.modelSource && modelName) {
|
|
366
|
+
if (modelName.startsWith('s3://')) {
|
|
367
|
+
combinedAnswers.modelSource = 's3';
|
|
368
|
+
combinedAnswers.artifactUri = modelName;
|
|
369
|
+
} else if (modelName.startsWith('jumpstart://')) {
|
|
370
|
+
combinedAnswers.modelSource = 'jumpstart';
|
|
371
|
+
} else if (modelName.startsWith('jumpstart-hub://')) {
|
|
372
|
+
combinedAnswers.modelSource = 'jumpstart-hub';
|
|
373
|
+
} else if (modelName.startsWith('registry://')) {
|
|
374
|
+
combinedAnswers.modelSource = 'registry';
|
|
375
|
+
}
|
|
376
|
+
}
|
|
377
|
+
// For s3:// models, the model name IS the artifact URI
|
|
378
|
+
if (combinedAnswers.modelSource === 's3' && !combinedAnswers.artifactUri) {
|
|
379
|
+
if (modelName && modelName.startsWith('s3://')) {
|
|
380
|
+
combinedAnswers.artifactUri = modelName;
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
const downloadSources = ['jumpstart', 's3'];
|
|
384
|
+
if (downloadSources.includes(combinedAnswers.modelSource) && !combinedAnswers.artifactUri) {
|
|
385
|
+
console.log(`\n ā ļø Model source is '${combinedAnswers.modelSource}' but no artifact URI was resolved.`);
|
|
386
|
+
console.log(' The model-picker could not determine the download location.');
|
|
387
|
+
console.log(' Falling back to HuggingFace source ā the model will be loaded by name.');
|
|
388
|
+
console.log(' If this model requires S3 download, set MODEL_ARTIFACT_URI in do/config after generation.\n');
|
|
389
|
+
combinedAnswers.modelSource = 'huggingface';
|
|
390
|
+
}
|
|
391
|
+
|
|
392
|
+
// Registry models ā note about InferenceSpecification requirement
|
|
393
|
+
if (combinedAnswers.modelSource === 'registry') {
|
|
394
|
+
if (!combinedAnswers.artifactUri) {
|
|
395
|
+
console.log(`\n ā ļø Model source is 'registry' but no artifact URI was resolved.`);
|
|
396
|
+
console.log(' The model package must have an InferenceSpecification with a valid');
|
|
397
|
+
console.log(' ModelDataUrl or S3DataSource for the runtime resolver to work.');
|
|
398
|
+
console.log(' If your model package was registered without an InferenceSpecification,');
|
|
399
|
+
console.log(' use the S3 path directly instead: --model-name="s3://bucket/path/model.tar.gz"');
|
|
400
|
+
console.log(' Or set MODEL_ARTIFACT_URI in do/config before deploying.\n');
|
|
401
|
+
} else {
|
|
402
|
+
console.log('\n ā¹ļø Registry model: the container will resolve the artifact URI at startup');
|
|
403
|
+
console.log(' via DescribeModelPackage. Ensure the model package has a valid');
|
|
404
|
+
console.log(' InferenceSpecification with ModelDataUrl or S3DataSource.\n');
|
|
405
|
+
}
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
// Warn about jumpstart-hub:// models ā private hub deployment requires
|
|
409
|
+
// HubAccessConfig on CreateModel, which is not yet supported by the generator.
|
|
410
|
+
if (combinedAnswers.modelSource === 'jumpstart-hub') {
|
|
411
|
+
console.log('\n ā ļø JumpStart Private Hub models are not yet fully supported.');
|
|
412
|
+
console.log(' Private hub artifacts live in AWS-managed S3 buckets that require');
|
|
413
|
+
console.log(' SageMaker\'s HubAccessConfig mechanism for access.');
|
|
414
|
+
console.log(' The generated project will not be able to download model artifacts at runtime.');
|
|
415
|
+
console.log(' This feature is tracked for a future release.\n');
|
|
416
|
+
console.log(' Falling back to HuggingFace source.\n');
|
|
417
|
+
combinedAnswers.modelSource = 'huggingface';
|
|
418
|
+
delete combinedAnswers.artifactUri;
|
|
419
|
+
}
|
|
420
|
+
|
|
421
|
+
// Apply auto-set model format for Triton backends with single format
|
|
422
|
+
// Requirements: 3.3, 3.4, 3.5
|
|
423
|
+
if (tritonAutoFormat) {
|
|
424
|
+
combinedAnswers.modelFormat = tritonAutoFormat
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
// Handle custom model name for transformers, diffusors, and Triton LLM backends
|
|
428
|
+
if ((combinedAnswers.architecture === 'transformers' ||
|
|
429
|
+
combinedAnswers.architecture === 'diffusors' ||
|
|
430
|
+
(combinedAnswers.architecture === 'triton' && (combinedAnswers.backend === 'vllm' || combinedAnswers.backend === 'tensorrtllm')))
|
|
431
|
+
&& combinedAnswers.customModelName) {
|
|
432
|
+
combinedAnswers.modelName = combinedAnswers.customModelName;
|
|
433
|
+
delete combinedAnswers.customModelName;
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
// Handle custom instance type
|
|
437
|
+
if (combinedAnswers.customInstanceType) {
|
|
438
|
+
combinedAnswers.instanceType = combinedAnswers.customInstanceType;
|
|
439
|
+
delete combinedAnswers.customInstanceType;
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
// Handle custom HyperPod cluster name
|
|
443
|
+
if (combinedAnswers.customHyperPodCluster) {
|
|
444
|
+
combinedAnswers.hyperPodCluster = combinedAnswers.customHyperPodCluster;
|
|
445
|
+
delete combinedAnswers.customHyperPodCluster;
|
|
446
|
+
}
|
|
447
|
+
|
|
448
|
+
// Apply CUDA version selection ā inference AMI override
|
|
449
|
+
if (combinedAnswers._resolvedInferenceAmiVersion) {
|
|
450
|
+
combinedAnswers.inferenceAmiVersion = combinedAnswers._resolvedInferenceAmiVersion;
|
|
451
|
+
}
|
|
452
|
+
if (combinedAnswers._selectedCudaVersion) {
|
|
453
|
+
combinedAnswers.selectedCudaVersion = combinedAnswers._selectedCudaVersion;
|
|
454
|
+
}
|
|
455
|
+
// Clean up internal fields
|
|
456
|
+
delete combinedAnswers._resolvedInferenceAmiVersion;
|
|
457
|
+
delete combinedAnswers._selectedCudaVersion;
|
|
458
|
+
|
|
459
|
+
// Handle custom AWS region
|
|
460
|
+
if (combinedAnswers.customAwsRegion) {
|
|
461
|
+
combinedAnswers.awsRegion = combinedAnswers.customAwsRegion;
|
|
462
|
+
delete combinedAnswers.customAwsRegion;
|
|
463
|
+
}
|
|
464
|
+
|
|
465
|
+
// Handle custom base image
|
|
466
|
+
if (combinedAnswers.customBaseImage) {
|
|
467
|
+
combinedAnswers.baseImage = combinedAnswers.customBaseImage
|
|
468
|
+
combinedAnswers._baseImageSource = 'custom'
|
|
469
|
+
delete combinedAnswers.customBaseImage
|
|
470
|
+
}
|
|
471
|
+
|
|
472
|
+
// Handle --base-image CLI override
|
|
473
|
+
if (this.options['base-image']) {
|
|
474
|
+
combinedAnswers.baseImage = this.options['base-image']
|
|
475
|
+
}
|
|
476
|
+
|
|
477
|
+
// Map awsRoleArn to roleArn for templates
|
|
478
|
+
if (combinedAnswers.awsRoleArn) {
|
|
479
|
+
combinedAnswers.roleArn = combinedAnswers.awsRoleArn;
|
|
480
|
+
delete combinedAnswers.awsRoleArn;
|
|
481
|
+
}
|
|
482
|
+
|
|
483
|
+
return combinedAnswers;
|
|
484
|
+
}
|
|
485
|
+
|
|
486
|
+
/**
|
|
487
|
+
* Checks if a parameter is promptable according to the parameter matrix
|
|
488
|
+
* @param {string} parameterName - Name of the parameter
|
|
489
|
+
* @returns {boolean} True if parameter is promptable
|
|
490
|
+
* @private
|
|
491
|
+
*/
|
|
492
|
+
_isParameterPromptable(parameterName) {
|
|
493
|
+
if (!this.configManager || !this.configManager.parameterMatrix) {
|
|
494
|
+
return true; // Default to promptable if matrix not available
|
|
495
|
+
}
|
|
496
|
+
|
|
497
|
+
const paramConfig = this.configManager.parameterMatrix[parameterName];
|
|
498
|
+
return paramConfig ? paramConfig.promptable : true;
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
/**
|
|
502
|
+
* Filters prompts to exclude non-promptable parameters
|
|
503
|
+
* @param {Array} prompts - Array of prompt objects
|
|
504
|
+
* @returns {Array} Filtered prompts excluding non-promptable parameters
|
|
505
|
+
* @private
|
|
506
|
+
*/
|
|
507
|
+
_filterPromptableParameters(prompts) {
|
|
508
|
+
return prompts.filter(prompt => this._isParameterPromptable(prompt.name));
|
|
509
|
+
}
|
|
510
|
+
|
|
511
|
+
/**
|
|
512
|
+
* Runs a single phase of prompts
|
|
513
|
+
* @private
|
|
514
|
+
*/
|
|
515
|
+
async _runPhase(prompts, previousAnswers = {}, explicitConfig = {}, existingConfig = {}) {
|
|
516
|
+
// Filter out non-promptable parameters
|
|
517
|
+
const promptablePrompts = this._filterPromptableParameters(prompts);
|
|
518
|
+
|
|
519
|
+
if (promptablePrompts.length === 0) return {};
|
|
520
|
+
|
|
521
|
+
// First, add any existing config values to previousAnswers so they're available for defaults
|
|
522
|
+
const allPreviousAnswers = { ...existingConfig, ...previousAnswers };
|
|
523
|
+
|
|
524
|
+
return await this._runPrompts(promptablePrompts.map(prompt => ({
|
|
525
|
+
...prompt,
|
|
526
|
+
// Wrap message to inject previousAnswers so prompts can access _mcpInstanceChoices etc.
|
|
527
|
+
message: typeof prompt.message === 'function' ? (answers) => {
|
|
528
|
+
return prompt.message({...allPreviousAnswers, ...answers});
|
|
529
|
+
} : prompt.message,
|
|
530
|
+
// Use existing config as default if available
|
|
531
|
+
default: prompt.default ? (answers) => {
|
|
532
|
+
// Check if we have a value from existing config first
|
|
533
|
+
if (existingConfig[prompt.name] !== undefined && existingConfig[prompt.name] !== null) {
|
|
534
|
+
return existingConfig[prompt.name];
|
|
535
|
+
}
|
|
536
|
+
// Otherwise use the original default logic
|
|
537
|
+
if (typeof prompt.default === 'function') {
|
|
538
|
+
return prompt.default({...allPreviousAnswers, ...answers});
|
|
539
|
+
}
|
|
540
|
+
return prompt.default;
|
|
541
|
+
} : (existingConfig[prompt.name] !== undefined && existingConfig[prompt.name] !== null) ?
|
|
542
|
+
existingConfig[prompt.name] : undefined,
|
|
543
|
+
// Skip prompt ONLY if we have explicit config (not defaults)
|
|
544
|
+
when: prompt.when ? (answers) => {
|
|
545
|
+
// Skip if we have the value from explicit config (CLI, env vars, config files)
|
|
546
|
+
if (explicitConfig[prompt.name] !== undefined && explicitConfig[prompt.name] !== null) {
|
|
547
|
+
return false;
|
|
548
|
+
}
|
|
549
|
+
return prompt.when({...allPreviousAnswers, ...answers});
|
|
550
|
+
} : (explicitConfig[prompt.name] !== undefined && explicitConfig[prompt.name] !== null) ?
|
|
551
|
+
() => false : undefined,
|
|
552
|
+
// Provide access to previous answers for conditional logic
|
|
553
|
+
// For unbounded parameters, inject MCP-provided choices if available
|
|
554
|
+
choices: prompt.choices ? (answers) => {
|
|
555
|
+
const mcpChoices = this.configManager?.mcpChoices?.[prompt.name];
|
|
556
|
+
if (mcpChoices && mcpChoices.length > 0) {
|
|
557
|
+
return [...mcpChoices.map(v => ({ name: v, value: v })), { name: 'Custom (enter manually)', value: 'custom' }];
|
|
558
|
+
}
|
|
559
|
+
// Fallback to original choices
|
|
560
|
+
if (typeof prompt.choices === 'function') {
|
|
561
|
+
return prompt.choices({...allPreviousAnswers, ...answers});
|
|
562
|
+
}
|
|
563
|
+
return prompt.choices;
|
|
564
|
+
} : undefined
|
|
565
|
+
})));
|
|
566
|
+
}
|
|
567
|
+
|
|
568
|
+
/**
|
|
569
|
+
* Get auto-set model format for Triton backends with a single format.
|
|
570
|
+
* Returns null if the backend requires user selection (FIL, Python) or
|
|
571
|
+
* doesn't use model formats (vllm, tensorrtllm).
|
|
572
|
+
* Requirements: 3.3, 3.4, 3.5
|
|
573
|
+
* @param {string} architecture - Resolved architecture
|
|
574
|
+
* @param {string} backend - Resolved backend
|
|
575
|
+
* @returns {string|null} Auto-set model format or null
|
|
576
|
+
* @private
|
|
577
|
+
*/
|
|
578
|
+
_getTritonAutoModelFormat(architecture, backend) {
|
|
579
|
+
if (architecture !== 'triton') return null
|
|
580
|
+
|
|
581
|
+
const meta = this._tritonBackends[backend]
|
|
582
|
+
if (!meta || !meta.modelFormats) return null
|
|
583
|
+
|
|
584
|
+
// Only auto-set if there's exactly one format
|
|
585
|
+
if (meta.modelFormats.length === 1) {
|
|
586
|
+
return meta.modelFormats[0]
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
return null
|
|
590
|
+
}
|
|
591
|
+
|
|
592
|
+
/**
|
|
593
|
+
* Query MCP region-picker server before infrastructure prompts.
|
|
594
|
+
* Populates configManager.mcpChoices so _runPhase injects them into list prompts.
|
|
595
|
+
* @private
|
|
596
|
+
*/
|
|
597
|
+
async _queryMcpForRegion(frameworkAnswers, explicitConfig) {
|
|
598
|
+
const cm = this.configManager;
|
|
599
|
+
if (!cm) return;
|
|
600
|
+
|
|
601
|
+
const mcpServers = cm.getMcpServerNames();
|
|
602
|
+
if (mcpServers.length === 0) return;
|
|
603
|
+
|
|
604
|
+
const smart = this.options.smart === true;
|
|
605
|
+
|
|
606
|
+
// Region: skip MCP query if region was explicitly provided via CLI, config file, or bootstrap profile
|
|
607
|
+
const cliRegion = this.options.region;
|
|
608
|
+
const bootstrapRegion = explicitConfig.awsRegion;
|
|
609
|
+
const skipRegionQuery = (cliRegion !== undefined && cliRegion !== null) ||
|
|
610
|
+
(bootstrapRegion !== undefined && bootstrapRegion !== null);
|
|
611
|
+
|
|
612
|
+
if (!skipRegionQuery && mcpServers.includes('region-picker')) {
|
|
613
|
+
const { regionSearch } = await this._runPrompts([{
|
|
614
|
+
type: 'input',
|
|
615
|
+
name: 'regionSearch',
|
|
616
|
+
message: 'š Search for a region (e.g. "europe", "us west", "tokyo"):',
|
|
617
|
+
default: ''
|
|
618
|
+
}]);
|
|
619
|
+
|
|
620
|
+
if (regionSearch && regionSearch.trim()) {
|
|
621
|
+
console.log(` š Querying region-picker${smart ? ' [smart]' : ''}...`);
|
|
622
|
+
const result = await cm.queryMcpServer('region-picker', {
|
|
623
|
+
...frameworkAnswers,
|
|
624
|
+
regionSearch: regionSearch.trim()
|
|
625
|
+
});
|
|
626
|
+
if (result && result.choices?.awsRegion?.length > 0) {
|
|
627
|
+
const choices = result.choices.awsRegion;
|
|
628
|
+
const preview = choices.length <= 5
|
|
629
|
+
? choices.join(', ')
|
|
630
|
+
: `${choices.slice(0, 5).join(', ') } (+${choices.length - 5} more)`;
|
|
631
|
+
console.log(` ā ${choices.length} region(s): [${preview}]`);
|
|
632
|
+
} else {
|
|
633
|
+
console.log(' ā³ No MCP results, using static list');
|
|
634
|
+
}
|
|
635
|
+
}
|
|
636
|
+
}
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
/**
|
|
640
|
+
* Query MCP instance-recommender server after deployment target is known.
|
|
641
|
+
* Only runs when deploymentTarget is managed-inference.
|
|
642
|
+
* Populates configManager.mcpChoices so _runPhase injects them into list prompts.
|
|
643
|
+
* @private
|
|
644
|
+
*/
|
|
645
|
+
async _queryMcpForInstance(frameworkAnswers, explicitConfig) {
|
|
646
|
+
const cm = this.configManager;
|
|
647
|
+
if (!cm) return;
|
|
648
|
+
|
|
649
|
+
const mcpServers = cm.getMcpServerNames();
|
|
650
|
+
if (mcpServers.length === 0) return;
|
|
651
|
+
|
|
652
|
+
const smart = this.options.smart === true;
|
|
653
|
+
|
|
654
|
+
// Instance type: query if not already provided via CLI/config
|
|
655
|
+
if (!explicitConfig.instanceType && mcpServers.includes('instance-recommender')) {
|
|
656
|
+
const { instanceSearch } = await this._runPrompts([{
|
|
657
|
+
type: 'input',
|
|
658
|
+
name: 'instanceSearch',
|
|
659
|
+
message: 'š Describe your instance needs (e.g. "multi-gpu", "cost-effective cpu"):',
|
|
660
|
+
default: frameworkAnswers.framework || ''
|
|
661
|
+
}]);
|
|
662
|
+
|
|
663
|
+
if (instanceSearch && instanceSearch.trim()) {
|
|
664
|
+
console.log(` š Querying instance-recommender${smart ? ' [smart]' : ''}...`);
|
|
665
|
+
const result = await cm.queryMcpServer('instance-recommender', {
|
|
666
|
+
...frameworkAnswers,
|
|
667
|
+
instanceSearch: instanceSearch.trim()
|
|
668
|
+
});
|
|
669
|
+
if (result && result.choices?.instanceType?.length > 0) {
|
|
670
|
+
const choices = result.choices.instanceType;
|
|
671
|
+
const preview = choices.length <= 5
|
|
672
|
+
? choices.join(', ')
|
|
673
|
+
: `${choices.slice(0, 5).join(', ') } (+${choices.length - 5} more)`;
|
|
674
|
+
console.log(` ā ${choices.length} instance(s): [${preview}]`);
|
|
675
|
+
} else {
|
|
676
|
+
console.log(' ā³ No MCP results, using static list');
|
|
677
|
+
}
|
|
678
|
+
}
|
|
679
|
+
}
|
|
680
|
+
}
|
|
681
|
+
|
|
682
|
+
/**
|
|
683
|
+
* Query the hyperpod-cluster-picker MCP server for available HyperPod EKS clusters.
|
|
684
|
+
* Populates configManager.mcpChoices.hyperPodCluster so _runPhase injects them into the list prompt.
|
|
685
|
+
* Falls back to manual entry if the MCP server is not configured or fails.
|
|
686
|
+
* Requirements: 12.1, 12.2, 12.3
|
|
687
|
+
* @private
|
|
688
|
+
*/
|
|
689
|
+
async _queryMcpForHyperPod(infraAnswers, explicitConfig) {
|
|
690
|
+
const cm = this.configManager;
|
|
691
|
+
if (!cm) return;
|
|
692
|
+
|
|
693
|
+
const mcpServers = cm.getMcpServerNames();
|
|
694
|
+
if (!mcpServers.includes('hyperpod-cluster-picker')) return;
|
|
695
|
+
|
|
696
|
+
// Skip if cluster already provided via CLI/config
|
|
697
|
+
if (explicitConfig.hyperPodCluster) return;
|
|
698
|
+
|
|
699
|
+
const smart = this.options.smart === true;
|
|
700
|
+
console.log(` š Querying hyperpod-cluster-picker${smart ? ' [smart]' : ''}...`);
|
|
701
|
+
|
|
702
|
+
const result = await cm.queryMcpServer('hyperpod-cluster-picker', {
|
|
703
|
+
...infraAnswers
|
|
704
|
+
});
|
|
705
|
+
|
|
706
|
+
if (result && result.choices?.hyperPodCluster?.length > 0) {
|
|
707
|
+
const choices = result.choices.hyperPodCluster;
|
|
708
|
+
const preview = choices.length <= 5
|
|
709
|
+
? choices.join(', ')
|
|
710
|
+
: `${choices.slice(0, 5).join(', ')} (+${choices.length - 5} more)`;
|
|
711
|
+
console.log(` ā ${choices.length} cluster(s): [${preview}]`);
|
|
712
|
+
} else {
|
|
713
|
+
// Surface any error message from the MCP server
|
|
714
|
+
if (result?.message) {
|
|
715
|
+
console.log(` ā ļø ${result.message}`);
|
|
716
|
+
} else {
|
|
717
|
+
console.log(' ā³ No HyperPod clusters found via MCP, manual entry available');
|
|
718
|
+
}
|
|
719
|
+
}
|
|
720
|
+
}
|
|
721
|
+
|
|
722
|
+
/**
|
|
723
|
+
* Query MCP base-image-picker server after deployment config is selected.
|
|
724
|
+
* Populates _mcpBaseImageChoices for the base image selection prompt.
|
|
725
|
+
* Requirements: 5.1, 5.2, 5.3, 5.4, 9.1, 9.2, 9.3
|
|
726
|
+
* @private
|
|
727
|
+
*/
|
|
728
|
+
async _queryMcpForBaseImage(frameworkAnswers, explicitConfig) {
|
|
729
|
+
// Skip if base image provided via CLI --base-image flag
|
|
730
|
+
if (this.options['base-image']) return
|
|
731
|
+
|
|
732
|
+
const cm = this.configManager
|
|
733
|
+
if (!cm) return
|
|
734
|
+
|
|
735
|
+
const mcpServers = cm.getMcpServerNames()
|
|
736
|
+
if (!mcpServers.includes('base-image-picker')) return
|
|
737
|
+
|
|
738
|
+
const smart = this.options.smart === true
|
|
739
|
+
const discover = this.options.discover === true
|
|
740
|
+
const framework = frameworkAnswers.framework
|
|
741
|
+
const modelServer = frameworkAnswers.modelServer
|
|
742
|
+
const architecture = frameworkAnswers.architecture || frameworkAnswers.deploymentConfig?.split('-')[0]
|
|
743
|
+
const isTransformer = framework === 'transformers'
|
|
744
|
+
const isTriton = architecture === 'triton'
|
|
745
|
+
const isDiffusors = architecture === 'diffusors'
|
|
746
|
+
|
|
747
|
+
// For non-transformer, non-triton, non-diffusors frameworks, prompt for optional search criteria
|
|
748
|
+
let searchCriteria
|
|
749
|
+
if (!isTransformer && !isTriton && !isDiffusors) {
|
|
750
|
+
const searchAnswer = await this._runPrompts(baseImageSearchPrompts.map(p => ({
|
|
751
|
+
...p,
|
|
752
|
+
when: () => true // Always show for non-transformer since we already checked
|
|
753
|
+
})))
|
|
754
|
+
searchCriteria = searchAnswer.baseImageSearch
|
|
755
|
+
}
|
|
756
|
+
|
|
757
|
+
const modeLabel = [smart && '[smart]', discover && '[discover]'].filter(Boolean).join(' ')
|
|
758
|
+
console.log(` š Querying base-image-picker${modeLabel ? ` ${modeLabel}` : ''}...`)
|
|
759
|
+
|
|
760
|
+
const context = { framework, modelServer, architecture }
|
|
761
|
+
if (searchCriteria && searchCriteria.trim()) {
|
|
762
|
+
context.searchCriteria = searchCriteria.trim()
|
|
763
|
+
}
|
|
764
|
+
|
|
765
|
+
const result = await cm.queryMcpServer('base-image-picker', context)
|
|
766
|
+
|
|
767
|
+
if (result && result.metadata?.baseImage?.length > 0) {
|
|
768
|
+
const entries = result.metadata.baseImage
|
|
769
|
+
this._mcpBaseImageChoices = formatImageChoices(entries, isTransformer || isTriton || isDiffusors)
|
|
770
|
+
const count = entries.length
|
|
771
|
+
console.log(` ā ${count} base image(s) available`)
|
|
772
|
+
} else {
|
|
773
|
+
console.log(' ā³ No MCP results, using default image')
|
|
774
|
+
}
|
|
775
|
+
}
|
|
776
|
+
|
|
777
|
+
/**
|
|
778
|
+
* Query model-picker MCP server catalog for model choices.
|
|
779
|
+
* Reads the architecture-specific catalog (popular-transformers.json or
|
|
780
|
+
* popular-diffusors.json) to populate the model selection prompt.
|
|
781
|
+
* @param {string} [architecture] - Current architecture ('transformers', 'diffusors', etc.)
|
|
782
|
+
* @private
|
|
783
|
+
*/
|
|
784
|
+
_queryMcpForModels(architecture) {
|
|
785
|
+
const cm = this.configManager
|
|
786
|
+
if (!cm) return
|
|
787
|
+
|
|
788
|
+
const mcpServers = cm.getMcpServerNames()
|
|
789
|
+
if (!mcpServers.includes('model-picker')) return
|
|
790
|
+
|
|
791
|
+
try {
|
|
792
|
+
const mcpConfigPath = path.join(GENERATOR_ROOT, 'config', 'mcp.json')
|
|
793
|
+
if (!fs.existsSync(mcpConfigPath)) return
|
|
794
|
+
|
|
795
|
+
const mcpConfig = JSON.parse(fs.readFileSync(mcpConfigPath, 'utf8'))
|
|
796
|
+
const serverConfig = mcpConfig.mcpServers?.['model-picker']
|
|
797
|
+
if (!serverConfig?.args?.length) return
|
|
798
|
+
|
|
799
|
+
// Resolve the server entry point directory from the args
|
|
800
|
+
const serverEntryPoint = serverConfig.args[serverConfig.args.length - 1]
|
|
801
|
+
const serverDir = path.dirname(serverEntryPoint)
|
|
802
|
+
|
|
803
|
+
// Read manifest to find catalog path
|
|
804
|
+
const manifestPath = path.join(serverDir, 'manifest.json')
|
|
805
|
+
if (!fs.existsSync(manifestPath)) return
|
|
806
|
+
|
|
807
|
+
const manifest = JSON.parse(fs.readFileSync(manifestPath, 'utf8'))
|
|
808
|
+
|
|
809
|
+
// Select catalog based on architecture
|
|
810
|
+
const catalogKey = architecture === 'diffusors'
|
|
811
|
+
? 'popular-diffusors'
|
|
812
|
+
: 'popular-transformers'
|
|
813
|
+
const catalogRelPath = manifest.catalogs?.[catalogKey]
|
|
814
|
+
if (!catalogRelPath) return
|
|
815
|
+
|
|
816
|
+
const catalogPath = path.resolve(serverDir, catalogRelPath)
|
|
817
|
+
if (!fs.existsSync(catalogPath)) return
|
|
818
|
+
|
|
819
|
+
const catalog = JSON.parse(fs.readFileSync(catalogPath, 'utf8'))
|
|
820
|
+
|
|
821
|
+
// Extract model IDs, filtering out glob patterns (entries with *)
|
|
822
|
+
const modelIds = Object.keys(catalog).filter(id => !id.includes('*'))
|
|
823
|
+
|
|
824
|
+
if (modelIds.length > 0) {
|
|
825
|
+
this._mcpModelChoices = modelIds
|
|
826
|
+
}
|
|
827
|
+
} catch {
|
|
828
|
+
// Silently fall back to hardcoded defaults
|
|
829
|
+
}
|
|
830
|
+
}
|
|
831
|
+
|
|
832
|
+
/**
|
|
833
|
+
* Get framework version choices from registry
|
|
834
|
+
* Requirements: 2.1, 2.6, 8.2, 8.3
|
|
835
|
+
* @private
|
|
836
|
+
*/
|
|
837
|
+
_getFrameworkVersionChoices(framework) {
|
|
838
|
+
const registryConfigManager = this.registryConfigManager;
|
|
839
|
+
|
|
840
|
+
if (!registryConfigManager || !registryConfigManager.frameworkRegistry) {
|
|
841
|
+
return [];
|
|
842
|
+
}
|
|
843
|
+
|
|
844
|
+
const frameworkVersions = registryConfigManager.frameworkRegistry[framework];
|
|
845
|
+
if (!frameworkVersions || Object.keys(frameworkVersions).length === 0) {
|
|
846
|
+
return [];
|
|
847
|
+
}
|
|
848
|
+
|
|
849
|
+
// Get available versions and sort them
|
|
850
|
+
const versions = Object.keys(frameworkVersions).sort((a, b) => {
|
|
851
|
+
// Simple version comparison (can be enhanced with semver)
|
|
852
|
+
return b.localeCompare(a, undefined, { numeric: true });
|
|
853
|
+
});
|
|
854
|
+
|
|
855
|
+
// Create choices with validation level indicators
|
|
856
|
+
return versions.map(version => {
|
|
857
|
+
const config = frameworkVersions[version];
|
|
858
|
+
const validationLevel = config.validationLevel || 'unknown';
|
|
859
|
+
const indicator = {
|
|
860
|
+
'tested': 'ā
',
|
|
861
|
+
'community-validated': 'š„',
|
|
862
|
+
'experimental': 'š§Ŗ',
|
|
863
|
+
'unknown': 'ā'
|
|
864
|
+
}[validationLevel] || 'ā';
|
|
865
|
+
|
|
866
|
+
return {
|
|
867
|
+
name: `${version} ${indicator} (${validationLevel})`,
|
|
868
|
+
value: version,
|
|
869
|
+
short: version
|
|
870
|
+
};
|
|
871
|
+
});
|
|
872
|
+
}
|
|
873
|
+
|
|
874
|
+
/**
|
|
875
|
+
* Display framework validation information
|
|
876
|
+
* Requirements: 2.6, 8.2, 8.3
|
|
877
|
+
* @private
|
|
878
|
+
*/
|
|
879
|
+
_displayFrameworkValidationInfo(framework, version) {
|
|
880
|
+
const registryConfigManager = this.registryConfigManager;
|
|
881
|
+
|
|
882
|
+
if (!registryConfigManager || !registryConfigManager.frameworkRegistry) {
|
|
883
|
+
return;
|
|
884
|
+
}
|
|
885
|
+
|
|
886
|
+
const config = registryConfigManager.frameworkRegistry[framework]?.[version];
|
|
887
|
+
if (!config) {
|
|
888
|
+
return;
|
|
889
|
+
}
|
|
890
|
+
|
|
891
|
+
console.log('\nš Framework Configuration:');
|
|
892
|
+
console.log(` ⢠Framework: ${framework} ${version}`);
|
|
893
|
+
console.log(` ⢠Validation Level: ${config.validationLevel || 'unknown'}`);
|
|
894
|
+
console.log(' ⢠Source: Framework_Registry');
|
|
895
|
+
|
|
896
|
+
if (config.accelerator) {
|
|
897
|
+
console.log(` ⢠Accelerator: ${config.accelerator.type} ${config.accelerator.version || 'any'}`);
|
|
898
|
+
}
|
|
899
|
+
|
|
900
|
+
if (config.recommendedInstanceTypes && config.recommendedInstanceTypes.length > 0) {
|
|
901
|
+
console.log(` ⢠Recommended Instances: ${config.recommendedInstanceTypes.slice(0, 3).join(', ')}`);
|
|
902
|
+
}
|
|
903
|
+
|
|
904
|
+
if (config.notes) {
|
|
905
|
+
console.log(` ⢠Notes: ${config.notes}`);
|
|
906
|
+
}
|
|
907
|
+
}
|
|
908
|
+
|
|
909
|
+
/**
|
|
910
|
+
* Get framework profile choices from registry
|
|
911
|
+
* Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.10
|
|
912
|
+
* @private
|
|
913
|
+
*/
|
|
914
|
+
_getFrameworkProfileChoices(framework, version) {
|
|
915
|
+
const registryConfigManager = this.registryConfigManager;
|
|
916
|
+
|
|
917
|
+
if (!registryConfigManager || !registryConfigManager.frameworkRegistry) {
|
|
918
|
+
return [];
|
|
919
|
+
}
|
|
920
|
+
|
|
921
|
+
const config = registryConfigManager.frameworkRegistry[framework]?.[version];
|
|
922
|
+
if (!config || !config.profiles || Object.keys(config.profiles).length === 0) {
|
|
923
|
+
return [];
|
|
924
|
+
}
|
|
925
|
+
|
|
926
|
+
// Create choices from profiles
|
|
927
|
+
const choices = Object.entries(config.profiles).map(([profileName, profileConfig]) => {
|
|
928
|
+
return {
|
|
929
|
+
name: `${profileConfig.displayName || profileName} - ${profileConfig.description || 'No description'}`,
|
|
930
|
+
value: profileName,
|
|
931
|
+
short: profileConfig.displayName || profileName
|
|
932
|
+
};
|
|
933
|
+
});
|
|
934
|
+
|
|
935
|
+
// Add "default" option to skip profile selection
|
|
936
|
+
choices.unshift({
|
|
937
|
+
name: 'Default (no profile)',
|
|
938
|
+
value: null,
|
|
939
|
+
short: 'Default'
|
|
940
|
+
});
|
|
941
|
+
|
|
942
|
+
return choices;
|
|
943
|
+
}
|
|
944
|
+
|
|
945
|
+
/**
|
|
946
|
+
* Get model profile choices from registry
|
|
947
|
+
* Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.10
|
|
948
|
+
* @private
|
|
949
|
+
*/
|
|
950
|
+
_getModelProfileChoices(modelId) {
|
|
951
|
+
const registryConfigManager = this.registryConfigManager;
|
|
952
|
+
|
|
953
|
+
if (!registryConfigManager || !registryConfigManager.modelRegistry || !modelId) {
|
|
954
|
+
return [];
|
|
955
|
+
}
|
|
956
|
+
|
|
957
|
+
// Try to find model in registry (exact match or pattern match)
|
|
958
|
+
let modelConfig = registryConfigManager.modelRegistry[modelId];
|
|
959
|
+
|
|
960
|
+
// If no exact match, try pattern matching
|
|
961
|
+
if (!modelConfig) {
|
|
962
|
+
for (const [pattern, config] of Object.entries(registryConfigManager.modelRegistry)) {
|
|
963
|
+
if (pattern.includes('*')) {
|
|
964
|
+
const regex = new RegExp(`^${ pattern.replace(/\*/g, '.*') }$`);
|
|
965
|
+
if (regex.test(modelId)) {
|
|
966
|
+
modelConfig = config;
|
|
967
|
+
break;
|
|
968
|
+
}
|
|
969
|
+
}
|
|
970
|
+
}
|
|
971
|
+
}
|
|
972
|
+
|
|
973
|
+
if (!modelConfig || !modelConfig.profiles || Object.keys(modelConfig.profiles).length === 0) {
|
|
974
|
+
return [];
|
|
975
|
+
}
|
|
976
|
+
|
|
977
|
+
// Create choices from profiles
|
|
978
|
+
const choices = Object.entries(modelConfig.profiles).map(([profileName, profileConfig]) => {
|
|
979
|
+
return {
|
|
980
|
+
name: `${profileConfig.displayName || profileName} - ${profileConfig.description || 'No description'}`,
|
|
981
|
+
value: profileName,
|
|
982
|
+
short: profileConfig.displayName || profileName
|
|
983
|
+
};
|
|
984
|
+
});
|
|
985
|
+
|
|
986
|
+
// Add "default" option to skip profile selection
|
|
987
|
+
choices.unshift({
|
|
988
|
+
name: 'Default (no profile)',
|
|
989
|
+
value: null,
|
|
990
|
+
short: 'Default'
|
|
991
|
+
});
|
|
992
|
+
|
|
993
|
+
return choices;
|
|
994
|
+
}
|
|
995
|
+
|
|
996
|
+
/**
|
|
997
|
+
* Fetch and display model information from HuggingFace API and Model Registry
|
|
998
|
+
* Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.11, 11.1, 11.2, 11.3, 11.5, 11.6, 11.7
|
|
999
|
+
* @private
|
|
1000
|
+
*/
|
|
1001
|
+
async _fetchAndDisplayModelInfo(modelId) {
|
|
1002
|
+
console.log(`\n š Querying model-picker [discover]...`);
|
|
1003
|
+
|
|
1004
|
+
const sources = [];
|
|
1005
|
+
let chatTemplate = null;
|
|
1006
|
+
let modelFamily = null;
|
|
1007
|
+
let mcpUsed = false;
|
|
1008
|
+
|
|
1009
|
+
// Try model-picker MCP server in discover mode (queries HuggingFace + merges with catalog)
|
|
1010
|
+
const cm = this.configManager;
|
|
1011
|
+
if (cm) {
|
|
1012
|
+
const mcpServers = cm.getMcpServerNames();
|
|
1013
|
+
if (mcpServers.includes('model-picker')) {
|
|
1014
|
+
try {
|
|
1015
|
+
const mcpConfigPath = path.join(GENERATOR_ROOT, 'config', 'mcp.json');
|
|
1016
|
+
if (fs.existsSync(mcpConfigPath)) {
|
|
1017
|
+
const mcpConfig = JSON.parse(fs.readFileSync(mcpConfigPath, 'utf8'));
|
|
1018
|
+
const serverConfig = mcpConfig.mcpServers?.['model-picker'];
|
|
1019
|
+
if (serverConfig) {
|
|
1020
|
+
const { McpClient } = await import('./mcp-client.js');
|
|
1021
|
+
const client = new McpClient(serverConfig, { timeout: 15000 });
|
|
1022
|
+
|
|
1023
|
+
// Override _buildContext to pass model_id and mode directly
|
|
1024
|
+
client._getUnboundedParameterNames = () => [];
|
|
1025
|
+
client._buildContext = () => ({});
|
|
1026
|
+
|
|
1027
|
+
// Connect and call get_models directly
|
|
1028
|
+
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
|
|
1029
|
+
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
|
|
1030
|
+
|
|
1031
|
+
const transport = new StdioClientTransport({
|
|
1032
|
+
command: serverConfig.command,
|
|
1033
|
+
args: serverConfig.args || [],
|
|
1034
|
+
env: { ...process.env, ...(serverConfig.env || {}) },
|
|
1035
|
+
stderr: 'pipe'
|
|
1036
|
+
});
|
|
1037
|
+
|
|
1038
|
+
const mcpClient = new Client(
|
|
1039
|
+
{ name: 'ml-container-creator', version: '1.0.0' },
|
|
1040
|
+
{ capabilities: {} }
|
|
1041
|
+
);
|
|
1042
|
+
|
|
1043
|
+
await mcpClient.connect(transport);
|
|
1044
|
+
|
|
1045
|
+
const result = await mcpClient.callTool({
|
|
1046
|
+
name: 'get_models',
|
|
1047
|
+
arguments: { model_id: modelId, mode: 'discover' }
|
|
1048
|
+
});
|
|
1049
|
+
|
|
1050
|
+
await mcpClient.close();
|
|
1051
|
+
|
|
1052
|
+
// Parse the response
|
|
1053
|
+
const textBlock = result?.content?.find(b => b.type === 'text');
|
|
1054
|
+
if (textBlock) {
|
|
1055
|
+
const parsed = JSON.parse(textBlock.text);
|
|
1056
|
+
if (parsed.values && Object.keys(parsed.values).length > 0) {
|
|
1057
|
+
mcpUsed = true;
|
|
1058
|
+
const vals = parsed.values;
|
|
1059
|
+
|
|
1060
|
+
if (vals.chat_template) {
|
|
1061
|
+
chatTemplate = vals.chat_template;
|
|
1062
|
+
}
|
|
1063
|
+
if (vals.family) {
|
|
1064
|
+
modelFamily = vals.family;
|
|
1065
|
+
}
|
|
1066
|
+
|
|
1067
|
+
// Extract model source metadata for loading adapter
|
|
1068
|
+
// Requirements: 2.1, 2.2, 2.3, 2.4
|
|
1069
|
+
if (vals.provider) {
|
|
1070
|
+
this._mcpModelSource = vals.provider;
|
|
1071
|
+
}
|
|
1072
|
+
if (vals.artifactUri) {
|
|
1073
|
+
this._mcpArtifactUri = vals.artifactUri;
|
|
1074
|
+
}
|
|
1075
|
+
|
|
1076
|
+
// Determine sources based on what was returned
|
|
1077
|
+
if (vals.tags || vals.pipeline_tag) {
|
|
1078
|
+
sources.push('HuggingFace_Hub_API');
|
|
1079
|
+
}
|
|
1080
|
+
if (vals.validation_level || vals.framework_compatibility) {
|
|
1081
|
+
sources.push('Model_Picker_Catalog');
|
|
1082
|
+
}
|
|
1083
|
+
if (sources.length === 0) {
|
|
1084
|
+
sources.push('model-picker');
|
|
1085
|
+
}
|
|
1086
|
+
console.log(` ā Resolved: ${modelId}`);
|
|
1087
|
+
} else if (parsed.message) {
|
|
1088
|
+
console.log(` ā³ ${parsed.message}`);
|
|
1089
|
+
}
|
|
1090
|
+
}
|
|
1091
|
+
}
|
|
1092
|
+
}
|
|
1093
|
+
} catch (err) {
|
|
1094
|
+
console.log(' ā³ model-picker unavailable, using fallback');
|
|
1095
|
+
}
|
|
1096
|
+
}
|
|
1097
|
+
}
|
|
1098
|
+
|
|
1099
|
+
// Fallback to legacy path if MCP didn't resolve
|
|
1100
|
+
if (!mcpUsed) {
|
|
1101
|
+
const registryConfigManager = this.registryConfigManager;
|
|
1102
|
+
if (registryConfigManager) {
|
|
1103
|
+
// Only try HuggingFace API for bare model IDs (not prefixed URIs)
|
|
1104
|
+
const isNonHfUri = modelId.startsWith('jumpstart://') ||
|
|
1105
|
+
modelId.startsWith('jumpstart-hub://') ||
|
|
1106
|
+
modelId.startsWith('s3://') ||
|
|
1107
|
+
modelId.startsWith('registry://');
|
|
1108
|
+
|
|
1109
|
+
if (!isNonHfUri) {
|
|
1110
|
+
// Try HuggingFace API directly
|
|
1111
|
+
try {
|
|
1112
|
+
const hfData = await registryConfigManager._fetchHuggingFaceData(modelId);
|
|
1113
|
+
if (hfData) {
|
|
1114
|
+
sources.push('HuggingFace_Hub_API');
|
|
1115
|
+
if (hfData.chatTemplate) {
|
|
1116
|
+
chatTemplate = hfData.chatTemplate;
|
|
1117
|
+
}
|
|
1118
|
+
console.log(' ā
Found on HuggingFace Hub');
|
|
1119
|
+
} else {
|
|
1120
|
+
console.log(' ā¹ļø Not found on HuggingFace Hub (may be private or offline)');
|
|
1121
|
+
}
|
|
1122
|
+
} catch (error) {
|
|
1123
|
+
console.log(' ā ļø HuggingFace API unavailable');
|
|
1124
|
+
}
|
|
1125
|
+
} else {
|
|
1126
|
+
// Non-HF URI (jumpstart://, s3://, etc.) ā skip HF lookup silently
|
|
1127
|
+
// The summary at the end of this function will report "No additional model information"
|
|
1128
|
+
}
|
|
1129
|
+
|
|
1130
|
+
// Check Model Registry for overrides
|
|
1131
|
+
if (registryConfigManager.modelRegistry) {
|
|
1132
|
+
let modelConfig = registryConfigManager.modelRegistry[modelId];
|
|
1133
|
+
|
|
1134
|
+
if (!modelConfig) {
|
|
1135
|
+
for (const [pattern, config] of Object.entries(registryConfigManager.modelRegistry)) {
|
|
1136
|
+
if (pattern.includes('*')) {
|
|
1137
|
+
const regex = new RegExp('^' + pattern.replace(/\*/g, '.*') + '$');
|
|
1138
|
+
if (regex.test(modelId)) {
|
|
1139
|
+
modelConfig = config;
|
|
1140
|
+
console.log(` ā
Matched pattern in Model_Registry: ${pattern}`);
|
|
1141
|
+
break;
|
|
1142
|
+
}
|
|
1143
|
+
}
|
|
1144
|
+
}
|
|
1145
|
+
} else {
|
|
1146
|
+
console.log(' ā
Found in Model_Registry');
|
|
1147
|
+
}
|
|
1148
|
+
|
|
1149
|
+
if (modelConfig) {
|
|
1150
|
+
sources.push('Model_Registry');
|
|
1151
|
+
if (modelConfig.chatTemplate) {
|
|
1152
|
+
chatTemplate = modelConfig.chatTemplate;
|
|
1153
|
+
}
|
|
1154
|
+
if (modelConfig.family) {
|
|
1155
|
+
modelFamily = modelConfig.family;
|
|
1156
|
+
}
|
|
1157
|
+
}
|
|
1158
|
+
}
|
|
1159
|
+
}
|
|
1160
|
+
}
|
|
1161
|
+
|
|
1162
|
+
// Display information
|
|
1163
|
+
if (sources.length > 0) {
|
|
1164
|
+
console.log('\nš Model Information:');
|
|
1165
|
+
console.log(` ⢠Model ID: ${modelId}`);
|
|
1166
|
+
if (modelFamily) {
|
|
1167
|
+
console.log(` ⢠Family: ${modelFamily}`);
|
|
1168
|
+
}
|
|
1169
|
+
if (chatTemplate) {
|
|
1170
|
+
console.log(' ⢠Chat Template: ā
Available');
|
|
1171
|
+
console.log(' (Will be injected into generated files)');
|
|
1172
|
+
} else {
|
|
1173
|
+
console.log(' ⢠Chat Template: ā Not available');
|
|
1174
|
+
console.log(' (Chat endpoints may require manual configuration)');
|
|
1175
|
+
}
|
|
1176
|
+
console.log(` ⢠Sources: ${sources.join(', ')}`);
|
|
1177
|
+
} else {
|
|
1178
|
+
console.log(' ā¹ļø No additional model information available');
|
|
1179
|
+
console.log(' Proceeding with default configuration');
|
|
1180
|
+
}
|
|
1181
|
+
}
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
|
|
1185
|
+
/**
|
|
1186
|
+
* Validate and display instance type compatibility
|
|
1187
|
+
* Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6
|
|
1188
|
+
* @private
|
|
1189
|
+
*/
|
|
1190
|
+
async _validateAndDisplayInstanceType(instanceType, framework, version) {
|
|
1191
|
+
const registryConfigManager = this.registryConfigManager;
|
|
1192
|
+
|
|
1193
|
+
if (!registryConfigManager) {
|
|
1194
|
+
return;
|
|
1195
|
+
}
|
|
1196
|
+
|
|
1197
|
+
// Get framework configuration
|
|
1198
|
+
const frameworkConfig = registryConfigManager.frameworkRegistry?.[framework]?.[version];
|
|
1199
|
+
if (!frameworkConfig) {
|
|
1200
|
+
return; // No framework config, skip validation
|
|
1201
|
+
}
|
|
1202
|
+
|
|
1203
|
+
console.log(`\nš Validating instance type: ${instanceType}`);
|
|
1204
|
+
|
|
1205
|
+
// Validate instance type
|
|
1206
|
+
const validationResult = registryConfigManager.validateInstanceType(instanceType, frameworkConfig);
|
|
1207
|
+
|
|
1208
|
+
if (validationResult.compatible) {
|
|
1209
|
+
console.log(' ā
Instance type is compatible');
|
|
1210
|
+
if (validationResult.info) {
|
|
1211
|
+
console.log(` ā¹ļø ${validationResult.info}`);
|
|
1212
|
+
}
|
|
1213
|
+
} else {
|
|
1214
|
+
console.log(' ā Instance type compatibility issue detected');
|
|
1215
|
+
if (validationResult.error) {
|
|
1216
|
+
console.log(` Error: ${validationResult.error}`);
|
|
1217
|
+
}
|
|
1218
|
+
if (validationResult.recommendations && validationResult.recommendations.length > 0) {
|
|
1219
|
+
console.log(` š” Recommended instances: ${validationResult.recommendations.join(', ')}`);
|
|
1220
|
+
}
|
|
1221
|
+
|
|
1222
|
+
// In test mode or non-interactive mode, throw error instead of prompting
|
|
1223
|
+
if (this.options.skipPrompts || process.env.NODE_ENV === 'test') {
|
|
1224
|
+
throw new Error('Instance type validation failed. Please select a compatible instance type.');
|
|
1225
|
+
}
|
|
1226
|
+
|
|
1227
|
+
// Ask user if they want to proceed
|
|
1228
|
+
const proceed = await this._runPrompts([{
|
|
1229
|
+
type: 'confirm',
|
|
1230
|
+
name: 'proceedWithIncompatible',
|
|
1231
|
+
message: 'Instance type may not be compatible. Proceed anyway?',
|
|
1232
|
+
default: false
|
|
1233
|
+
}]);
|
|
1234
|
+
|
|
1235
|
+
if (!proceed.proceedWithIncompatible) {
|
|
1236
|
+
throw new Error('Instance type validation failed. Please select a compatible instance type.');
|
|
1237
|
+
}
|
|
1238
|
+
}
|
|
1239
|
+
|
|
1240
|
+
if (validationResult.warning) {
|
|
1241
|
+
console.log(` ā ļø Warning: ${validationResult.warning}`);
|
|
1242
|
+
}
|
|
1243
|
+
}
|
|
1244
|
+
|
|
1245
|
+
/**
|
|
1246
|
+
* CUDA-to-AMI mapping.
|
|
1247
|
+
* Maps CUDA major.minor versions to the SageMaker inference AMI that provides
|
|
1248
|
+
* the matching CUDA driver. Derived from the framework registry patterns.
|
|
1249
|
+
* @private
|
|
1250
|
+
*/
|
|
1251
|
+
static CUDA_AMI_MAP = {
|
|
1252
|
+
'11.0': 'al2-ami-sagemaker-inference-gpu-2-1',
|
|
1253
|
+
'11.4': 'al2-ami-sagemaker-inference-gpu-2-1',
|
|
1254
|
+
'11.8': 'al2-ami-sagemaker-inference-gpu-3-1',
|
|
1255
|
+
'12.1': 'al2-ami-sagemaker-inference-gpu-3-1',
|
|
1256
|
+
'12.2': 'al2-ami-sagemaker-inference-gpu-3-2',
|
|
1257
|
+
'12.4': 'al2-ami-sagemaker-inference-gpu-3-2',
|
|
1258
|
+
'12.6': 'al2-ami-sagemaker-inference-gpu-3-2'
|
|
1259
|
+
};
|
|
1260
|
+
|
|
1261
|
+
/**
|
|
1262
|
+
* Prompt the user to select a CUDA version when the selected GPU instance
|
|
1263
|
+
* supports multiple versions. The choice transparently resolves to the
|
|
1264
|
+
* correct SageMaker inference AMI.
|
|
1265
|
+
*
|
|
1266
|
+
* Skipped for CPU instances, non-CUDA accelerators, or when only one
|
|
1267
|
+
* compatible CUDA version exists.
|
|
1268
|
+
*
|
|
1269
|
+
* @param {string} instanceType - Selected instance type (e.g. "ml.g5.2xlarge")
|
|
1270
|
+
* @param {string} framework - Selected framework name
|
|
1271
|
+
* @param {string} frameworkVersion - Selected framework version
|
|
1272
|
+
* @returns {Promise<{cudaVersion: string, inferenceAmiVersion: string}|null>}
|
|
1273
|
+
* @private
|
|
1274
|
+
*/
|
|
1275
|
+
async _promptCudaVersion(instanceType, framework, frameworkVersion) {
|
|
1276
|
+
if (!instanceType) return null;
|
|
1277
|
+
|
|
1278
|
+
// Look up instance in accelerator mapping
|
|
1279
|
+
const instanceInfo = this._instanceAcceleratorMapping[instanceType];
|
|
1280
|
+
if (!instanceInfo || instanceInfo.accelerator.type !== 'cuda') return null;
|
|
1281
|
+
|
|
1282
|
+
const instanceCudaVersions = instanceInfo.accelerator.versions;
|
|
1283
|
+
if (!instanceCudaVersions || instanceCudaVersions.length === 0) return null;
|
|
1284
|
+
|
|
1285
|
+
// Get framework CUDA requirements (if available)
|
|
1286
|
+
const registryConfigManager = this.registryConfigManager;
|
|
1287
|
+
const frameworkConfig = registryConfigManager?.frameworkRegistry?.[framework]?.[frameworkVersion];
|
|
1288
|
+
const frameworkAccel = frameworkConfig?.accelerator;
|
|
1289
|
+
|
|
1290
|
+
// Compute compatible CUDA versions: intersection of instance support and framework range
|
|
1291
|
+
let compatibleVersions;
|
|
1292
|
+
if (frameworkAccel?.versionRange) {
|
|
1293
|
+
const { min, max } = frameworkAccel.versionRange;
|
|
1294
|
+
compatibleVersions = instanceCudaVersions.filter(v => {
|
|
1295
|
+
return v >= min && v <= max;
|
|
1296
|
+
});
|
|
1297
|
+
} else {
|
|
1298
|
+
compatibleVersions = [...instanceCudaVersions];
|
|
1299
|
+
}
|
|
1300
|
+
|
|
1301
|
+
if (compatibleVersions.length === 0) {
|
|
1302
|
+
// No overlap ā fall back to all instance versions (validation already warned)
|
|
1303
|
+
compatibleVersions = [...instanceCudaVersions];
|
|
1304
|
+
}
|
|
1305
|
+
|
|
1306
|
+
// If only one option, auto-select it silently
|
|
1307
|
+
if (compatibleVersions.length === 1) {
|
|
1308
|
+
const cudaVersion = compatibleVersions[0];
|
|
1309
|
+
const inferenceAmiVersion = PromptRunner.CUDA_AMI_MAP[cudaVersion];
|
|
1310
|
+
if (inferenceAmiVersion) {
|
|
1311
|
+
console.log(`\nš§ CUDA ${cudaVersion} auto-selected (only compatible version for ${instanceType})`);
|
|
1312
|
+
console.log(` AMI: ${inferenceAmiVersion}`);
|
|
1313
|
+
}
|
|
1314
|
+
return inferenceAmiVersion ? { cudaVersion, inferenceAmiVersion } : null;
|
|
1315
|
+
}
|
|
1316
|
+
|
|
1317
|
+
// Multiple options ā let the user choose
|
|
1318
|
+
const defaultVersion = frameworkAccel?.version
|
|
1319
|
+
&& compatibleVersions.includes(frameworkAccel.version)
|
|
1320
|
+
? frameworkAccel.version
|
|
1321
|
+
: instanceInfo.accelerator.default || compatibleVersions[compatibleVersions.length - 1];
|
|
1322
|
+
|
|
1323
|
+
const choices = compatibleVersions.map(v => {
|
|
1324
|
+
const ami = PromptRunner.CUDA_AMI_MAP[v] || 'unknown';
|
|
1325
|
+
const isDefault = v === defaultVersion ? ' (recommended)' : '';
|
|
1326
|
+
return {
|
|
1327
|
+
name: `CUDA ${v}${isDefault} ā AMI: ${ami}`,
|
|
1328
|
+
value: v,
|
|
1329
|
+
short: `CUDA ${v}`
|
|
1330
|
+
};
|
|
1331
|
+
});
|
|
1332
|
+
|
|
1333
|
+
const { cudaVersion } = await this._runPrompts([{
|
|
1334
|
+
type: 'list',
|
|
1335
|
+
name: 'cudaVersion',
|
|
1336
|
+
message: `Select CUDA version for ${instanceType} (${instanceInfo.accelerator.hardware}):`,
|
|
1337
|
+
choices,
|
|
1338
|
+
default: defaultVersion
|
|
1339
|
+
}]);
|
|
1340
|
+
|
|
1341
|
+
const inferenceAmiVersion = PromptRunner.CUDA_AMI_MAP[cudaVersion];
|
|
1342
|
+
if (inferenceAmiVersion) {
|
|
1343
|
+
console.log(` ā
CUDA ${cudaVersion} ā AMI: ${inferenceAmiVersion}`);
|
|
1344
|
+
}
|
|
1345
|
+
|
|
1346
|
+
return inferenceAmiVersion ? { cudaVersion, inferenceAmiVersion } : null;
|
|
1347
|
+
}
|
|
1348
|
+
}
|
|
1349
|
+
|