@aws/ml-container-creator 0.2.6 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,8 @@ import { fileURLToPath } from 'node:url';
24
24
  import BootstrapConfig from './bootstrap-config.js';
25
25
  import AwsProfileParser from './aws-profile-parser.js';
26
26
  import AssetManager from './asset-manager.js';
27
+ import McpCommandHandler from './mcp-command-handler.js';
28
+ import RegistryCommandHandler from './registry-command-handler.js';
27
29
  import { runPrompts } from '../prompt-adapter.js';
28
30
 
29
31
  const __filename = fileURLToPath(import.meta.url);
@@ -311,6 +313,9 @@ export default class BootstrapCommandHandler {
311
313
 
312
314
  // Display summary
313
315
  this._displaySummary(profileName, profileData);
316
+
317
+ // Step 6: Post-setup chain (mcp init → sync-architectures → sync-schemas)
318
+ await this._runPostSetupChain(options);
314
319
  }
315
320
 
316
321
  /**
@@ -1054,6 +1059,74 @@ export default class BootstrapCommandHandler {
1054
1059
  // Save updated profile
1055
1060
  this.config.setProfile(name, profileConfig);
1056
1061
  console.log(`\n✅ Update complete for profile "${name}"`);
1062
+
1063
+ // Re-run post-setup chain after updating AWS resources
1064
+ await this._runPostSetupChain(options);
1065
+ }
1066
+
1067
+ /**
1068
+ * Run the post-setup chain: mcp init → registry sync-architectures → sync-schemas.
1069
+ * Each step is independent — failures are collected and reported at the end.
1070
+ *
1071
+ * @param {object} options - Parsed CLI options (checks the 'skip-post-setup' flag)
1072
+ */
1073
+ async _runPostSetupChain(options = {}) {
1074
+ if (options['skip-post-setup']) {
1075
+ console.log('\n⏭️ Skipping post-setup chain (--skip-post-setup)');
1076
+ return;
1077
+ }
1078
+
1079
+ console.log('\n🔗 Running post-setup configuration...\n');
1080
+
1081
+ const failures = [];
1082
+
1083
+ // 1. MCP init — register bundled MCP servers
1084
+ console.log('📡 Registering MCP servers...');
1085
+ try {
1086
+ const generatorAdapter = {
1087
+ destinationPath(...segments) {
1088
+ return path.resolve(process.cwd(), ...segments);
1089
+ }
1090
+ };
1091
+ const mcpHandler = new McpCommandHandler(generatorAdapter);
1092
+ await mcpHandler.handle(['init'], {});
1093
+ } catch (error) {
1094
+ failures.push({ step: 'mcp init', error: error.message });
1095
+ console.log(` ⚠️ mcp init failed: ${error.message}`);
1096
+ }
1097
+
1098
+ // 2. Registry sync-architectures — populate supportedModelTypes
1099
+ console.log('\n📋 Syncing model architecture registry...');
1100
+ try {
1101
+ const registryHandler = new RegistryCommandHandler();
1102
+ await registryHandler.handle(['sync-architectures'], {});
1103
+ } catch (error) {
1104
+ failures.push({ step: 'registry sync-architectures', error: error.message });
1105
+ console.log(` ⚠️ registry sync-architectures failed: ${error.message}`);
1106
+ }
1107
+
1108
+ // 3. Schema sync — download AWS service models
1109
+ console.log('\n📐 Syncing service schemas...');
1110
+ try {
1111
+ await this._handleSyncSchemas();
1112
+ } catch (error) {
1113
+ failures.push({ step: 'sync-schemas', error: error.message });
1114
+ console.log(` ⚠️ sync-schemas failed: ${error.message}`);
1115
+ }
1116
+
1117
+ // Report results
1118
+ if (failures.length === 0) {
1119
+ console.log('\n✅ Bootstrap complete — all systems operational');
1120
+ } else {
1121
+ console.log(`\n⚠️ Bootstrap complete with ${failures.length} warning${failures.length === 1 ? '' : 's'}:`);
1122
+ for (const { step, error } of failures) {
1123
+ console.log(` • ${step}: ${error}`);
1124
+ }
1125
+ console.log('\n These steps can be re-run individually:');
1126
+ console.log(' ml-container-creator mcp init');
1127
+ console.log(' ml-container-creator registry sync-architectures');
1128
+ console.log(' ml-container-creator bootstrap sync-schemas');
1129
+ }
1057
1130
  }
