@aws/ml-container-creator 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. package/README.md +1 -1
  2. package/bin/cli.js +1 -1
  3. package/config/tune-catalog.json +303 -1
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +3 -2
  6. package/servers/base-image-picker/index.js +65 -18
  7. package/servers/instance-sizer/index.js +32 -0
  8. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  9. package/servers/lib/catalogs/model-arch-support.json +51 -0
  10. package/servers/lib/catalogs/model-servers.json +2842 -1516
  11. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  12. package/src/app.js +6 -4
  13. package/src/lib/bootstrap-command-handler.js +12 -2
  14. package/src/lib/bootstrap-profile-manager.js +16 -0
  15. package/src/lib/cross-cutting-checker.js +6 -1
  16. package/src/lib/generated/cli-options.js +1 -1
  17. package/src/lib/generated/parameter-matrix.js +1 -1
  18. package/src/lib/generated/validation-rules.js +1 -1
  19. package/src/lib/mcp-query-runner.js +110 -3
  20. package/src/lib/prompt-runner.js +66 -22
  21. package/src/lib/template-variable-resolver.js +8 -0
  22. package/src/lib/train-config-builder.js +339 -0
  23. package/templates/do/.benchmark_writer.py +3 -0
  24. package/templates/do/.eval_helper.py +409 -0
  25. package/templates/do/.register_helper.py +185 -11
  26. package/templates/do/.train_build_request.py +102 -113
  27. package/templates/do/.train_helper.py +433 -0
  28. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  29. package/templates/do/adapter +157 -0
  30. package/templates/do/benchmark +60 -3
  31. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  32. package/templates/do/evaluate +272 -0
  33. package/templates/do/lib/resolve-instance.sh +155 -0
  34. package/templates/do/register +5 -0
  35. package/templates/do/test +1 -0
  36. package/templates/do/train +879 -126
  37. package/templates/do/training/config.yaml +83 -11
  38. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  39. package/templates/do/training/dpo/defaults.yaml +26 -0
  40. package/templates/do/training/dpo/prompts.json +8 -0
  41. package/templates/do/training/dpo/train.py +363 -0
  42. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  43. package/templates/do/training/sft/defaults.yaml +18 -0
  44. package/templates/do/training/sft/prompts.json +7 -0
  45. package/templates/do/training/sft/train.py +310 -0
  46. package/templates/do/tune +11 -2
  47. package/templates/do/.train_poll_parser.py +0 -135
  48. package/templates/do/.train_status_parser.py +0 -187
  49. /package/templates/do/training/{train.py → custom/train.py} +0 -0
