@aws/ml-container-creator 0.15.1 → 1.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/package.json +1 -1
- package/servers/endpoint-picker/index.js +24 -4
- package/src/lib/bootstrap-command-handler.js +8 -0
- package/src/lib/bootstrap-profile-manager.js +17 -0
- package/src/lib/bootstrap-provisioners.js +48 -0
- package/src/lib/path-prover-brain.js +57 -0
- package/src/lib/prove-pipeline-executor.js +35 -0
- package/templates/do/.benchmark_writer.py +114 -4
- package/templates/do/.register_helper.py +643 -67
- package/templates/do/.stage_helper.py +1 -0
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +267 -171
- package/templates/do/benchmark +74 -14
- package/templates/do/config +1 -1
- package/templates/do/lib/inference-component.sh +6 -25
- package/templates/do/register +29 -2
- package/templates/do/tune +94 -12
package/package.json
CHANGED
|
@@ -200,22 +200,37 @@ async function fetchEndpoints(client, { limit = 10, showFull = false } = {}) {
|
|
|
200
200
|
|
|
201
201
|
const variantName = primaryVariant.VariantName || 'AllTraffic';
|
|
202
202
|
let instanceType = primaryVariant.InstanceType || null;
|
|
203
|
+
let instancePools = primaryVariant.InstancePools || null;
|
|
203
204
|
|
|
204
205
|
// For IC-based endpoints, InstanceType may not be in the variant runtime response.
|
|
205
|
-
// Fall back to DescribeEndpointConfig which
|
|
206
|
-
if (!instanceType && detail.EndpointConfigName) {
|
|
206
|
+
// Fall back to DescribeEndpointConfig which has either InstanceType or InstancePools.
|
|
207
|
+
if (!instanceType && !instancePools && detail.EndpointConfigName) {
|
|
207
208
|
try {
|
|
208
209
|
const ecCmd = new _DescribeEndpointConfigCommand({ EndpointConfigName: detail.EndpointConfigName });
|
|
209
210
|
const ecDetail = await client.send(ecCmd);
|
|
210
211
|
const ecVariant = (ecDetail.ProductionVariants || [])[0];
|
|
211
212
|
if (ecVariant?.InstanceType) {
|
|
212
213
|
instanceType = ecVariant.InstanceType;
|
|
214
|
+
} else if (ecVariant?.InstancePools && ecVariant.InstancePools.length > 0) {
|
|
215
|
+
instancePools = ecVariant.InstancePools;
|
|
213
216
|
}
|
|
214
217
|
} catch (ecErr) {
|
|
215
218
|
log(`Warning: could not describe endpoint config for "${endpointName}": ${ecErr.message}`);
|
|
216
219
|
}
|
|
217
220
|
}
|
|
218
|
-
|
|
221
|
+
|
|
222
|
+
// Resolve instanceType display string from pools if needed
|
|
223
|
+
if (!instanceType && instancePools && instancePools.length > 0) {
|
|
224
|
+
// Sort by priority, use highest-priority (lowest number) for GPU lookup
|
|
225
|
+
const sorted = [...instancePools].sort((a, b) => (a.Priority || 99) - (b.Priority || 99));
|
|
226
|
+
instanceType = sorted[0].InstanceType || 'unknown';
|
|
227
|
+
// Build display string showing the pool: "ml.g5.12xl (pool: 3 types)"
|
|
228
|
+
if (sorted.length > 1) {
|
|
229
|
+
instanceType = `${instanceType} (pool: ${sorted.length} types)`;
|
|
230
|
+
}
|
|
231
|
+
} else {
|
|
232
|
+
instanceType = instanceType || 'unknown';
|
|
233
|
+
}
|
|
219
234
|
|
|
220
235
|
const instanceCount = primaryVariant.CurrentInstanceCount ?? primaryVariant.DesiredInstanceCount ?? 1;
|
|
221
236
|
const hasInstancePools = !!(primaryVariant.InstancePools && primaryVariant.InstancePools.length > 0);
|
|
@@ -244,7 +259,12 @@ async function fetchEndpoints(client, { limit = 10, showFull = false } = {}) {
|
|
|
244
259
|
} while (icNextToken);
|
|
245
260
|
|
|
246
261
|
// Capacity estimation
|
|
247
|
-
|
|
262
|
+
// For pool endpoints, instanceType may be "ml.g5.12xlarge (pool: 3 types)"
|
|
263
|
+
// Extract the raw type for catalog lookup
|
|
264
|
+
const instanceTypeForLookup = instanceType.includes(' (pool:')
|
|
265
|
+
? instanceType.split(' (pool:')[0]
|
|
266
|
+
: instanceType;
|
|
267
|
+
const gpusPerInstance = getGpusForInstance(instanceTypeForLookup);
|
|
248
268
|
let availableGpus;
|
|
249
269
|
if (gpusPerInstance === null) {
|
|
250
270
|
availableGpus = '?';
|
|
@@ -52,6 +52,7 @@ export default class BootstrapCommandHandler {
|
|
|
52
52
|
_setupS3Buckets() { return this.provisioners._setupS3Buckets(); }
|
|
53
53
|
_createS3Bucket(name, tags) { return this.provisioners._createS3Bucket(name, tags); }
|
|
54
54
|
_verifyCliV2() { return this.provisioners._verifyCliV2(); }
|
|
55
|
+
_provisionAiRegistryHub(profileData) { return this.provisioners.provisionAiRegistryHub(profileData); }
|
|
55
56
|
|
|
56
57
|
// ── ProfileManager delegations (backward compat for tests) ──────
|
|
57
58
|
|
|
@@ -357,6 +358,9 @@ export default class BootstrapCommandHandler {
|
|
|
357
358
|
console.log(' Tune jobs will still work but experiment tracking may not be available.');
|
|
358
359
|
}
|
|
359
360
|
|
|
361
|
+
// Step 4c: AI Registry Hub
|
|
362
|
+
await this._provisionAiRegistryHub(profileData);
|
|
363
|
+
|
|
360
364
|
// Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
|
|
361
365
|
this._displayProgress('🧪', 'CI Testing Infrastructure...');
|
|
362
366
|
try {
|
|
@@ -714,6 +718,10 @@ export default class BootstrapCommandHandler {
|
|
|
714
718
|
console.log(` ⚠️ MLflow App setup skipped: ${error.message}`);
|
|
715
719
|
}
|
|
716
720
|
|
|
721
|
+
// Ensure AI Registry hub exists
|
|
722
|
+
this._currentProfile = profileConfig.awsProfile;
|
|
723
|
+
await this._provisionAiRegistryHub(profileConfig);
|
|
724
|
+
|
|
717
725
|
// Save updated profile
|
|
718
726
|
this.config.setProfile(name, profileConfig);
|
|
719
727
|
console.log(`\n✅ Update complete for profile "${name}"`);
|
|
@@ -172,6 +172,23 @@ export default class BootstrapProfileManager {
|
|
|
172
172
|
}
|
|
173
173
|
}
|
|
174
174
|
|
|
175
|
+
// Check AI Registry hub status
|
|
176
|
+
if (profile.config.aiRegistryHubName) {
|
|
177
|
+
try {
|
|
178
|
+
const hubExists = this.handler._resourceExists(
|
|
179
|
+
`sagemaker describe-hub --hub-name ${profile.config.aiRegistryHubName} --region ${profile.config.awsRegion}`,
|
|
180
|
+
profile.config.awsProfile
|
|
181
|
+
);
|
|
182
|
+
console.log(hubExists
|
|
183
|
+
? ` ✅ AI Registry hub: ${profile.config.aiRegistryHubName}`
|
|
184
|
+
: ` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — missing`);
|
|
185
|
+
} catch {
|
|
186
|
+
console.log(` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — could not validate`);
|
|
187
|
+
}
|
|
188
|
+
} else {
|
|
189
|
+
console.log(' ℹ️ AI Registry hub: not provisioned (run bootstrap to create)');
|
|
190
|
+
}
|
|
191
|
+
|
|
175
192
|
// Display deployed resources from manifest
|
|
176
193
|
console.log('\n📦 Deployed Resources:');
|
|
177
194
|
|
|
@@ -405,6 +405,54 @@ export default class BootstrapProvisioners {
|
|
|
405
405
|
}
|
|
406
406
|
}
|
|
407
407
|
|
|
408
|
+
/**
|
|
409
|
+
* Provision a deterministic SageMaker AI Registry Hub.
|
|
410
|
+
* Idempotent: checks if `mlcc-registry-{accountId}` already exists before creating.
|
|
411
|
+
* Non-fatal: catches all errors and prints a warning — bootstrap continues regardless.
|
|
412
|
+
*
|
|
413
|
+
* @param {object} profileData - Profile data object (mutated in place with hub info)
|
|
414
|
+
*/
|
|
415
|
+
async provisionAiRegistryHub(profileData) {
|
|
416
|
+
const hubName = `mlcc-registry-${profileData.accountId}`;
|
|
417
|
+
const region = profileData.awsRegion;
|
|
418
|
+
|
|
419
|
+
console.log('\n📦 Provisioning AI Registry hub...');
|
|
420
|
+
|
|
421
|
+
try {
|
|
422
|
+
// Check if hub already exists (idempotent)
|
|
423
|
+
const hubExists = this.handler._resourceExists(
|
|
424
|
+
`sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
|
|
425
|
+
this.handler._currentProfile
|
|
426
|
+
);
|
|
427
|
+
|
|
428
|
+
if (hubExists) {
|
|
429
|
+
const hubInfo = this.handler._execAws(
|
|
430
|
+
`sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
|
|
431
|
+
this.handler._currentProfile
|
|
432
|
+
);
|
|
433
|
+
console.log(` ✅ AI Registry hub already provisioned: ${hubName}`);
|
|
434
|
+
profileData.aiRegistryHubName = hubName;
|
|
435
|
+
profileData.aiRegistryHubArn = hubInfo.HubArn;
|
|
436
|
+
return;
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
// Create new hub (always — no adopt-existing logic)
|
|
440
|
+
const tags = this._buildResourceTags();
|
|
441
|
+
const tagsFile = this.handler._formatTagsForCli(tags);
|
|
442
|
+
const createResult = this.handler._execAws(
|
|
443
|
+
`sagemaker create-hub --hub-name ${hubName} --hub-display-name "MCC AI Registry" --hub-description "Dataset, evaluator, and model versioning for ml-container-creator" --tags ${tagsFile} --region ${region}`,
|
|
444
|
+
this.handler._currentProfile
|
|
445
|
+
);
|
|
446
|
+
console.log(` ✅ AI Registry hub "${hubName}" — created`);
|
|
447
|
+
profileData.aiRegistryHubName = hubName;
|
|
448
|
+
profileData.aiRegistryHubArn = createResult.HubArn;
|
|
449
|
+
} catch (err) {
|
|
450
|
+
const message = err.message || String(err);
|
|
451
|
+
console.log(` ⚠️ Could not provision AI Registry hub (non-fatal): ${message}`);
|
|
452
|
+
console.log(' Dataset registration will use local JSON registry.');
|
|
453
|
+
}
|
|
454
|
+
}
|
|
455
|
+
|
|
408
456
|
/**
|
|
409
457
|
* Build the standard resource tag set.
|
|
410
458
|
* @returns {Array<{Key: string, Value: string}>} Tag array
|
|
@@ -16,6 +16,24 @@ const __dirname = dirname(__filename);
|
|
|
16
16
|
* classifies failures, gates tune/adapter stages, and builds
|
|
17
17
|
* Athena-compatible records with run_type='path_prove'.
|
|
18
18
|
*
|
|
19
|
+
* ## Module Status (AC-1.4)
|
|
20
|
+
*
|
|
21
|
+
* ALL exported functions are FULLY FUNCTIONAL:
|
|
22
|
+
* - `identifyGaps()` — Cartesian product gap finder, prioritized by neighbor count
|
|
23
|
+
* - `findNearestSubstitution()` — Hamming distance nearest-neighbor, same-family constraint
|
|
24
|
+
* - `classifyFailure()` — regex pattern matching to 6 categories (capacity, timeout, oom, code_bug, model_incompatibility, service_limitation)
|
|
25
|
+
* - `shouldExecuteTuneStages()` — gating logic for tune/adapter stages
|
|
26
|
+
* - `hammingDistance()` — config vector comparison across CONFIG_DIMENSIONS
|
|
27
|
+
* - `buildPathProverRecord()` — Athena record construction with run_type='path_prove'
|
|
28
|
+
* - `findUnfeasibleRecord()` — checks if a config is known-unfeasible to prevent repeated attempts
|
|
29
|
+
* - `getNextPriorityConfig()` — priority queue management for v1 validation mode
|
|
30
|
+
* - `updatePriorityStatus()` — updates target status after prove attempts
|
|
31
|
+
* - `getPriorityQueueStatus()` — summary counts for priority queue
|
|
32
|
+
* - `loadPriorityTargets()` — file-based priority target loading
|
|
33
|
+
* - `resolveProveTpDegree()` — TP degree auto-resolution from instance catalog
|
|
34
|
+
*
|
|
35
|
+
* This is stabilization (tests + docs), not implementation. No new logic needed.
|
|
36
|
+
*
|
|
19
37
|
* Feature: ci-benchmark-pipeline
|
|
20
38
|
* Requirements: 8.1–8.12
|
|
21
39
|
*/
|
|
@@ -611,6 +629,45 @@ export function loadPriorityTargets(configPath) {
|
|
|
611
629
|
}
|
|
612
630
|
}
|
|
613
631
|
|
|
632
|
+
// ── Optimization Space Schema (Task 3 — AC-3.5) ─────────────────────────────
|
|
633
|
+
|
|
634
|
+
/**
 * Load the optimization search space schema from config/optimization-space.json.
 *
 * Returns the parsed schema with dimensions, version, and description.
 * Used by gap identification to enumerate sweepable dimensions and their
 * allowed values for the optimization/prove sweep.
 *
 * Best-effort: any read or parse failure yields null instead of throwing.
 *
 * @param {string} [schemaPath] - Optional explicit path to the schema file. When
 *   omitted, `config/optimization-space.json` two directories above this module is used.
 * @returns {object|null} Parsed schema object, or null if file not found/invalid
 */
export function loadOptimizationSpace(schemaPath) {
  try {
    // Resolve the default inside the try so a path failure still maps to null.
    const filePath = schemaPath ?? resolve(__dirname, '..', '..', 'config', 'optimization-space.json');
    return JSON.parse(readFileSync(filePath, 'utf8'));
  } catch {
    return null;
  }
}
|
|
652
|
+
|
|
653
|
+
/**
 * Get the list of sweepable dimension names from the optimization space schema.
 *
 * Filters dimensions by status === 'sweepable' and returns their keys in the
 * schema's key order. Dimension entries that are null or lack a `status` are
 * skipped rather than throwing.
 * Useful for verifying sync between CONFIG_DIMENSIONS and the schema.
 *
 * @param {object|null} [schema] - Pre-loaded schema (loads from file if omitted)
 * @returns {string[]} Array of sweepable dimension names (empty if no schema is available)
 */
export function getSweepableDimensions(schema = null) {
  const data = schema || loadOptimizationSpace();
  const dimensions = data?.dimensions;
  if (!dimensions) return [];

  return Object.entries(dimensions)
    .filter(([, spec]) => spec?.status === 'sweepable')
    .map(([name]) => name);
}
|
|
670
|
+
|
|
614
671
|
// ── TP Degree Auto-Resolution at Prove-Time (Task 6.5) ──────────────────────
|
|
615
672
|
|
|
616
673
|
/**
|
|
@@ -8,6 +8,25 @@
|
|
|
8
8
|
* Handles stage-specific logic including idempotency checks, status tracking,
|
|
9
9
|
* and fail-fast behavior.
|
|
10
10
|
*
|
|
11
|
+
* ## Module Status (AC-1.4)
|
|
12
|
+
*
|
|
13
|
+
* FUNCTIONAL stages:
|
|
14
|
+
* - `executeStageStep()` — fully wired with idempotency via `.mlcc/staged-assets.json`
|
|
15
|
+
* - `isAlreadyStaged()` — checks staged assets existence and validity
|
|
16
|
+
* - `getStagingState()` — resolves current staging state from filesystem + step results
|
|
17
|
+
* - `isValidLifecycleStage()` — validates individual stage names
|
|
18
|
+
* - `validateStagesArray()` — validates arrays of stage names
|
|
19
|
+
* - `formatStagingStatus()` — formats staging state for display
|
|
20
|
+
* - `buildTargetStatus()` — builds status summary for a prove target
|
|
21
|
+
*
|
|
22
|
+
* INTENTIONALLY INCOMPLETE (post-v1 scope):
|
|
23
|
+
* - Other lifecycle stage executors (build, push, deploy, test, tune, adapter,
|
|
24
|
+
* test-adapter, benchmark, register, clean) are NOT implemented.
|
|
25
|
+
* - Only the `stage` step has execution logic. Other stages are recognized in
|
|
26
|
+
* validation but have no executor function.
|
|
27
|
+
* - This is not "broken" — these were never finished before the laptop was bricked.
|
|
28
|
+
* They are explicitly post-v1 scope.
|
|
29
|
+
*
|
|
11
30
|
* Feature: s3-model-loading
|
|
12
31
|
* Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
|
13
32
|
*/
|
|
@@ -40,6 +59,22 @@ export const VALID_LIFECYCLE_STAGES = [
|
|
|
40
59
|
'clean'
|
|
41
60
|
];
|
|
42
61
|
|
|
62
|
+
// TODO(post-v1): Implement executor functions for lifecycle stages beyond 'stage'.
|
|
63
|
+
// The following stages are recognized for validation purposes but have no execution logic:
|
|
64
|
+
// - generate: Should invoke `mcc generate` to produce project scaffolding
|
|
65
|
+
// - build: Should run `do/build` to build the Docker container
|
|
66
|
+
// - push: Should run `do/push` to push container to ECR
|
|
67
|
+
// - deploy: Should run `do/deploy` to create SageMaker endpoint
|
|
68
|
+
// - test: Should run `do/test` to invoke endpoint and verify correctness
|
|
69
|
+
// - tune: Should run `do/tune` for fine-tuning jobs (gated by shouldExecuteTuneStages)
|
|
70
|
+
// - adapter: Should run `do/adapter` for LoRA adapter serving
|
|
71
|
+
// - test-adapter: Should test adapter endpoints after deployment
|
|
72
|
+
// - benchmark: Should run `do/benchmark` for performance measurement
|
|
73
|
+
// - register: Should register proven config in Athena/DynamoDB
|
|
74
|
+
// - clean: Should tear down deployed resources
|
|
75
|
+
// These were never finished before the original developer's laptop was bricked.
|
|
76
|
+
// They are explicitly post-v1 scope, not "broken" code.
|
|
77
|
+
|
|
43
78
|
/**
|
|
44
79
|
* Possible staging states for status output.
|
|
45
80
|
*/
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
2
|
|
|
5
3
|
"""Benchmark Writer — Converts do/benchmark output to enriched Parquet for Athena.
|
|
@@ -340,7 +338,7 @@ def _extract_base_image_version(base_image):
|
|
|
340
338
|
return ''
|
|
341
339
|
|
|
342
340
|
|
|
343
|
-
def enrich_records(config, results, run_timestamp=None):
|
|
341
|
+
def enrich_records(config, results, run_timestamp=None, instance_catalog=None):
|
|
344
342
|
"""Build enriched records from config context and benchmark results.
|
|
345
343
|
|
|
346
344
|
Each metrics entry becomes one enriched record with all Athena columns populated.
|
|
@@ -349,6 +347,7 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
349
347
|
config: dict with config context fields (project_name, model_name, etc.)
|
|
350
348
|
results: dict with benchmark results (job_name, metrics array)
|
|
351
349
|
run_timestamp: Optional datetime for run_timestamp. Defaults to now UTC.
|
|
350
|
+
instance_catalog: Optional pre-loaded instance catalog dict. If None, loaded from disk.
|
|
352
351
|
|
|
353
352
|
Returns:
|
|
354
353
|
list of enriched record dicts (one per concurrency level).
|
|
@@ -364,10 +363,21 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
364
363
|
|
|
365
364
|
# Derived fields
|
|
366
365
|
model_family = derive_model_family(model_name)
|
|
366
|
+
instance_family = derive_instance_family(instance_type)
|
|
367
|
+
|
|
368
|
+
# Resolve instance metadata from catalog (AC-2.8)
|
|
369
|
+
hw_meta = resolve_instance_metadata(instance_type, instance_catalog)
|
|
370
|
+
gpu_count = hw_meta['gpu_count']
|
|
371
|
+
gpu_type = hw_meta['gpu_type']
|
|
372
|
+
gpu_memory_gb = hw_meta['gpu_memory_gb']
|
|
367
373
|
|
|
368
374
|
# Optional context fields
|
|
369
375
|
deployment_target = config.get('deployment_target', 'realtime-inference')
|
|
370
|
-
|
|
376
|
+
try:
|
|
377
|
+
tensor_parallel_degree = int(config.get('tensor_parallel_degree', 1))
|
|
378
|
+
except (ValueError, TypeError):
|
|
379
|
+
tensor_parallel_degree = 1
|
|
380
|
+
|
|
371
381
|
quantization = config.get('quantization', 'none')
|
|
372
382
|
enable_lora = config.get('enable_lora', False)
|
|
373
383
|
base_image = config.get('base_image', '')
|
|
@@ -377,6 +387,11 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
377
387
|
ci_run_id = config.get('ci_run_id', '')
|
|
378
388
|
account_id = config.get('account_id', '')
|
|
379
389
|
|
|
390
|
+
# Configuration dimensions (nullable)
|
|
391
|
+
max_model_len_raw = config.get('max_model_len')
|
|
392
|
+
max_model_len = int(max_model_len_raw) if max_model_len_raw not in (None, '', 0) else None
|
|
393
|
+
kv_cache_dtype = config.get('kv_cache_dtype') or None
|
|
394
|
+
|
|
380
395
|
|
|
381
396
|
# Get metrics from results
|
|
382
397
|
metrics = results.get('metrics', []) if isinstance(results, dict) else []
|
|
@@ -447,6 +462,13 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
447
462
|
'deployment_target': deployment_target,
|
|
448
463
|
'quantization': quantization,
|
|
449
464
|
'tensor_parallel_degree': tensor_parallel_degree,
|
|
465
|
+
'instance_family': instance_family,
|
|
466
|
+
'gpu_count': gpu_count,
|
|
467
|
+
'gpu_type': gpu_type,
|
|
468
|
+
'gpu_memory_gb': gpu_memory_gb,
|
|
469
|
+
'max_model_len': max_model_len,
|
|
470
|
+
'enable_lora': enable_lora,
|
|
471
|
+
'kv_cache_dtype': kv_cache_dtype,
|
|
450
472
|
'serving_config': json.dumps(serving_config_dict),
|
|
451
473
|
'workload': config.get('workload', 'manual'),
|
|
452
474
|
'concurrency': concurrency,
|
|
@@ -481,6 +503,7 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
481
503
|
'output_sequence_length_avg': scalar(metric.get('output_sequence_length', metric.get('output_sequence_length_avg', 0.0))),
|
|
482
504
|
'input_sequence_length_avg': scalar(metric.get('input_sequence_length', metric.get('input_sequence_length_avg', 0.0))),
|
|
483
505
|
'error_rate': error_rate,
|
|
506
|
+
'cost_per_1m_tokens': cost,
|
|
484
507
|
'benchmark_duration_sec': metric.get('benchmark_duration_sec', duration_seconds),
|
|
485
508
|
'run_type': run_type,
|
|
486
509
|
'benchmark_job_name': results.get('job_name', '') if isinstance(results, dict) else '',
|
|
@@ -792,6 +815,54 @@ def register_partition(bucket, model, instance, target,
|
|
|
792
815
|
# ── Parquet Serialization ─────────────────────────────────────────────────────
|
|
793
816
|
|
|
794
817
|
|
|
818
|
+
def load_instance_catalog(catalog_path=None):
    """Load the instance catalog from servers/lib/catalogs/instances.json.

    Resolves the path relative to the project root (two levels up from templates/do/)
    unless an explicit ``catalog_path`` is supplied.
    Returns the 'catalog' dict mapping instance_type → metadata, or empty dict on failure.

    Args:
        catalog_path: Optional explicit path to an instances.json file.

    Returns:
        dict mapping instance type strings to their metadata dicts. An empty dict is
        returned when the file is missing or unreadable, is not valid UTF-8 JSON, or
        does not contain a top-level ``catalog`` object.
    """
    if catalog_path is None:
        # Resolve relative to this file: templates/do/.benchmark_writer.py → project root
        this_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.normpath(os.path.join(this_dir, '..', '..'))
        catalog_path = os.path.join(project_root, 'servers', 'lib', 'catalogs', 'instances.json')

    try:
        with open(catalog_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers FileNotFoundError/IOError; ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError.
        return {}

    # Guard against unexpected shapes (top-level list, "catalog": null, ...)
    # so callers can always rely on dict.get().
    catalog = data.get('catalog') if isinstance(data, dict) else None
    return catalog if isinstance(catalog, dict) else {}
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
def resolve_instance_metadata(instance_type, instance_catalog=None):
    """Resolve GPU metadata from the instance catalog for a given instance_type.

    Args:
        instance_type: SageMaker instance type (e.g., 'ml.g5.xlarge').
        instance_catalog: Optional pre-loaded catalog dict. If None, loads from disk.

    Returns:
        dict with keys: gpu_count (int|None), gpu_type (str|None), gpu_memory_gb (float|None).
        All values are None if instance_type is not found in catalog, or if the
        catalog / its entry for instance_type is not a dict.
    """
    if instance_catalog is None:
        instance_catalog = load_instance_catalog()
    # A malformed catalog must not crash record enrichment; treat it as empty.
    if not isinstance(instance_catalog, dict):
        instance_catalog = {}

    entry = instance_catalog.get(instance_type)
    if not isinstance(entry, dict):
        return {'gpu_count': None, 'gpu_type': None, 'gpu_memory_gb': None}

    return {
        'gpu_count': entry.get('gpus'),
        'gpu_type': entry.get('gpuType'),
        'gpu_memory_gb': entry.get('gpuMemoryGb'),
    }
|
|
864
|
+
|
|
865
|
+
|
|
795
866
|
def get_parquet_schema():
|
|
796
867
|
"""Return the pyarrow schema matching the Athena DDL for benchmark_results.
|
|
797
868
|
|
|
@@ -814,6 +885,17 @@ def get_parquet_schema():
|
|
|
814
885
|
pa.field("quantization", pa.string()),
|
|
815
886
|
pa.field("tensor_parallel_degree", pa.int32()),
|
|
816
887
|
|
|
888
|
+
# Hardware metadata (resolved from instance catalog at write time)
|
|
889
|
+
pa.field("instance_family", pa.string()),
|
|
890
|
+
pa.field("gpu_count", pa.int32()),
|
|
891
|
+
pa.field("gpu_type", pa.string()),
|
|
892
|
+
pa.field("gpu_memory_gb", pa.float64()),
|
|
893
|
+
|
|
894
|
+
# Configuration dimensions (top-level for Athena queryability)
|
|
895
|
+
pa.field("max_model_len", pa.int32()),
|
|
896
|
+
pa.field("enable_lora", pa.bool_()),
|
|
897
|
+
pa.field("kv_cache_dtype", pa.string()),
|
|
898
|
+
|
|
817
899
|
# Full serving config (extensible JSON blob)
|
|
818
900
|
pa.field("serving_config", pa.string()),
|
|
819
901
|
|
|
@@ -852,6 +934,7 @@ def get_parquet_schema():
|
|
|
852
934
|
pa.field("output_sequence_length_avg", pa.float64()),
|
|
853
935
|
pa.field("input_sequence_length_avg", pa.float64()),
|
|
854
936
|
pa.field("error_rate", pa.float64()),
|
|
937
|
+
pa.field("cost_per_1m_tokens", pa.float64()),
|
|
855
938
|
pa.field("benchmark_duration_sec", pa.float64()),
|
|
856
939
|
|
|
857
940
|
# Run Metadata
|
|
@@ -1182,6 +1265,9 @@ def cmd_write(args):
|
|
|
1182
1265
|
if args.adapter_name:
|
|
1183
1266
|
input_data['adapter_name'] = args.adapter_name
|
|
1184
1267
|
|
|
1268
|
+
if getattr(args, 'instance_type', None):
|
|
1269
|
+
input_data['instance_type'] = args.instance_type
|
|
1270
|
+
|
|
1185
1271
|
# ── Validate before any S3 interaction ────────────────────────────────
|
|
1186
1272
|
errors = validate_benchmark_input(input_data)
|
|
1187
1273
|
if errors:
|
|
@@ -1391,6 +1477,8 @@ def _load_config_file(config_path):
|
|
|
1391
1477
|
'MODEL_NAME': 'model_name',
|
|
1392
1478
|
'HF_MODEL_ID': 'hf_model_id',
|
|
1393
1479
|
'INSTANCE_TYPE': 'instance_type',
|
|
1480
|
+
'INSTANCE_POOLS': 'instance_pools',
|
|
1481
|
+
'BENCHMARK_INSTANCE_TYPE': 'benchmark_instance_type',
|
|
1394
1482
|
'DEPLOYMENT_CONFIG': 'deployment_config',
|
|
1395
1483
|
'DEPLOYMENT_TARGET': 'deployment_target',
|
|
1396
1484
|
'AWS_REGION': 'region',
|
|
@@ -1429,6 +1517,24 @@ def _load_config_file(config_path):
|
|
|
1429
1517
|
parts = context['model_name'].rstrip('/').split('/')
|
|
1430
1518
|
context['model_name'] = parts[-1] if parts else context['model_name']
|
|
1431
1519
|
|
|
1520
|
+
# Resolve instance_type precedence:
|
|
1521
|
+
# BENCHMARK_INSTANCE_TYPE (live-resolved, persisted by do/benchmark) > INSTANCE_TYPE > INSTANCE_POOLS fallback
|
|
1522
|
+
if context.get('benchmark_instance_type'):
|
|
1523
|
+
context['instance_type'] = context.pop('benchmark_instance_type')
|
|
1524
|
+
# Fall back to INSTANCE_POOLS when neither is set.
|
|
1525
|
+
# Heterogeneous pool configs may not have a standalone INSTANCE_TYPE value
|
|
1526
|
+
# but always define INSTANCE_POOLS as a JSON array with Priority fields.
|
|
1527
|
+
if not context.get('instance_type') and context.get('instance_pools'):
|
|
1528
|
+
try:
|
|
1529
|
+
pools = json.loads(context['instance_pools'])
|
|
1530
|
+
if pools:
|
|
1531
|
+
# Pick the highest-priority (lowest number) instance
|
|
1532
|
+
best = min(pools, key=lambda p: p.get('Priority', 999))
|
|
1533
|
+
context['instance_type'] = best.get('InstanceType', '')
|
|
1534
|
+
except (json.JSONDecodeError, TypeError, KeyError):
|
|
1535
|
+
pass
|
|
1536
|
+
context.pop('instance_pools', None) # Don't leak raw JSON into record
|
|
1537
|
+
|
|
1432
1538
|
# Also scan IC config files (do/ic/*.conf) for IC_ENV_* serving params
|
|
1433
1539
|
# These override do/config values for serving-specific settings
|
|
1434
1540
|
try:
|
|
@@ -1505,6 +1611,10 @@ def main():
|
|
|
1505
1611
|
help='LoRA adapter name (differentiates adapter benchmarks from base model in Athena)'
|
|
1506
1612
|
)
|
|
1507
1613
|
|
|
1614
|
+
write_parser.add_argument(
|
|
1615
|
+
'--instance-type', dest='instance_type', default=None,
|
|
1616
|
+
help='Override instance type (use when actual provisioned instance differs from config, e.g. heterogeneous pools)'
|
|
1617
|
+
)
|
|
1508
1618
|
write_parser.add_argument(
|
|
1509
1619
|
'--dry-run', dest='dry_run', action='store_true',
|
|
1510
1620
|
help='Output enriched records as JSON without writing to S3'
|