@aws/ml-container-creator 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143) hide show
  1. package/LICENSE +202 -0
  2. package/LICENSE-THIRD-PARTY +68620 -0
  3. package/NOTICE +2 -0
  4. package/README.md +106 -0
  5. package/bin/cli.js +365 -0
  6. package/config/defaults.json +32 -0
  7. package/config/presets/transformers-djl.json +26 -0
  8. package/config/presets/transformers-gpu.json +24 -0
  9. package/config/presets/transformers-lmi.json +27 -0
  10. package/package.json +129 -0
  11. package/servers/README.md +419 -0
  12. package/servers/base-image-picker/catalogs/model-servers.json +1191 -0
  13. package/servers/base-image-picker/catalogs/python-slim.json +38 -0
  14. package/servers/base-image-picker/catalogs/triton-backends.json +51 -0
  15. package/servers/base-image-picker/catalogs/triton.json +38 -0
  16. package/servers/base-image-picker/index.js +495 -0
  17. package/servers/base-image-picker/manifest.json +17 -0
  18. package/servers/base-image-picker/package.json +15 -0
  19. package/servers/hyperpod-cluster-picker/LICENSE +202 -0
  20. package/servers/hyperpod-cluster-picker/index.js +424 -0
  21. package/servers/hyperpod-cluster-picker/manifest.json +14 -0
  22. package/servers/hyperpod-cluster-picker/package.json +17 -0
  23. package/servers/instance-recommender/LICENSE +202 -0
  24. package/servers/instance-recommender/catalogs/instances.json +852 -0
  25. package/servers/instance-recommender/index.js +284 -0
  26. package/servers/instance-recommender/manifest.json +16 -0
  27. package/servers/instance-recommender/package.json +15 -0
  28. package/servers/lib/LICENSE +202 -0
  29. package/servers/lib/bedrock-client.js +160 -0
  30. package/servers/lib/custom-validators.js +46 -0
  31. package/servers/lib/dynamic-resolver.js +36 -0
  32. package/servers/lib/package.json +11 -0
  33. package/servers/lib/schemas/image-catalog.schema.json +185 -0
  34. package/servers/lib/schemas/instances.schema.json +124 -0
  35. package/servers/lib/schemas/manifest.schema.json +64 -0
  36. package/servers/lib/schemas/model-catalog.schema.json +91 -0
  37. package/servers/lib/schemas/regions.schema.json +26 -0
  38. package/servers/lib/schemas/triton-backends.schema.json +51 -0
  39. package/servers/model-picker/catalogs/jumpstart-public.json +66 -0
  40. package/servers/model-picker/catalogs/popular-diffusors.json +88 -0
  41. package/servers/model-picker/catalogs/popular-transformers.json +226 -0
  42. package/servers/model-picker/index.js +1693 -0
  43. package/servers/model-picker/manifest.json +18 -0
  44. package/servers/model-picker/package.json +20 -0
  45. package/servers/region-picker/LICENSE +202 -0
  46. package/servers/region-picker/catalogs/regions.json +263 -0
  47. package/servers/region-picker/index.js +230 -0
  48. package/servers/region-picker/manifest.json +16 -0
  49. package/servers/region-picker/package.json +15 -0
  50. package/src/app.js +1007 -0
  51. package/src/copy-tpl.js +77 -0
  52. package/src/lib/accelerator-validator.js +39 -0
  53. package/src/lib/asset-manager.js +385 -0
  54. package/src/lib/aws-profile-parser.js +181 -0
  55. package/src/lib/bootstrap-command-handler.js +1647 -0
  56. package/src/lib/bootstrap-config.js +238 -0
  57. package/src/lib/ci-register-helpers.js +124 -0
  58. package/src/lib/ci-report-helpers.js +158 -0
  59. package/src/lib/ci-stage-helpers.js +268 -0
  60. package/src/lib/cli-handler.js +529 -0
  61. package/src/lib/comment-generator.js +544 -0
  62. package/src/lib/community-reports-validator.js +91 -0
  63. package/src/lib/config-manager.js +2106 -0
  64. package/src/lib/configuration-exporter.js +204 -0
  65. package/src/lib/configuration-manager.js +695 -0
  66. package/src/lib/configuration-matcher.js +221 -0
  67. package/src/lib/cpu-validator.js +36 -0
  68. package/src/lib/cuda-validator.js +57 -0
  69. package/src/lib/deployment-config-resolver.js +103 -0
  70. package/src/lib/deployment-entry-schema.js +125 -0
  71. package/src/lib/deployment-registry.js +598 -0
  72. package/src/lib/docker-introspection-validator.js +51 -0
  73. package/src/lib/engine-prefix-resolver.js +60 -0
  74. package/src/lib/huggingface-client.js +172 -0
  75. package/src/lib/key-value-parser.js +37 -0
  76. package/src/lib/known-flags-validator.js +200 -0
  77. package/src/lib/manifest-cli.js +280 -0
  78. package/src/lib/mcp-client.js +303 -0
  79. package/src/lib/mcp-command-handler.js +532 -0
  80. package/src/lib/neuron-validator.js +80 -0
  81. package/src/lib/parameter-schema-validator.js +284 -0
  82. package/src/lib/prompt-runner.js +1349 -0
  83. package/src/lib/prompts.js +1138 -0
  84. package/src/lib/registry-command-handler.js +519 -0
  85. package/src/lib/registry-loader.js +198 -0
  86. package/src/lib/rocm-validator.js +80 -0
  87. package/src/lib/schema-validator.js +157 -0
  88. package/src/lib/sensitive-redactor.js +59 -0
  89. package/src/lib/template-engine.js +156 -0
  90. package/src/lib/template-manager.js +341 -0
  91. package/src/lib/validation-engine.js +314 -0
  92. package/src/prompt-adapter.js +63 -0
  93. package/templates/Dockerfile +300 -0
  94. package/templates/IAM_PERMISSIONS.md +84 -0
  95. package/templates/MIGRATION.md +488 -0
  96. package/templates/PROJECT_README.md +439 -0
  97. package/templates/TEMPLATE_SYSTEM.md +243 -0
  98. package/templates/buildspec.yml +64 -0
  99. package/templates/code/chat_template.jinja +1 -0
  100. package/templates/code/flask/gunicorn_config.py +35 -0
  101. package/templates/code/flask/wsgi.py +10 -0
  102. package/templates/code/model_handler.py +387 -0
  103. package/templates/code/serve +300 -0
  104. package/templates/code/serve.py +175 -0
  105. package/templates/code/serving.properties +105 -0
  106. package/templates/code/start_server.py +39 -0
  107. package/templates/code/start_server.sh +39 -0
  108. package/templates/diffusors/Dockerfile +72 -0
  109. package/templates/diffusors/patch_image_api.py +35 -0
  110. package/templates/diffusors/serve +115 -0
  111. package/templates/diffusors/start_server.sh +114 -0
  112. package/templates/do/.gitkeep +1 -0
  113. package/templates/do/README.md +541 -0
  114. package/templates/do/build +83 -0
  115. package/templates/do/ci +681 -0
  116. package/templates/do/clean +811 -0
  117. package/templates/do/config +260 -0
  118. package/templates/do/deploy +1560 -0
  119. package/templates/do/export +306 -0
  120. package/templates/do/logs +319 -0
  121. package/templates/do/manifest +12 -0
  122. package/templates/do/push +119 -0
  123. package/templates/do/register +580 -0
  124. package/templates/do/run +113 -0
  125. package/templates/do/submit +417 -0
  126. package/templates/do/test +1147 -0
  127. package/templates/hyperpod/configmap.yaml +24 -0
  128. package/templates/hyperpod/deployment.yaml +71 -0
  129. package/templates/hyperpod/pvc.yaml +42 -0
  130. package/templates/hyperpod/service.yaml +17 -0
  131. package/templates/nginx-diffusors.conf +74 -0
  132. package/templates/nginx-predictors.conf +47 -0
  133. package/templates/nginx-tensorrt.conf +74 -0
  134. package/templates/requirements.txt +61 -0
  135. package/templates/sample_model/test_inference.py +123 -0
  136. package/templates/sample_model/train_abalone.py +252 -0
  137. package/templates/test/test_endpoint.sh +79 -0
  138. package/templates/test/test_local_image.sh +80 -0
  139. package/templates/test/test_model_handler.py +180 -0
  140. package/templates/triton/Dockerfile +128 -0
  141. package/templates/triton/config.pbtxt +163 -0
  142. package/templates/triton/model.py +130 -0
  143. package/templates/triton/requirements.txt +11 -0
