@aws/ml-container-creator 0.2.6 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. package/bin/cli.js +38 -2
  2. package/config/bootstrap-stack.json +94 -1
  3. package/config/defaults.json +1 -1
  4. package/infra/ci-harness/package-lock.json +22 -9
  5. package/package.json +3 -1
  6. package/servers/instance-sizer/index.js +45 -8
  7. package/servers/instance-sizer/lib/instance-ranker.js +140 -11
  8. package/servers/instance-sizer/lib/model-resolver.js +10 -6
  9. package/servers/instance-sizer/lib/quota-resolver.js +368 -0
  10. package/servers/instance-sizer/package.json +2 -0
  11. package/servers/lib/catalogs/instances.json +527 -12
  12. package/servers/lib/catalogs/model-servers.json +298 -20
  13. package/servers/lib/catalogs/model-sizes.json +27 -0
  14. package/servers/lib/catalogs/models.json +101 -0
  15. package/servers/lib/schemas/image-catalog.schema.json +15 -1
  16. package/servers/model-picker/index.js +2 -1
  17. package/src/app.js +96 -2
  18. package/src/lib/architecture-sync.js +171 -0
  19. package/src/lib/arn-detection.js +22 -0
  20. package/src/lib/bootstrap-command-handler.js +178 -3
  21. package/src/lib/cli-handler.js +2 -2
  22. package/src/lib/config-manager.js +121 -1
  23. package/src/lib/cross-cutting-checker.js +119 -0
  24. package/src/lib/deployment-entry-schema.js +1 -2
  25. package/src/lib/prompt-runner.js +514 -20
  26. package/src/lib/prompts.js +67 -5
  27. package/src/lib/registry-command-handler.js +236 -0
  28. package/src/lib/schema-sync.js +31 -0
  29. package/src/lib/secret-classification.js +56 -0
  30. package/src/lib/secrets-command-handler.js +550 -0
  31. package/src/lib/template-manager.js +49 -1
  32. package/src/lib/validate-runner.js +174 -2
  33. package/src/lib/validation-report.js +8 -1
  34. package/src/prompt-adapter.js +3 -2
  35. package/templates/Dockerfile +10 -2
  36. package/templates/code/cuda_compat.sh +22 -0
  37. package/templates/code/serve +3 -0
  38. package/templates/code/start_server.sh +3 -0
  39. package/templates/diffusors/Dockerfile +2 -1
  40. package/templates/diffusors/serve +3 -0
  41. package/templates/do/README.md +33 -0
  42. package/templates/do/benchmark +646 -0
  43. package/templates/do/build +22 -0
  44. package/templates/do/clean +86 -0
  45. package/templates/do/config +41 -6
  46. package/templates/do/deploy +66 -6
  47. package/templates/do/logs +18 -3
  48. package/templates/do/register +8 -1
  49. package/templates/do/run +10 -0
  50. package/templates/triton/Dockerfile +5 -0
@@ -300,6 +300,15 @@ export default class ConfigManager {
300
300
  finalConfig.hfToken = this._resolveHfToken(finalConfig.hfToken);
301
301
  }
302
302
 
303
+ // Mutual exclusion: ARN takes precedence over plaintext when both are set
304
+ // (CLI validation should prevent this, but enforce at config level too)
305
+ if (finalConfig.hfTokenArn) {
306
+ finalConfig.hfToken = null;
307
+ }
308
+ if (finalConfig.ngcTokenArn) {
309
+ finalConfig.ngcApiKey = null;
310
+ }
311
+
303
312
  // Map awsRoleArn to roleArn for templates
