@aws/ml-container-creator 0.2.6 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/bin/cli.js +38 -2
  2. package/config/bootstrap-stack.json +94 -1
  3. package/config/defaults.json +1 -1
  4. package/infra/ci-harness/package-lock.json +22 -9
  5. package/package.json +3 -1
  6. package/servers/instance-sizer/index.js +45 -8
  7. package/servers/instance-sizer/lib/instance-ranker.js +140 -11
  8. package/servers/instance-sizer/lib/model-resolver.js +10 -6
  9. package/servers/instance-sizer/lib/quota-resolver.js +368 -0
  10. package/servers/instance-sizer/package.json +2 -0
  11. package/servers/lib/catalogs/instances.json +527 -12
  12. package/servers/lib/catalogs/model-servers.json +298 -20
  13. package/servers/lib/catalogs/model-sizes.json +27 -0
  14. package/servers/lib/catalogs/models.json +101 -0
  15. package/servers/lib/schemas/image-catalog.schema.json +15 -1
  16. package/servers/model-picker/index.js +2 -1
  17. package/src/app.js +96 -2
  18. package/src/lib/architecture-sync.js +171 -0
  19. package/src/lib/arn-detection.js +22 -0
  20. package/src/lib/bootstrap-command-handler.js +178 -3
  21. package/src/lib/cli-handler.js +2 -2
  22. package/src/lib/config-manager.js +121 -1
  23. package/src/lib/cross-cutting-checker.js +119 -0
  24. package/src/lib/deployment-entry-schema.js +1 -2
  25. package/src/lib/prompt-runner.js +514 -20
  26. package/src/lib/prompts.js +67 -5
  27. package/src/lib/registry-command-handler.js +236 -0
  28. package/src/lib/schema-sync.js +31 -0
  29. package/src/lib/secret-classification.js +56 -0
  30. package/src/lib/secrets-command-handler.js +550 -0
  31. package/src/lib/template-manager.js +49 -1
  32. package/src/lib/validate-runner.js +174 -2
  33. package/src/lib/validation-report.js +8 -1
  34. package/src/prompt-adapter.js +3 -2
  35. package/templates/Dockerfile +10 -2
  36. package/templates/code/cuda_compat.sh +22 -0
  37. package/templates/code/serve +3 -0
  38. package/templates/code/start_server.sh +3 -0
  39. package/templates/diffusors/Dockerfile +2 -1
  40. package/templates/diffusors/serve +3 -0
  41. package/templates/do/README.md +33 -0
  42. package/templates/do/benchmark +646 -0
  43. package/templates/do/build +22 -0
  44. package/templates/do/clean +86 -0
  45. package/templates/do/config +41 -6
  46. package/templates/do/deploy +66 -6
  47. package/templates/do/logs +18 -3
  48. package/templates/do/register +8 -1
  49. package/templates/do/run +10 -0
  50. package/templates/triton/Dockerfile +5 -0