@@ -0,0 +1,1349 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Prompt Runner - Orchestrates the prompting phases with clear user feedback
6
+ *
7
+ * This module handles running prompts in organized phases with console output
8
+ * to guide users through the configuration process.
9
+ */
10
+
11
+ import {
12
+ deploymentConfigPrompts,
13
+ enginePrompts,
14
+ frameworkVersionPrompts,
15
+ frameworkProfilePrompts,
16
+ modelFormatPrompts,
17
+ modelServerPrompts,
18
+ modelLoadStrategyPrompts,
19
+ modelProfilePrompts,
20
+ hfTokenPrompts,
21
+ ngcApiKeyPrompts,
22
+ modulePrompts,
23
+ infraRegionAndTargetPrompts,
24
+ infraInstancePrompts,
25
+ infraAsyncPrompts,
26
+ infraBatchTransformPrompts,
27
+ infraHyperPodPrompts,
28
+ infraBuildPrompts,
29
+ projectPrompts,
30
+ destinationPrompts,
31
+ baseImageSearchPrompts,
32
+ baseImagePrompts,
33
+ formatImageChoices
34
+ } from './prompts.js';
35
+
36
+ import fs from 'fs';
37
+ import path from 'path';
38
+ import { fileURLToPath } from 'node:url';
39
+ import RegistryLoader from './registry-loader.js';
40
+ import { runPrompts } from '../prompt-adapter.js';
41
+
42
+ const __pr_filename = fileURLToPath(import.meta.url);
43
+ const __pr_dirname = path.dirname(__pr_filename);
44
+ const GENERATOR_ROOT = path.resolve(__pr_dirname, '..', '..');
45
+
46
+ export default class PromptRunner {
47
+ constructor({ configManager, options, registryConfigManager, baseConfig, promptFn }) {
48
+ this.configManager = configManager;
49
+ this.options = options || {};
50
+ this.registryConfigManager = registryConfigManager || null;
51
+ this.baseConfig = baseConfig || {};
52
+ this._runPrompts = promptFn || runPrompts;
53
+ }
54
+
55
+ /**
56
+ * Runs all prompting phases and returns combined answers
57
+ * @returns {Promise<Object>} Combined answers from all phases
58
+ */
59
+ async run() {
60
+ const buildTimestamp = new Date().toISOString().replace(/[:.]/g, '-').slice(0, 19);
61
+
62
+ // Load catalog data via Registry_Loader
63
+ const registryLoader = new RegistryLoader()
64
+ this._tritonBackends = await registryLoader.loadTritonBackends()
65
+ this._instanceAcceleratorMapping = await registryLoader.loadInstanceAcceleratorMapping()
66
+
67
+ // Get existing configuration to use as defaults
68
+ const existingConfig = this.baseConfig || {};
69
+
70
+ // Get only explicit configuration (not defaults) for prompt skipping
71
+ const explicitConfig = this.configManager ? this.configManager.getExplicitConfiguration() : {};
72
+
73
+ // Phase 1: Infrastructure & Deployment
74
+ // Requirements: 3.1 — infrastructure prompts run first
75
+ // Ordering: Region → Deployment Target → Instance (if managed) → HyperPod (if eks) → Build Target
76
+ console.log('\nšŸ’Ŗ Infrastructure & Deployment');
77
+
78
+ // 1a. Query region MCP, then prompt for region + deployment target
79
+ await this._queryMcpForRegion({}, explicitConfig);
80
+ const bootstrapRegion = existingConfig.awsRegion || explicitConfig.awsRegion
81
+ const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {}
82
+ const regionAndTargetAnswers = await this._runPhase(infraRegionAndTargetPrompts, regionPreviousAnswers, explicitConfig, existingConfig);
83
+
84
+ // 1b. Instance type — query MCP and prompt for managed-inference, async-inference, batch-transform, and hyperpod-eks
85
+ let instanceAnswers = {};
86
+ if (regionAndTargetAnswers.deploymentTarget === 'managed-inference' ||
87
+ regionAndTargetAnswers.deploymentTarget === 'async-inference' ||
88
+ regionAndTargetAnswers.deploymentTarget === 'batch-transform' ||
89
+ regionAndTargetAnswers.deploymentTarget === 'hyperpod-eks') {
90
+ await this._queryMcpForInstance({}, explicitConfig);
91
+ const mcpInstanceChoices = this.configManager?.mcpChoices?.instanceType;
92
+ const instancePreviousAnswers = {
93
+ ...regionAndTargetAnswers,
94
+ ...(mcpInstanceChoices && mcpInstanceChoices.length > 0 ? { _mcpInstanceChoices: mcpInstanceChoices } : {})
95
+ };
96
+ instanceAnswers = await this._runPhase(infraInstancePrompts, instancePreviousAnswers, explicitConfig, existingConfig);
97
+ }
98
+
99
+ // 1b-async. Async-specific prompts (only when deploymentTarget === 'async-inference')
100
+ let asyncAnswers = {};
101
+ if (regionAndTargetAnswers.deploymentTarget === 'async-inference') {
102
+ asyncAnswers = await this._runPhase(infraAsyncPrompts, { ...regionAndTargetAnswers }, explicitConfig, existingConfig);
103
+ }
104
+
105
+ // 1b-batch. Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
106
+ let batchTransformAnswers = {};
107
+ if (regionAndTargetAnswers.deploymentTarget === 'batch-transform') {
108
+ batchTransformAnswers = await this._runPhase(
109
+ infraBatchTransformPrompts,
110
+ { ...regionAndTargetAnswers },
111
+ explicitConfig,
112
+ existingConfig
113
+ );
114
+ }
115
+
116
+ // 1c. HyperPod prompts — only query MCP and prompt when deployment target is hyperpod-eks
117
+ let hyperPodAnswers = {};
118
+ if (regionAndTargetAnswers.deploymentTarget === 'hyperpod-eks') {
119
+ // Resolve the actual region (handle 'custom' selection)
120
+ const resolvedRegion = regionAndTargetAnswers.customAwsRegion || regionAndTargetAnswers.awsRegion;
121
+ await this._queryMcpForHyperPod({ ...regionAndTargetAnswers, awsRegion: resolvedRegion }, explicitConfig);
122
+ hyperPodAnswers = await this._runPhase(infraHyperPodPrompts, { ...regionAndTargetAnswers }, explicitConfig, existingConfig);
123
+ }
124
+
125
+ // 1d. Build target + role ARN (always)
126
+ const buildAnswers = await this._runPhase(infraBuildPrompts, { ...regionAndTargetAnswers, ...instanceAnswers, ...hyperPodAnswers }, explicitConfig, existingConfig);
127
+
128
+ // Combine all infrastructure answers
129
+ const infraAnswers = {
130
+ ...regionAndTargetAnswers,
131
+ ...instanceAnswers,
132
+ ...asyncAnswers,
133
+ ...batchTransformAnswers,
134
+ ...hyperPodAnswers,
135
+ ...buildAnswers
136
+ };
137
+
138
+ // Phase 2: Core ML Configuration
139
+ // Requirements: 3.1, 3.2 — ML configuration prompts run after infrastructure
140
+ console.log('\nšŸ”§ Core ML Configuration');
141
+ const deploymentConfigAnswers = await this._runPhase(deploymentConfigPrompts, { ...infraAnswers }, explicitConfig, existingConfig);
142
+
143
+ // Derive architecture, backend, and legacy framework/modelServer from deploymentConfig
144
+ // Requirements: 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7
145
+ let architecture, backend, framework, modelServer;
146
+ if (deploymentConfigAnswers.deploymentConfig) {
147
+ const parts = deploymentConfigAnswers.deploymentConfig.split('-');
148
+ architecture = parts[0];
149
+ backend = parts.slice(1).join('-');
150
+ // Legacy compatibility: derive framework and modelServer
151
+ framework = architecture;
152
+ modelServer = backend;
153
+ }
154
+
155
+ // Add derived values to answers
156
+ const frameworkAnswers = {
157
+ ...deploymentConfigAnswers,
158
+ architecture: architecture || deploymentConfigAnswers.architecture,
159
+ backend: backend || deploymentConfigAnswers.backend,
160
+ framework: framework || deploymentConfigAnswers.framework,
161
+ modelServer: modelServer || deploymentConfigAnswers.modelServer
162
+ };
163
+
164
+ // Engine prompt for http architecture
165
+ // Requirements: 3.7
166
+ const engineAnswers = await this._runPhase(enginePrompts, { ...frameworkAnswers }, explicitConfig, existingConfig);
167
+
168
+ // Auto-set model format for Triton backends with single format
169
+ // Requirements: 3.3, 3.4, 3.5
170
+ const tritonAutoFormat = this._getTritonAutoModelFormat(architecture, backend);
171
+
172
+ // Query base-image-picker MCP server for base image choices
173
+ // Requirements: 5.1, 5.2, 5.3
174
+ await this._queryMcpForBaseImage(frameworkAnswers, explicitConfig)
175
+ const baseImagePreviousAnswers = {
176
+ ...frameworkAnswers,
177
+ ...engineAnswers,
178
+ ...(this._mcpBaseImageChoices ? { _mcpBaseImageChoices: this._mcpBaseImageChoices } : {})
179
+ }
180
+ const baseImageAnswers = await this._runPhase(
181
+ baseImagePrompts,
182
+ baseImagePreviousAnswers,
183
+ explicitConfig,
184
+ existingConfig
185
+ )
186
+
187
+ // Populate framework version choices from registry
188
+ const frameworkVersionChoices = this._getFrameworkVersionChoices(frameworkAnswers.framework);
189
+ const frameworkVersionAnswers = await this._runPhase(
190
+ frameworkVersionPrompts,
191
+ {...frameworkAnswers, ...engineAnswers, _frameworkVersionChoices: frameworkVersionChoices},
192
+ explicitConfig,
193
+ existingConfig
194
+ );
195
+
196
+ // Display validation information if version was selected
197
+ if (frameworkVersionAnswers.frameworkVersion) {
198
+ this._displayFrameworkValidationInfo(frameworkAnswers.framework, frameworkVersionAnswers.frameworkVersion);
199
+ }
200
+
201
+ // Populate framework profile choices from registry
202
+ const frameworkProfileChoices = this._getFrameworkProfileChoices(
203
+ frameworkAnswers.framework,
204
+ frameworkVersionAnswers.frameworkVersion
205
+ );
206
+ const frameworkProfileAnswers = await this._runPhase(
207
+ frameworkProfilePrompts,
208
+ {...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, _frameworkProfileChoices: frameworkProfileChoices},
209
+ explicitConfig,
210
+ existingConfig
211
+ );
212
+
213
+ // Query model-picker MCP server for model choices
214
+ this._queryMcpForModels(frameworkAnswers.architecture)
215
+ if (this._mcpModelChoices) {
216
+ console.log(` šŸ” Querying model-picker...`)
217
+ console.log(` āœ“ ${this._mcpModelChoices.length} model(s) available from catalog`)
218
+ }
219
+ const modelFormatPreviousAnswers = {
220
+ ...frameworkAnswers,
221
+ ...engineAnswers,
222
+ ...frameworkVersionAnswers,
223
+ ...frameworkProfileAnswers,
224
+ ...(this._mcpModelChoices ? { _mcpModelChoices: this._mcpModelChoices } : {})
225
+ }
226
+ const modelFormatAnswers = await this._runPhase(
227
+ modelFormatPrompts,
228
+ modelFormatPreviousAnswers,
229
+ explicitConfig,
230
+ existingConfig
231
+ );
232
+
233
+ // Model server prompts are now deprecated (empty array)
234
+ const modelServerAnswers = await this._runPhase(
235
+ modelServerPrompts,
236
+ {...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers},
237
+ explicitConfig,
238
+ existingConfig
239
+ );
240
+
241
+ // Populate model profile choices from registry (if model ID is available)
242
+ const currentAnswers = {...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers, ...modelFormatAnswers, ...modelServerAnswers};
243
+ const modelId = currentAnswers.customModelName || currentAnswers.modelName || explicitConfig.modelName;
244
+
245
+ // Fetch model information from HuggingFace and Model Registry
246
+ // Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.11, 11.1, 11.2, 11.3, 11.5, 11.6, 11.7
247
+ if (modelId && modelId !== 'Custom (enter manually)') {
248
+ await this._fetchAndDisplayModelInfo(modelId);
249
+ }
250
+
251
+ const modelProfileChoices = this._getModelProfileChoices(modelId);
252
+ const modelProfileAnswers = await this._runPhase(
253
+ modelProfilePrompts,
254
+ {...currentAnswers, _modelProfileChoices: modelProfileChoices},
255
+ explicitConfig,
256
+ existingConfig
257
+ );
258
+
259
+ // Model loading strategy prompt (build-time vs runtime)
260
+ // Requirements: 13.1, 13.2, 13.3, 13.4, 13.5
261
+ const modelLoadStrategyAnswers = await this._runPhase(
262
+ modelLoadStrategyPrompts,
263
+ { ...frameworkAnswers, ...engineAnswers, ...modelFormatAnswers, ...modelServerAnswers, ...modelProfileAnswers },
264
+ explicitConfig,
265
+ existingConfig
266
+ );
267
+
268
+ const hfTokenAnswers = await this._runPhase(hfTokenPrompts,
269
+ { ...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers, ...modelFormatAnswers, ...modelServerAnswers, ...modelProfileAnswers },
270
+ explicitConfig, existingConfig);
271
+
272
+ const ngcApiKeyAnswers = await this._runPhase(ngcApiKeyPrompts,
273
+ { ...frameworkAnswers, ...engineAnswers, ...frameworkVersionAnswers, ...frameworkProfileAnswers, ...modelFormatAnswers, ...modelServerAnswers, ...modelProfileAnswers },
274
+ explicitConfig, existingConfig);
275
+
276
+ // Validate instance type against framework requirements (now that framework is known)
277
+ // Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6
278
+ const instanceType = infraAnswers.customInstanceType || infraAnswers.instanceType;
279
+ if (instanceType && frameworkVersionAnswers.frameworkVersion) {
280
+ await this._validateAndDisplayInstanceType(
281
+ instanceType,
282
+ frameworkAnswers.framework,
283
+ frameworkVersionAnswers.frameworkVersion
284
+ );
285
+ }
286
+
287
+ // CUDA version selection: if the selected instance supports multiple CUDA versions,
288
+ // let the user pick which one. This transparently sets the inference AMI version.
289
+ const cudaAnswer = await this._promptCudaVersion(
290
+ instanceType,
291
+ frameworkAnswers.framework,
292
+ frameworkVersionAnswers.frameworkVersion
293
+ );
294
+ if (cudaAnswer) {
295
+ infraAnswers._selectedCudaVersion = cudaAnswer.cudaVersion;
296
+ infraAnswers._resolvedInferenceAmiVersion = cudaAnswer.inferenceAmiVersion;
297
+ }
298
+
299
+ // Phase 3: Module Selection
300
+ // Requirements: 3.3 — module selection after ML configuration
301
+ console.log('\nšŸ“¦ Module Selection');
302
+ const moduleAnswers = await this._runPhase(modulePrompts, { ...frameworkAnswers, ...engineAnswers }, explicitConfig, existingConfig);
303
+
304
+ // Ensure transformers, diffusors, and ineligible Triton backends don't get sample model
305
+ if (frameworkAnswers.architecture === 'transformers' ||
306
+ frameworkAnswers.architecture === 'diffusors' ||
307
+ (frameworkAnswers.architecture === 'triton' &&
308
+ !this._tritonBackends[frameworkAnswers.backend]?.supportsSampleModel)) {
309
+ moduleAnswers.includeSampleModel = false;
310
+ }
311
+
312
+ // Phase 4: Project Configuration
313
+ // Requirements: 3.4 — project configuration last
314
+ console.log('\nšŸ“‹ Project Configuration');
315
+ const allTechnicalAnswers = {
316
+ ...frameworkAnswers,
317
+ ...engineAnswers,
318
+ ...modelFormatAnswers,
319
+ ...modelServerAnswers,
320
+ ...moduleAnswers,
321
+ ...infraAnswers
322
+ };
323
+ const projectAnswers = await this._runPhase(projectPrompts, allTechnicalAnswers, explicitConfig, existingConfig);
324
+ const destinationAnswers = await this._runPhase(destinationPrompts,
325
+ { ...allTechnicalAnswers, ...projectAnswers }, explicitConfig, existingConfig);
326
+
327
+ // Combine all answers
328
+ const combinedAnswers = {
329
+ ...infraAnswers,
330
+ ...frameworkAnswers,
331
+ ...engineAnswers,
332
+ ...baseImageAnswers,
333
+ ...frameworkVersionAnswers,
334
+ ...frameworkProfileAnswers,
335
+ ...modelFormatAnswers,
336
+ ...modelServerAnswers,
337
+ ...modelProfileAnswers,
338
+ ...modelLoadStrategyAnswers,
339
+ ...hfTokenAnswers,
340
+ ...ngcApiKeyAnswers,
341
+ ...moduleAnswers,
342
+ ...projectAnswers,
343
+ ...destinationAnswers,
344
+ buildTimestamp
345
+ };
346
+
347
+ // Ensure CLI-provided values that were skipped during prompting are in combinedAnswers
348
+ if (explicitConfig.modelName && !combinedAnswers.modelName) {
349
+ combinedAnswers.modelName = explicitConfig.modelName;
350
+ }
351
+
352
+ // Flow model source metadata from model-picker MCP response
353
+ // Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
354
+ if (this._mcpModelSource) {
355
+ combinedAnswers.modelSource = this._mcpModelSource;
356
+ }
357
+ if (this._mcpArtifactUri) {
358
+ combinedAnswers.artifactUri = this._mcpArtifactUri;
359
+ }
360
+
361
+ // Validate: non-HF model sources require an artifact URI
362
+ // Without it, the serve script can't download the model at runtime
363
+ // Infer modelSource from model name prefix if not set by MCP
364
+ const modelName = combinedAnswers.customModelName || combinedAnswers.modelName;
365
+ if (!combinedAnswers.modelSource && modelName) {
366
+ if (modelName.startsWith('s3://')) {
367
+ combinedAnswers.modelSource = 's3';
368
+ combinedAnswers.artifactUri = modelName;
369
+ } else if (modelName.startsWith('jumpstart://')) {
370
+ combinedAnswers.modelSource = 'jumpstart';
371
+ } else if (modelName.startsWith('jumpstart-hub://')) {
372
+ combinedAnswers.modelSource = 'jumpstart-hub';
373
+ } else if (modelName.startsWith('registry://')) {
374
+ combinedAnswers.modelSource = 'registry';
375
+ }
376
+ }
377
+ // For s3:// models, the model name IS the artifact URI
378
+ if (combinedAnswers.modelSource === 's3' && !combinedAnswers.artifactUri) {
379
+ if (modelName && modelName.startsWith('s3://')) {
380
+ combinedAnswers.artifactUri = modelName;
381
+ }
382
+ }
383
+ const downloadSources = ['jumpstart', 's3'];
384
+ if (downloadSources.includes(combinedAnswers.modelSource) && !combinedAnswers.artifactUri) {
385
+ console.log(`\n āš ļø Model source is '${combinedAnswers.modelSource}' but no artifact URI was resolved.`);
386
+ console.log(' The model-picker could not determine the download location.');
387
+ console.log(' Falling back to HuggingFace source — the model will be loaded by name.');
388
+ console.log(' If this model requires S3 download, set MODEL_ARTIFACT_URI in do/config after generation.\n');
389
+ combinedAnswers.modelSource = 'huggingface';
390
+ }
391
+
392
+ // Registry models — note about InferenceSpecification requirement
393
+ if (combinedAnswers.modelSource === 'registry') {
394
+ if (!combinedAnswers.artifactUri) {
395
+ console.log(`\n āš ļø Model source is 'registry' but no artifact URI was resolved.`);
396
+ console.log(' The model package must have an InferenceSpecification with a valid');
397
+ console.log(' ModelDataUrl or S3DataSource for the runtime resolver to work.');
398
+ console.log(' If your model package was registered without an InferenceSpecification,');
399
+ console.log(' use the S3 path directly instead: --model-name="s3://bucket/path/model.tar.gz"');
400
+ console.log(' Or set MODEL_ARTIFACT_URI in do/config before deploying.\n');
401
+ } else {
402
+ console.log('\n ā„¹ļø Registry model: the container will resolve the artifact URI at startup');
403
+ console.log(' via DescribeModelPackage. Ensure the model package has a valid');
404
+ console.log(' InferenceSpecification with ModelDataUrl or S3DataSource.\n');
405
+ }
406
+ }
407
+
408
+ // Warn about jumpstart-hub:// models — private hub deployment requires
409
+ // HubAccessConfig on CreateModel, which is not yet supported by the generator.
410
+ if (combinedAnswers.modelSource === 'jumpstart-hub') {
411
+ console.log('\n āš ļø JumpStart Private Hub models are not yet fully supported.');
412
+ console.log(' Private hub artifacts live in AWS-managed S3 buckets that require');
413
+ console.log(' SageMaker\'s HubAccessConfig mechanism for access.');
414
+ console.log(' The generated project will not be able to download model artifacts at runtime.');
415
+ console.log(' This feature is tracked for a future release.\n');
416
+ console.log(' Falling back to HuggingFace source.\n');
417
+ combinedAnswers.modelSource = 'huggingface';
418
+ delete combinedAnswers.artifactUri;
419
+ }
420
+
421
+ // Apply auto-set model format for Triton backends with single format
422
+ // Requirements: 3.3, 3.4, 3.5
423
+ if (tritonAutoFormat) {
424
+ combinedAnswers.modelFormat = tritonAutoFormat
425
+ }
426
+
427
+ // Handle custom model name for transformers, diffusors, and Triton LLM backends
428
+ if ((combinedAnswers.architecture === 'transformers' ||
429
+ combinedAnswers.architecture === 'diffusors' ||
430
+ (combinedAnswers.architecture === 'triton' && (combinedAnswers.backend === 'vllm' || combinedAnswers.backend === 'tensorrtllm')))
431
+ && combinedAnswers.customModelName) {
432
+ combinedAnswers.modelName = combinedAnswers.customModelName;
433
+ delete combinedAnswers.customModelName;
434
+ }
435
+
436
+ // Handle custom instance type
437
+ if (combinedAnswers.customInstanceType) {
438
+ combinedAnswers.instanceType = combinedAnswers.customInstanceType;
439
+ delete combinedAnswers.customInstanceType;
440
+ }
441
+
442
+ // Handle custom HyperPod cluster name
443
+ if (combinedAnswers.customHyperPodCluster) {
444
+ combinedAnswers.hyperPodCluster = combinedAnswers.customHyperPodCluster;
445
+ delete combinedAnswers.customHyperPodCluster;
446
+ }
447
+
448
+ // Apply CUDA version selection → inference AMI override
449
+ if (combinedAnswers._resolvedInferenceAmiVersion) {
450
+ combinedAnswers.inferenceAmiVersion = combinedAnswers._resolvedInferenceAmiVersion;
451
+ }
452
+ if (combinedAnswers._selectedCudaVersion) {
453
+ combinedAnswers.selectedCudaVersion = combinedAnswers._selectedCudaVersion;
454
+ }
455
+ // Clean up internal fields
456
+ delete combinedAnswers._resolvedInferenceAmiVersion;
457
+ delete combinedAnswers._selectedCudaVersion;
458
+
459
+ // Handle custom AWS region
460
+ if (combinedAnswers.customAwsRegion) {
461
+ combinedAnswers.awsRegion = combinedAnswers.customAwsRegion;
462
+ delete combinedAnswers.customAwsRegion;
463
+ }
464
+
465
+ // Handle custom base image
466
+ if (combinedAnswers.customBaseImage) {
467
+ combinedAnswers.baseImage = combinedAnswers.customBaseImage
468
+ combinedAnswers._baseImageSource = 'custom'
469
+ delete combinedAnswers.customBaseImage
470
+ }
471
+
472
+ // Handle --base-image CLI override
473
+ if (this.options['base-image']) {
474
+ combinedAnswers.baseImage = this.options['base-image']
475
+ }
476
+
477
+ // Map awsRoleArn to roleArn for templates
478
+ if (combinedAnswers.awsRoleArn) {
479
+ combinedAnswers.roleArn = combinedAnswers.awsRoleArn;
480
+ delete combinedAnswers.awsRoleArn;
481
+ }
482
+
483
+ return combinedAnswers;
484
+ }
485
+
486
+ /**
487
+ * Checks if a parameter is promptable according to the parameter matrix
488
+ * @param {string} parameterName - Name of the parameter
489
+ * @returns {boolean} True if parameter is promptable
490
+ * @private
491
+ */
492
+ _isParameterPromptable(parameterName) {
493
+ if (!this.configManager || !this.configManager.parameterMatrix) {
494
+ return true; // Default to promptable if matrix not available
495
+ }
496
+
497
+ const paramConfig = this.configManager.parameterMatrix[parameterName];
498
+ return paramConfig ? paramConfig.promptable : true;
499
+ }
500
+
501
+ /**
502
+ * Filters prompts to exclude non-promptable parameters
503
+ * @param {Array} prompts - Array of prompt objects
504
+ * @returns {Array} Filtered prompts excluding non-promptable parameters
505
+ * @private
506
+ */
507
+ _filterPromptableParameters(prompts) {
508
+ return prompts.filter(prompt => this._isParameterPromptable(prompt.name));
509
+ }
510
+
511
+ /**
512
+ * Runs a single phase of prompts
513
+ * @private
514
+ */
515
+ async _runPhase(prompts, previousAnswers = {}, explicitConfig = {}, existingConfig = {}) {
516
+ // Filter out non-promptable parameters
517
+ const promptablePrompts = this._filterPromptableParameters(prompts);
518
+
519
+ if (promptablePrompts.length === 0) return {};
520
+
521
+ // First, add any existing config values to previousAnswers so they're available for defaults
522
+ const allPreviousAnswers = { ...existingConfig, ...previousAnswers };
523
+
524
+ return await this._runPrompts(promptablePrompts.map(prompt => ({
525
+ ...prompt,
526
+ // Wrap message to inject previousAnswers so prompts can access _mcpInstanceChoices etc.
527
+ message: typeof prompt.message === 'function' ? (answers) => {
528
+ return prompt.message({...allPreviousAnswers, ...answers});
529
+ } : prompt.message,
530
+ // Use existing config as default if available
531
+ default: prompt.default ? (answers) => {
532
+ // Check if we have a value from existing config first
533
+ if (existingConfig[prompt.name] !== undefined && existingConfig[prompt.name] !== null) {
534
+ return existingConfig[prompt.name];
535
+ }
536
+ // Otherwise use the original default logic
537
+ if (typeof prompt.default === 'function') {
538
+ return prompt.default({...allPreviousAnswers, ...answers});
539
+ }
540
+ return prompt.default;
541
+ } : (existingConfig[prompt.name] !== undefined && existingConfig[prompt.name] !== null) ?
542
+ existingConfig[prompt.name] : undefined,
543
+ // Skip prompt ONLY if we have explicit config (not defaults)
544
+ when: prompt.when ? (answers) => {
545
+ // Skip if we have the value from explicit config (CLI, env vars, config files)
546
+ if (explicitConfig[prompt.name] !== undefined && explicitConfig[prompt.name] !== null) {
547
+ return false;
548
+ }
549
+ return prompt.when({...allPreviousAnswers, ...answers});
550
+ } : (explicitConfig[prompt.name] !== undefined && explicitConfig[prompt.name] !== null) ?
551
+ () => false : undefined,
552
+ // Provide access to previous answers for conditional logic
553
+ // For unbounded parameters, inject MCP-provided choices if available
554
+ choices: prompt.choices ? (answers) => {
555
+ const mcpChoices = this.configManager?.mcpChoices?.[prompt.name];
556
+ if (mcpChoices && mcpChoices.length > 0) {
557
+ return [...mcpChoices.map(v => ({ name: v, value: v })), { name: 'Custom (enter manually)', value: 'custom' }];
558
+ }
559
+ // Fallback to original choices
560
+ if (typeof prompt.choices === 'function') {
561
+ return prompt.choices({...allPreviousAnswers, ...answers});
562
+ }
563
+ return prompt.choices;
564
+ } : undefined
565
+ })));
566
+ }
567
+
568
+ /**
569
+ * Get auto-set model format for Triton backends with a single format.
570
+ * Returns null if the backend requires user selection (FIL, Python) or
571
+ * doesn't use model formats (vllm, tensorrtllm).
572
+ * Requirements: 3.3, 3.4, 3.5
573
+ * @param {string} architecture - Resolved architecture
574
+ * @param {string} backend - Resolved backend
575
+ * @returns {string|null} Auto-set model format or null
576
+ * @private
577
+ */
578
+ _getTritonAutoModelFormat(architecture, backend) {
579
+ if (architecture !== 'triton') return null
580
+
581
+ const meta = this._tritonBackends[backend]
582
+ if (!meta || !meta.modelFormats) return null
583
+
584
+ // Only auto-set if there's exactly one format
585
+ if (meta.modelFormats.length === 1) {
586
+ return meta.modelFormats[0]
587
+ }
588
+
589
+ return null
590
+ }
591
+
592
+ /**
593
+ * Query MCP region-picker server before infrastructure prompts.
594
+ * Populates configManager.mcpChoices so _runPhase injects them into list prompts.
595
+ * @private
596
+ */
597
+ async _queryMcpForRegion(frameworkAnswers, explicitConfig) {
598
+ const cm = this.configManager;
599
+ if (!cm) return;
600
+
601
+ const mcpServers = cm.getMcpServerNames();
602
+ if (mcpServers.length === 0) return;
603
+
604
+ const smart = this.options.smart === true;
605
+
606
+ // Region: skip MCP query if region was explicitly provided via CLI, config file, or bootstrap profile
607
+ const cliRegion = this.options.region;
608
+ const bootstrapRegion = explicitConfig.awsRegion;
609
+ const skipRegionQuery = (cliRegion !== undefined && cliRegion !== null) ||
610
+ (bootstrapRegion !== undefined && bootstrapRegion !== null);
611
+
612
+ if (!skipRegionQuery && mcpServers.includes('region-picker')) {
613
+ const { regionSearch } = await this._runPrompts([{
614
+ type: 'input',
615
+ name: 'regionSearch',
616
+ message: 'šŸ”Œ Search for a region (e.g. "europe", "us west", "tokyo"):',
617
+ default: ''
618
+ }]);
619
+
620
+ if (regionSearch && regionSearch.trim()) {
621
+ console.log(` šŸ” Querying region-picker${smart ? ' [smart]' : ''}...`);
622
+ const result = await cm.queryMcpServer('region-picker', {
623
+ ...frameworkAnswers,
624
+ regionSearch: regionSearch.trim()
625
+ });
626
+ if (result && result.choices?.awsRegion?.length > 0) {
627
+ const choices = result.choices.awsRegion;
628
+ const preview = choices.length <= 5
629
+ ? choices.join(', ')
630
+ : `${choices.slice(0, 5).join(', ') } (+${choices.length - 5} more)`;
631
+ console.log(` āœ“ ${choices.length} region(s): [${preview}]`);
632
+ } else {
633
+ console.log(' ↳ No MCP results, using static list');
634
+ }
635
+ }
636
+ }
637
+ }
638
+
639
+ /**
640
+ * Query MCP instance-recommender server after deployment target is known.
641
+ * Only runs when deploymentTarget is managed-inference.
642
+ * Populates configManager.mcpChoices so _runPhase injects them into list prompts.
643
+ * @private
644
+ */
645
+ async _queryMcpForInstance(frameworkAnswers, explicitConfig) {
646
+ const cm = this.configManager;
647
+ if (!cm) return;
648
+
649
+ const mcpServers = cm.getMcpServerNames();
650
+ if (mcpServers.length === 0) return;
651
+
652
+ const smart = this.options.smart === true;
653
+
654
+ // Instance type: query if not already provided via CLI/config
655
+ if (!explicitConfig.instanceType && mcpServers.includes('instance-recommender')) {
656
+ const { instanceSearch } = await this._runPrompts([{
657
+ type: 'input',
658
+ name: 'instanceSearch',
659
+ message: 'šŸ”Œ Describe your instance needs (e.g. "multi-gpu", "cost-effective cpu"):',
660
+ default: frameworkAnswers.framework || ''
661
+ }]);
662
+
663
+ if (instanceSearch && instanceSearch.trim()) {
664
+ console.log(` šŸ” Querying instance-recommender${smart ? ' [smart]' : ''}...`);
665
+ const result = await cm.queryMcpServer('instance-recommender', {
666
+ ...frameworkAnswers,
667
+ instanceSearch: instanceSearch.trim()
668
+ });
669
+ if (result && result.choices?.instanceType?.length > 0) {
670
+ const choices = result.choices.instanceType;
671
+ const preview = choices.length <= 5
672
+ ? choices.join(', ')
673
+ : `${choices.slice(0, 5).join(', ') } (+${choices.length - 5} more)`;
674
+ console.log(` āœ“ ${choices.length} instance(s): [${preview}]`);
675
+ } else {
676
+ console.log(' ↳ No MCP results, using static list');
677
+ }
678
+ }
679
+ }
680
+ }
681
+
682
+ /**
683
+ * Query the hyperpod-cluster-picker MCP server for available HyperPod EKS clusters.
684
+ * Populates configManager.mcpChoices.hyperPodCluster so _runPhase injects them into the list prompt.
685
+ * Falls back to manual entry if the MCP server is not configured or fails.
686
+ * Requirements: 12.1, 12.2, 12.3
687
+ * @private
688
+ */
689
+ async _queryMcpForHyperPod(infraAnswers, explicitConfig) {
690
+ const cm = this.configManager;
691
+ if (!cm) return;
692
+
693
+ const mcpServers = cm.getMcpServerNames();
694
+ if (!mcpServers.includes('hyperpod-cluster-picker')) return;
695
+
696
+ // Skip if cluster already provided via CLI/config
697
+ if (explicitConfig.hyperPodCluster) return;
698
+
699
+ const smart = this.options.smart === true;
700
+ console.log(` šŸ” Querying hyperpod-cluster-picker${smart ? ' [smart]' : ''}...`);
701
+
702
+ const result = await cm.queryMcpServer('hyperpod-cluster-picker', {
703
+ ...infraAnswers
704
+ });
705
+
706
+ if (result && result.choices?.hyperPodCluster?.length > 0) {
707
+ const choices = result.choices.hyperPodCluster;
708
+ const preview = choices.length <= 5
709
+ ? choices.join(', ')
710
+ : `${choices.slice(0, 5).join(', ')} (+${choices.length - 5} more)`;
711
+ console.log(` āœ“ ${choices.length} cluster(s): [${preview}]`);
712
+ } else {
713
+ // Surface any error message from the MCP server
714
+ if (result?.message) {
715
+ console.log(` āš ļø ${result.message}`);
716
+ } else {
717
+ console.log(' ↳ No HyperPod clusters found via MCP, manual entry available');
718
+ }
719
+ }
720
+ }
721
+
722
+ /**
723
+ * Query MCP base-image-picker server after deployment config is selected.
724
+ * Populates _mcpBaseImageChoices for the base image selection prompt.
725
+ * Requirements: 5.1, 5.2, 5.3, 5.4, 9.1, 9.2, 9.3
726
+ * @private
727
+ */
728
+ async _queryMcpForBaseImage(frameworkAnswers, explicitConfig) {
729
+ // Skip if base image provided via CLI --base-image flag
730
+ if (this.options['base-image']) return
731
+
732
+ const cm = this.configManager
733
+ if (!cm) return
734
+
735
+ const mcpServers = cm.getMcpServerNames()
736
+ if (!mcpServers.includes('base-image-picker')) return
737
+
738
+ const smart = this.options.smart === true
739
+ const discover = this.options.discover === true
740
+ const framework = frameworkAnswers.framework
741
+ const modelServer = frameworkAnswers.modelServer
742
+ const architecture = frameworkAnswers.architecture || frameworkAnswers.deploymentConfig?.split('-')[0]
743
+ const isTransformer = framework === 'transformers'
744
+ const isTriton = architecture === 'triton'
745
+ const isDiffusors = architecture === 'diffusors'
746
+
747
+ // For non-transformer, non-triton, non-diffusors frameworks, prompt for optional search criteria
748
+ let searchCriteria
749
+ if (!isTransformer && !isTriton && !isDiffusors) {
750
+ const searchAnswer = await this._runPrompts(baseImageSearchPrompts.map(p => ({
751
+ ...p,
752
+ when: () => true // Always show for non-transformer since we already checked
753
+ })))
754
+ searchCriteria = searchAnswer.baseImageSearch
755
+ }
756
+
757
+ const modeLabel = [smart && '[smart]', discover && '[discover]'].filter(Boolean).join(' ')
758
+ console.log(` šŸ” Querying base-image-picker${modeLabel ? ` ${modeLabel}` : ''}...`)
759
+
760
+ const context = { framework, modelServer, architecture }
761
+ if (searchCriteria && searchCriteria.trim()) {
762
+ context.searchCriteria = searchCriteria.trim()
763
+ }
764
+
765
+ const result = await cm.queryMcpServer('base-image-picker', context)
766
+
767
+ if (result && result.metadata?.baseImage?.length > 0) {
768
+ const entries = result.metadata.baseImage
769
+ this._mcpBaseImageChoices = formatImageChoices(entries, isTransformer || isTriton || isDiffusors)
770
+ const count = entries.length
771
+ console.log(` āœ“ ${count} base image(s) available`)
772
+ } else {
773
+ console.log(' ↳ No MCP results, using default image')
774
+ }
775
+ }
776
+
777
+ /**
778
+ * Query model-picker MCP server catalog for model choices.
779
+ * Reads the architecture-specific catalog (popular-transformers.json or
780
+ * popular-diffusors.json) to populate the model selection prompt.
781
+ * @param {string} [architecture] - Current architecture ('transformers', 'diffusors', etc.)
782
+ * @private
783
+ */
784
+ _queryMcpForModels(architecture) {
785
+ const cm = this.configManager
786
+ if (!cm) return
787
+
788
+ const mcpServers = cm.getMcpServerNames()
789
+ if (!mcpServers.includes('model-picker')) return
790
+
791
+ try {
792
+ const mcpConfigPath = path.join(GENERATOR_ROOT, 'config', 'mcp.json')
793
+ if (!fs.existsSync(mcpConfigPath)) return
794
+
795
+ const mcpConfig = JSON.parse(fs.readFileSync(mcpConfigPath, 'utf8'))
796
+ const serverConfig = mcpConfig.mcpServers?.['model-picker']
797
+ if (!serverConfig?.args?.length) return
798
+
799
+ // Resolve the server entry point directory from the args
800
+ const serverEntryPoint = serverConfig.args[serverConfig.args.length - 1]
801
+ const serverDir = path.dirname(serverEntryPoint)
802
+
803
+ // Read manifest to find catalog path
804
+ const manifestPath = path.join(serverDir, 'manifest.json')
805
+ if (!fs.existsSync(manifestPath)) return
806
+
807
+ const manifest = JSON.parse(fs.readFileSync(manifestPath, 'utf8'))
808
+
809
+ // Select catalog based on architecture
810
+ const catalogKey = architecture === 'diffusors'
811
+ ? 'popular-diffusors'
812
+ : 'popular-transformers'
813
+ const catalogRelPath = manifest.catalogs?.[catalogKey]
814
+ if (!catalogRelPath) return
815
+
816
+ const catalogPath = path.resolve(serverDir, catalogRelPath)
817
+ if (!fs.existsSync(catalogPath)) return
818
+
819
+ const catalog = JSON.parse(fs.readFileSync(catalogPath, 'utf8'))
820
+
821
+ // Extract model IDs, filtering out glob patterns (entries with *)
822
+ const modelIds = Object.keys(catalog).filter(id => !id.includes('*'))
823
+
824
+ if (modelIds.length > 0) {
825
+ this._mcpModelChoices = modelIds
826
+ }
827
+ } catch {
828
+ // Silently fall back to hardcoded defaults
829
+ }
830
+ }
831
+
832
+ /**
833
+ * Get framework version choices from registry
834
+ * Requirements: 2.1, 2.6, 8.2, 8.3
835
+ * @private
836
+ */
837
+ _getFrameworkVersionChoices(framework) {
838
+ const registryConfigManager = this.registryConfigManager;
839
+
840
+ if (!registryConfigManager || !registryConfigManager.frameworkRegistry) {
841
+ return [];
842
+ }
843
+
844
+ const frameworkVersions = registryConfigManager.frameworkRegistry[framework];
845
+ if (!frameworkVersions || Object.keys(frameworkVersions).length === 0) {
846
+ return [];
847
+ }
848
+
849
+ // Get available versions and sort them
850
+ const versions = Object.keys(frameworkVersions).sort((a, b) => {
851
+ // Simple version comparison (can be enhanced with semver)
852
+ return b.localeCompare(a, undefined, { numeric: true });
853
+ });
854
+
855
+ // Create choices with validation level indicators
856
+ return versions.map(version => {
857
+ const config = frameworkVersions[version];
858
+ const validationLevel = config.validationLevel || 'unknown';
859
+ const indicator = {
860
+ 'tested': 'āœ…',
861
+ 'community-validated': 'šŸ‘„',
862
+ 'experimental': '🧪',
863
+ 'unknown': 'ā“'
864
+ }[validationLevel] || 'ā“';
865
+
866
+ return {
867
+ name: `${version} ${indicator} (${validationLevel})`,
868
+ value: version,
869
+ short: version
870
+ };
871
+ });
872
+ }
873
+
874
+ /**
875
+ * Display framework validation information
876
+ * Requirements: 2.6, 8.2, 8.3
877
+ * @private
878
+ */
879
+ _displayFrameworkValidationInfo(framework, version) {
880
+ const registryConfigManager = this.registryConfigManager;
881
+
882
+ if (!registryConfigManager || !registryConfigManager.frameworkRegistry) {
883
+ return;
884
+ }
885
+
886
+ const config = registryConfigManager.frameworkRegistry[framework]?.[version];
887
+ if (!config) {
888
+ return;
889
+ }
890
+
891
+ console.log('\nšŸ“‹ Framework Configuration:');
892
+ console.log(` • Framework: ${framework} ${version}`);
893
+ console.log(` • Validation Level: ${config.validationLevel || 'unknown'}`);
894
+ console.log(' • Source: Framework_Registry');
895
+
896
+ if (config.accelerator) {
897
+ console.log(` • Accelerator: ${config.accelerator.type} ${config.accelerator.version || 'any'}`);
898
+ }
899
+
900
+ if (config.recommendedInstanceTypes && config.recommendedInstanceTypes.length > 0) {
901
+ console.log(` • Recommended Instances: ${config.recommendedInstanceTypes.slice(0, 3).join(', ')}`);
902
+ }
903
+
904
+ if (config.notes) {
905
+ console.log(` • Notes: ${config.notes}`);
906
+ }
907
+ }
908
+
909
+ /**
910
+ * Get framework profile choices from registry
911
+ * Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.10
912
+ * @private
913
+ */
914
+ _getFrameworkProfileChoices(framework, version) {
915
+ const registryConfigManager = this.registryConfigManager;
916
+
917
+ if (!registryConfigManager || !registryConfigManager.frameworkRegistry) {
918
+ return [];
919
+ }
920
+
921
+ const config = registryConfigManager.frameworkRegistry[framework]?.[version];
922
+ if (!config || !config.profiles || Object.keys(config.profiles).length === 0) {
923
+ return [];
924
+ }
925
+
926
+ // Create choices from profiles
927
+ const choices = Object.entries(config.profiles).map(([profileName, profileConfig]) => {
928
+ return {
929
+ name: `${profileConfig.displayName || profileName} - ${profileConfig.description || 'No description'}`,
930
+ value: profileName,
931
+ short: profileConfig.displayName || profileName
932
+ };
933
+ });
934
+
935
+ // Add "default" option to skip profile selection
936
+ choices.unshift({
937
+ name: 'Default (no profile)',
938
+ value: null,
939
+ short: 'Default'
940
+ });
941
+
942
+ return choices;
943
+ }
944
+
945
+ /**
946
+ * Get model profile choices from registry
947
+ * Requirements: 12.1, 12.2, 12.3, 12.4, 12.5, 12.10
948
+ * @private
949
+ */
950
+ _getModelProfileChoices(modelId) {
951
+ const registryConfigManager = this.registryConfigManager;
952
+
953
+ if (!registryConfigManager || !registryConfigManager.modelRegistry || !modelId) {
954
+ return [];
955
+ }
956
+
957
+ // Try to find model in registry (exact match or pattern match)
958
+ let modelConfig = registryConfigManager.modelRegistry[modelId];
959
+
960
+ // If no exact match, try pattern matching
961
+ if (!modelConfig) {
962
+ for (const [pattern, config] of Object.entries(registryConfigManager.modelRegistry)) {
963
+ if (pattern.includes('*')) {
964
+ const regex = new RegExp(`^${ pattern.replace(/\*/g, '.*') }$`);
965
+ if (regex.test(modelId)) {
966
+ modelConfig = config;
967
+ break;
968
+ }
969
+ }
970
+ }
971
+ }
972
+
973
+ if (!modelConfig || !modelConfig.profiles || Object.keys(modelConfig.profiles).length === 0) {
974
+ return [];
975
+ }
976
+
977
+ // Create choices from profiles
978
+ const choices = Object.entries(modelConfig.profiles).map(([profileName, profileConfig]) => {
979
+ return {
980
+ name: `${profileConfig.displayName || profileName} - ${profileConfig.description || 'No description'}`,
981
+ value: profileName,
982
+ short: profileConfig.displayName || profileName
983
+ };
984
+ });
985
+
986
+ // Add "default" option to skip profile selection
987
+ choices.unshift({
988
+ name: 'Default (no profile)',
989
+ value: null,
990
+ short: 'Default'
991
+ });
992
+
993
+ return choices;
994
+ }
995
+
996
+ /**
997
+ * Fetch and display model information from HuggingFace API and Model Registry
998
+ * Requirements: 5.1, 5.2, 5.3, 5.4, 5.5, 5.6, 5.11, 11.1, 11.2, 11.3, 11.5, 11.6, 11.7
999
+ * @private
1000
+ */
1001
+ async _fetchAndDisplayModelInfo(modelId) {
1002
+ console.log(`\n šŸ” Querying model-picker [discover]...`);
1003
+
1004
+ const sources = [];
1005
+ let chatTemplate = null;
1006
+ let modelFamily = null;
1007
+ let mcpUsed = false;
1008
+
1009
+ // Try model-picker MCP server in discover mode (queries HuggingFace + merges with catalog)
1010
+ const cm = this.configManager;
1011
+ if (cm) {
1012
+ const mcpServers = cm.getMcpServerNames();
1013
+ if (mcpServers.includes('model-picker')) {
1014
+ try {
1015
+ const mcpConfigPath = path.join(GENERATOR_ROOT, 'config', 'mcp.json');
1016
+ if (fs.existsSync(mcpConfigPath)) {
1017
+ const mcpConfig = JSON.parse(fs.readFileSync(mcpConfigPath, 'utf8'));
1018
+ const serverConfig = mcpConfig.mcpServers?.['model-picker'];
1019
+ if (serverConfig) {
1020
+ const { McpClient } = await import('./mcp-client.js');
1021
+ const client = new McpClient(serverConfig, { timeout: 15000 });
1022
+
1023
+ // Override _buildContext to pass model_id and mode directly
1024
+ client._getUnboundedParameterNames = () => [];
1025
+ client._buildContext = () => ({});
1026
+
1027
+ // Connect and call get_models directly
1028
+ const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
1029
+ const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
1030
+
1031
+ const transport = new StdioClientTransport({
1032
+ command: serverConfig.command,
1033
+ args: serverConfig.args || [],
1034
+ env: { ...process.env, ...(serverConfig.env || {}) },
1035
+ stderr: 'pipe'
1036
+ });
1037
+
1038
+ const mcpClient = new Client(
1039
+ { name: 'ml-container-creator', version: '1.0.0' },
1040
+ { capabilities: {} }
1041
+ );
1042
+
1043
+ await mcpClient.connect(transport);
1044
+
1045
+ const result = await mcpClient.callTool({
1046
+ name: 'get_models',
1047
+ arguments: { model_id: modelId, mode: 'discover' }
1048
+ });
1049
+
1050
+ await mcpClient.close();
1051
+
1052
+ // Parse the response
1053
+ const textBlock = result?.content?.find(b => b.type === 'text');
1054
+ if (textBlock) {
1055
+ const parsed = JSON.parse(textBlock.text);
1056
+ if (parsed.values && Object.keys(parsed.values).length > 0) {
1057
+ mcpUsed = true;
1058
+ const vals = parsed.values;
1059
+
1060
+ if (vals.chat_template) {
1061
+ chatTemplate = vals.chat_template;
1062
+ }
1063
+ if (vals.family) {
1064
+ modelFamily = vals.family;
1065
+ }
1066
+
1067
+ // Extract model source metadata for loading adapter
1068
+ // Requirements: 2.1, 2.2, 2.3, 2.4
1069
+ if (vals.provider) {
1070
+ this._mcpModelSource = vals.provider;
1071
+ }
1072
+ if (vals.artifactUri) {
1073
+ this._mcpArtifactUri = vals.artifactUri;
1074
+ }
1075
+
1076
+ // Determine sources based on what was returned
1077
+ if (vals.tags || vals.pipeline_tag) {
1078
+ sources.push('HuggingFace_Hub_API');
1079
+ }
1080
+ if (vals.validation_level || vals.framework_compatibility) {
1081
+ sources.push('Model_Picker_Catalog');
1082
+ }
1083
+ if (sources.length === 0) {
1084
+ sources.push('model-picker');
1085
+ }
1086
+ console.log(` āœ“ Resolved: ${modelId}`);
1087
+ } else if (parsed.message) {
1088
+ console.log(` ↳ ${parsed.message}`);
1089
+ }
1090
+ }
1091
+ }
1092
+ }
1093
+ } catch (err) {
1094
+ console.log(' ↳ model-picker unavailable, using fallback');
1095
+ }
1096
+ }
1097
+ }
1098
+
1099
+ // Fallback to legacy path if MCP didn't resolve
1100
+ if (!mcpUsed) {
1101
+ const registryConfigManager = this.registryConfigManager;
1102
+ if (registryConfigManager) {
1103
+ // Only try HuggingFace API for bare model IDs (not prefixed URIs)
1104
+ const isNonHfUri = modelId.startsWith('jumpstart://') ||
1105
+ modelId.startsWith('jumpstart-hub://') ||
1106
+ modelId.startsWith('s3://') ||
1107
+ modelId.startsWith('registry://');
1108
+
1109
+ if (!isNonHfUri) {
1110
+ // Try HuggingFace API directly
1111
+ try {
1112
+ const hfData = await registryConfigManager._fetchHuggingFaceData(modelId);
1113
+ if (hfData) {
1114
+ sources.push('HuggingFace_Hub_API');
1115
+ if (hfData.chatTemplate) {
1116
+ chatTemplate = hfData.chatTemplate;
1117
+ }
1118
+ console.log(' āœ… Found on HuggingFace Hub');
1119
+ } else {
1120
+ console.log(' ā„¹ļø Not found on HuggingFace Hub (may be private or offline)');
1121
+ }
1122
+ } catch (error) {
1123
+ console.log(' āš ļø HuggingFace API unavailable');
1124
+ }
1125
+ } else {
1126
+ // Non-HF URI (jumpstart://, s3://, etc.) — skip HF lookup silently
1127
+ // The summary at the end of this function will report "No additional model information"
1128
+ }
1129
+
1130
+ // Check Model Registry for overrides
1131
+ if (registryConfigManager.modelRegistry) {
1132
+ let modelConfig = registryConfigManager.modelRegistry[modelId];
1133
+
1134
+ if (!modelConfig) {
1135
+ for (const [pattern, config] of Object.entries(registryConfigManager.modelRegistry)) {
1136
+ if (pattern.includes('*')) {
1137
+ const regex = new RegExp('^' + pattern.replace(/\*/g, '.*') + '$');
1138
+ if (regex.test(modelId)) {
1139
+ modelConfig = config;
1140
+ console.log(` āœ… Matched pattern in Model_Registry: ${pattern}`);
1141
+ break;
1142
+ }
1143
+ }
1144
+ }
1145
+ } else {
1146
+ console.log(' āœ… Found in Model_Registry');
1147
+ }
1148
+
1149
+ if (modelConfig) {
1150
+ sources.push('Model_Registry');
1151
+ if (modelConfig.chatTemplate) {
1152
+ chatTemplate = modelConfig.chatTemplate;
1153
+ }
1154
+ if (modelConfig.family) {
1155
+ modelFamily = modelConfig.family;
1156
+ }
1157
+ }
1158
+ }
1159
+ }
1160
+ }
1161
+
1162
+ // Display information
1163
+ if (sources.length > 0) {
1164
+ console.log('\nšŸ“‹ Model Information:');
1165
+ console.log(` • Model ID: ${modelId}`);
1166
+ if (modelFamily) {
1167
+ console.log(` • Family: ${modelFamily}`);
1168
+ }
1169
+ if (chatTemplate) {
1170
+ console.log(' • Chat Template: āœ… Available');
1171
+ console.log(' (Will be injected into generated files)');
1172
+ } else {
1173
+ console.log(' • Chat Template: āŒ Not available');
1174
+ console.log(' (Chat endpoints may require manual configuration)');
1175
+ }
1176
+ console.log(` • Sources: ${sources.join(', ')}`);
1177
+ } else {
1178
+ console.log(' ā„¹ļø No additional model information available');
1179
+ console.log(' Proceeding with default configuration');
1180
+ }
1181
+ }
1182
+
1183
+
1184
+
1185
+ /**
1186
+ * Validate and display instance type compatibility
1187
+ * Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6
1188
+ * @private
1189
+ */
1190
+ async _validateAndDisplayInstanceType(instanceType, framework, version) {
1191
+ const registryConfigManager = this.registryConfigManager;
1192
+
1193
+ if (!registryConfigManager) {
1194
+ return;
1195
+ }
1196
+
1197
+ // Get framework configuration
1198
+ const frameworkConfig = registryConfigManager.frameworkRegistry?.[framework]?.[version];
1199
+ if (!frameworkConfig) {
1200
+ return; // No framework config, skip validation
1201
+ }
1202
+
1203
+ console.log(`\nšŸ” Validating instance type: ${instanceType}`);
1204
+
1205
+ // Validate instance type
1206
+ const validationResult = registryConfigManager.validateInstanceType(instanceType, frameworkConfig);
1207
+
1208
+ if (validationResult.compatible) {
1209
+ console.log(' āœ… Instance type is compatible');
1210
+ if (validationResult.info) {
1211
+ console.log(` ā„¹ļø ${validationResult.info}`);
1212
+ }
1213
+ } else {
1214
+ console.log(' āŒ Instance type compatibility issue detected');
1215
+ if (validationResult.error) {
1216
+ console.log(` Error: ${validationResult.error}`);
1217
+ }
1218
+ if (validationResult.recommendations && validationResult.recommendations.length > 0) {
1219
+ console.log(` šŸ’” Recommended instances: ${validationResult.recommendations.join(', ')}`);
1220
+ }
1221
+
1222
+ // In test mode or non-interactive mode, throw error instead of prompting
1223
+ if (this.options.skipPrompts || process.env.NODE_ENV === 'test') {
1224
+ throw new Error('Instance type validation failed. Please select a compatible instance type.');
1225
+ }
1226
+
1227
+ // Ask user if they want to proceed
1228
+ const proceed = await this._runPrompts([{
1229
+ type: 'confirm',
1230
+ name: 'proceedWithIncompatible',
1231
+ message: 'Instance type may not be compatible. Proceed anyway?',
1232
+ default: false
1233
+ }]);
1234
+
1235
+ if (!proceed.proceedWithIncompatible) {
1236
+ throw new Error('Instance type validation failed. Please select a compatible instance type.');
1237
+ }
1238
+ }
1239
+
1240
+ if (validationResult.warning) {
1241
+ console.log(` āš ļø Warning: ${validationResult.warning}`);
1242
+ }
1243
+ }
1244
+
1245
/**
 * CUDA-to-AMI mapping.
 * Maps CUDA major.minor versions to the SageMaker inference AMI that provides
 * the matching CUDA driver. Derived from the framework registry patterns.
 *
 * Keys are "major.minor" strings. `_promptCudaVersion` looks up the selected
 * CUDA version here and returns null when the version has no entry, so a new
 * CUDA version must be added to this table to be resolvable.
 * @private
 */
static CUDA_AMI_MAP = {
    '11.0': 'al2-ami-sagemaker-inference-gpu-2-1',
    '11.4': 'al2-ami-sagemaker-inference-gpu-2-1',
    '11.8': 'al2-ami-sagemaker-inference-gpu-3-1',
    '12.1': 'al2-ami-sagemaker-inference-gpu-3-1',
    '12.2': 'al2-ami-sagemaker-inference-gpu-3-2',
    '12.4': 'al2-ami-sagemaker-inference-gpu-3-2',
    '12.6': 'al2-ami-sagemaker-inference-gpu-3-2'
};
1260
+
1261
+ /**
1262
+ * Prompt the user to select a CUDA version when the selected GPU instance
1263
+ * supports multiple versions. The choice transparently resolves to the
1264
+ * correct SageMaker inference AMI.
1265
+ *
1266
+ * Skipped for CPU instances, non-CUDA accelerators, or when only one
1267
+ * compatible CUDA version exists.
1268
+ *
1269
+ * @param {string} instanceType - Selected instance type (e.g. "ml.g5.2xlarge")
1270
+ * @param {string} framework - Selected framework name
1271
+ * @param {string} frameworkVersion - Selected framework version
1272
+ * @returns {Promise<{cudaVersion: string, inferenceAmiVersion: string}|null>}
1273
+ * @private
1274
+ */
1275
+ async _promptCudaVersion(instanceType, framework, frameworkVersion) {
1276
+ if (!instanceType) return null;
1277
+
1278
+ // Look up instance in accelerator mapping
1279
+ const instanceInfo = this._instanceAcceleratorMapping[instanceType];
1280
+ if (!instanceInfo || instanceInfo.accelerator.type !== 'cuda') return null;
1281
+
1282
+ const instanceCudaVersions = instanceInfo.accelerator.versions;
1283
+ if (!instanceCudaVersions || instanceCudaVersions.length === 0) return null;
1284
+
1285
+ // Get framework CUDA requirements (if available)
1286
+ const registryConfigManager = this.registryConfigManager;
1287
+ const frameworkConfig = registryConfigManager?.frameworkRegistry?.[framework]?.[frameworkVersion];
1288
+ const frameworkAccel = frameworkConfig?.accelerator;
1289
+
1290
+ // Compute compatible CUDA versions: intersection of instance support and framework range
1291
+ let compatibleVersions;
1292
+ if (frameworkAccel?.versionRange) {
1293
+ const { min, max } = frameworkAccel.versionRange;
1294
+ compatibleVersions = instanceCudaVersions.filter(v => {
1295
+ return v >= min && v <= max;
1296
+ });
1297
+ } else {
1298
+ compatibleVersions = [...instanceCudaVersions];
1299
+ }
1300
+
1301
+ if (compatibleVersions.length === 0) {
1302
+ // No overlap — fall back to all instance versions (validation already warned)
1303
+ compatibleVersions = [...instanceCudaVersions];
1304
+ }
1305
+
1306
+ // If only one option, auto-select it silently
1307
+ if (compatibleVersions.length === 1) {
1308
+ const cudaVersion = compatibleVersions[0];
1309
+ const inferenceAmiVersion = PromptRunner.CUDA_AMI_MAP[cudaVersion];
1310
+ if (inferenceAmiVersion) {
1311
+ console.log(`\nšŸ”§ CUDA ${cudaVersion} auto-selected (only compatible version for ${instanceType})`);
1312
+ console.log(` AMI: ${inferenceAmiVersion}`);
1313
+ }
1314
+ return inferenceAmiVersion ? { cudaVersion, inferenceAmiVersion } : null;
1315
+ }
1316
+
1317
+ // Multiple options — let the user choose
1318
+ const defaultVersion = frameworkAccel?.version
1319
+ && compatibleVersions.includes(frameworkAccel.version)
1320
+ ? frameworkAccel.version
1321
+ : instanceInfo.accelerator.default || compatibleVersions[compatibleVersions.length - 1];
1322
+
1323
+ const choices = compatibleVersions.map(v => {
1324
+ const ami = PromptRunner.CUDA_AMI_MAP[v] || 'unknown';
1325
+ const isDefault = v === defaultVersion ? ' (recommended)' : '';
1326
+ return {
1327
+ name: `CUDA ${v}${isDefault} → AMI: ${ami}`,
1328
+ value: v,
1329
+ short: `CUDA ${v}`
1330
+ };
1331
+ });
1332
+
1333
+ const { cudaVersion } = await this._runPrompts([{
1334
+ type: 'list',
1335
+ name: 'cudaVersion',
1336
+ message: `Select CUDA version for ${instanceType} (${instanceInfo.accelerator.hardware}):`,
1337
+ choices,
1338
+ default: defaultVersion
1339
+ }]);
1340
+
1341
+ const inferenceAmiVersion = PromptRunner.CUDA_AMI_MAP[cudaVersion];
1342
+ if (inferenceAmiVersion) {
1343
+ console.log(` āœ… CUDA ${cudaVersion} → AMI: ${inferenceAmiVersion}`);
1344
+ }
1345
+
1346
+ return inferenceAmiVersion ? { cudaVersion, inferenceAmiVersion } : null;
1347
+ }
1348
+ }
1349
+