304
313
  if (finalConfig.awsRoleArn) {
305
314
  finalConfig.roleArn = finalConfig.awsRoleArn;
@@ -643,6 +652,28 @@ export default class ConfigManager {
643
652
  default: null,
644
653
  valueSpace: 'bounded'
645
654
  },
655
+ hfTokenArn: {
656
+ cliOption: 'hf-token-arn',
657
+ envVar: null,
658
+ configFile: true,
659
+ packageJson: false,
660
+ mcp: false,
661
+ promptable: false,
662
+ required: false,
663
+ default: null,
664
+ valueSpace: 'bounded'
665
+ },
666
+ ngcTokenArn: {
667
+ cliOption: 'ngc-token-arn',
668
+ envVar: null,
669
+ configFile: true,
670
+ packageJson: false,
671
+ mcp: false,
672
+ promptable: false,
673
+ required: false,
674
+ default: null,
675
+ valueSpace: 'bounded'
676
+ },
646
677
  deploymentTarget: {
647
678
  cliOption: 'deployment-target',
648
679
  envVar: 'ML_DEPLOYMENT_TARGET',
@@ -948,6 +979,83 @@ export default class ConfigManager {
948
979
  default: 1.0,
949
980
  valueSpace: 'bounded',
950
981
  schemaValidated: true
982
+ },
983
+ includeBenchmark: {
984
+ cliOption: 'include-benchmark',
985
+ envVar: 'ML_INCLUDE_BENCHMARK',
986
+ configFile: true,
987
+ packageJson: false,
988
+ mcp: false,
989
+ promptable: true,
990
+ required: false,
991
+ default: false,
992
+ valueSpace: 'bounded'
993
+ },
994
+ benchmarkConcurrency: {
995
+ cliOption: 'benchmark-concurrency',
996
+ envVar: null,
997
+ configFile: true,
998
+ packageJson: false,
999
+ mcp: false,
1000
+ promptable: true,
1001
+ required: false,
1002
+ default: 10,
1003
+ valueSpace: 'bounded'
1004
+ },
1005
+ benchmarkInputTokensMean: {
1006
+ cliOption: 'benchmark-input-tokens',
1007
+ envVar: null,
1008
+ configFile: true,
1009
+ packageJson: false,
1010
+ mcp: false,
1011
+ promptable: true,
1012
+ required: false,
1013
+ default: 550,
1014
+ valueSpace: 'bounded'
1015
+ },
1016
+ benchmarkOutputTokensMean: {
1017
+ cliOption: 'benchmark-output-tokens',
1018
+ envVar: null,
1019
+ configFile: true,
1020
+ packageJson: false,
1021
+ mcp: false,
1022
+ promptable: true,
1023
+ required: false,
1024
+ default: 150,
1025
+ valueSpace: 'bounded'
1026
+ },
1027
+ benchmarkStreaming: {
1028
+ cliOption: 'benchmark-streaming',
1029
+ envVar: null,
1030
+ configFile: true,
1031
+ packageJson: false,
1032
+ mcp: false,
1033
+ promptable: true,
1034
+ required: false,
1035
+ default: true,
1036
+ valueSpace: 'bounded'
1037
+ },
1038
+ benchmarkRequestCount: {
1039
+ cliOption: 'benchmark-request-count',
1040
+ envVar: null,
1041
+ configFile: true,
1042
+ packageJson: false,
1043
+ mcp: false,
1044
+ promptable: true,
1045
+ required: false,
1046
+ default: null,
1047
+ valueSpace: 'bounded'
1048
+ },
1049
+ benchmarkS3OutputPath: {
1050
+ cliOption: 'benchmark-s3-output-path',
1051
+ envVar: 'ML_BENCHMARK_S3_OUTPUT_PATH',
1052
+ configFile: true,
1053
+ packageJson: false,
1054
+ mcp: false,
1055
+ promptable: true,
1056
+ required: false,
1057
+ default: null,
1058
+ valueSpace: 'bounded'
951
1059
  }
952
1060
  };
953
1061
  }