@@ -583,7 +583,7 @@ const modulePrompts = [
583
583
  type: 'confirm',
584
584
  name: 'includeSampleModel',
585
585
  message: 'Include sample Abalone classifier?',
586
- default: false,
586
+ default: true,
587
587
  when: (answers) => {
588
588
  const architecture = answers.architecture || answers.deploymentConfig?.split('-')[0];
589
589
  const backend = answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');
@@ -622,7 +622,10 @@ const modulePrompts = [
622
622
 
623
623
  // Transformers and Triton LLM backends only support hosted endpoint tests
624
624
  if (architecture === 'transformers') {
625
- return ['hosted-model-endpoint'];
625
+ return ['hosted-model-endpoint', 'sagemaker-ai-automated-benchmarking'];
626
+ }
627
+ if (architecture === 'diffusors') {
628
+ return ['hosted-model-endpoint', 'sagemaker-ai-automated-benchmarking'];
626
629
  }
627
630
  if (architecture === 'triton' && (backend === 'vllm' || backend === 'tensorrtllm')) {
628
631
  return ['hosted-model-endpoint'];
@@ -635,7 +638,10 @@ const modulePrompts = [
635
638
  const backend = answers.backend || answers.deploymentConfig?.split('-').slice(1).join('-');
636
639
 
637
640
  if (architecture === 'transformers') {
638
- return ['hosted-model-endpoint'];
641
+ return ['hosted-model-endpoint', 'sagemaker-ai-automated-benchmarking'];
642
+ }
643
+ if (architecture === 'diffusors') {
644
+ return ['hosted-model-endpoint', 'sagemaker-ai-automated-benchmarking'];
639
645
  }
640
646
  if (architecture === 'triton' && (backend === 'vllm' || backend === 'tensorrtllm')) {
641
647
  return ['hosted-model-endpoint'];
@@ -700,7 +706,12 @@ const infraInstancePrompts = [
700
706
  when: answers => answers.deploymentTarget === 'realtime-inference' || answers.deploymentTarget === 'async-inference' || answers.deploymentTarget === 'batch-transform' || answers.deploymentTarget === 'hyperpod-eks',
701
707
  message: (answers) => {
702
708
  const framework = answers.framework || answers.deploymentConfig?.split('-')[0];
703
-
709
+
710
+ // Skip table when MCP sizer already displayed annotated results
711
+ if (answers._mcpInstanceChoices && answers._mcpInstanceChoices.length > 0) {
712
+ return 'Select instance type:';
713
+ }
714
+
704
715
  const table = new Table({
705
716
  head: [
706
717
  chalk.cyan('Instance Type'),
@@ -1053,7 +1064,7 @@ function formatImageChoices(entries, isTransformer) {
1053
1064
  ? `${entry.repository.padEnd(30)} ${entry.tag.padEnd(16)} ${entry.architecture.padEnd(7)} ${cuda.padEnd(6)} ${python.padEnd(8)} ${date}`
1054
1065
  : `${entry.repository.padEnd(30)} ${entry.tag.padEnd(16)} ${entry.architecture.padEnd(7)} ${python.padEnd(8)} ${date}`;
1055
1066
 
1056
- return { name, value: entry.image };
1067
+ return { name, value: entry.image, _meta: { labels: entry.labels, accelerator: entry.accelerator } };
1057
1068
  });
1058
1069
  }
1059
1070
 
@@ -1110,6 +1121,56 @@ const baseImagePrompts = [
1110
1121
  }
1111
1122
  ];
1112
1123
 
1124
/**
 * Benchmark prompts for SageMaker AI Benchmarking (NVIDIA AIPerf).
 *
 * Sub-prompts for the 'sagemaker-ai-automated-benchmarking' test type. Every
 * prompt is gated on `answers.includeBenchmark === true`
 * (NOTE(review): confirm that the caller derives includeBenchmark from the
 * testTypes selection).
 *
 * Numeric answers are validated so that unparseable input (which a `number`
 * prompt typically yields as NaN) or a non-positive count is rejected at the
 * prompt instead of being accepted as a benchmark parameter.
 *
 * Requirements: 2.1, 2.2, 2.3, 2.4, 2.5
 */
const benchmarkPrompts = (() => {
  // Shared gate: only ask benchmark questions when benchmarking is enabled.
  const isBenchmarkEnabled = (answers) => answers.includeBenchmark === true;

  // Accepts numbers or numeric strings; requires a whole number greater than zero.
  const requirePositiveInteger = (value) => {
    const n = Number(value);
    return (Number.isInteger(n) && n > 0) || 'Please enter a positive integer.';
  };

  // Empty means "use the service default"; otherwise a positive integer without sign or decimals.
  const optionalPositiveInteger = (value) => {
    const text = String(value ?? '').trim();
    return text === '' || /^[1-9]\d*$/.test(text) || 'Please enter a positive integer or leave empty.';
  };

  return [
    {
      type: 'number',
      name: 'benchmarkConcurrency',
      message: 'Concurrent requests for benchmark:',
      default: 10,
      when: isBenchmarkEnabled,
      validate: requirePositiveInteger
    },
    {
      type: 'number',
      name: 'benchmarkInputTokensMean',
      message: 'Mean input tokens per request:',
      default: 550,
      when: isBenchmarkEnabled,
      validate: requirePositiveInteger
    },
    {
      type: 'number',
      name: 'benchmarkOutputTokensMean',
      message: 'Mean output tokens per request:',
      default: 150,
      when: isBenchmarkEnabled,
      validate: requirePositiveInteger
    },
    {
      type: 'confirm',
      name: 'benchmarkStreaming',
      message: 'Enable streaming for benchmark?',
      default: true,
      when: isBenchmarkEnabled
    },
    {
      type: 'input',
      name: 'benchmarkRequestCount',
      message: 'Total request count (leave empty for service default):',
      default: '',
      when: isBenchmarkEnabled,
      validate: optionalPositiveInteger
    },
    {
      type: 'input',
      name: 'benchmarkS3OutputPath',
      message: 'Benchmark results S3 path (leave empty for auto-created bucket):',
      default: '',
      when: isBenchmarkEnabled
    }
  ];
})();
1173
+
1113
1174
  export {
1114
1175
  deploymentConfigPrompts,
1115
1176
  frameworkPrompts, // Deprecated: kept for backward compatibility
@@ -1123,6 +1184,7 @@ export {
1123
1184
  hfTokenPrompts,
1124
1185
  ngcApiKeyPrompts,
1125
1186
  modulePrompts,
1187
+ benchmarkPrompts,
1126
1188
  infrastructurePrompts,
1127
1189
  infraRegionAndTargetPrompts,
1128
1190
  infraInstancePrompts,
@@ -24,6 +24,8 @@ import { readFileSync } from 'node:fs';
24
24
  import { execSync } from 'node:child_process';
25
25
  import { fileURLToPath } from 'node:url';
26
26
  import DeploymentRegistry, { reconstructReplayFlags } from './deployment-registry.js';
27
+ import { syncArchitectures } from './architecture-sync.js';
28
+ import HuggingFaceClient from './huggingface-client.js';
27
29
 
28
30
  const PERSONAL_REGISTRY_PATH = path.join(os.homedir(), '.ml-container-creator', 'registry.json');
29
31
  const PROJECT_REGISTRY_PATH = path.join(process.cwd(), '.ml-container-creator', 'registry.json');
@@ -71,6 +73,15 @@ export default class RegistryCommandHandler {
71
73
  case 'search':
72
74
  this._handleSearch(options);
73
75
  break;
76
+ case 'sync-architectures':
77
+ await this._handleSyncArchitectures();
78
+ break;
79
+ case 'list-architectures':
80
+ this._handleListArchitectures(args, options);
81
+ break;
82
+ case 'check':
83
+ await this._handleCheck(args);
84
+ break;
74
85
  default:
75
86
  console.log(`Unknown registry subcommand: ${subcommand}`);
76
87
  this._showRegistryHelp();
@@ -431,6 +442,220 @@ export default class RegistryCommandHandler {
431
442
  console.log('');
432
443
  }
433
444
 
445
+ /**
446
+ * registry sync-architectures
447
+ *
448
+ * Fetches model registry source files from server GitHub repositories
449
+ * and populates supportedModelTypes in the model-servers catalog.
450
+ */
451
+ async _handleSyncArchitectures() {
452
+ const __filename = fileURLToPath(import.meta.url);
453
+ const __dirname = path.dirname(__filename);
454
+ const catalogPath = path.resolve(__dirname, '../../servers/lib/catalogs/model-servers.json');
455
+
456
+ console.log('\n📋 Syncing model architecture registry...\n');
457
+
458
+ const summary = await syncArchitectures(catalogPath);
459
+
460
+ console.log('\n── Summary ──────────────────────────────────────');
461
+ if (summary.servers.length > 0) {
462
+ console.log('\n Architectures synced:');
463
+ for (const { server, version, count } of summary.servers) {
464
+ console.log(` ${server} ${version}: ${count} architectures`);
465
+ }
466
+ }
467
+ if (summary.failures.length > 0) {
468
+ console.log('\n Failures:');
469
+ for (const { server, version, reason } of summary.failures) {
470
+ console.log(` ${server} ${version}: ${reason}`);
471
+ }
472
+ }
473
+ if (summary.servers.length === 0 && summary.failures.length === 0) {
474
+ console.log('\n No server entries found with matching registry sources.');
475
+ }
476
+ console.log('');
477
+ }
478
+
479
+ /**
480
+ * registry list-architectures [--server <name>] [--verbose]
481
+ *
482
+ * Displays a table of server versions and their supported architecture counts.
483
+ * With --server or --verbose, shows the full list of supported model types.
484
+ *
485
+ * @param {object} options - Parsed CLI options
486
+ */
487
+ _handleListArchitectures(args, options) {
488
+ const __filename = fileURLToPath(import.meta.url);
489
+ const __dirname = path.dirname(__filename);
490
+ const catalogPath = path.resolve(__dirname, '../../servers/lib/catalogs/model-servers.json');
491
+
492
+ let catalog;
493
+ try {
494
+ catalog = JSON.parse(readFileSync(catalogPath, 'utf8'));
495
+ } catch (err) {
496
+ console.log(`Error: Could not read model-servers catalog: ${err.message}`);
497
+ return;
498
+ }
499
+
500
+ // Parse --server and --verbose from pass-through args (Commander's passThroughOptions
501
+ // puts options after the subcommand into the args array)
502
+ let serverFilter = options.server || null;
503
+ let verbose = options.verbose || false;
504
+ for (const arg of args) {
505
+ if (arg.startsWith('--server=')) {
506
+ serverFilter = arg.split('=')[1];
507
+ } else if (arg === '--server' && args.indexOf(arg) + 1 < args.length) {
508
+ serverFilter = args[args.indexOf(arg) + 1];
509
+ } else if (arg === '--verbose') {
510
+ verbose = true;
511
+ }
512
+ }
513
+
514
+ // Collect rows: { server, version, count, types }
515
+ const rows = [];
516
+ for (const [server, entries] of Object.entries(catalog)) {
517
+ if (serverFilter && server !== serverFilter) continue;
518
+ for (const entry of entries) {
519
+ const version = entry.labels?.framework_version || '(unknown)';
520
+ const types = entry.supportedModelTypes || [];
521
+ rows.push({ server, version, count: types.length, types });
522
+ }
523
+ }
524
+
525
+ if (rows.length === 0) {
526
+ if (serverFilter) {
527
+ console.log(`No entries found for server "${serverFilter}".`);
528
+ } else {
529
+ console.log('No server entries found in catalog.');
530
+ }
531
+ return;
532
+ }
533
+
534
+ // Display summary table
535
+ console.log('\nModel Architecture Support:\n');
536
+ console.log(' Server Version Architectures');
537
+ console.log(' ──────────────────── ─────────── ─────────────');
538
+ for (const row of rows) {
539
+ const srv = row.server.padEnd(20);
540
+ const ver = row.version.padEnd(11);
541
+ const cnt = row.count === 0 ? '(not synced)' : String(row.count);
542
+ console.log(` ${srv} ${ver} ${cnt}`);
543
+ }
544
+ console.log('');
545
+
546
+ // Show full list when --server or --verbose is set
547
+ if (serverFilter || verbose) {
548
+ for (const row of rows) {
549
+ if (row.types.length === 0) continue;
550
+ console.log(` ${row.server} ${row.version} supported model types:`);
551
+ console.log(` ${row.types.join(', ')}`);
552
+ console.log('');
553
+ }
554
+ }
555
+ }
556
+
557
+ /**
558
+ * registry check <model-id>
559
+ *
560
+ * Fetches a model's config.json from HuggingFace, extracts the model_type,
561
+ * and checks compatibility against all server versions in the catalog.
562
+ *
563
+ * @param {string[]} args - Remaining positional args (args[1] = model-id)
564
+ */
565
+ async _handleCheck(args) {
566
+ const modelId = args[1];
567
+
568
+ if (!modelId) {
569
+ console.log('Usage: ml-container-creator registry check <model-id>');
570
+ console.log('Example: ml-container-creator registry check meta-llama/Llama-2-7b-chat-hf');
571
+ return;
572
+ }
573
+
574
+ const __filename = fileURLToPath(import.meta.url);
575
+ const __dirname = path.dirname(__filename);
576
+ const catalogPath = path.resolve(__dirname, '../../servers/lib/catalogs/model-servers.json');
577
+
578
+ // Fetch model's config.json from HuggingFace
579
+ console.log(`\n🔍 Checking model: ${modelId}\n`);
580
+ console.log(' Fetching model config from HuggingFace...');
581
+
582
+ const hfClient = new HuggingFaceClient({ timeout: 10000 });
583
+ const config = await hfClient.fetchModelConfig(modelId);
584
+
585
+ if (!config) {
586
+ console.log(`\n ❌ Could not fetch config.json for "${modelId}".`);
587
+ console.log(' Verify the model ID is correct and accessible on HuggingFace.');
588
+ return;
589
+ }
590
+
591
+ const modelType = config.model_type;
592
+ if (!modelType) {
593
+ console.log(`\n ❌ No "model_type" field found in config.json for "${modelId}".`);
594
+ return;
595
+ }
596
+
597
+ console.log(` Model type: ${modelType}`);
598
+
599
+ // Load model-servers catalog
600
+ let catalog;
601
+ try {
602
+ catalog = JSON.parse(readFileSync(catalogPath, 'utf8'));
603
+ } catch (err) {
604
+ console.log(`\n ❌ Could not read model-servers catalog: ${err.message}`);
605
+ return;
606
+ }
607
+
608
+ // Check model_type against all server entries
609
+ const compatible = [];
610
+ const incompatible = [];
611
+ let hasAnyData = false;
612
+
613
+ for (const [server, entries] of Object.entries(catalog)) {
614
+ for (const entry of entries) {
615
+ const version = entry.labels?.framework_version || '(unknown)';
616
+ const supported = entry.supportedModelTypes;
617
+
618
+ if (!supported || supported.length === 0) continue;
619
+
620
+ hasAnyData = true;
621
+ const modelTypeLower = modelType.toLowerCase();
622
+ if (supported.includes(modelTypeLower) || supported.includes(modelType)) {
623
+ compatible.push({ server, version });
624
+ } else {
625
+ incompatible.push({ server, version });
626
+ }
627
+ }
628
+ }
629
+
630
+ // Display results
631
+ if (!hasAnyData) {
632
+ console.log('\n ⚠️ No architecture data available. Run "registry sync-architectures" first.');
633
+ return;
634
+ }
635
+
636
+ if (compatible.length > 0) {
637
+ console.log('\n ✅ Compatible server versions:');
638
+ for (const { server, version } of compatible) {
639
+ console.log(` • ${server} ${version}`);
640
+ }
641
+ }
642
+
643
+ if (incompatible.length > 0) {
644
+ console.log('\n ⚠️ Potentially incompatible server versions:');
645
+ for (const { server, version } of incompatible) {
646
+ console.log(` • ${server} ${version}`);
647
+ }
648
+ }
649
+
650
+ if (compatible.length === 0) {
651
+ console.log(`\n ⚠️ Model architecture "${modelType}" was not found in any server's supported types.`);
652
+ console.log(' This may indicate the model requires a newer server version,');
653
+ console.log(' or it may work via trust_remote_code. Check server documentation for details.');
654
+ }
655
+
656
+ console.log('');
657
+ }
658
+
434
659
  /**
435
660
  * Show registry usage help.
436
661
  */
@@ -449,6 +674,9 @@ SUBCOMMANDS:
449
674
  export [id] [--status <status>] Export entries as JSON
450
675
  import <file> [--merge|--replace] Import entries from JSON
451
676
  search [filters] Search entries with glob matching
677
+ sync-architectures Sync supported model types from server repos
678
+ list-architectures Show supported architectures per server version
679
+ check <model-id> Check model compatibility with server versions
452
680
 
453
681
  FILTER OPTIONS (for list and search):
454
682
  --backend <backend> Filter by backend (e.g., vllm, flask)
@@ -467,6 +695,10 @@ IMPORT OPTIONS:
467
695
  --merge Keep both existing and imported on conflict
468
696
  --replace Overwrite existing with imported on conflict
469
697
 
698
+ LIST-ARCHITECTURES OPTIONS:
699
+ --server <name> Show full model type list for a specific server
700
+ --verbose Show full model type list for all servers
701
+
470
702
  OTHER OPTIONS:
471
703
  --project Use project-level registry instead of personal
472
704
 
@@ -481,6 +713,10 @@ EXAMPLES:
481
713
  ml-container-creator registry export a1b2c3d4
482
714
  ml-container-creator registry import team-deployments.json --merge
483
715
  ml-container-creator registry search --model "meta-llama/*" --backend vllm
716
+ ml-container-creator registry list-architectures
717
+ ml-container-creator registry list-architectures --server vllm
718
+ ml-container-creator registry list-architectures --verbose
719
+ ml-container-creator registry check meta-llama/Llama-2-7b-chat-hf
484
720
  `);
485
721
  }
486
722
 
@@ -188,6 +188,37 @@ export function loadServiceModel(serviceName, registryPath) {
188
188
  return readFileSync(modelPath, 'utf8');
189
189
  }
190
190
 
191
/**
 * Report whether the registry's SageMaker service model defines the
 * CreateAIBenchmarkJob operation or its CreateAIBenchmarkJobRequest input
 * shape. Callers use this to decide whether benchmark parameter validation
 * can be performed.
 *
 * @param {string} [registryPath] - Registry location; getRegistryPath() is used when falsy.
 * @returns {{ available: boolean, reason?: string }} `{ available: true }` when either
 *   definition is present; otherwise `available: false` with a human-readable reason.
 */
export function hasBenchmarkShape(registryPath) {
  const serviceModelText = loadServiceModel('sagemaker', registryPath || getRegistryPath());

  if (!serviceModelText) {
    return { available: false, reason: 'SageMaker service model not found in registry' };
  }

  // Property access stays inside the try: a document that parses to null
  // must also be reported as a parse failure.
  try {
    const serviceModel = JSON.parse(serviceModelText);
    const hasOperation = Boolean((serviceModel.operations || {}).CreateAIBenchmarkJob);
    const hasInputShape = Boolean((serviceModel.shapes || {}).CreateAIBenchmarkJobRequest);

    return hasOperation || hasInputShape
      ? { available: true }
      : { available: false, reason: 'service model does not include AI Benchmark operations' };
  } catch {
    return { available: false, reason: 'Failed to parse SageMaker service model' };
  }
}
221
+
191
222
  /**
192
223
  * Store a service model in the registry.
193
224
  * @param {string} serviceName - Service name (e.g., 'sagemaker')
@@ -0,0 +1,56 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Secret Classification Registry
6
+ *
7
+ * Single source of truth for all secret type metadata. Each entry defines
8
+ * the identifier, display name, applicable stages, purpose, CLI flags,
9
+ * environment variable names, and prompt labels for a secret type.
10
+ *
11
+ * Adding a new secret type requires only adding a new entry to this array —
12
+ * the CLI, prompt flow, and do-script templates derive behavior from this registry.
13
+ */
14
+
15
/**
 * Frozen registry of secret classifications.
 *
 * `Object.freeze` is shallow. This registry is therefore frozen at every level:
 * the array, each entry, and each entry's `stages` array. As a result, no
 * consumer (for example a caller holding the object returned by
 * getClassification()) can mutate the shared metadata for the rest of the
 * process.
 *
 * Entry fields:
 * - identifier: stable key, e.g. 'hf-token'
 * - displayName: human-readable name
 * - stages: subset of 'build-time' / 'runtime'
 * - purpose: why the secret is needed
 * - cliFlag / cliFlagPlaintext: CLI flag names for the ARN and plaintext forms
 * - envVar / envVarArn: environment variable names for the value and its ARN
 * - promptLabel: label shown when prompting for the secret
 */
export const SECRET_CLASSIFICATIONS = Object.freeze(
  [
    {
      identifier: 'hf-token',
      displayName: 'HuggingFace Token',
      stages: ['build-time', 'runtime'],
      purpose: 'Gated model download from HuggingFace Hub',
      cliFlag: 'hf-token-arn',
      cliFlagPlaintext: 'hf-token',
      envVar: 'HF_TOKEN',
      envVarArn: 'HF_TOKEN_ARN',
      promptLabel: 'HuggingFace token'
    },
    {
      identifier: 'ngc-token',
      displayName: 'NVIDIA NGC Token',
      stages: ['build-time'],
      purpose: 'Pulling base images from NVIDIA NGC registry',
      cliFlag: 'ngc-token-arn',
      cliFlagPlaintext: 'ngc-token',
      envVar: 'NGC_API_KEY',
      envVarArn: 'NGC_API_KEY_ARN',
      promptLabel: 'NVIDIA NGC API key'
    }
  ].map((entry) => Object.freeze({ ...entry, stages: Object.freeze([...entry.stages]) }))
);
39
+
40
/**
 * Find the secret classification whose `identifier` matches exactly.
 * @param {string} identifier - e.g. 'hf-token'
 * @returns {Object|undefined} The first matching entry, or undefined when none matches.
 */
export function getClassification(identifier) {
  for (const classification of SECRET_CLASSIFICATIONS) {
    if (classification.identifier === identifier) {
      return classification;
    }
  }
  return undefined;
}
48
+
49
/**
 * Collect every secret classification whose `stages` list contains `stage`.
 * @param {string} stage - 'build-time' or 'runtime'
 * @returns {Object[]} A new array in registry order; empty when nothing applies.
 */
export function getClassificationsForStage(stage) {
  const matches = [];
  for (const classification of SECRET_CLASSIFICATIONS) {
    if (classification.stages.includes(stage)) {
      matches.push(classification);
    }
  }
  return matches;
}