1058
1131
 
1059
1132
  /**
@@ -1242,12 +1315,20 @@ export default class BootstrapCommandHandler {
1242
1315
  Effect: 'Allow',
1243
1316
  Action: [
1244
1317
  's3:GetObject',
1318
+ 's3:PutObject',
1319
+ 's3:AbortMultipartUpload',
1245
1320
  's3:ListBucket'
1246
1321
  ],
1247
1322
  Resource: [
1248
1323
  'arn:aws:s3:::ml-container-creator-*',
1249
1324
  'arn:aws:s3:::ml-container-creator-*/*'
1250
1325
  ]
1326
+ },
1327
+ {
1328
+ Sid: 'SNSPublish',
1329
+ Effect: 'Allow',
1330
+ Action: 'sns:Publish',
1331
+ Resource: 'arn:aws:sns:*:*:ml-container-creator-*'
1251
1332
  }
1252
1333
  ]
1253
1334
  };
@@ -1649,6 +1730,7 @@ SETUP OPTIONS:
1649
1730
  --skip-s3 Skip S3 bucket creation
1650
1731
  --ci Provision CI testing infrastructure
1651
1732
  --skip-ci Skip CI infrastructure provisioning
1733
+ --skip-post-setup Skip post-setup chain (mcp init, sync-architectures, sync-schemas)
1652
1734
 
1653
1735
  STATUS OPTIONS:
1654
1736
  --verify Check each active resource against AWS APIs for drift detection
@@ -300,6 +300,15 @@ export default class ConfigManager {
300
300
  finalConfig.hfToken = this._resolveHfToken(finalConfig.hfToken);
301
301
  }
302
302
 
303
+ // Mutual exclusion: ARN takes precedence over plaintext when both are set
304
+ // (CLI validation should prevent this, but enforce at config level too)
305
+ if (finalConfig.hfTokenArn) {
306
+ finalConfig.hfToken = null;
307
+ }
308
+ if (finalConfig.ngcTokenArn) {
309
+ finalConfig.ngcApiKey = null;
310
+ }
311
+
303
312
  // Map awsRoleArn to roleArn for templates