package/README.md CHANGED
@@ -97,7 +97,7 @@ Full documentation is available at [awslabs.github.io/ml-container-creator](http
97
97
 
98
98
  ### Python dependencies
99
99
 
100
- The `do/` lifecycle scripts (`do/tune`, `do/stage`, `do/adapter`) require Python packages. Install them in your Python environment before first use:
100
+ The `do/` lifecycle scripts (`do/tune`, `do/train`, `do/stage`, `do/adapter`) require Python packages. Install them in your Python environment before first use:
101
101
 
102
102
  ```bash
103
103
  # Recommended (fast):
package/bin/cli.js CHANGED
@@ -162,7 +162,7 @@ program
162
162
  .command('bootstrap')
163
163
  .description('Set up AWS infrastructure (IAM role, ECR repo, S3 buckets)')
164
164
  .passThroughOptions()
165
- .argument('[action]', 'Bootstrap action (status, use, list, remove, scan, prune, update, sync-schemas)')
165
+ .argument('[action]', 'Bootstrap action (status, use, list, remove, scan, prune, update, migrate, sync-schemas, sync-model-families)')
166
166
  .argument('[args...]', 'Additional arguments')
167
167
  .option('--profile <profile>', 'AWS profile name')
168
168
  .option('--region <region>', 'AWS region')
package/config/tune-catalog.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "version": "2026-05-27",
3
- "lastSynced": "2026-05-28T09:48:25.209Z",
3
+ "lastSynced": "2026-06-26T19:01:02.821Z",
4
4
  "source": "https://docs.aws.amazon.com/sagemaker/latest/dg/model-customize-open-weight.html",
5
5
  "models": {
6
6
  "huggingface-llm-qwen2-5-7b-instruct": {
@@ -1614,6 +1614,24 @@
1614
1614
  "prompt": "array"
1615
1615
  }
1616
1616
  }
1617
+ },
1618
+ "dpo": {
1619
+ "trainingTypes": [
1620
+ "lora"
1621
+ ],
1622
+ "datasetFormat": "default-dpo",
1623
+ "datasetSchema": {
1624
+ "required": [
1625
+ "prompt",
1626
+ "chosen",
1627
+ "rejected"
1628
+ ],
1629
+ "types": {
1630
+ "prompt": "string",
1631
+ "chosen": "string",
1632
+ "rejected": "string"
1633
+ }
1634
+ }
1617
1635
  }
1618
1636
  },
1619
1637
  "goldenPath": false
@@ -1667,6 +1685,24 @@
1667
1685
  "prompt": "array"
1668
1686
  }
1669
1687
  }
1688
+ },
1689
+ "dpo": {
1690
+ "trainingTypes": [
1691
+ "lora"
1692
+ ],
1693
+ "datasetFormat": "default-dpo",
1694
+ "datasetSchema": {
1695
+ "required": [
1696
+ "prompt",
1697
+ "chosen",
1698
+ "rejected"
1699
+ ],
1700
+ "types": {
1701
+ "prompt": "string",
1702
+ "chosen": "string",
1703
+ "rejected": "string"
1704
+ }
1705
+ }
1670
1706
  }
1671
1707
  },
1672
1708
  "goldenPath": false
@@ -1773,6 +1809,272 @@
1773
1809
  "prompt": "array"
1774
1810
  }
1775
1811
  }
1812
+ },
1813
+ "dpo": {
1814
+ "trainingTypes": [
1815
+ "lora"
1816
+ ],
1817
+ "datasetFormat": "default-dpo",
1818
+ "datasetSchema": {
1819
+ "required": [
1820
+ "prompt",
1821
+ "chosen",
1822
+ "rejected"
1823
+ ],
1824
+ "types": {
1825
+ "prompt": "string",
1826
+ "chosen": "string",
1827
+ "rejected": "string"
1828
+ }
1829
+ }
1830
+ }
1831
+ },
1832
+ "goldenPath": false
1833
+ },
1834
+ "huggingface-llm-nvidia-nemotron-3-super-120b-a12b-bf16": {
1835
+ "family": "huggingface-llm-nvidia-nemotron",
1836
+ "provider": "unknown",
1837
+ "displayName": "NVIDIA-Nemotron-3-Super-120B-A12B-BF16",
1838
+ "huggingFaceId": "",
1839
+ "techniques": {
1840
+ "sft": {
1841
+ "trainingTypes": [
1842
+ "lora"
1843
+ ],
1844
+ "datasetFormat": "default-sft",
1845
+ "datasetSchema": {
1846
+ "required": [
1847
+ "prompt",
1848
+ "completion"
1849
+ ],
1850
+ "types": {
1851
+ "prompt": "string",
1852
+ "completion": "string"
1853
+ }
1854
+ }
1855
+ },
1856
+ "rlvr": {
1857
+ "trainingTypes": [
1858
+ "lora"
1859
+ ],
1860
+ "datasetFormat": "default-rlvr",
1861
+ "datasetSchema": {
1862
+ "required": [
1863
+ "prompt"
1864
+ ],
1865
+ "types": {
1866
+ "prompt": "array"
1867
+ }
1868
+ }
1869
+ },
1870
+ "rlaif": {
1871
+ "trainingTypes": [
1872
+ "lora"
1873
+ ],
1874
+ "datasetFormat": "default-rlaif",
1875
+ "datasetSchema": {
1876
+ "required": [
1877
+ "prompt"
1878
+ ],
1879
+ "types": {
1880
+ "prompt": "array"
1881
+ }
1882
+ }
1883
+ }
1884
+ },
1885
+ "goldenPath": false
1886
+ },
1887
+ "huggingface-reasoning-nvidia-nemotron-3-nano-30b-a3b-bf16": {
1888
+ "family": "huggingface-reasoning-nvidia-nemotron",
1889
+ "provider": "unknown",
1890
+ "displayName": "NVIDIA-Nemotron-3-Nano-30B-A3B-BF16",
1891
+ "huggingFaceId": "",
1892
+ "techniques": {
1893
+ "sft": {
1894
+ "trainingTypes": [
1895
+ "lora"
1896
+ ],
1897
+ "datasetFormat": "default-sft",
1898
+ "datasetSchema": {
1899
+ "required": [
1900
+ "prompt",
1901
+ "completion"
1902
+ ],
1903
+ "types": {
1904
+ "prompt": "string",
1905
+ "completion": "string"
1906
+ }
1907
+ }
1908
+ },
1909
+ "rlaif": {
1910
+ "trainingTypes": [
1911
+ "lora"
1912
+ ],
1913
+ "datasetFormat": "default-rlaif",
1914
+ "datasetSchema": {
1915
+ "required": [
1916
+ "prompt"
1917
+ ],
1918
+ "types": {
1919
+ "prompt": "array"
1920
+ }
1921
+ }
1922
+ },
1923
+ "rlvr": {
1924
+ "trainingTypes": [
1925
+ "lora"
1926
+ ],
1927
+ "datasetFormat": "default-rlvr",
1928
+ "datasetSchema": {
1929
+ "required": [
1930
+ "prompt"
1931
+ ],
1932
+ "types": {
1933
+ "prompt": "array"
1934
+ }
1935
+ }
1936
+ }
1937
+ },
1938
+ "goldenPath": false
1939
+ },
1940
+ "huggingface-vlm-gemma-4-e4b-it": {
1941
+ "family": "huggingface-vlm",
1942
+ "provider": "unknown",
1943
+ "displayName": "gemma-4-e4b-it",
1944
+ "huggingFaceId": "",
1945
+ "techniques": {
1946
+ "dpo": {
1947
+ "trainingTypes": [
1948
+ "lora"
1949
+ ],
1950
+ "datasetFormat": "default-dpo",
1951
+ "datasetSchema": {
1952
+ "required": [
1953
+ "prompt",
1954
+ "chosen",
1955
+ "rejected"
1956
+ ],
1957
+ "types": {
1958
+ "prompt": "string",
1959
+ "chosen": "string",
1960
+ "rejected": "string"
1961
+ }
1962
+ }
1963
+ },
1964
+ "sft": {
1965
+ "trainingTypes": [
1966
+ "lora"
1967
+ ],
1968
+ "datasetFormat": "default-sft",
1969
+ "datasetSchema": {
1970
+ "required": [
1971
+ "prompt",
1972
+ "completion"
1973
+ ],
1974
+ "types": {
1975
+ "prompt": "string",
1976
+ "completion": "string"
1977
+ }
1978
+ }
1979
+ },
1980
+ "rlvr": {
1981
+ "trainingTypes": [
1982
+ "lora"
1983
+ ],
1984
+ "datasetFormat": "default-rlvr",
1985
+ "datasetSchema": {
1986
+ "required": [
1987
+ "prompt"
1988
+ ],
1989
+ "types": {
1990
+ "prompt": "array"
1991
+ }
1992
+ }
1993
+ },
1994
+ "rlaif": {
1995
+ "trainingTypes": [
1996
+ "lora"
1997
+ ],
1998
+ "datasetFormat": "default-rlaif",
1999
+ "datasetSchema": {
2000
+ "required": [
2001
+ "prompt"
2002
+ ],
2003
+ "types": {
2004
+ "prompt": "array"
2005
+ }
2006
+ }
2007
+ }
2008
+ },
2009
+ "goldenPath": false
2010
+ },
2011
+ "huggingface-vlm-gemma-4-31b-it": {
2012
+ "family": "huggingface-vlm",
2013
+ "provider": "unknown",
2014
+ "displayName": "gemma-4-31b-it",
2015
+ "huggingFaceId": "",
2016
+ "techniques": {
2017
+ "dpo": {
2018
+ "trainingTypes": [
2019
+ "lora"
2020
+ ],
2021
+ "datasetFormat": "default-dpo",
2022
+ "datasetSchema": {
2023
+ "required": [
2024
+ "prompt",
2025
+ "chosen",
2026
+ "rejected"
2027
+ ],
2028
+ "types": {
2029
+ "prompt": "string",
2030
+ "chosen": "string",
2031
+ "rejected": "string"
2032
+ }
2033
+ }
2034
+ },
2035
+ "sft": {
2036
+ "trainingTypes": [
2037
+ "lora"
2038
+ ],
2039
+ "datasetFormat": "default-sft",
2040
+ "datasetSchema": {
2041
+ "required": [
2042
+ "prompt",
2043
+ "completion"
2044
+ ],
2045
+ "types": {
2046
+ "prompt": "string",
2047
+ "completion": "string"
2048
+ }
2049
+ }
2050
+ },
2051
+ "rlaif": {
2052
+ "trainingTypes": [
2053
+ "lora"
2054
+ ],
2055
+ "datasetFormat": "default-rlaif",
2056
+ "datasetSchema": {
2057
+ "required": [
2058
+ "prompt"
2059
+ ],
2060
+ "types": {
2061
+ "prompt": "array"
2062
+ }
2063
+ }
2064
+ },
2065
+ "rlvr": {
2066
+ "trainingTypes": [
2067
+ "lora"
2068
+ ],
2069
+ "datasetFormat": "default-rlvr",
2070
+ "datasetSchema": {
2071
+ "required": [
2072
+ "prompt"
2073
+ ],
2074
+ "types": {
2075
+ "prompt": "array"
2076
+ }
2077
+ }
1776
2078
  }
1777
2079
  },
1778
2080
  "goldenPath": false
package/infra/ci-harness/lib/ci-harness-stack.ts CHANGED
@@ -1057,6 +1057,49 @@ export class MlccCiHarnessStack extends cdk.Stack {
1057
1057
  glueTable.addDependency(glueDatabase);
1058
1058
  glueTable.cfnOptions.condition = benchmarkInfraCondition;
1059
1059
 
1060
+ // Glue Table: mlcc_evaluations — model quality evaluation results
1061
+ // Written by do/evaluate via .eval_helper.py eval-write subcommand.
1062
+ // Partitioned by model + adapter for efficient comparison queries.
1063
+ const evalGlueTable = new glue.CfnTable(this, 'EvaluationResultsTable', {
1064
+ catalogId: this.account,
1065
+ databaseName: 'mlcc_ci',
1066
+ tableInput: {
1067
+ name: 'mlcc_evaluations',
1068
+ tableType: 'EXTERNAL_TABLE',
1069
+ parameters: {
1070
+ 'classification': 'json',
1071
+ },
1072
+ storageDescriptor: {
1073
+ columns: [
1074
+ { name: 'project_name', type: 'string', comment: 'MCC project name' },
1075
+ { name: 'model_name', type: 'string', comment: 'HuggingFace model ID' },
1076
+ { name: 'adapter_name', type: 'string', comment: 'Adapter name or IC name' },
1077
+ { name: 'technique', type: 'string', comment: 'Training technique (sft, dpo)' },
1078
+ { name: 'eval_dataset', type: 'string', comment: 'Evaluation dataset URI or name' },
1079
+ { name: 'samples_evaluated', type: 'int', comment: 'Number of samples evaluated' },
1080
+ { name: 'metrics', type: 'string', comment: 'JSON blob of all computed metrics' },
1081
+ { name: 'timestamp', type: 'string', comment: 'ISO 8601 UTC timestamp' },
1082
+ { name: 'region', type: 'string', comment: 'AWS region' },
1083
+ ],
1084
+ location: `s3://mlcc-benchmark-results-${this.account}-${this.region}/evaluations/`,
1085
+ inputFormat: 'org.apache.hadoop.mapred.TextInputFormat',
1086
+ outputFormat: 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat',
1087
+ serdeInfo: {
1088
+ serializationLibrary: 'org.openx.data.jsonserde.JsonSerDe',
1089
+ parameters: {
1090
+ 'serialization.format': '1',
1091
+ },
1092
+ },
1093
+ },
1094
+ partitionKeys: [
1095
+ { name: 'model', type: 'string', comment: 'Model name (partition key)' },
1096
+ { name: 'adapter', type: 'string', comment: 'Adapter name (partition key)' },
1097
+ ],
1098
+ },
1099
+ });
1100
+ evalGlueTable.addDependency(glueDatabase);
1101
+ evalGlueTable.cfnOptions.condition = benchmarkInfraCondition;
1102
+
1060
1103
  // Configurable lifecycle parameters for the benchmark results bucket
