@aws/ml-container-creator 1.0.3 → 1.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. package/README.md +10 -1
  2. package/bin/cli.js +57 -0
  3. package/config/agent.json +16 -0
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +5 -2
  6. package/pyproject.toml +3 -0
  7. package/servers/agent-knowledge/index.js +592 -0
  8. package/servers/agent-knowledge/package.json +15 -0
  9. package/servers/base-image-picker/index.js +65 -18
  10. package/servers/instance-sizer/index.js +32 -0
  11. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  12. package/servers/lib/catalogs/model-arch-support.json +51 -0
  13. package/servers/lib/catalogs/model-servers.json +2842 -1730
  14. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  15. package/src/agent/__init__.py +2 -0
  16. package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
  17. package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
  18. package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
  19. package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
  20. package/src/agent/agent.py +513 -0
  21. package/src/agent/config_loader.py +215 -0
  22. package/src/agent/context.py +380 -0
  23. package/src/agent/data/capability-matrix.json +106 -0
  24. package/src/agent/health_check.py +341 -0
  25. package/src/agent/prompts/system.md +173 -0
  26. package/src/agent/requirements-agent.txt +3 -0
  27. package/src/app.js +6 -4
  28. package/src/lib/generated/cli-options.js +1 -1
  29. package/src/lib/generated/parameter-matrix.js +1 -1
  30. package/src/lib/generated/validation-rules.js +1 -1
  31. package/src/lib/mcp-query-runner.js +110 -3
  32. package/src/lib/prompt-runner.js +66 -22
  33. package/src/lib/template-variable-resolver.js +8 -0
  34. package/src/lib/train-config-builder.js +339 -0
  35. package/src/lib/tune-config-state.js +89 -68
  36. package/templates/do/.benchmark_writer.py +3 -0
  37. package/templates/do/.eval_helper.py +409 -0
  38. package/templates/do/.register_helper.py +185 -11
  39. package/templates/do/.train_build_request.py +102 -113
  40. package/templates/do/.train_helper.py +433 -0
  41. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  42. package/templates/do/adapter +157 -0
  43. package/templates/do/benchmark +60 -3
  44. package/templates/do/config +6 -1
  45. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  46. package/templates/do/evaluate +272 -0
  47. package/templates/do/lib/resolve-instance.sh +155 -0
  48. package/templates/do/register +5 -0
  49. package/templates/do/test +1 -0
  50. package/templates/do/train +879 -126
  51. package/templates/do/training/config.yaml +83 -11
  52. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  53. package/templates/do/training/dpo/defaults.yaml +26 -0
  54. package/templates/do/training/dpo/prompts.json +8 -0
  55. package/templates/do/training/dpo/train.py +363 -0
  56. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  57. package/templates/do/training/sft/defaults.yaml +18 -0
  58. package/templates/do/training/sft/prompts.json +7 -0
  59. package/templates/do/training/sft/train.py +310 -0
  60. package/templates/do/tune +11 -2
  61. package/src/lib/auto-prompt-builder.js +0 -172
  62. package/src/lib/cli-handler.js +0 -529
  63. package/src/lib/community-reports-validator.js +0 -91
  64. package/src/lib/configuration-exporter.js +0 -204
  65. package/src/lib/dataset-slug.js +0 -152
  66. package/src/lib/docker-introspection-validator.js +0 -51
  67. package/src/lib/known-flags-validator.js +0 -200
  68. package/src/lib/schema-validator.js +0 -157
  69. package/src/lib/train-config-parser.js +0 -136
  70. package/src/lib/train-config-persistence.js +0 -143
  71. package/src/lib/train-config-validator.js +0 -112
  72. package/src/lib/train-feedback.js +0 -46
  73. package/src/lib/train-idempotency.js +0 -97
  74. package/src/lib/train-request-builder.js +0 -120
  75. package/src/lib/tune-dataset-validator.js +0 -279
  76. package/src/lib/tune-output-resolver.js +0 -66
  77. package/templates/do/.train_poll_parser.py +0 -135
  78. package/templates/do/.train_status_parser.py +0 -187
  79. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -25,6 +25,8 @@ import { readFileSync } from 'node:fs';
25
25
  import { fileURLToPath } from 'node:url';
26
26
  import { resolve, dirname } from 'node:path';
27
27
  import { DynamicResolver as DynamicResolverBase } from '../lib/dynamic-resolver.js';
28
+ import { filterImages, deriveMinDriverVersion } from '../lib/image-filter.js';
29
+ import { resolveModelArchitecture } from '../lib/model-id-resolver.js';
28
30
 
29
31
  // ── Catalog loader ───────────────────────────────────────────────────────────
30
32
 
@@ -156,15 +158,25 @@ class DynamicResolver extends ImageResolver {
156
158
  }
157
159
 
158
160
  const data = await response.json();