304
313
  if (finalConfig.awsRoleArn) {
305
314
  finalConfig.roleArn = finalConfig.awsRoleArn;
@@ -643,6 +652,28 @@ export default class ConfigManager {
643
652
  default: null,
644
653
  valueSpace: 'bounded'
645
654
  },
655
+ hfTokenArn: {
656
+ cliOption: 'hf-token-arn',
657
+ envVar: null,
658
+ configFile: true,
659
+ packageJson: false,
660
+ mcp: false,
661
+ promptable: false,
662
+ required: false,
663
+ default: null,
664
+ valueSpace: 'bounded'
665
+ },
666
+ ngcTokenArn: {
667
+ cliOption: 'ngc-token-arn',
668
+ envVar: null,
669
+ configFile: true,
670
+ packageJson: false,
671
+ mcp: false,
672
+ promptable: false,
673
+ required: false,
674
+ default: null,
675
+ valueSpace: 'bounded'
676
+ },
646
677
  deploymentTarget: {
647
678
  cliOption: 'deployment-target',
648
679
  envVar: 'ML_DEPLOYMENT_TARGET',
@@ -1675,6 +1706,18 @@ export default class ConfigManager {
1675
1706
  }
1676
1707
  }
1677
1708
 
1709
+ // Validate mutual exclusion: plaintext token and ARN cannot both be set
1710
+ if (this.config.hfToken && this.config.hfTokenArn) {
1711
+ errors.push('Cannot specify both --hf-token and --hf-token-arn. Use one or the other.');
1712
+ }
1713
+ if (this.config.ngcTokenArn) {
1714
+ // Read the raw --ngc-token value; this.options keys are kebab-case at this point
1715
+ const ngcTokenFromCli = this.options['ngc-token'];
1716
+ if (ngcTokenFromCli) {
1717
+ errors.push('Cannot specify both --ngc-token and --ngc-token-arn. Use one or the other.');
1718
+ }
1719
+ }
1720
+
1678
1721
  // Validate AWS Role ARN format if provided
1679
1722
  if (this.config.awsRoleArn) {
1680
1723
  try {
@@ -22,6 +22,7 @@ export default class CrossCuttingChecker {
22
22
  findings.push(...this.checkRoleArnFormat(context));
23
23
  findings.push(...this.checkCudaCompatibility(context, instanceCatalog));
24
24
  findings.push(...this.checkModelTypeInstanceAlignment(context, instanceCatalog));
25
+ findings.push(...this.checkKvCacheMemoryFit(context, instanceCatalog));
25
26
 
26
27
  return findings;
27
28
  }
@@ -298,6 +299,45 @@ export default class CrossCuttingChecker {
298
299
  return findings;
299
300
  }
300
301
 
302
+ /**
303
+ * Verify model architecture compatibility with the selected server version.
304
+ * Checks model_type against the server's supportedModelTypes from the catalog.
305
+ * Skips silently when supportedModelTypes is empty (sync not run).
306
+ *
307
+ * @param {Object} context - ValidationContext
308
+ * @param {Object} modelServersCatalog - Model servers catalog (from servers/lib/catalogs/model-servers.json)
309
+ * @returns {Array} Findings
310
+ */
311
+ checkModelArchitectureCompatibility(context, modelServersCatalog) {
312
+ const findings = [];
313
+ const config = context.config || {};
314
+
315
+ const modelType = config.modelType;
316
+ const serverVersion = config.baseImageVersion;
317
+ const server = config.modelServer;
318
+
319
+ if (!modelType || !server || !serverVersion) return findings;
320
+
321
+ const entries = modelServersCatalog[server] || [];
322
+ const entry = entries.find(e => e.labels?.framework_version === serverVersion);
323
+ if (!entry?.supportedModelTypes?.length) return findings;
324
+
325
+ if (!entry.supportedModelTypes.includes(modelType.toLowerCase())) {
326
+ findings.push({
327
+ service: 'cross-cutting',
328
+ operation: 'configuration',
329
+ fieldPath: 'MODEL_NAME',
330
+ invalidValue: modelType,
331
+ constraint: { type: 'architecture-compatibility', server, version: serverVersion },
332
+ severity: 'warning',
333
+ confidence: 'medium',
334
+ source: 'cross-cutting',
335
+ remediationHint: `Model architecture "${modelType}" may not be supported by ${server} ${serverVersion}. Consider a newer server version.`
336
+ });
337
+ }
338
+ return findings;
339
+ }
340
+
301
341
  /**
302
342
  * Verify predictor models are not assigned GPU instances.
303
343
  * @param {Object} context - ValidationContext
@@ -338,4 +378,83 @@ export default class CrossCuttingChecker {
338
378
 
339
379
  return findings;
340
380
  }
381
+
382
+ /**
383
+ * Verify that the model's estimated VRAM (weights + KV cache at configured max_model_len)
384
+ * fits in the instance's available GPU memory.
385
+ *
386
+ * Uses the same estimation formula as the instance-sizer's vram-estimator:
387
+ * total = weights + KV cache + 10% overhead
388
+ *
389
+ * @param {Object} context - ValidationContext
390
+ * @param {Object} instanceCatalog - Instance catalog
391
+ * @returns {Array} Findings
392
+ */
393
+ checkKvCacheMemoryFit(context, instanceCatalog) {
394
+ const findings = [];
395
+ const config = context.config || {};
396
+ const catalog = instanceCatalog?.catalog || instanceCatalog || {};
397
+
398
+ const instanceType = config.INSTANCE_TYPE;
399
+ if (!instanceType) return findings;
400
+
401
+ const instanceInfo = catalog[instanceType];
402
+ if (!instanceInfo || !instanceInfo.gpus || instanceInfo.gpus <= 0) return findings;
403
+
404
+ // Need parameter count to estimate weights
405
+ const parameterCount = config._parameterCount || config.parameterCount;
406
+ if (!parameterCount) return findings;
407
+
408
+ // Resolve max sequence length: explicit env var > model's max_position_embeddings > skip
409
+ const maxModelLen = parseInt(config.VLLM_MAX_MODEL_LEN || config.SGLANG_MAX_MODEL_LEN || '0', 10);
410
+ const maxPosEmbed = parseInt(config._maxPositionEmbeddings || '0', 10);
411
+ const seqLen = maxModelLen || maxPosEmbed;
412
+ if (!seqLen) return findings;
413
+
414
+ // Estimate per-GPU VRAM from instance catalog
415
+ let perGpuVramGb = instanceInfo.gpuMemoryGb;
416
+ if (!perGpuVramGb && instanceInfo.accelerator) {
417
+ const match = instanceInfo.accelerator.match(/(\d+)GB/);
418
+ if (match) {
419
+ const totalGb = parseInt(match[1], 10);
420
+ const hasMultiplier = instanceInfo.accelerator.match(/^(\d+)x\s/);
421
+ perGpuVramGb = hasMultiplier ? totalGb / instanceInfo.gpus : totalGb;
422
+ }
423
+ }
424
+ if (!perGpuVramGb) return findings;
425
+
426
+ const totalVramGb = perGpuVramGb * instanceInfo.gpus;
427
+
428
+ // Estimate VRAM needed (same formula as vram-estimator.js)
429
+ const dtype = config._dtype || 'float16';
430
+ const bytesPerParam = dtype === 'float32' ? 4.0 : dtype === 'int8' ? 1.0 : 2.0;
431
+ const weightsGb = (parameterCount * bytesPerParam) / (1024 ** 3);
432
+ const kvCacheGb = (parameterCount * (seqLen / 4096) * 0.05) / (1024 ** 3);
433
+ const overheadGb = weightsGb * 0.1;
434
+ const estimatedTotalGb = weightsGb + kvCacheGb + overheadGb;
435
+
436
+ if (estimatedTotalGb > totalVramGb) {
437
+ findings.push({
438
+ service: 'cross-cutting',
439
+ operation: 'configuration',
440
+ fieldPath: 'INSTANCE_TYPE',
441
+ invalidValue: instanceType,
442
+ constraint: {
443
+ type: 'kv-cache-memory-fit',
444
+ estimatedVramGb: Math.round(estimatedTotalGb * 10) / 10,
445
+ weightsGb: Math.round(weightsGb * 10) / 10,
446
+ kvCacheGb: Math.round(kvCacheGb * 10) / 10,
447
+ totalVramGb,
448
+ maxModelLen: seqLen,
449
+ instanceType
450
+ },
451
+ severity: 'warning',
452
+ confidence: 'medium',
453
+ source: 'cross-cutting',
454
+ remediationHint: `Estimated VRAM needed: ${estimatedTotalGb.toFixed(1)}GB (weights: ${weightsGb.toFixed(1)}GB + KV cache: ${kvCacheGb.toFixed(1)}GB at seq_len=${seqLen}) exceeds instance capacity (${totalVramGb}GB). Reduce VLLM_MAX_MODEL_LEN, use quantization, or select a larger instance.`
455
+ });
456
+ }
457
+
458
+ return findings;
459
+ }
341
460
  }
@@ -65,8 +65,7 @@ export default {
65
65
  required: ['modelName'],
66
66
  properties: {
67
67
  modelName: {
68
- type: 'string',
69
- minLength: 1
68
+ type: ['string', 'null']
70
69
  },
71
70
  modelFormat: {
72
71
  type: ['string', 'null']