@@ -980,7 +1088,7 @@ export default class ConfigManager {
980
1088
  */
981
1089
  _parseValue(parameter, value) {
982
1090
  // Handle boolean parameters
983
- if (parameter === 'includeSampleModel' || parameter === 'includeTesting' || parameter === 'skipPrompts') {
1091
+ if (parameter === 'includeSampleModel' || parameter === 'includeTesting' || parameter === 'skipPrompts' || parameter === 'includeBenchmark' || parameter === 'benchmarkStreaming') {
984
1092
  return value === true || value === 'true';
985
1093
  }
986
1094
 
@@ -1675,6 +1783,18 @@ export default class ConfigManager {
1675
1783
  }
1676
1784
  }
1677
1785
 
1786
+ // Validate mutual exclusion: plaintext token and ARN cannot both be set
1787
+ if (this.config.hfToken && this.config.hfTokenArn) {
1788
+ errors.push('Cannot specify both --hf-token and --hf-token-arn. Use one or the other.');
1789
+ }
1790
+ if (this.config.ngcTokenArn) {
1791
+ // Check ngcToken from CLI options (Commander converts --ngc-token to ngcToken)
1792
+ const ngcTokenFromCli = this.options['ngc-token'];
1793
+ if (ngcTokenFromCli) {
1794
+ errors.push('Cannot specify both --ngc-token and --ngc-token-arn. Use one or the other.');
1795
+ }
1796
+ }
1797
+
1678
1798
  // Validate AWS Role ARN format if provided
1679
1799
  if (this.config.awsRoleArn) {
1680
1800
  try {
@@ -22,6 +22,7 @@ export default class CrossCuttingChecker {
22
22
  findings.push(...this.checkRoleArnFormat(context));
23
23
  findings.push(...this.checkCudaCompatibility(context, instanceCatalog));
24
24
  findings.push(...this.checkModelTypeInstanceAlignment(context, instanceCatalog));
25
+ findings.push(...this.checkKvCacheMemoryFit(context, instanceCatalog));
25
26
 
26
27
  return findings;
27
28
  }
@@ -298,6 +299,45 @@ export default class CrossCuttingChecker {
298
299
  return findings;
299
300
  }
300
301
 
302
+ /**
303
+ * Verify model architecture compatibility with the selected server version.
304
+ * Checks model_type against the server's supportedModelTypes from the catalog.
305
+ * Skips silently when supportedModelTypes is empty (sync not run).
306
+ *
307
+ * @param {Object} context - ValidationContext
308
+ * @param {Object} modelServersCatalog - Model servers catalog (from servers/lib/catalogs/model-servers.json)
309
+ * @returns {Array} Findings
310
+ */
311
+ checkModelArchitectureCompatibility(context, modelServersCatalog) {
312
+ const findings = [];
313
+ const config = context.config || {};
314
+
315
+ const modelType = config.modelType;
316
+ const serverVersion = config.baseImageVersion;
317
+ const server = config.modelServer;
318
+
319
+ if (!modelType || !server || !serverVersion) return findings;
320
+
321
+ const entries = modelServersCatalog[server] || [];
322
+ const entry = entries.find(e => e.labels?.framework_version === serverVersion);
323
+ if (!entry?.supportedModelTypes?.length) return findings;
324
+
325
+ if (!entry.supportedModelTypes.includes(modelType.toLowerCase())) {
326
+ findings.push({
327
+ service: 'cross-cutting',
328
+ operation: 'configuration',
329
+ fieldPath: 'MODEL_NAME',
330
+ invalidValue: modelType,
331
+ constraint: { type: 'architecture-compatibility', server, version: serverVersion },
332
+ severity: 'warning',
333
+ confidence: 'medium',
334
+ source: 'cross-cutting',
335
+ remediationHint: `Model architecture "${modelType}" may not be supported by ${server} ${serverVersion}. Consider a newer server version.`
336
+ });
337
+ }
338
+ return findings;
339
+ }
340
+
301
341
  /**
302
342
  * Verify predictor models are not assigned GPU instances.
303
343
  * @param {Object} context - ValidationContext
@@ -338,4 +378,83 @@ export default class CrossCuttingChecker {
338
378
 
339
379
  return findings;
340
380
  }
381
+
382
+ /**
383
+ * Verify that the model's estimated VRAM (weights + KV cache at configured max_model_len)
384
+ * fits in the instance's available GPU memory.
385
+ *
386
+ * Uses the same estimation formula as the instance-sizer's vram-estimator:
387
+ * total = weights + KV cache + 10% overhead
388
+ *
389
+ * @param {Object} context - ValidationContext
390
+ * @param {Object} instanceCatalog - Instance catalog
391
+ * @returns {Array} Findings
392
+ */
393
+ checkKvCacheMemoryFit(context, instanceCatalog) {
394
+ const findings = [];
395
+ const config = context.config || {};
396
+ const catalog = instanceCatalog?.catalog || instanceCatalog || {};
397
+
398
+ const instanceType = config.INSTANCE_TYPE;
399
+ if (!instanceType) return findings;
400
+
401
+ const instanceInfo = catalog[instanceType];
402
+ if (!instanceInfo || !instanceInfo.gpus || instanceInfo.gpus <= 0) return findings;
403
+
404
+ // Need parameter count to estimate weights
405
+ const parameterCount = config._parameterCount || config.parameterCount;
406
+ if (!parameterCount) return findings;
407
+
408
+ // Resolve max sequence length: explicit env var > model's max_position_embeddings > skip
409
+ const maxModelLen = parseInt(config.VLLM_MAX_MODEL_LEN || config.SGLANG_MAX_MODEL_LEN || '0', 10);
410
+ const maxPosEmbed = parseInt(config._maxPositionEmbeddings || '0', 10);
411
+ const seqLen = maxModelLen || maxPosEmbed;
412
+ if (!seqLen) return findings;
413
+
414
+ // Estimate per-GPU VRAM from instance catalog
415
+ let perGpuVramGb = instanceInfo.gpuMemoryGb;
416
+ if (!perGpuVramGb && instanceInfo.accelerator) {
417
+ const match = instanceInfo.accelerator.match(/(\d+)GB/);
418
+ if (match) {
419
+ const totalGb = parseInt(match[1], 10);
420
+ const hasMultiplier = instanceInfo.accelerator.match(/^(\d+)x\s/);
421
+ perGpuVramGb = hasMultiplier ? totalGb / instanceInfo.gpus : totalGb;
422
+ }
423
+ }
424
+ if (!perGpuVramGb) return findings;
425
+
426
+ const totalVramGb = perGpuVramGb * instanceInfo.gpus;
427
+
428
+ // Estimate VRAM needed (same formula as vram-estimator.js)
429
+ const dtype = config._dtype || 'float16';
430
+ const bytesPerParam = dtype === 'float32' ? 4.0 : dtype === 'int8' ? 1.0 : 2.0;
431
+ const weightsGb = (parameterCount * bytesPerParam) / (1024 ** 3);
432
+ const kvCacheGb = (parameterCount * (seqLen / 4096) * 0.05) / (1024 ** 3);
433
+ const overheadGb = weightsGb * 0.1;
434
+ const estimatedTotalGb = weightsGb + kvCacheGb + overheadGb;
435
+
436
+ if (estimatedTotalGb > totalVramGb) {
437
+ findings.push({
438
+ service: 'cross-cutting',
439
+ operation: 'configuration',
440
+ fieldPath: 'INSTANCE_TYPE',
441
+ invalidValue: instanceType,
442
+ constraint: {
443
+ type: 'kv-cache-memory-fit',
444
+ estimatedVramGb: Math.round(estimatedTotalGb * 10) / 10,
445
+ weightsGb: Math.round(weightsGb * 10) / 10,
446
+ kvCacheGb: Math.round(kvCacheGb * 10) / 10,
447
+ totalVramGb,
448
+ maxModelLen: seqLen,
449
+ instanceType
450
+ },
451
+ severity: 'warning',
452
+ confidence: 'medium',
453
+ source: 'cross-cutting',
454
+ remediationHint: `Estimated VRAM needed: ${estimatedTotalGb.toFixed(1)}GB (weights: ${weightsGb.toFixed(1)}GB + KV cache: ${kvCacheGb.toFixed(1)}GB at seq_len=${seqLen}) exceeds instance capacity (${totalVramGb}GB). Reduce VLLM_MAX_MODEL_LEN, use quantization, or select a larger instance.`
455
+ });
456
+ }
457
+
458
+ return findings;
459
+ }
341
460
  }
@@ -65,8 +65,7 @@ export default {
65
65
  required: ['modelName'],
66
66
  properties: {
67
67
  modelName: {
68
- type: 'string',
69
- minLength: 1
68
+ type: ['string', 'null']
70
69
  },
71
70
  modelFormat: {
72
71
  type: ['string', 'null']