@aws/ml-container-creator 1.0.0 → 1.0.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@aws/ml-container-creator",
3
- "version": "1.0.0",
3
+ "version": "1.0.2",
4
4
  "description": "Build and deploy custom ML containers on AWS SageMaker with minimal configuration.",
5
5
  "main": "src/index.js",
6
6
  "bin": {
@@ -200,22 +200,37 @@ async function fetchEndpoints(client, { limit = 10, showFull = false } = {}) {
200
200
 
201
201
  const variantName = primaryVariant.VariantName || 'AllTraffic';
202
202
  let instanceType = primaryVariant.InstanceType || null;
203
+ let instancePools = primaryVariant.InstancePools || null;
203
204
 
204
205
  // For IC-based endpoints, InstanceType may not be in the variant runtime response.
205
- // Fall back to DescribeEndpointConfig which always has it.
206
- if (!instanceType && detail.EndpointConfigName) {
206
+ // Fall back to DescribeEndpointConfig which has either InstanceType or InstancePools.
207
+ if (!instanceType && !instancePools && detail.EndpointConfigName) {
207
208
  try {
208
209
  const ecCmd = new _DescribeEndpointConfigCommand({ EndpointConfigName: detail.EndpointConfigName });
209
210
  const ecDetail = await client.send(ecCmd);
210
211
  const ecVariant = (ecDetail.ProductionVariants || [])[0];
211
212
  if (ecVariant?.InstanceType) {
212
213
  instanceType = ecVariant.InstanceType;
214
+ } else if (ecVariant?.InstancePools && ecVariant.InstancePools.length > 0) {
215
+ instancePools = ecVariant.InstancePools;
213
216
  }
214
217
  } catch (ecErr) {
215
218
  log(`Warning: could not describe endpoint config for "${endpointName}": ${ecErr.message}`);
216
219
  }
217
220
  }
218
- instanceType = instanceType || 'unknown';
221
+
222
+ // Resolve instanceType display string from pools if needed
223
+ if (!instanceType && instancePools && instancePools.length > 0) {
224
+ // Sort by priority, use highest-priority (lowest number) for GPU lookup
225
+ const sorted = [...instancePools].sort((a, b) => (a.Priority || 99) - (b.Priority || 99));
226
+ instanceType = sorted[0].InstanceType || 'unknown';
227
+ // Build display string showing the pool: "ml.g5.12xl (pool: 3 types)"
228
+ if (sorted.length > 1) {
229
+ instanceType = `${instanceType} (pool: ${sorted.length} types)`;
230
+ }
231
+ } else {
232
+ instanceType = instanceType || 'unknown';
233
+ }
219
234
 
220
235
  const instanceCount = primaryVariant.CurrentInstanceCount ?? primaryVariant.DesiredInstanceCount ?? 1;
221
236
  const hasInstancePools = !!(primaryVariant.InstancePools && primaryVariant.InstancePools.length > 0);
@@ -244,7 +259,12 @@ async function fetchEndpoints(client, { limit = 10, showFull = false } = {}) {
244
259
  } while (icNextToken);
245
260
 
246
261
  // Capacity estimation
247
- const gpusPerInstance = getGpusForInstance(instanceType);
262
+ // For pool endpoints, instanceType may be "ml.g5.12xlarge (pool: 3 types)"
263
+ // Extract the raw type for catalog lookup
264
+ const instanceTypeForLookup = instanceType.includes(' (pool:')
265
+ ? instanceType.split(' (pool:')[0]
266
+ : instanceType;
267
+ const gpusPerInstance = getGpusForInstance(instanceTypeForLookup);
248
268
  let availableGpus;
249
269
  if (gpusPerInstance === null) {
250
270
  availableGpus = '?';
@@ -52,6 +52,7 @@ export default class BootstrapCommandHandler {
52
52
  _setupS3Buckets() { return this.provisioners._setupS3Buckets(); }
53
53
  _createS3Bucket(name, tags) { return this.provisioners._createS3Bucket(name, tags); }
54
54
  _verifyCliV2() { return this.provisioners._verifyCliV2(); }
55
+ _provisionAiRegistryHub(profileData) { return this.provisioners.provisionAiRegistryHub(profileData); }
55
56
 
56
57
  // ── ProfileManager delegations (backward compat for tests) ──────
57
58
 
@@ -357,6 +358,9 @@ export default class BootstrapCommandHandler {
357
358
  console.log(' Tune jobs will still work but experiment tracking may not be available.');
358
359
  }
359
360
 
361
+ // Step 4c: AI Registry Hub
362
+ await this._provisionAiRegistryHub(profileData);
363
+
360
364
  // Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
361
365
  this._displayProgress('🧪', 'CI Testing Infrastructure...');
362
366
  try {
@@ -714,6 +718,10 @@ export default class BootstrapCommandHandler {
714
718
  console.log(` ⚠️ MLflow App setup skipped: ${error.message}`);
715
719
  }
716
720
 
721
+ // Ensure AI Registry hub exists
722
+ this._currentProfile = profileConfig.awsProfile;
723
+ await this._provisionAiRegistryHub(profileConfig);
724
+
717
725
  // Save updated profile
718
726
  this.config.setProfile(name, profileConfig);
719
727
  console.log(`\n✅ Update complete for profile "${name}"`);
@@ -172,6 +172,23 @@ export default class BootstrapProfileManager {
172
172
  }
173
173
  }
174
174
 
175
+ // Check AI Registry hub status
176
+ if (profile.config.aiRegistryHubName) {
177
+ try {
178
+ const hubExists = this.handler._resourceExists(
179
+ `sagemaker describe-hub --hub-name ${profile.config.aiRegistryHubName} --region ${profile.config.awsRegion}`,
180
+ profile.config.awsProfile
181
+ );
182
+ console.log(hubExists
183
+ ? ` ✅ AI Registry hub: ${profile.config.aiRegistryHubName}`
184
+ : ` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — missing`);
185
+ } catch {
186
+ console.log(` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — could not validate`);
187
+ }
188
+ } else {
189
+ console.log(' ℹ️ AI Registry hub: not provisioned (run bootstrap to create)');
190
+ }
191
+
175
192
  // Display deployed resources from manifest
176
193
  console.log('\n📦 Deployed Resources:');
177
194
 
@@ -405,6 +405,54 @@ export default class BootstrapProvisioners {
405
405
  }
406
406
  }
407
407
 
408
+ /**
409
+ * Provision a deterministic SageMaker AI Registry Hub.
410
+ * Idempotent: checks if `mlcc-registry-{accountId}` already exists before creating.
411
+ * Non-fatal: catches all errors and prints a warning — bootstrap continues regardless.
412
+ *
413
+ * @param {object} profileData - Profile data object (mutated in place with hub info)
414
+ */
415
+ async provisionAiRegistryHub(profileData) {
416
+ const hubName = `mlcc-registry-${profileData.accountId}`;
417
+ const region = profileData.awsRegion;
418
+
419
+ console.log('\n📦 Provisioning AI Registry hub...');
420
+
421
+ try {
422
+ // Check if hub already exists (idempotent)
423
+ const hubExists = this.handler._resourceExists(
424
+ `sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
425
+ this.handler._currentProfile
426
+ );
427
+
428
+ if (hubExists) {
429
+ const hubInfo = this.handler._execAws(
430
+ `sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
431
+ this.handler._currentProfile
432
+ );
433
+ console.log(` ✅ AI Registry hub already provisioned: ${hubName}`);
434
+ profileData.aiRegistryHubName = hubName;
435
+ profileData.aiRegistryHubArn = hubInfo.HubArn;
436
+ return;
437
+ }
438
+
439
+ // Create new hub (always — no adopt-existing logic)
440
+ const tags = this._buildResourceTags();
441
+ const tagsFile = this.handler._formatTagsForCli(tags);
442
+ const createResult = this.handler._execAws(
443
+ `sagemaker create-hub --hub-name ${hubName} --hub-display-name "MCC AI Registry" --hub-description "Dataset, evaluator, and model versioning for ml-container-creator" --tags ${tagsFile} --region ${region}`,
444
+ this.handler._currentProfile
445
+ );
446
+ console.log(` ✅ AI Registry hub "${hubName}" — created`);
447
+ profileData.aiRegistryHubName = hubName;
448
+ profileData.aiRegistryHubArn = createResult.HubArn;
449
+ } catch (err) {
450
+ const message = err.message || String(err);
451
+ console.log(` ⚠️ Could not provision AI Registry hub (non-fatal): ${message}`);
452
+ console.log(' Dataset registration will use local JSON registry.');
453
+ }
454
+ }
455
+
408
456
  /**
409
457
  * Build the standard resource tag set.
410
458
  * @returns {Array<{Key: string, Value: string}>} Tag array
@@ -16,6 +16,24 @@ const __dirname = dirname(__filename);
16
16
  * classifies failures, gates tune/adapter stages, and builds
17
17
  * Athena-compatible records with run_type='path_prove'.
18
18
  *
19
+ * ## Module Status (AC-1.4)
20
+ *
21
+ * ALL exported functions are FULLY FUNCTIONAL:
22
+ * - `identifyGaps()` — Cartesian product gap finder, prioritized by neighbor count
23
+ * - `findNearestSubstitution()` — Hamming distance nearest-neighbor, same-family constraint
24
+ * - `classifyFailure()` — regex pattern matching to 6 categories (capacity, timeout, oom, code_bug, model_incompatibility, service_limitation)
25
+ * - `shouldExecuteTuneStages()` — gating logic for tune/adapter stages
26
+ * - `hammingDistance()` — config vector comparison across CONFIG_DIMENSIONS
27
+ * - `buildPathProverRecord()` — Athena record construction with run_type='path_prove'
28
+ * - `findUnfeasibleRecord()` — checks if a config is known-unfeasible to prevent repeated attempts
29
+ * - `getNextPriorityConfig()` — priority queue management for v1 validation mode
30
+ * - `updatePriorityStatus()` — updates target status after prove attempts
31
+ * - `getPriorityQueueStatus()` — summary counts for priority queue
32
+ * - `loadPriorityTargets()` — file-based priority target loading
33
+ * - `resolveProveTpDegree()` — TP degree auto-resolution from instance catalog
34
+ *
35
+ * This is stabilization (tests + docs), not implementation. No new logic needed.
36
+ *
19
37
  * Feature: ci-benchmark-pipeline
20
38
  * Requirements: 8.1–8.12
21
39
  */
@@ -611,6 +629,45 @@ export function loadPriorityTargets(configPath) {
611
629
  }
612
630
  }
613
631
 
632
+ // ── Optimization Space Schema (Task 3 — AC-3.5) ─────────────────────────────
633
+
634
+ /**
635
+ * Load the optimization search space schema from config/optimization-space.json.
636
+ *
637
+ * Returns the parsed schema with dimensions, version, and description.
638
+ * Used by gap identification to enumerate sweepable dimensions and their
639
+ * allowed values for the optimization/prove sweep.
640
+ *
641
+ * @returns {object|null} Parsed schema object, or null if file not found/invalid
642
+ */
643
+ export function loadOptimizationSpace() {
644
+ try {
645
+ const schemaPath = resolve(__dirname, '..', '..', 'config', 'optimization-space.json');
646
+ const raw = readFileSync(schemaPath, 'utf8');
647
+ return JSON.parse(raw);
648
+ } catch {
649
+ return null;
650
+ }
651
+ }
652
+
653
+ /**
654
+ * Get the list of sweepable dimension names from the optimization space schema.
655
+ *
656
+ * Filters dimensions by status === 'sweepable' and returns their keys.
657
+ * Useful for verifying sync between CONFIG_DIMENSIONS and the schema.
658
+ *
659
+ * @param {object} [schema] - Pre-loaded schema (loads from file if omitted)
660
+ * @returns {string[]} Array of sweepable dimension names
661
+ */
662
+ export function getSweepableDimensions(schema = null) {
663
+ const data = schema || loadOptimizationSpace();
664
+ if (!data || !data.dimensions) return [];
665
+
666
+ return Object.keys(data.dimensions).filter(
667
+ key => data.dimensions[key].status === 'sweepable'
668
+ );
669
+ }
670
+
614
671
  // ── TP Degree Auto-Resolution at Prove-Time (Task 6.5) ──────────────────────
615
672
 
616
673
  /**
@@ -8,6 +8,25 @@
8
8
  * Handles stage-specific logic including idempotency checks, status tracking,
9
9
  * and fail-fast behavior.
10
10
  *
11
+ * ## Module Status (AC-1.4)
12
+ *
13
+ * FUNCTIONAL stages:
14
+ * - `executeStageStep()` — fully wired with idempotency via `.mlcc/staged-assets.json`
15
+ * - `isAlreadyStaged()` — checks staged assets existence and validity
16
+ * - `getStagingState()` — resolves current staging state from filesystem + step results
17
+ * - `isValidLifecycleStage()` — validates individual stage names
18
+ * - `validateStagesArray()` — validates arrays of stage names
19
+ * - `formatStagingStatus()` — formats staging state for display
20
+ * - `buildTargetStatus()` — builds status summary for a prove target
21
+ *
22
+ * INTENTIONALLY INCOMPLETE (post-v1 scope):
23
+ * - Other lifecycle stage executors (build, push, deploy, test, tune, adapter,
24
+ * test-adapter, benchmark, register, clean) are NOT implemented.
25
+ * - Only the `stage` step has execution logic. Other stages are recognized in
26
+ * validation but have no executor function.
27
+ * - This is not "broken" — these were never finished before the laptop was bricked.
28
+ * They are explicitly post-v1 scope.
29
+ *
11
30
  * Feature: s3-model-loading
12
31
  * Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
13
32
  */
@@ -40,6 +59,22 @@ export const VALID_LIFECYCLE_STAGES = [
40
59
  'clean'
41
60
  ];
42
61
 
62
+ // TODO(post-v1): Implement executor functions for lifecycle stages beyond 'stage'.
63
+ // The following stages are recognized for validation purposes but have no execution logic:
64
+ // - generate: Should invoke `mcc generate` to produce project scaffolding
65
+ // - build: Should run `do/build` to build the Docker container
66
+ // - push: Should run `do/push` to push container to ECR
67
+ // - deploy: Should run `do/deploy` to create SageMaker endpoint
68
+ // - test: Should run `do/test` to invoke endpoint and verify correctness
69
+ // - tune: Should run `do/tune` for fine-tuning jobs (gated by shouldExecuteTuneStages)
70
+ // - adapter: Should run `do/adapter` for LoRA adapter serving
71
+ // - test-adapter: Should test adapter endpoints after deployment
72
+ // - benchmark: Should run `do/benchmark` for performance measurement
73
+ // - register: Should register proven config in Athena/DynamoDB
74
+ // - clean: Should tear down deployed resources
75
+ // These were never finished before the original developer's laptop was bricked.
76
+ // They are explicitly post-v1 scope, not "broken" code.
77
+
43
78
  /**
44
79
  * Possible staging states for status output.
45
80
  */
@@ -1,5 +1,3 @@
1
- #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
1
  # SPDX-License-Identifier: Apache-2.0
4
2
 
5
3
  """Benchmark Writer — Converts do/benchmark output to enriched Parquet for Athena.
@@ -340,7 +338,7 @@ def _extract_base_image_version(base_image):
340
338
  return ''
341
339
 
342
340
 
343
- def enrich_records(config, results, run_timestamp=None):
341
+ def enrich_records(config, results, run_timestamp=None, instance_catalog=None):
344
342
  """Build enriched records from config context and benchmark results.
345
343
 
346
344
  Each metrics entry becomes one enriched record with all Athena columns populated.
@@ -349,6 +347,7 @@ def enrich_records(config, results, run_timestamp=None):
349
347
  config: dict with config context fields (project_name, model_name, etc.)
350
348
  results: dict with benchmark results (job_name, metrics array)
351
349
  run_timestamp: Optional datetime for run_timestamp. Defaults to now UTC.
350
+ instance_catalog: Optional pre-loaded instance catalog dict. If None, loaded from disk.
352
351
 
353
352
  Returns:
354
353
  list of enriched record dicts (one per concurrency level).
@@ -364,10 +363,21 @@ def enrich_records(config, results, run_timestamp=None):
364
363
 
365
364
  # Derived fields
366
365
  model_family = derive_model_family(model_name)
366
+ instance_family = derive_instance_family(instance_type)
367
+
368
+ # Resolve instance metadata from catalog (AC-2.8)
369
+ hw_meta = resolve_instance_metadata(instance_type, instance_catalog)
370
+ gpu_count = hw_meta['gpu_count']
371
+ gpu_type = hw_meta['gpu_type']
372
+ gpu_memory_gb = hw_meta['gpu_memory_gb']
367
373
 
368
374
  # Optional context fields
369
375
  deployment_target = config.get('deployment_target', 'realtime-inference')
370
- tensor_parallel_degree = config.get('tensor_parallel_degree', 1)
376
+ try:
377
+ tensor_parallel_degree = int(config.get('tensor_parallel_degree', 1))
378
+ except (ValueError, TypeError):
379
+ tensor_parallel_degree = 1
380
+
371
381
  quantization = config.get('quantization', 'none')
372
382
  enable_lora = config.get('enable_lora', False)
373
383
  base_image = config.get('base_image', '')
@@ -377,6 +387,11 @@ def enrich_records(config, results, run_timestamp=None):
377
387
  ci_run_id = config.get('ci_run_id', '')
378
388
  account_id = config.get('account_id', '')
379
389
 
390
+ # Configuration dimensions (nullable)
391
+ max_model_len_raw = config.get('max_model_len')
392
+ max_model_len = int(max_model_len_raw) if max_model_len_raw not in (None, '', 0) else None
393
+ kv_cache_dtype = config.get('kv_cache_dtype') or None
394
+
380
395
 
381
396
  # Get metrics from results
382
397
  metrics = results.get('metrics', []) if isinstance(results, dict) else []
@@ -447,6 +462,13 @@ def enrich_records(config, results, run_timestamp=None):
447
462
  'deployment_target': deployment_target,
448
463
  'quantization': quantization,
449
464
  'tensor_parallel_degree': tensor_parallel_degree,
465
+ 'instance_family': instance_family,
466
+ 'gpu_count': gpu_count,
467
+ 'gpu_type': gpu_type,
468
+ 'gpu_memory_gb': gpu_memory_gb,
469
+ 'max_model_len': max_model_len,
470
+ 'enable_lora': enable_lora,
471
+ 'kv_cache_dtype': kv_cache_dtype,
450
472
  'serving_config': json.dumps(serving_config_dict),
451
473
  'workload': config.get('workload', 'manual'),
452
474
  'concurrency': concurrency,
@@ -481,6 +503,7 @@ def enrich_records(config, results, run_timestamp=None):
481
503
  'output_sequence_length_avg': scalar(metric.get('output_sequence_length', metric.get('output_sequence_length_avg', 0.0))),
482
504
  'input_sequence_length_avg': scalar(metric.get('input_sequence_length', metric.get('input_sequence_length_avg', 0.0))),
483
505
  'error_rate': error_rate,
506
+ 'cost_per_1m_tokens': cost,
484
507
  'benchmark_duration_sec': metric.get('benchmark_duration_sec', duration_seconds),
485
508
  'run_type': run_type,
486
509
  'benchmark_job_name': results.get('job_name', '') if isinstance(results, dict) else '',
@@ -792,6 +815,54 @@ def register_partition(bucket, model, instance, target,
792
815
  # ── Parquet Serialization ─────────────────────────────────────────────────────
793
816
 
794
817
 
818
+ def load_instance_catalog():
819
+ """Load the instance catalog from servers/lib/catalogs/instances.json.
820
+
821
+ Resolves the path relative to the project root (two levels up from templates/do/).
822
+ Returns the 'catalog' dict mapping instance_type → metadata, or empty dict on failure.
823
+
824
+ Returns:
825
+ dict mapping instance type strings to their metadata dicts.
826
+ """
827
+ # Resolve relative to this file: templates/do/.benchmark_writer.py → project root
828
+ this_dir = os.path.dirname(os.path.abspath(__file__))
829
+ # Navigate up from templates/do/ to project root
830
+ project_root = os.path.normpath(os.path.join(this_dir, '..', '..'))
831
+ catalog_path = os.path.join(project_root, 'servers', 'lib', 'catalogs', 'instances.json')
832
+
833
+ try:
834
+ with open(catalog_path, 'r') as f:
835
+ data = json.load(f)
836
+ return data.get('catalog', {})
837
+ except (FileNotFoundError, json.JSONDecodeError, IOError):
838
+ return {}
839
+
840
+
841
+ def resolve_instance_metadata(instance_type, instance_catalog=None):
842
+ """Resolve GPU metadata from the instance catalog for a given instance_type.
843
+
844
+ Args:
845
+ instance_type: SageMaker instance type (e.g., 'ml.g5.xlarge').
846
+ instance_catalog: Optional pre-loaded catalog dict. If None, loads from disk.
847
+
848
+ Returns:
849
+ dict with keys: gpu_count (int|None), gpu_type (str|None), gpu_memory_gb (float|None).
850
+ All values are None if instance_type is not found in catalog.
851
+ """
852
+ if instance_catalog is None:
853
+ instance_catalog = load_instance_catalog()
854
+
855
+ entry = instance_catalog.get(instance_type)
856
+ if entry is None:
857
+ return {'gpu_count': None, 'gpu_type': None, 'gpu_memory_gb': None}
858
+
859
+ return {
860
+ 'gpu_count': entry.get('gpus'),
861
+ 'gpu_type': entry.get('gpuType'),
862
+ 'gpu_memory_gb': entry.get('gpuMemoryGb'),
863
+ }
864
+
865
+
795
866
  def get_parquet_schema():
796
867
  """Return the pyarrow schema matching the Athena DDL for benchmark_results.
797
868
 
@@ -814,6 +885,17 @@ def get_parquet_schema():
814
885
  pa.field("quantization", pa.string()),
815
886
  pa.field("tensor_parallel_degree", pa.int32()),
816
887
 
888
+ # Hardware metadata (resolved from instance catalog at write time)
889
+ pa.field("instance_family", pa.string()),
890
+ pa.field("gpu_count", pa.int32()),
891
+ pa.field("gpu_type", pa.string()),
892
+ pa.field("gpu_memory_gb", pa.float64()),
893
+
894
+ # Configuration dimensions (top-level for Athena queryability)
895
+ pa.field("max_model_len", pa.int32()),
896
+ pa.field("enable_lora", pa.bool_()),
897
+ pa.field("kv_cache_dtype", pa.string()),
898
+
817
899
  # Full serving config (extensible JSON blob)
818
900
  pa.field("serving_config", pa.string()),
819
901
 
@@ -852,6 +934,7 @@ def get_parquet_schema():
852
934
  pa.field("output_sequence_length_avg", pa.float64()),
853
935
  pa.field("input_sequence_length_avg", pa.float64()),
854
936
  pa.field("error_rate", pa.float64()),
937
+ pa.field("cost_per_1m_tokens", pa.float64()),
855
938
  pa.field("benchmark_duration_sec", pa.float64()),
856
939
 
857
940
  # Run Metadata
@@ -1182,6 +1265,9 @@ def cmd_write(args):
1182
1265
  if args.adapter_name:
1183
1266
  input_data['adapter_name'] = args.adapter_name
1184
1267
 
1268
+ if getattr(args, 'instance_type', None):
1269
+ input_data['instance_type'] = args.instance_type
1270
+
1185
1271
  # ── Validate before any S3 interaction ────────────────────────────────
1186
1272
  errors = validate_benchmark_input(input_data)
1187
1273
  if errors:
@@ -1391,6 +1477,8 @@ def _load_config_file(config_path):
1391
1477
  'MODEL_NAME': 'model_name',
1392
1478
  'HF_MODEL_ID': 'hf_model_id',
1393
1479
  'INSTANCE_TYPE': 'instance_type',
1480
+ 'INSTANCE_POOLS': 'instance_pools',
1481
+ 'BENCHMARK_INSTANCE_TYPE': 'benchmark_instance_type',
1394
1482
  'DEPLOYMENT_CONFIG': 'deployment_config',
1395
1483
  'DEPLOYMENT_TARGET': 'deployment_target',
1396
1484
  'AWS_REGION': 'region',
@@ -1429,6 +1517,24 @@ def _load_config_file(config_path):
1429
1517
  parts = context['model_name'].rstrip('/').split('/')
1430
1518
  context['model_name'] = parts[-1] if parts else context['model_name']
1431
1519
 
1520
+ # Resolve instance_type precedence:
1521
+ # BENCHMARK_INSTANCE_TYPE (live-resolved, persisted by do/benchmark) > INSTANCE_TYPE > INSTANCE_POOLS fallback
1522
+ if context.get('benchmark_instance_type'):
1523
+ context['instance_type'] = context.pop('benchmark_instance_type')
1524
+ # Fall back to INSTANCE_POOLS when neither is set.
1525
+ # Heterogeneous pool configs may not have a standalone INSTANCE_TYPE value
1526
+ # but always define INSTANCE_POOLS as a JSON array with Priority fields.
1527
+ if not context.get('instance_type') and context.get('instance_pools'):
1528
+ try:
1529
+ pools = json.loads(context['instance_pools'])
1530
+ if pools:
1531
+ # Pick the highest-priority (lowest number) instance
1532
+ best = min(pools, key=lambda p: p.get('Priority', 999))
1533
+ context['instance_type'] = best.get('InstanceType', '')
1534
+ except (json.JSONDecodeError, TypeError, KeyError):
1535
+ pass
1536
+ context.pop('instance_pools', None) # Don't leak raw JSON into record
1537
+
1432
1538
  # Also scan IC config files (do/ic/*.conf) for IC_ENV_* serving params
1433
1539
  # These override do/config values for serving-specific settings
1434
1540
  try:
@@ -1505,6 +1611,10 @@ def main():
1505
1611
  help='LoRA adapter name (differentiates adapter benchmarks from base model in Athena)'
1506
1612
  )
1507
1613
 
1614
+ write_parser.add_argument(
1615
+ '--instance-type', dest='instance_type', default=None,
1616
+ help='Override instance type (use when actual provisioned instance differs from config, e.g. heterogeneous pools)'
1617
+ )
1508
1618
  write_parser.add_argument(
1509
1619
  '--dry-run', dest='dry_run', action='store_true',
1510
1620
  help='Output enriched records as JSON without writing to S3'