1061
1104
  const benchmarkIaTransitionDays = new cdk.CfnParameter(this, 'BenchmarkIaTransitionDays', {
1062
1105
  type: 'Number',
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@aws/ml-container-creator",
3
- "version": "1.0.2",
3
+ "version": "1.0.4",
4
4
  "description": "Build and deploy custom ML containers on AWS SageMaker with minimal configuration.",
5
5
  "main": "src/index.js",
6
6
  "bin": {
@@ -88,7 +88,7 @@
88
88
  },
89
89
  "scripts": {
90
90
  "test": "mocha 'test/**/*.test.js' --ignore 'test/property/**' --recursive --timeout 30000 --parallel",
91
- "test:property": "mocha 'test/property/**/*.test.js' --recursive --timeout 60000 --parallel",
91
+ "test:property": "NODE_OPTIONS='--max-old-space-size=8192' mocha 'test/property/**/*.test.js' --recursive --timeout 60000 --parallel --jobs 4",
92
92
  "test:all": "npm run test && npm run test:property",
93
93
  "test:fast": "mocha 'test/**/*.test.js' --recursive --timeout 15000 --parallel",
94
94
  "test:unit": "mocha 'test/unit/**/*.test.js' --recursive --timeout 15000",
@@ -107,6 +107,7 @@
107
107
  "prepare": "husky || true"
108
108
  },
109
109
  "dependencies": {
110
+ "@aws/ml-container-creator": "^1.0.2",
110
111
  "@inquirer/prompts": "^8.4.2",
111
112
  "@modelcontextprotocol/sdk": "^1.27.1",
112
113
  "ajv": "^8.12.0",
package/servers/base-image-picker/index.js CHANGED
@@ -25,6 +25,8 @@ import { readFileSync } from 'node:fs';
25
25
  import { fileURLToPath } from 'node:url';
26
26
  import { resolve, dirname } from 'node:path';
27
27
  import { DynamicResolver as DynamicResolverBase } from '../lib/dynamic-resolver.js';
28
+ import { filterImages, deriveMinDriverVersion } from '../lib/image-filter.js';
29
+ import { resolveModelArchitecture } from '../lib/model-id-resolver.js';
28
30
 
29
31
  // ── Catalog loader ───────────────────────────────────────────────────────────
30
32
 
@@ -156,15 +158,25 @@ class DynamicResolver extends ImageResolver {
156
158
  }
157
159
 
158
160
  const data = await response.json();
159
- const images = (data.results || []).map(tag => ({
160
- image: `${this._repoForFramework(framework)}:${tag.name}`,
161
- tag: tag.name,
162
- architecture: 'amd64',
163
- created: tag.last_updated || tag.tag_last_pushed || new Date().toISOString(),
164
- labels: {},
165
- registry: 'dockerhub',
166
- repository: this._repoForFramework(framework)
167
- }));
161
+ const images = (data.results || []).map(tag => {
162
+ const entry = {
163
+ image: `${this._repoForFramework(framework)}:${tag.name}`,
164
+ tag: tag.name,
165
+ architecture: 'amd64',
166
+ created: tag.last_updated || tag.tag_last_pushed || new Date().toISOString(),
167
+ labels: {},
168
+ registry: 'dockerhub',
169
+ repository: this._repoForFramework(framework)
170
+ };
171
+
172
+ // Derive min_driver_version from CUDA version in tag or labels
173
+ const minDriver = deriveMinDriverVersion(entry);
174
+ if (minDriver) {
175
+ entry.min_driver_version = minDriver;
176
+ }
177
+
178
+ return entry;
179
+ });
168
180
 
169
181
  return {
170
182
  images: images.slice(0, limit),
@@ -375,7 +387,9 @@ if (discoverMode) {
375
387
  * When discover mode is active, merges static and dynamic results.
376
388
  */
377
389
  async function resolveBaseImage(context, limit) {
378
- const { framework, modelServer, searchCriteria, architecture } = context;
390
+ const { framework, modelServer, searchCriteria, architecture,
391
+ instanceType, driverVersion, inferenceAmiVersion,
392
+ tensorParallelSize, modelArchitecture, modelId } = context;
379
393
 
380
394
  // Determine which framework identifier to resolve
381
395
  let resolverKey;
@@ -398,21 +412,52 @@ async function resolveBaseImage(context, limit) {
398
412
 
399
413
  if (discoverMode && dynamicResolver && dynamicResolver.supportedFrameworks().includes(resolverKey)) {
400
414
  // Fetch both static and dynamic results, then merge
401
- const staticResult = await staticResolver.fetchImages(resolverKey, { limit, searchCriteria });
415
+ const staticResult = await staticResolver.fetchImages(resolverKey, { limit: limit * 3, searchCriteria });
402
416
  const dynamicResult = await dynamicResolver.fetchImages(resolverKey, { limit: 5 });
403
417
 
404
- resultImages = mergeStaticAndDynamic(staticResult.images, dynamicResult.images, limit);
418
+ resultImages = mergeStaticAndDynamic(staticResult.images, dynamicResult.images, limit * 3);
405
419
  } else {
406
- // Static-only path (no network calls)
407
- const result = await resolver.fetchImages(resolverKey, { limit, searchCriteria });
420
+ // Static-only path (no network calls) — fetch extra to allow for filtering
421
+ const fetchLimit = (instanceType || driverVersion || modelArchitecture || modelId) ? limit * 3 : limit;
422
+ const result = await resolver.fetchImages(resolverKey, { limit: fetchLimit, searchCriteria });
408
423
  resultImages = result.images;
409
424
  }
410
425
 
426
+ // ── Resolve modelId → modelArchitecture if needed ───────────────────
427
+ let resolvedModelArchitecture = modelArchitecture || '';
428
+ if (!modelArchitecture && modelId) {
429
+ const arch = await resolveModelArchitecture(modelId);
430
+ if (arch) {
431
+ resolvedModelArchitecture = arch;
432
+ }
433
+ }
434
+
435
+ // ── Apply driver-aware + model-architecture filtering ─────────────────
436
+ let filterMetadata = null;
437
+ if (instanceType || driverVersion || inferenceAmiVersion || resolvedModelArchitecture) {
438
+ const filterResult = filterImages(resultImages, {
439
+ framework: resolverKey,
440
+ instanceType,
441
+ driverVersion,
442
+ inferenceAmiVersion,
443
+ tensorParallelSize: tensorParallelSize || 1,
444
+ modelArchitecture: resolvedModelArchitecture
445
+ });
446
+ resultImages = filterResult.images;
447
+ filterMetadata = filterResult.metadata;
448
+ }
449
+
450
+ // Apply final limit after filtering
451
+ resultImages = resultImages.slice(0, limit);
452
+
411
453
  const images = resultImages.map(e => e.image);
412
454
  return {
413
455
  values: { baseImage: images[0] || null },
414
456
  choices: { baseImage: images },
415
- metadata: { baseImage: resultImages }
457
+ metadata: {
458
+ baseImage: resultImages,
459
+ ...(filterMetadata ? { driverFilter: filterMetadata } : {})
460
+ }
416
461
  };
417
462
  }
418
463
 
@@ -432,11 +477,11 @@ const server = new McpServer({
432
477
 
433
478
  server.tool(
434
479
  'get_base_images',
435
- 'Returns curated base container images for ML Container Creator Dockerfiles',
480
+ 'Returns curated base container images for ML Container Creator Dockerfiles. Supports driver-aware filtering when instanceType is provided — excludes images incompatible with the fleet GPU driver, especially for multi-GPU tensor-parallel deployments.',
436
481
  {
437
482
  parameters: z.array(z.string()).describe('List of parameter names to provide values for'),
438
483
  limit: z.number().int().positive().default(5).describe('Maximum number of choices per parameter'),
439
- context: z.record(z.string(), z.any()).optional().describe('Current configuration context (framework, modelServer, searchCriteria)')
484
+ context: z.record(z.string(), z.any()).optional().describe('Configuration context. Supports: framework, modelServer, searchCriteria, architecture, instanceType (triggers driver filtering), driverVersion (override), inferenceAmiVersion (resolves to driver), tensorParallelSize (TP>1 = strict filtering), modelId, modelArchitecture (excludes old framework versions)')
440
485
  },
441
486
  async ({ parameters, limit, context }) => {
442
487
  const values = {};
@@ -472,10 +517,12 @@ export {
472
517
  TRITON_IMAGE_CATALOG,
473
518
  resolveBaseImage,
474
519
  mergeStaticAndDynamic,
520
+ filterImages,
475
521
  registry,
476
522
  staticResolver,
477
523
  dynamicResolver,
478
- discoverMode
524
+ discoverMode,
525
+ resolveModelArchitecture
479
526
  };
480
527
 
481
528
  export { DynamicResolverBase as DynamicResolverBase };
package/servers/instance-sizer/index.js CHANGED
@@ -393,6 +393,38 @@ async function handleGetInstanceRecommendation(params) {
393
393
  { limit }
394
394
  );
395
395
 
396
+ // Step 3-recommended: When VRAM filter returns empty but catalog has recommendedInstances,
397
+ // use those as the fallback (they represent tested/validated deployments).
398
+ if (recommendations.length === 0 && modelMetadata.recommendedInstances && modelMetadata.recommendedInstances.length > 0) {
399
+ for (const instanceType of modelMetadata.recommendedInstances) {
400
+ const meta = effectiveCatalog[instanceType];
401
+ if (meta) {
402
+ const perGpuMemory = getPerGpuMemoryGb(meta);
403
+ const gpuCount = meta.gpus || 1;
404
+ const totalVramGb = perGpuMemory ? perGpuMemory * gpuCount : null;
405
+ recommendations.push({
406
+ instanceType,
407
+ gpuCount,
408
+ totalVramGb,
409
+ utilizationPercent: totalVramGb ? Math.round((vramEstimate.vramGb / totalVramGb) * 100) : null,
410
+ tensorParallelism: gpuCount,
411
+ costTier: meta.costTier || null
412
+ });
413
+ } else {
414
+ // Instance not in catalog but listed as recommended — still include it
415
+ recommendations.push({
416
+ instanceType,
417
+ gpuCount: null,
418
+ totalVramGb: null,
419
+ utilizationPercent: null,
420
+ tensorParallelism: null,
421
+ costTier: null
422
+ });
423
+ }
424
+ }
425
+ log(`Using catalog recommendedInstances for "${modelName}" (VRAM filter returned empty)`);
426
+ }
427
+
396
428
  // Step 3-max_model_len: When no instance fits at full context, try capping context length
397
429
  // NFR-1 guard: skip this logic for models with recommendedInstances in catalog
398
430
  let suggestedMaxModelLen = null;
package/servers/lib/catalogs/fleet-drivers.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "_comment": "Instance family → GPU driver version mapping for SageMaker inference fleet. Source: AWS docs (inference-gpu-drivers.html) + empirical validation. Updated quarterly or when AWS announces fleet driver updates.",
3
+ "_last_updated": "2026-06-29",
4
+ "instance_families": {
5
+ "g4dn": { "driver": "535.183", "cuda_native": "12.2", "gpu": "T4", "gpu_memory_gb": 16 },
6
+ "g5": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A10G", "gpu_memory_gb": 24 },
7
+ "g5n": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A10G", "gpu_memory_gb": 24 },
8
+ "g6": { "driver": "560.35", "cuda_native": "12.6", "gpu": "L4", "gpu_memory_gb": 24 },
9
+ "g6e": { "driver": "560.35", "cuda_native": "12.6", "gpu": "L40S", "gpu_memory_gb": 48 },
10
+ "p4d": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A100", "gpu_memory_gb": 40 },
11
+ "p4de": { "driver": "550.163", "cuda_native": "12.4", "gpu": "A100", "gpu_memory_gb": 80 },
12
+ "p5": { "driver": "580.95", "cuda_native": "12.9", "gpu": "H100", "gpu_memory_gb": 80 },
13
+ "p5e": { "driver": "580.95", "cuda_native": "12.9", "gpu": "H200", "gpu_memory_gb": 141 },
14
+ "trn1": null,
15
+ "trn2": null,
16
+ "inf2": null
17
+ },
18
+ "ami_versions": {
19
+ "_comment": "InferenceAmiVersion enum → driver version mapping. From SDK enum + empirical.",
20
+ "al2-ami-sagemaker-inference-gpu-2": { "driver": "535.183", "cuda_native": "12.2" },
21
+ "al2-ami-sagemaker-inference-gpu-2-1": { "driver": "535.216", "cuda_native": "12.2" },
22
+ "al2-ami-sagemaker-inference-gpu-3-1": { "driver": "550.163", "cuda_native": "12.4" },
23
+ "al2023-ami-sagemaker-inference-gpu-4-1": { "driver": "570.86", "cuda_native": "12.8" }
24
+ },
25
+ "cuda_to_min_driver": {
26
+ "_comment": "CUDA toolkit version → minimum required driver (Linux data center). Source: NVIDIA CUDA compatibility docs.",
27
+ "12.0": "525.60",
28
+ "12.1": "525.60",
29
+ "12.2": "535.54",
30
+ "12.3": "535.54",
31
+ "12.4": "550.54",
32
+ "12.5": "555.42",
33
+ "12.6": "560.28",
34
+ "12.7": "565.57",
35
+ "12.8": "570.86",
36
+ "12.9": "580.00"
37
+ }
38
+ }