@aws/ml-container-creator 1.0.0 → 1.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -52,6 +52,7 @@ export default class BootstrapCommandHandler {
52
52
  _setupS3Buckets() { return this.provisioners._setupS3Buckets(); }
53
53
  _createS3Bucket(name, tags) { return this.provisioners._createS3Bucket(name, tags); }
54
54
  _verifyCliV2() { return this.provisioners._verifyCliV2(); }
55
+ _provisionAiRegistryHub(profileData) { return this.provisioners.provisionAiRegistryHub(profileData); }
55
56
 
56
57
  // ── ProfileManager delegations (backward compat for tests) ──────
57
58
 
@@ -63,6 +64,7 @@ export default class BootstrapCommandHandler {
63
64
  _handlePrune() { return this.profileManager._handlePrune(); }
64
65
  _handleSyncSchemas() { return this.profileManager._handleSyncSchemas(); }
65
66
  _handleSyncModelFamilies() { return this.profileManager._handleSyncModelFamilies(); }
67
+ _handleSyncServingVersions() { return this.profileManager._handleSyncServingVersions(); }
66
68
 
67
69
  /**
68
70
  * Dispatch bootstrap subcommands.
@@ -131,6 +133,9 @@ export default class BootstrapCommandHandler {
131
133
  case 'sync-model-families':
132
134
  await this._handleSyncModelFamilies();
133
135
  break;
136
+ case 'sync-serving-versions':
137
+ await this._handleSyncServingVersions();
138
+ break;
134
139
  // Migration path: upgrades legacy profiles to current naming conventions.
135
140
  // Corrects stackName to mlcc-bootstrap-{profileName}, renames sharedStackFrom
136
141
  // to sharedInfraFrom. Idempotent — safe to run multiple times.
@@ -357,6 +362,9 @@ export default class BootstrapCommandHandler {
357
362
  console.log(' Tune jobs will still work but experiment tracking may not be available.');
358
363
  }
359
364
 
365
+ // Step 4c: AI Registry Hub
366
+ await this._provisionAiRegistryHub(profileData);
367
+
360
368
  // Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
361
369
  this._displayProgress('🧪', 'CI Testing Infrastructure...');
362
370
  try {
@@ -714,6 +722,10 @@ export default class BootstrapCommandHandler {
714
722
  console.log(` ⚠️ MLflow App setup skipped: ${error.message}`);
715
723
  }
716
724
 
725
+ // Ensure AI Registry hub exists
726
+ this._currentProfile = profileConfig.awsProfile;
727
+ await this._provisionAiRegistryHub(profileConfig);
728
+
717
729
  // Save updated profile
718
730
  this.config.setProfile(name, profileConfig);
719
731
  console.log(`\n✅ Update complete for profile "${name}"`);
@@ -1459,7 +1471,9 @@ SUBCOMMANDS:
1459
1471
  prune Remove deleted and unknown records from the deployment manifest
1460
1472
  update Re-deploy bootstrap stacks using active profile (no prompts)
1461
1473
  migrate Upgrade legacy profiles to current naming conventions
1474
+ sync-schemas Download AWS service model schemas (sagemaker, iam, ecr, s3)
1462
1475
  sync-model-families Discover tune-eligible models from JumpStart Hub and update catalog
1476
+ sync-serving-versions Discover latest vLLM/SGLang/TRT-LLM image versions and update catalog
1463
1477
 
1464
1478
  SETUP OPTIONS:
1465
1479
  --non-interactive Run without interactive prompts
@@ -1469,8 +1483,10 @@ SETUP OPTIONS:
1469
1483
  --role-arn <arn> Use existing IAM role ARN (skip role creation)
1470
1484
  --skip-s3 Skip S3 bucket creation
1471
1485
  --ci Provision CI testing infrastructure
1486
+ --benchmark-infra Provision Athena/Glue benchmark infrastructure (requires --ci)
1472
1487
  --skip-ci Skip CI infrastructure provisioning
1473
1488
  --skip-post-setup Skip post-setup chain (mcp init, sync-architectures, sync-schemas)
1489
+ --ignore-staleness Suppress schema staleness warnings
1474
1490
 
1475
1491
  STATUS OPTIONS:
1476
1492
  --verify Check each active resource against AWS APIs for drift detection
@@ -1487,13 +1503,15 @@ EXAMPLES:
1487
1503
  ml-container-creator bootstrap list
1488
1504
  ml-container-creator bootstrap remove dev
1489
1505
  ml-container-creator bootstrap remove dev --force --delete-stack
1506
+ ml-container-creator bootstrap update
1507
+ ml-container-creator bootstrap update --ci --benchmark-infra
1490
1508
  ml-container-creator bootstrap scan
1509
+ ml-container-creator bootstrap sync-schemas
1491
1510
  ml-container-creator bootstrap sync-model-families
1511
+ ml-container-creator bootstrap sync-serving-versions
1492
1512
  ml-container-creator bootstrap migrate
1493
1513
  ml-container-creator bootstrap --non-interactive --profile my-aws-profile --region us-west-2
1494
- ml-container-creator bootstrap --non-interactive --profile my-aws-profile --role-arn arn:aws:iam::123456789012:role/MyRole --skip-s3
1495
1514
  ml-container-creator bootstrap --non-interactive --profile my-aws-profile --region us-west-2 --ci
1496
- ml-container-creator bootstrap --non-interactive --profile my-aws-profile --region us-west-2 --skip-ci
1497
1515
  `);
1498
1516
  }
1499
1517
 
@@ -172,6 +172,23 @@ export default class BootstrapProfileManager {
172
172
  }
173
173
  }
174
174
 
175
+ // Check AI Registry hub status
176
+ if (profile.config.aiRegistryHubName) {
177
+ try {
178
+ const hubExists = this.handler._resourceExists(
179
+ `sagemaker describe-hub --hub-name ${profile.config.aiRegistryHubName} --region ${profile.config.awsRegion}`,
180
+ profile.config.awsProfile
181
+ );
182
+ console.log(hubExists
183
+ ? ` ✅ AI Registry hub: ${profile.config.aiRegistryHubName}`
184
+ : ` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — missing`);
185
+ } catch {
186
+ console.log(` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — could not validate`);
187
+ }
188
+ } else {
189
+ console.log(' ℹ️ AI Registry hub: not provisioned (run bootstrap to create)');
190
+ }
191
+
175
192
  // Display deployed resources from manifest
176
193
  console.log('\n📦 Deployed Resources:');
177
194
 
@@ -638,4 +655,20 @@ export default class BootstrapProfileManager {
638
655
  process.exit(1);
639
656
  }
640
657
  }
658
+
659
+ /**
660
+ * Handle sync-serving-versions subcommand: discover latest container image
661
+ * versions for vLLM, SGLang, and TensorRT-LLM and update the model-servers catalog.
662
+ */
663
+ async _handleSyncServingVersions() {
664
+ console.log('\n🔄 Sync Serving Versions — Discovering latest container images...\n');
665
+ try {
666
+ const { syncServingVersions } = await import('../../scripts/sync-serving-versions.js');
667
+ const result = await syncServingVersions();
668
+ console.log(`\n✅ Sync complete: ${result.totalAdded} new, ${result.totalRemoved} pruned\n`);
669
+ } catch (err) {
670
+ console.log(`❌ Sync failed: ${err.message}`);
671
+ process.exit(1);
672
+ }
673
+ }
641
674
  }
@@ -405,6 +405,54 @@ export default class BootstrapProvisioners {
405
405
  }
406
406
  }
407
407
 
408
+ /**
409
+ * Provision a deterministic SageMaker AI Registry Hub.
410
+ * Idempotent: checks if `mlcc-registry-{accountId}` already exists before creating.
411
+ * Non-fatal: catches all errors and prints a warning — bootstrap continues regardless.
412
+ *
413
+ * @param {object} profileData - Profile data object (mutated in place with hub info)
414
+ */
415
+ async provisionAiRegistryHub(profileData) {
416
+ const hubName = `mlcc-registry-${profileData.accountId}`;
417
+ const region = profileData.awsRegion;
418
+
419
+ console.log('\n📦 Provisioning AI Registry hub...');
420
+
421
+ try {
422
+ // Check if hub already exists (idempotent)
423
+ const hubExists = this.handler._resourceExists(
424
+ `sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
425
+ this.handler._currentProfile
426
+ );
427
+
428
+ if (hubExists) {
429
+ const hubInfo = this.handler._execAws(
430
+ `sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
431
+ this.handler._currentProfile
432
+ );
433
+ console.log(` ✅ AI Registry hub already provisioned: ${hubName}`);
434
+ profileData.aiRegistryHubName = hubName;
435
+ profileData.aiRegistryHubArn = hubInfo.HubArn;
436
+ return;
437
+ }
438
+
439
+ // Create new hub (always — no adopt-existing logic)
440
+ const tags = this._buildResourceTags();
441
+ const tagsFile = this.handler._formatTagsForCli(tags);
442
+ const createResult = this.handler._execAws(
443
+ `sagemaker create-hub --hub-name ${hubName} --hub-display-name "MCC AI Registry" --hub-description "Dataset, evaluator, and model versioning for ml-container-creator" --tags ${tagsFile} --region ${region}`,
444
+ this.handler._currentProfile
445
+ );
446
+ console.log(` ✅ AI Registry hub "${hubName}" — created`);
447
+ profileData.aiRegistryHubName = hubName;
448
+ profileData.aiRegistryHubArn = createResult.HubArn;
449
+ } catch (err) {
450
+ const message = err.message || String(err);
451
+ console.log(` ⚠️ Could not provision AI Registry hub (non-fatal): ${message}`);
452
+ console.log(' Dataset registration will use local JSON registry.');
453
+ }
454
+ }
455
+
408
456
  /**
409
457
  * Build the standard resource tag set.
410
458
  * @returns {Array<{Key: string, Value: string}>} Tag array
@@ -290,7 +290,12 @@ export default class CrossCuttingChecker {
290
290
  if (!modelType || !server || !serverVersion) return findings;
291
291
 
292
292
  const entries = modelServersCatalog[server] || [];
293
- const entry = entries.find(e => e.labels?.framework_version === serverVersion);
293
+ // Try exact version match first, then fall back to the first catalog entry that has supportedModelTypes
294
+ let entry = entries.find(e => e.labels?.framework_version === serverVersion);
295
+ if (!entry?.supportedModelTypes?.length) {
296
+ // Fall back to any entry that has supportedModelTypes populated
297
+ entry = entries.find(e => e.supportedModelTypes?.length > 0);
298
+ }
294
299
  if (!entry?.supportedModelTypes?.length) return findings;
295
300
 
296
301
  if (!entry.supportedModelTypes.includes(modelType.toLowerCase())) {
@@ -1,6 +1,6 @@
1
1
  // AUTO-GENERATED by scripts/codegen-cli.js — DO NOT EDIT
2
2
  // Source: config/parameter-schema-v2.json
3
- // Generated: 2026-06-23T20:55:23.381Z
3
+ // Generated: 2026-06-29T13:37:06.271Z
4
4
 
5
5
  /**
6
6
  * CLI option definitions derived from parameter-schema-v2.json.
@@ -1,6 +1,6 @@
1
1
  // AUTO-GENERATED by scripts/codegen-parameter-matrix.js — DO NOT EDIT
2
2
  // Source: config/parameter-schema-v2.json
3
- // Generated: 2026-06-23T20:55:23.482Z
3
+ // Generated: 2026-06-29T13:37:06.375Z
4
4
 
5
5
  /**
6
6
  * Parameter matrix defining how each parameter is loaded from various sources.
@@ -1,6 +1,6 @@
1
1
  // AUTO-GENERATED by scripts/codegen-validator.js — DO NOT EDIT
2
2
  // Source: config/parameter-schema-v2.json
3
- // Generated: 2026-06-23T20:55:23.412Z
3
+ // Generated: 2026-06-29T13:37:06.303Z
4
4
 
5
5
  /**
6
6
  * Validation rules derived from parameter-schema-v2.json.
@@ -16,6 +16,24 @@ const __dirname = dirname(__filename);
16
16
  * classifies failures, gates tune/adapter stages, and builds
17
17
  * Athena-compatible records with run_type='path_prove'.
18
18
  *
19
+ * ## Module Status (AC-1.4)
20
+ *
21
+ * ALL exported functions are FULLY FUNCTIONAL:
22
+ * - `identifyGaps()` — Cartesian product gap finder, prioritized by neighbor count
23
+ * - `findNearestSubstitution()` — Hamming distance nearest-neighbor, same-family constraint
24
+ * - `classifyFailure()` — regex pattern matching to 6 categories (capacity, timeout, oom, code_bug, model_incompatibility, service_limitation)
25
+ * - `shouldExecuteTuneStages()` — gating logic for tune/adapter stages
26
+ * - `hammingDistance()` — config vector comparison across CONFIG_DIMENSIONS
27
+ * - `buildPathProverRecord()` — Athena record construction with run_type='path_prove'
28
+ * - `findUnfeasibleRecord()` — checks if a config is known-unfeasible to prevent repeated attempts
29
+ * - `getNextPriorityConfig()` — priority queue management for v1 validation mode
30
+ * - `updatePriorityStatus()` — updates target status after prove attempts
31
+ * - `getPriorityQueueStatus()` — summary counts for priority queue
32
+ * - `loadPriorityTargets()` — file-based priority target loading
33
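+ * - `resolveProveTpDegree()` — TP degree auto-resolution from instance catalog
+ * - `loadOptimizationSpace()` — loads config/optimization-space.json (null if missing/invalid)
+ * - `getSweepableDimensions()` — names of optimization-space dimensions whose status is 'sweepable'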
34
+ *
35
+ * This is stabilization (tests + docs), not implementation. No new logic needed.
36
+ *
19
37
  * Feature: ci-benchmark-pipeline
20
38
  * Requirements: 8.1–8.12
21
39
  */
@@ -611,6 +629,45 @@ export function loadPriorityTargets(configPath) {
611
629
  }
612
630
  }
613
631
 
632
+ // ── Optimization Space Schema (Task 3 — AC-3.5) ─────────────────────────────
633
+
634
/**
 * Read and parse `config/optimization-space.json`, resolved two directories
 * above this module's directory. The schema describes the optimization/prove
 * sweep search space (dimensions, version, description) and is used by gap
 * identification to enumerate sweepable dimensions and their allowed values.
 *
 * Every failure is collapsed to `null`: a missing file, an unreadable file, or
 * malformed JSON.
 *
 * @returns {object|null} Parsed schema object, or null if it could not be loaded
 */
export function loadOptimizationSpace() {
  let parsedSchema = null;
  try {
    const schemaFile = resolve(__dirname, '..', '..', 'config', 'optimization-space.json');
    parsedSchema = JSON.parse(readFileSync(schemaFile, 'utf8'));
  } catch {
    parsedSchema = null;
  }
  return parsedSchema;
}
652
+
653
/**
 * List the names of the sweepable dimensions in the optimization space schema.
 *
 * A dimension is sweepable when its definition has `status === 'sweepable'`.
 * Useful for verifying sync between CONFIG_DIMENSIONS and the schema.
 *
 * @param {object|null} [schema] - Pre-loaded schema; when omitted or null it is
 *   loaded with `loadOptimizationSpace()`
 * @returns {string[]} Sweepable dimension names in schema key order; empty when
 *   there is no schema or it has no `dimensions`
 */
export function getSweepableDimensions(schema = null) {
  const data = schema || loadOptimizationSpace();
  const dimensions = data?.dimensions;
  if (!dimensions || typeof dimensions !== 'object') return [];

  // `spec?.` tolerates null/undefined dimension entries, which previously threw
  // a TypeError while reading `.status`.
  return Object.entries(dimensions)
    .filter(([, spec]) => spec?.status === 'sweepable')
    .map(([name]) => name);
}
670
+
614
671
  // ── TP Degree Auto-Resolution at Prove-Time (Task 6.5) ──────────────────────
615
672
 
616
673
  /**
@@ -8,6 +8,25 @@
8
8
  * Handles stage-specific logic including idempotency checks, status tracking,
9
9
  * and fail-fast behavior.
10
10
  *
11
+ * ## Module Status (AC-1.4)
12
+ *
13
+ * FUNCTIONAL stages:
14
+ * - `executeStageStep()` — fully wired with idempotency via `.mlcc/staged-assets.json`
15
+ * - `isAlreadyStaged()` — checks staged assets existence and validity
16
+ * - `getStagingState()` — resolves current staging state from filesystem + step results
17
+ * - `isValidLifecycleStage()` — validates individual stage names
18
+ * - `validateStagesArray()` — validates arrays of stage names
19
+ * - `formatStagingStatus()` — formats staging state for display
20
+ * - `buildTargetStatus()` — builds status summary for a prove target
21
+ *
22
+ * INTENTIONALLY INCOMPLETE (post-v1 scope):
23
+ * - Other lifecycle stage executors (build, push, deploy, test, tune, adapter,
24
+ * test-adapter, benchmark, register, clean) are NOT implemented.
25
+ * - Only the `stage` step has execution logic. Other stages are recognized in
26
+ * validation but have no executor function.
27
+ * - This is not "broken": these executors were never finished before the original developer's laptop was bricked.
28
+ * They are explicitly post-v1 scope.
29
+ *
11
30
  * Feature: s3-model-loading
12
31
  * Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
13
32
  */
@@ -40,6 +59,22 @@ export const VALID_LIFECYCLE_STAGES = [
40
59
  'clean'
41
60
  ];
42
61
 
62
+ // TODO(post-v1): Implement executor functions for lifecycle stages beyond 'stage'.
63
+ // The following stages are recognized for validation purposes but have no execution logic:
64
+ // - generate: Should invoke `mcc generate` to produce project scaffolding
65
+ // - build: Should run `do/build` to build the Docker container
66
+ // - push: Should run `do/push` to push container to ECR
67
+ // - deploy: Should run `do/deploy` to create SageMaker endpoint
68
+ // - test: Should run `do/test` to invoke endpoint and verify correctness
69
+ // - tune: Should run `do/tune` for fine-tuning jobs (gated by shouldExecuteTuneStages)
70
+ // - adapter: Should run `do/adapter` for LoRA adapter serving
71
+ // - test-adapter: Should test adapter endpoints after deployment
72
+ // - benchmark: Should run `do/benchmark` for performance measurement
73
+ // - register: Should register proven config in Athena/DynamoDB
74
+ // - clean: Should tear down deployed resources
75
+ // These were never finished before the original developer's laptop was bricked.
76
+ // They are explicitly post-v1 scope, not "broken" code.
77
+
43
78
  /**
44
79
  * Possible staging states for status output.
45
80
  */
@@ -1,5 +1,3 @@
1
- #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
1
  # SPDX-License-Identifier: Apache-2.0
4
2
 
5
3
  """Benchmark Writer — Converts do/benchmark output to enriched Parquet for Athena.
@@ -340,7 +338,7 @@ def _extract_base_image_version(base_image):
340
338
  return ''
341
339
 
342
340
 
343
- def enrich_records(config, results, run_timestamp=None):
341
+ def enrich_records(config, results, run_timestamp=None, instance_catalog=None):
344
342
  """Build enriched records from config context and benchmark results.
345
343
 
346
344
  Each metrics entry becomes one enriched record with all Athena columns populated.
@@ -349,6 +347,7 @@ def enrich_records(config, results, run_timestamp=None):
349
347
  config: dict with config context fields (project_name, model_name, etc.)
350
348
  results: dict with benchmark results (job_name, metrics array)
351
349
  run_timestamp: Optional datetime for run_timestamp. Defaults to now UTC.
350
+ instance_catalog: Optional pre-loaded instance catalog dict. If None, loaded from disk.
352
351
 
353
352
  Returns:
354
353
  list of enriched record dicts (one per concurrency level).
@@ -364,10 +363,21 @@ def enrich_records(config, results, run_timestamp=None):
364
363
 
365
364
  # Derived fields
366
365
  model_family = derive_model_family(model_name)
366
+ instance_family = derive_instance_family(instance_type)
367
+
368
+ # Resolve instance metadata from catalog (AC-2.8)
369
+ hw_meta = resolve_instance_metadata(instance_type, instance_catalog)
370
+ gpu_count = hw_meta['gpu_count']
371
+ gpu_type = hw_meta['gpu_type']
372
+ gpu_memory_gb = hw_meta['gpu_memory_gb']
367
373
 
368
374
  # Optional context fields
369
375
  deployment_target = config.get('deployment_target', 'realtime-inference')
370
- tensor_parallel_degree = config.get('tensor_parallel_degree', 1)
376
+ try:
377
+ tensor_parallel_degree = int(config.get('tensor_parallel_degree', 1))
378
+ except (ValueError, TypeError):
379
+ tensor_parallel_degree = 1
380
+
371
381
  quantization = config.get('quantization', 'none')
372
382
  enable_lora = config.get('enable_lora', False)
373
383
  base_image = config.get('base_image', '')
@@ -377,6 +387,11 @@ def enrich_records(config, results, run_timestamp=None):
377
387
  ci_run_id = config.get('ci_run_id', '')
378
388
  account_id = config.get('account_id', '')
379
389
 
390
+ # Configuration dimensions (nullable)
391
+ max_model_len_raw = config.get('max_model_len')
392
+ max_model_len = int(max_model_len_raw) if max_model_len_raw not in (None, '', 0) else None
393
+ kv_cache_dtype = config.get('kv_cache_dtype') or None
394
+
380
395
 
381
396
  # Get metrics from results
382
397
  metrics = results.get('metrics', []) if isinstance(results, dict) else []
@@ -447,6 +462,13 @@ def enrich_records(config, results, run_timestamp=None):
447
462
  'deployment_target': deployment_target,
448
463
  'quantization': quantization,
449
464
  'tensor_parallel_degree': tensor_parallel_degree,
465
+ 'instance_family': instance_family,
466
+ 'gpu_count': gpu_count,
467
+ 'gpu_type': gpu_type,
468
+ 'gpu_memory_gb': gpu_memory_gb,
469
+ 'max_model_len': max_model_len,
470
+ 'enable_lora': enable_lora,
471
+ 'kv_cache_dtype': kv_cache_dtype,
450
472
  'serving_config': json.dumps(serving_config_dict),
451
473
  'workload': config.get('workload', 'manual'),
452
474
  'concurrency': concurrency,
@@ -481,6 +503,7 @@ def enrich_records(config, results, run_timestamp=None):
481
503
  'output_sequence_length_avg': scalar(metric.get('output_sequence_length', metric.get('output_sequence_length_avg', 0.0))),
482
504
  'input_sequence_length_avg': scalar(metric.get('input_sequence_length', metric.get('input_sequence_length_avg', 0.0))),
483
505
  'error_rate': error_rate,
506
+ 'cost_per_1m_tokens': cost,
484
507
  'benchmark_duration_sec': metric.get('benchmark_duration_sec', duration_seconds),
485
508
  'run_type': run_type,
486
509
  'benchmark_job_name': results.get('job_name', '') if isinstance(results, dict) else '',
@@ -792,6 +815,54 @@ def register_partition(bucket, model, instance, target,
792
815
  # ── Parquet Serialization ─────────────────────────────────────────────────────
793
816
 
794
817
 
818
def load_instance_catalog():
    """Load the instance catalog from ``servers/lib/catalogs/instances.json``.

    The path is resolved relative to this file: two directories above the
    directory containing this module, then ``servers/lib/catalogs/instances.json``.

    Returns:
        dict mapping instance type strings to their metadata dicts (the file's
        top-level ``'catalog'`` object). Returns an empty dict if the file is
        missing or unreadable, is not valid UTF-8 JSON, or does not have the
        expected ``{"catalog": {...}}`` shape.
    """
    this_dir = os.path.dirname(os.path.abspath(__file__))
    # NOTE(review): assumes this module sits two levels below the project root
    # (e.g. templates/do/) — confirm for copies rendered into generated projects.
    project_root = os.path.normpath(os.path.join(this_dir, '..', '..'))
    catalog_path = os.path.join(project_root, 'servers', 'lib', 'catalogs', 'instances.json')

    try:
        with open(catalog_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers FileNotFoundError/PermissionError (IOError is OSError);
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}

    # Valid JSON with an unexpected shape (e.g. a top-level list, or a non-object
    # 'catalog') must degrade to "no catalog" instead of crashing callers.
    if not isinstance(data, dict):
        return {}
    catalog = data.get('catalog', {})
    return catalog if isinstance(catalog, dict) else {}
839
+
840
+
841
def resolve_instance_metadata(instance_type, instance_catalog=None):
    """Look up GPU hardware metadata for ``instance_type`` in the instance catalog.

    Args:
        instance_type: SageMaker instance type (e.g., 'ml.g5.xlarge').
        instance_catalog: Optional pre-loaded catalog dict. When None, the catalog
            is read from disk via load_instance_catalog(). An explicitly passed
            empty dict is used as-is.

    Returns:
        dict with keys ``gpu_count``, ``gpu_type`` and ``gpu_memory_gb``, copied from
        the catalog entry's ``gpus``, ``gpuType`` and ``gpuMemoryGb`` fields. All
        values are None when ``instance_type`` has no catalog entry; an individual
        value is None when the entry lacks that field.
    """
    catalog = instance_catalog if instance_catalog is not None else load_instance_catalog()

    # (output column name, key used inside instances.json entries)
    field_map = (
        ('gpu_count', 'gpus'),
        ('gpu_type', 'gpuType'),
        ('gpu_memory_gb', 'gpuMemoryGb'),
    )

    entry = catalog.get(instance_type)
    if entry is None:
        return {column: None for column, _ in field_map}
    return {column: entry.get(source_key) for column, source_key in field_map}
864
+
865
+
795
866
  def get_parquet_schema():
796
867
  """Return the pyarrow schema matching the Athena DDL for benchmark_results.
797
868
 
@@ -814,6 +885,17 @@ def get_parquet_schema():
814
885
  pa.field("quantization", pa.string()),
815
886
  pa.field("tensor_parallel_degree", pa.int32()),
816
887
 
888
+ # Hardware metadata (resolved from instance catalog at write time)
889
+ pa.field("instance_family", pa.string()),
890
+ pa.field("gpu_count", pa.int32()),
891
+ pa.field("gpu_type", pa.string()),
892
+ pa.field("gpu_memory_gb", pa.float64()),
893
+
894
+ # Configuration dimensions (top-level for Athena queryability)
895
+ pa.field("max_model_len", pa.int32()),
896
+ pa.field("enable_lora", pa.bool_()),
897
+ pa.field("kv_cache_dtype", pa.string()),
898
+
817
899
  # Full serving config (extensible JSON blob)
818
900
  pa.field("serving_config", pa.string()),
819
901
 
@@ -852,6 +934,7 @@ def get_parquet_schema():
852
934
  pa.field("output_sequence_length_avg", pa.float64()),
853
935
  pa.field("input_sequence_length_avg", pa.float64()),
854
936
  pa.field("error_rate", pa.float64()),
937
+ pa.field("cost_per_1m_tokens", pa.float64()),
855
938
  pa.field("benchmark_duration_sec", pa.float64()),
856
939
 
857
940
  # Run Metadata
@@ -1182,6 +1265,9 @@ def cmd_write(args):
1182
1265
  if args.adapter_name:
1183
1266
  input_data['adapter_name'] = args.adapter_name
1184
1267
 
1268
+ if getattr(args, 'instance_type', None):
1269
+ input_data['instance_type'] = args.instance_type
1270
+
1185
1271
  # ── Validate before any S3 interaction ────────────────────────────────
1186
1272
  errors = validate_benchmark_input(input_data)
1187
1273
  if errors:
@@ -1391,6 +1477,8 @@ def _load_config_file(config_path):
1391
1477
  'MODEL_NAME': 'model_name',
1392
1478
  'HF_MODEL_ID': 'hf_model_id',
1393
1479
  'INSTANCE_TYPE': 'instance_type',
1480
+ 'INSTANCE_POOLS': 'instance_pools',
1481
+ 'BENCHMARK_INSTANCE_TYPE': 'benchmark_instance_type',
1394
1482
  'DEPLOYMENT_CONFIG': 'deployment_config',
1395
1483
  'DEPLOYMENT_TARGET': 'deployment_target',
1396
1484
  'AWS_REGION': 'region',
@@ -1429,6 +1517,24 @@ def _load_config_file(config_path):
1429
1517
  parts = context['model_name'].rstrip('/').split('/')
1430
1518
  context['model_name'] = parts[-1] if parts else context['model_name']
1431
1519
 
1520
+ # Resolve instance_type precedence:
1521
+ # BENCHMARK_INSTANCE_TYPE (live-resolved, persisted by do/benchmark) > INSTANCE_TYPE > INSTANCE_POOLS fallback
1522
+ if context.get('benchmark_instance_type'):
1523
+ context['instance_type'] = context.pop('benchmark_instance_type')
1524
+ # Fall back to INSTANCE_POOLS when neither is set.
1525
+ # Heterogeneous pool configs may not have a standalone INSTANCE_TYPE value
1526
+ # but always define INSTANCE_POOLS as a JSON array with Priority fields.
1527
+ if not context.get('instance_type') and context.get('instance_pools'):
1528
+ try:
1529
+ pools = json.loads(context['instance_pools'])
1530
+ if pools:
1531
+ # Pick the highest-priority (lowest number) instance
1532
+ best = min(pools, key=lambda p: p.get('Priority', 999))
1533
+ context['instance_type'] = best.get('InstanceType', '')
1534
+ except (json.JSONDecodeError, TypeError, KeyError):
1535
+ pass
1536
+ context.pop('instance_pools', None) # Don't leak raw JSON into record
1537
+
1432
1538
  # Also scan IC config files (do/ic/*.conf) for IC_ENV_* serving params
1433
1539
  # These override do/config values for serving-specific settings
1434
1540
  try:
@@ -1505,6 +1611,10 @@ def main():
1505
1611
  help='LoRA adapter name (differentiates adapter benchmarks from base model in Athena)'
1506
1612
  )
1507
1613
 
1614
+ write_parser.add_argument(
1615
+ '--instance-type', dest='instance_type', default=None,
1616
+ help='Override instance type (use when actual provisioned instance differs from config, e.g. heterogeneous pools)'
1617
+ )
1508
1618
  write_parser.add_argument(
1509
1619
  '--dry-run', dest='dry_run', action='store_true',
1510
1620
  help='Output enriched records as JSON without writing to S3'