159
- const images = (data.results || []).map(tag => ({
160
- image: `${this._repoForFramework(framework)}:${tag.name}`,
161
- tag: tag.name,
162
- architecture: 'amd64',
163
- created: tag.last_updated || tag.tag_last_pushed || new Date().toISOString(),
164
- labels: {},
165
- registry: 'dockerhub',
166
- repository: this._repoForFramework(framework)
167
- }));
161
+ const images = (data.results || []).map(tag => {
162
+ const entry = {
163
+ image: `${this._repoForFramework(framework)}:${tag.name}`,
164
+ tag: tag.name,
165
+ architecture: 'amd64',
166
+ created: tag.last_updated || tag.tag_last_pushed || new Date().toISOString(),
167
+ labels: {},
168
+ registry: 'dockerhub',
169
+ repository: this._repoForFramework(framework)
170
+ };
171
+
172
+ // Derive min_driver_version from CUDA version in tag or labels
173
+ const minDriver = deriveMinDriverVersion(entry);
174
+ if (minDriver) {
175
+ entry.min_driver_version = minDriver;
176
+ }
177
+
178
+ return entry;
179
+ });
168
180
 
169
181
  return {
170
182
  images: images.slice(0, limit),
@@ -375,7 +387,9 @@ if (discoverMode) {
375
387
  * When discover mode is active, merges static and dynamic results.
376
388
  */
377
389
  async function resolveBaseImage(context, limit) {
378
- const { framework, modelServer, searchCriteria, architecture } = context;
390
+ const { framework, modelServer, searchCriteria, architecture,
391
+ instanceType, driverVersion, inferenceAmiVersion,
392
+ tensorParallelSize, modelArchitecture, modelId } = context;
379
393
 
380
394
  // Determine which framework identifier to resolve
381
395
  let resolverKey;
@@ -398,21 +412,52 @@ async function resolveBaseImage(context, limit) {
398
412
 
399
413
  if (discoverMode && dynamicResolver && dynamicResolver.supportedFrameworks().includes(resolverKey)) {
400
414
  // Fetch both static and dynamic results, then merge
401
- const staticResult = await staticResolver.fetchImages(resolverKey, { limit, searchCriteria });
415
+ const staticResult = await staticResolver.fetchImages(resolverKey, { limit: limit * 3, searchCriteria });
402
416
  const dynamicResult = await dynamicResolver.fetchImages(resolverKey, { limit: 5 });
403
417
 
404
- resultImages = mergeStaticAndDynamic(staticResult.images, dynamicResult.images, limit);
418
+ resultImages = mergeStaticAndDynamic(staticResult.images, dynamicResult.images, limit * 3);
405
419
  } else {
406
- // Static-only path (no network calls)
407
- const result = await resolver.fetchImages(resolverKey, { limit, searchCriteria });
420
+ // Static-only path (no network calls) — fetch extra to allow for filtering
421
+ const fetchLimit = (instanceType || driverVersion || modelArchitecture || modelId) ? limit * 3 : limit;
422
+ const result = await resolver.fetchImages(resolverKey, { limit: fetchLimit, searchCriteria });
408
423
  resultImages = result.images;
409
424
  }
410
425
 
426
+ // ── Resolve modelId → modelArchitecture if needed ───────────────────
427
+ let resolvedModelArchitecture = modelArchitecture || '';
428
+ if (!modelArchitecture && modelId) {
429
+ const arch = await resolveModelArchitecture(modelId);
430
+ if (arch) {
431
+ resolvedModelArchitecture = arch;
432
+ }
433
+ }
434
+
435
+ // ── Apply driver-aware + model-architecture filtering ─────────────────
436
+ let filterMetadata = null;
437
+ if (instanceType || driverVersion || inferenceAmiVersion || resolvedModelArchitecture) {
438
+ const filterResult = filterImages(resultImages, {
439
+ framework: resolverKey,
440
+ instanceType,
441
+ driverVersion,
442
+ inferenceAmiVersion,
443
+ tensorParallelSize: tensorParallelSize || 1,
444
+ modelArchitecture: resolvedModelArchitecture
445
+ });
446
+ resultImages = filterResult.images;
447
+ filterMetadata = filterResult.metadata;
448
+ }
449
+
450
+ // Apply final limit after filtering
451
+ resultImages = resultImages.slice(0, limit);
452
+
411
453
  const images = resultImages.map(e => e.image);
412
454
  return {
413
455
  values: { baseImage: images[0] || null },
414
456
  choices: { baseImage: images },
415
- metadata: { baseImage: resultImages }
457
+ metadata: {
458
+ baseImage: resultImages,
459
+ ...(filterMetadata ? { driverFilter: filterMetadata } : {})
460
+ }
416
461
  };
417
462
  }
418
463
 
@@ -432,11 +477,11 @@ const server = new McpServer({
432
477
 
433
478
  server.tool(
434
479
  'get_base_images',
435
- 'Returns curated base container images for ML Container Creator Dockerfiles',
480
+ 'Returns curated base container images for ML Container Creator Dockerfiles. Supports driver-aware filtering when instanceType is provided — excludes images incompatible with the fleet GPU driver, especially for multi-GPU tensor-parallel deployments.',
436
481
  {
437
482
  parameters: z.array(z.string()).describe('List of parameter names to provide values for'),
438
483
  limit: z.number().int().positive().default(5).describe('Maximum number of choices per parameter'),
439
- context: z.record(z.string(), z.any()).optional().describe('Current configuration context (framework, modelServer, searchCriteria)')
484
+ context: z.record(z.string(), z.any()).optional().describe('Configuration context. Supports: framework, modelServer, searchCriteria, architecture, instanceType (triggers driver filtering), driverVersion (override), inferenceAmiVersion (resolves to driver), tensorParallelSize (TP>1 = strict filtering), modelId, modelArchitecture (excludes old framework versions)')
440
485
  },
441
486
  async ({ parameters, limit, context }) => {
442
487
  const values = {};
@@ -472,10 +517,12 @@ export {
472
517
  TRITON_IMAGE_CATALOG,
473
518
  resolveBaseImage,
474
519
  mergeStaticAndDynamic,
520
+ filterImages,
475
521
  registry,
476
522
  staticResolver,
477
523
  dynamicResolver,
478
- discoverMode
524
+ discoverMode,
525
+ resolveModelArchitecture
479
526
  };
480
527
 
481
528
  export { DynamicResolverBase as DynamicResolverBase };
@@ -393,6 +393,38 @@ async function handleGetInstanceRecommendation(params) {
393
393
  { limit }
394
394
  );
395
395
 
396
+ // Step 3-recommended: When VRAM filter returns empty but catalog has recommendedInstances,
397
+ // use those as the fallback (they represent tested/validated deployments).
398
+ if (recommendations.length === 0 && modelMetadata.recommendedInstances && modelMetadata.recommendedInstances.length > 0) {
399
+ for (const instanceType of modelMetadata.recommendedInstances) {
400
+ const meta = effectiveCatalog[instanceType];
401
+ if (meta) {
402
+ const perGpuMemory = getPerGpuMemoryGb(meta);
403
+ const gpuCount = meta.gpus || 1;
404
+ const totalVramGb = perGpuMemory ? perGpuMemory * gpuCount : null;
405
+ recommendations.push({
406
+ instanceType,
407
+ gpuCount,
408
+ totalVramGb,
409
+ utilizationPercent: totalVramGb ? Math.round((vramEstimate.vramGb / totalVramGb) * 100) : null,
410
+ tensorParallelism: gpuCount,
411
+ costTier: meta.costTier || null
412
+ });
413
+ } else {
414
+ // Instance not in catalog but listed as recommended — still include it
415
+ recommendations.push({
416
+ instanceType,
417
+ gpuCount: null,
418
+ totalVramGb: null,
419
+ utilizationPercent: null,
420
+ tensorParallelism: null,
421
+ costTier: null
422
+ });
423
+ }
424
+ }
425
+ log(`Using catalog recommendedInstances for "${modelName}" (VRAM filter returned empty)`);
426
+ }
427
+
396
428
  // Step 3-max_model_len: When no instance fits at full context, try capping context length
397
429
  // NFR-1 guard: skip this logic for models with recommendedInstances in catalog
398
430
  let suggestedMaxModelLen = null;
@@ -0,0 +1,38 @@
1
+ {
2
+ "_comment": "Instance family → GPU driver version mapping for SageMaker inference fleet. Source: AWS docs (inference-gpu-drivers.html) + empirical validation. Updated quarterly or when AWS announces fleet driver updates.",
3
+ "_last_updated": "2026-06-29",
4
+ "instance_families": {
5
+ "g4dn": { "driver": "535.183", "cuda_native": "12.2", "gpu": "T4", "gpu_memory_gb": 16 },
6
+ "g5": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A10G", "gpu_memory_gb": 24 },
7
+ "g5n": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A10G", "gpu_memory_gb": 24 },
8
+ "g6": { "driver": "560.35", "cuda_native": "12.6", "gpu": "L4", "gpu_memory_gb": 24 },
9
+ "g6e": { "driver": "560.35", "cuda_native": "12.6", "gpu": "L40S", "gpu_memory_gb": 48 },
10
+ "p4d": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A100", "gpu_memory_gb": 40 },
11
+ "p4de": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A100", "gpu_memory_gb": 80 },
12
+ "p5": { "driver": "580.95", "cuda_native": "12.9", "gpu": "H100", "gpu_memory_gb": 80 },
13
+ "p5e": { "driver": "580.95", "cuda_native": "12.9", "gpu": "H200", "gpu_memory_gb": 141 },
14
+ "trn1": null,
15
+ "trn2": null,
16
+ "inf2": null
17
+ },
18
+ "ami_versions": {
19
+ "_comment": "InferenceAmiVersion enum → driver version mapping. From SDK enum + empirical.",
20
+ "al2-ami-sagemaker-inference-gpu-2": { "driver": "535.183", "cuda_native": "12.2" },
21
+ "al2-ami-sagemaker-inference-gpu-2-1": { "driver": "535.216", "cuda_native": "12.2" },
22
+ "al2-ami-sagemaker-inference-gpu-3-1": { "driver": "550.163", "cuda_native": "12.4" },
23
+ "al2023-ami-sagemaker-inference-gpu-4-1": { "driver": "570.86", "cuda_native": "12.8" }
24
+ },
25
+ "cuda_to_min_driver": {
26
+ "_comment": "CUDA toolkit version → minimum required driver (Linux data center). Source: NVIDIA CUDA compatibility docs.",
27
+ "12.0": "525.60",
28
+ "12.1": "525.60",
29
+ "12.2": "535.54",
30
+ "12.3": "535.54",
31
+ "12.4": "550.54",
32
+ "12.5": "555.42",
33
+ "12.6": "560.28",
34
+ "12.7": "565.57",
35
+ "12.8": "570.86",
36
+ "12.9": "580.00"
37
+ }
38
+ }
@@ -0,0 +1,51 @@
1
+ {
2
+ "_comment": "Model architecture → minimum framework version mapping. Used by driver-aware filtering to exclude framework versions that don't support a given model architecture. Source: vLLM/SGLang release notes and supported_models docs.",
3
+ "_last_updated": "2026-06-29",
4
+ "vllm": {
5
+ "LlamaForCausalLM": "v0.4.0",
6
+ "Llama4ForCausalLM": "v0.22.0",
7
+ "MistralForCausalLM": "v0.4.0",
8
+ "MixtralForCausalLM": "v0.4.0",
9
+ "Qwen2ForCausalLM": "v0.6.0",
10
+ "Qwen2MoeForCausalLM": "v0.6.0",
11
+ "Qwen3ForCausalLM": "v0.20.0",
12
+ "Qwen3MoeForCausalLM": "v0.20.0",
13
+ "DeepseekV2ForCausalLM": "v0.16.0",
14
+ "DeepseekV3ForCausalLM": "v0.19.0",
15
+ "Gemma2ForCausalLM": "v0.8.0",
16
+ "Gemma3ForCausalLM": "v0.20.0",
17
+ "Gemma4ForCausalLM": "v0.23.0",
18
+ "GptOssForCausalLM": "v0.22.0",
19
+ "NemotronForCausalLM": "v0.17.0",
20
+ "Phi3ForCausalLM": "v0.6.0",
21
+ "PhiMoEForCausalLM": "v0.16.0",
22
+ "GraniteForCausalLM": "v0.17.0",
23
+ "GraniteMoeForCausalLM": "v0.19.0",
24
+ "CohereForCausalLM": "v0.16.0",
25
+ "Cohere2ForCausalLM": "v0.19.0",
26
+ "FalconForCausalLM": "v0.4.0",
27
+ "StarCoder2ForCausalLM": "v0.6.0",
28
+ "InternLM2ForCausalLM": "v0.6.0",
29
+ "OlmoForCausalLM": "v0.16.0",
30
+ "Olmo2ForCausalLM": "v0.19.0"
31
+ },
32
+ "sglang": {
33
+ "LlamaForCausalLM": "v0.3.0",
34
+ "MistralForCausalLM": "v0.3.0",
35
+ "Qwen2ForCausalLM": "v0.4.0",
36
+ "Qwen3ForCausalLM": "v0.5.0",
37
+ "Qwen3MoeForCausalLM": "v0.5.0",
38
+ "DeepseekV2ForCausalLM": "v0.4.0",
39
+ "DeepseekV3ForCausalLM": "v0.5.0",
40
+ "Gemma2ForCausalLM": "v0.4.0",
41
+ "Gemma3ForCausalLM": "v0.5.0",
42
+ "Phi3ForCausalLM": "v0.4.0",
43
+ "InternLM2ForCausalLM": "v0.4.0"
44
+ },
45
+ "lmi": {
46
+ "_comment": "DJL LMI uses HuggingFace transformers directly — architecture support is determined by the bundled transformers version, not the LMI version itself. No version-based filtering needed for LMI.",
47
+ "LlamaForCausalLM": "v0.25.0",
48
+ "Qwen2ForCausalLM": "v0.27.0",
49
+ "DeepseekV2ForCausalLM": "v0.28.0"
50
+ }
51
+ }