@aws/ml-container-creator 1.0.0 → 1.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +1 -1
- package/config/tune-catalog.json +303 -1
- package/package.json +2 -1
- package/servers/endpoint-picker/index.js +24 -4
- package/servers/lib/catalogs/model-servers.json +334 -120
- package/src/lib/bootstrap-command-handler.js +20 -2
- package/src/lib/bootstrap-profile-manager.js +33 -0
- package/src/lib/bootstrap-provisioners.js +48 -0
- package/src/lib/cross-cutting-checker.js +6 -1
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/path-prover-brain.js +57 -0
- package/src/lib/prove-pipeline-executor.js +35 -0
- package/templates/do/.benchmark_writer.py +114 -4
- package/templates/do/.register_helper.py +643 -67
- package/templates/do/.stage_helper.py +1 -0
- package/templates/do/__pycache__/.benchmark_writer.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/__pycache__/.tune_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +267 -171
- package/templates/do/benchmark +60 -5
- package/templates/do/config +1 -1
- package/templates/do/lib/inference-component.sh +6 -25
- package/templates/do/register +29 -2
- package/templates/do/tune +94 -12
|
@@ -52,6 +52,7 @@ export default class BootstrapCommandHandler {
|
|
|
52
52
|
// Thin delegation: forwards to this.provisioners._setupS3Buckets() and returns its result.
_setupS3Buckets() { return this.provisioners._setupS3Buckets(); }
|
|
53
53
|
// Thin delegation: forwards (name, tags) to this.provisioners._createS3Bucket() and returns its result.
_createS3Bucket(name, tags) { return this.provisioners._createS3Bucket(name, tags); }
|
|
54
54
|
// Thin delegation: forwards to this.provisioners._verifyCliV2() and returns its result.
_verifyCliV2() { return this.provisioners._verifyCliV2(); }
|
|
55
|
+
// Thin delegation: forwards profileData to this.provisioners.provisionAiRegistryHub()
// (note: the provisioner method has no leading underscore) and returns its result.
_provisionAiRegistryHub(profileData) { return this.provisioners.provisionAiRegistryHub(profileData); }
|
|
55
56
|
|
|
56
57
|
// ── ProfileManager delegations (backward compat for tests) ──────
|
|
57
58
|
|
|
@@ -63,6 +64,7 @@ export default class BootstrapCommandHandler {
|
|
|
63
64
|
// Thin delegation: forwards to this.profileManager._handlePrune() and returns its result.
_handlePrune() { return this.profileManager._handlePrune(); }
|
|
64
65
|
// Thin delegation: forwards to this.profileManager._handleSyncSchemas() and returns its result.
_handleSyncSchemas() { return this.profileManager._handleSyncSchemas(); }
|
|
65
66
|
// Thin delegation: forwards to this.profileManager._handleSyncModelFamilies() and returns its result.
_handleSyncModelFamilies() { return this.profileManager._handleSyncModelFamilies(); }
|
|
67
|
+
// Thin delegation: forwards to this.profileManager._handleSyncServingVersions() and returns its result.
_handleSyncServingVersions() { return this.profileManager._handleSyncServingVersions(); }
|
|
66
68
|
|
|
67
69
|
/**
|
|
68
70
|
* Dispatch bootstrap subcommands.
|
|
@@ -131,6 +133,9 @@ export default class BootstrapCommandHandler {
|
|
|
131
133
|
case 'sync-model-families':
|
|
132
134
|
await this._handleSyncModelFamilies();
|
|
133
135
|
break;
|
|
136
|
+
case 'sync-serving-versions':
|
|
137
|
+
await this._handleSyncServingVersions();
|
|
138
|
+
break;
|
|
134
139
|
// Migration path: upgrades legacy profiles to current naming conventions.
|
|
135
140
|
// Corrects stackName to mlcc-bootstrap-{profileName}, renames sharedStackFrom
|
|
136
141
|
// to sharedInfraFrom. Idempotent — safe to run multiple times.
|
|
@@ -357,6 +362,9 @@ export default class BootstrapCommandHandler {
|
|
|
357
362
|
console.log(' Tune jobs will still work but experiment tracking may not be available.');
|
|
358
363
|
}
|
|
359
364
|
|
|
365
|
+
// Step 4c: AI Registry Hub
|
|
366
|
+
await this._provisionAiRegistryHub(profileData);
|
|
367
|
+
|
|
360
368
|
// Step 5: CI Infrastructure setup (separate CDK stack — unchanged)
|
|
361
369
|
this._displayProgress('🧪', 'CI Testing Infrastructure...');
|
|
362
370
|
try {
|
|
@@ -714,6 +722,10 @@ export default class BootstrapCommandHandler {
|
|
|
714
722
|
console.log(` ⚠️ MLflow App setup skipped: ${error.message}`);
|
|
715
723
|
}
|
|
716
724
|
|
|
725
|
+
// Ensure AI Registry hub exists
|
|
726
|
+
this._currentProfile = profileConfig.awsProfile;
|
|
727
|
+
await this._provisionAiRegistryHub(profileConfig);
|
|
728
|
+
|
|
717
729
|
// Save updated profile
|
|
718
730
|
this.config.setProfile(name, profileConfig);
|
|
719
731
|
console.log(`\n✅ Update complete for profile "${name}"`);
|
|
@@ -1459,7 +1471,9 @@ SUBCOMMANDS:
|
|
|
1459
1471
|
prune Remove deleted and unknown records from the deployment manifest
|
|
1460
1472
|
update Re-deploy bootstrap stacks using active profile (no prompts)
|
|
1461
1473
|
migrate Upgrade legacy profiles to current naming conventions
|
|
1474
|
+
sync-schemas Download AWS service model schemas (sagemaker, iam, ecr, s3)
|
|
1462
1475
|
sync-model-families Discover tune-eligible models from JumpStart Hub and update catalog
|
|
1476
|
+
sync-serving-versions Discover latest vLLM/SGLang/TRT-LLM image versions and update catalog
|
|
1463
1477
|
|
|
1464
1478
|
SETUP OPTIONS:
|
|
1465
1479
|
--non-interactive Run without interactive prompts
|
|
@@ -1469,8 +1483,10 @@ SETUP OPTIONS:
|
|
|
1469
1483
|
--role-arn <arn> Use existing IAM role ARN (skip role creation)
|
|
1470
1484
|
--skip-s3 Skip S3 bucket creation
|
|
1471
1485
|
--ci Provision CI testing infrastructure
|
|
1486
|
+
--benchmark-infra Provision Athena/Glue benchmark infrastructure (requires --ci)
|
|
1472
1487
|
--skip-ci Skip CI infrastructure provisioning
|
|
1473
1488
|
--skip-post-setup Skip post-setup chain (mcp init, sync-architectures, sync-schemas)
|
|
1489
|
+
--ignore-staleness Suppress schema staleness warnings
|
|
1474
1490
|
|
|
1475
1491
|
STATUS OPTIONS:
|
|
1476
1492
|
--verify Check each active resource against AWS APIs for drift detection
|
|
@@ -1487,13 +1503,15 @@ EXAMPLES:
|
|
|
1487
1503
|
ml-container-creator bootstrap list
|
|
1488
1504
|
ml-container-creator bootstrap remove dev
|
|
1489
1505
|
ml-container-creator bootstrap remove dev --force --delete-stack
|
|
1506
|
+
ml-container-creator bootstrap update
|
|
1507
|
+
ml-container-creator bootstrap update --ci --benchmark-infra
|
|
1490
1508
|
ml-container-creator bootstrap scan
|
|
1509
|
+
ml-container-creator bootstrap sync-schemas
|
|
1491
1510
|
ml-container-creator bootstrap sync-model-families
|
|
1511
|
+
ml-container-creator bootstrap sync-serving-versions
|
|
1492
1512
|
ml-container-creator bootstrap migrate
|
|
1493
1513
|
ml-container-creator bootstrap --non-interactive --profile my-aws-profile --region us-west-2
|
|
1494
|
-
ml-container-creator bootstrap --non-interactive --profile my-aws-profile --role-arn arn:aws:iam::123456789012:role/MyRole --skip-s3
|
|
1495
1514
|
ml-container-creator bootstrap --non-interactive --profile my-aws-profile --region us-west-2 --ci
|
|
1496
|
-
ml-container-creator bootstrap --non-interactive --profile my-aws-profile --region us-west-2 --skip-ci
|
|
1497
1515
|
`);
|
|
1498
1516
|
}
|
|
1499
1517
|
|
|
@@ -172,6 +172,23 @@ export default class BootstrapProfileManager {
|
|
|
172
172
|
}
|
|
173
173
|
}
|
|
174
174
|
|
|
175
|
+
// Check AI Registry hub status
|
|
176
|
+
if (profile.config.aiRegistryHubName) {
|
|
177
|
+
try {
|
|
178
|
+
const hubExists = this.handler._resourceExists(
|
|
179
|
+
`sagemaker describe-hub --hub-name ${profile.config.aiRegistryHubName} --region ${profile.config.awsRegion}`,
|
|
180
|
+
profile.config.awsProfile
|
|
181
|
+
);
|
|
182
|
+
console.log(hubExists
|
|
183
|
+
? ` ✅ AI Registry hub: ${profile.config.aiRegistryHubName}`
|
|
184
|
+
: ` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — missing`);
|
|
185
|
+
} catch {
|
|
186
|
+
console.log(` ⚠️ AI Registry hub: ${profile.config.aiRegistryHubName} — could not validate`);
|
|
187
|
+
}
|
|
188
|
+
} else {
|
|
189
|
+
console.log(' ℹ️ AI Registry hub: not provisioned (run bootstrap to create)');
|
|
190
|
+
}
|
|
191
|
+
|
|
175
192
|
// Display deployed resources from manifest
|
|
176
193
|
console.log('\n📦 Deployed Resources:');
|
|
177
194
|
|
|
@@ -638,4 +655,20 @@ export default class BootstrapProfileManager {
|
|
|
638
655
|
process.exit(1);
|
|
639
656
|
}
|
|
640
657
|
}
|
|
658
|
+
|
|
659
|
+
/**
|
|
660
|
+
* Handle sync-serving-versions subcommand: discover latest container image
|
|
661
|
+
* versions for vLLM, SGLang, and TensorRT-LLM and update the model-servers catalog.
|
|
662
|
+
*/
|
|
663
|
+
async _handleSyncServingVersions() {
|
|
664
|
+
console.log('\n🔄 Sync Serving Versions — Discovering latest container images...\n');
|
|
665
|
+
try {
|
|
666
|
+
const { syncServingVersions } = await import('../../scripts/sync-serving-versions.js');
|
|
667
|
+
const result = await syncServingVersions();
|
|
668
|
+
console.log(`\n✅ Sync complete: ${result.totalAdded} new, ${result.totalRemoved} pruned\n`);
|
|
669
|
+
} catch (err) {
|
|
670
|
+
console.log(`❌ Sync failed: ${err.message}`);
|
|
671
|
+
process.exit(1);
|
|
672
|
+
}
|
|
673
|
+
}
|
|
641
674
|
}
|
|
@@ -405,6 +405,54 @@ export default class BootstrapProvisioners {
|
|
|
405
405
|
}
|
|
406
406
|
}
|
|
407
407
|
|
|
408
|
+
/**
|
|
409
|
+
* Provision a deterministic SageMaker AI Registry Hub.
|
|
410
|
+
* Idempotent: checks if `mlcc-registry-{accountId}` already exists before creating.
|
|
411
|
+
* Non-fatal: catches all errors and prints a warning — bootstrap continues regardless.
|
|
412
|
+
*
|
|
413
|
+
* @param {object} profileData - Profile data object (mutated in place with hub info)
|
|
414
|
+
*/
|
|
415
|
+
async provisionAiRegistryHub(profileData) {
|
|
416
|
+
const hubName = `mlcc-registry-${profileData.accountId}`;
|
|
417
|
+
const region = profileData.awsRegion;
|
|
418
|
+
|
|
419
|
+
console.log('\n📦 Provisioning AI Registry hub...');
|
|
420
|
+
|
|
421
|
+
try {
|
|
422
|
+
// Check if hub already exists (idempotent)
|
|
423
|
+
const hubExists = this.handler._resourceExists(
|
|
424
|
+
`sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
|
|
425
|
+
this.handler._currentProfile
|
|
426
|
+
);
|
|
427
|
+
|
|
428
|
+
if (hubExists) {
|
|
429
|
+
const hubInfo = this.handler._execAws(
|
|
430
|
+
`sagemaker describe-hub --hub-name ${hubName} --region ${region}`,
|
|
431
|
+
this.handler._currentProfile
|
|
432
|
+
);
|
|
433
|
+
console.log(` ✅ AI Registry hub already provisioned: ${hubName}`);
|
|
434
|
+
profileData.aiRegistryHubName = hubName;
|
|
435
|
+
profileData.aiRegistryHubArn = hubInfo.HubArn;
|
|
436
|
+
return;
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
// Create new hub (always — no adopt-existing logic)
|
|
440
|
+
const tags = this._buildResourceTags();
|
|
441
|
+
const tagsFile = this.handler._formatTagsForCli(tags);
|
|
442
|
+
const createResult = this.handler._execAws(
|
|
443
|
+
`sagemaker create-hub --hub-name ${hubName} --hub-display-name "MCC AI Registry" --hub-description "Dataset, evaluator, and model versioning for ml-container-creator" --tags ${tagsFile} --region ${region}`,
|
|
444
|
+
this.handler._currentProfile
|
|
445
|
+
);
|
|
446
|
+
console.log(` ✅ AI Registry hub "${hubName}" — created`);
|
|
447
|
+
profileData.aiRegistryHubName = hubName;
|
|
448
|
+
profileData.aiRegistryHubArn = createResult.HubArn;
|
|
449
|
+
} catch (err) {
|
|
450
|
+
const message = err.message || String(err);
|
|
451
|
+
console.log(` ⚠️ Could not provision AI Registry hub (non-fatal): ${message}`);
|
|
452
|
+
console.log(' Dataset registration will use local JSON registry.');
|
|
453
|
+
}
|
|
454
|
+
}
|
|
455
|
+
|
|
408
456
|
/**
|
|
409
457
|
* Build the standard resource tag set.
|
|
410
458
|
* @returns {Array<{Key: string, Value: string}>} Tag array
|
|
@@ -290,7 +290,12 @@ export default class CrossCuttingChecker {
|
|
|
290
290
|
if (!modelType || !server || !serverVersion) return findings;
|
|
291
291
|
|
|
292
292
|
const entries = modelServersCatalog[server] || [];
|
|
293
|
-
|
|
293
|
+
// Try exact version match first, then fall back to nearest entry with supportedModelTypes
|
|
294
|
+
let entry = entries.find(e => e.labels?.framework_version === serverVersion);
|
|
295
|
+
if (!entry?.supportedModelTypes?.length) {
|
|
296
|
+
// Fall back to any entry that has supportedModelTypes populated
|
|
297
|
+
entry = entries.find(e => e.supportedModelTypes?.length > 0);
|
|
298
|
+
}
|
|
294
299
|
if (!entry?.supportedModelTypes?.length) return findings;
|
|
295
300
|
|
|
296
301
|
if (!entry.supportedModelTypes.includes(modelType.toLowerCase())) {
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
// AUTO-GENERATED by scripts/codegen-parameter-matrix.js — DO NOT EDIT
|
|
2
2
|
// Source: config/parameter-schema-v2.json
|
|
3
|
-
// Generated: 2026-06-
|
|
3
|
+
// Generated: 2026-06-29T13:37:06.375Z
|
|
4
4
|
|
|
5
5
|
/**
|
|
6
6
|
* Parameter matrix defining how each parameter is loaded from various sources.
|
|
@@ -16,6 +16,24 @@ const __dirname = dirname(__filename);
|
|
|
16
16
|
* classifies failures, gates tune/adapter stages, and builds
|
|
17
17
|
* Athena-compatible records with run_type='path_prove'.
|
|
18
18
|
*
|
|
19
|
+
* ## Module Status (AC-1.4)
|
|
20
|
+
*
|
|
21
|
+
* ALL exported functions are FULLY FUNCTIONAL:
|
|
22
|
+
* - `identifyGaps()` — Cartesian product gap finder, prioritized by neighbor count
|
|
23
|
+
* - `findNearestSubstitution()` — Hamming distance nearest-neighbor, same-family constraint
|
|
24
|
+
* - `classifyFailure()` — regex pattern matching to 6 categories (capacity, timeout, oom, code_bug, model_incompatibility, service_limitation)
|
|
25
|
+
* - `shouldExecuteTuneStages()` — gating logic for tune/adapter stages
|
|
26
|
+
* - `hammingDistance()` — config vector comparison across CONFIG_DIMENSIONS
|
|
27
|
+
* - `buildPathProverRecord()` — Athena record construction with run_type='path_prove'
|
|
28
|
+
* - `findUnfeasibleRecord()` — checks if a config is known-unfeasible to prevent repeated attempts
|
|
29
|
+
* - `getNextPriorityConfig()` — priority queue management for v1 validation mode
|
|
30
|
+
* - `updatePriorityStatus()` — updates target status after prove attempts
|
|
31
|
+
* - `getPriorityQueueStatus()` — summary counts for priority queue
|
|
32
|
+
* - `loadPriorityTargets()` — file-based priority target loading
|
|
33
|
+
* - `resolveProveTpDegree()` — TP degree auto-resolution from instance catalog
|
|
34
|
+
*
|
|
35
|
+
* This is stabilization (tests + docs), not implementation. No new logic needed.
|
|
36
|
+
*
|
|
19
37
|
* Feature: ci-benchmark-pipeline
|
|
20
38
|
* Requirements: 8.1–8.12
|
|
21
39
|
*/
|
|
@@ -611,6 +629,45 @@ export function loadPriorityTargets(configPath) {
|
|
|
611
629
|
}
|
|
612
630
|
}
|
|
613
631
|
|
|
632
|
+
// ── Optimization Space Schema (Task 3 — AC-3.5) ─────────────────────────────
|
|
633
|
+
|
|
634
|
+
/**
 * Read and parse `config/optimization-space.json`, located two directories
 * above this module.
 *
 * The parsed JSON value is returned unchanged. Any failure while resolving,
 * reading or parsing the file is swallowed and reported as `null`, so callers
 * must handle a missing schema.
 *
 * @returns {object|null} Parsed schema, or null if the file is missing or invalid
 */
export function loadOptimizationSpace() {
    let parsedSchema = null;

    try {
        const location = resolve(__dirname, '..', '..', 'config', 'optimization-space.json');
        parsedSchema = JSON.parse(readFileSync(location, 'utf8'));
    } catch {
        // Best-effort load: an unreadable or malformed schema is treated as absent.
        parsedSchema = null;
    }

    return parsedSchema;
}
|
|
652
|
+
|
|
653
|
+
/**
 * List the names of the dimensions marked sweepable in the optimization space schema.
 *
 * A dimension is included when its entry under `schema.dimensions` has
 * `status === 'sweepable'`. Entries that are null or not objects are skipped
 * rather than raising, so one malformed dimension cannot break the listing.
 *
 * @param {object} [schema] - Pre-loaded schema; when omitted (or falsy) the
 *   schema is loaded via `loadOptimizationSpace()`.
 * @returns {string[]} Sweepable dimension names in the schema's key order;
 *   an empty array when no schema or no `dimensions` section is available.
 */
export function getSweepableDimensions(schema = null) {
    const data = schema || loadOptimizationSpace();
    if (!data || !data.dimensions) return [];

    const dimensions = data.dimensions;
    return Object.keys(dimensions).filter(name => {
        const dimension = dimensions[name];
        // `!= null` guards both null and undefined entries before reading `.status`.
        return dimension != null && dimension.status === 'sweepable';
    });
}
|
|
670
|
+
|
|
614
671
|
// ── TP Degree Auto-Resolution at Prove-Time (Task 6.5) ──────────────────────
|
|
615
672
|
|
|
616
673
|
/**
|
|
@@ -8,6 +8,25 @@
|
|
|
8
8
|
* Handles stage-specific logic including idempotency checks, status tracking,
|
|
9
9
|
* and fail-fast behavior.
|
|
10
10
|
*
|
|
11
|
+
* ## Module Status (AC-1.4)
|
|
12
|
+
*
|
|
13
|
+
* FUNCTIONAL stages:
|
|
14
|
+
* - `executeStageStep()` — fully wired with idempotency via `.mlcc/staged-assets.json`
|
|
15
|
+
* - `isAlreadyStaged()` — checks staged assets existence and validity
|
|
16
|
+
* - `getStagingState()` — resolves current staging state from filesystem + step results
|
|
17
|
+
* - `isValidLifecycleStage()` — validates individual stage names
|
|
18
|
+
* - `validateStagesArray()` — validates arrays of stage names
|
|
19
|
+
* - `formatStagingStatus()` — formats staging state for display
|
|
20
|
+
* - `buildTargetStatus()` — builds status summary for a prove target
|
|
21
|
+
*
|
|
22
|
+
* INTENTIONALLY INCOMPLETE (post-v1 scope):
|
|
23
|
+
* - Other lifecycle stage executors (build, push, deploy, test, tune, adapter,
|
|
24
|
+
* test-adapter, benchmark, register, clean) are NOT implemented.
|
|
25
|
+
* - Only the `stage` step has execution logic. Other stages are recognized in
|
|
26
|
+
* validation but have no executor function.
|
|
27
|
+
* - This is not "broken" — these were never finished before the laptop was bricked.
|
|
28
|
+
* They are explicitly post-v1 scope.
|
|
29
|
+
*
|
|
11
30
|
* Feature: s3-model-loading
|
|
12
31
|
* Requirements: 5.1, 5.2, 5.3, 5.4, 5.5
|
|
13
32
|
*/
|
|
@@ -40,6 +59,22 @@ export const VALID_LIFECYCLE_STAGES = [
|
|
|
40
59
|
'clean'
|
|
41
60
|
];
|
|
42
61
|
|
|
62
|
+
// TODO(post-v1): Implement executor functions for lifecycle stages beyond 'stage'.
|
|
63
|
+
// The following stages are recognized for validation purposes but have no execution logic:
|
|
64
|
+
// - generate: Should invoke `mcc generate` to produce project scaffolding
|
|
65
|
+
// - build: Should run `do/build` to build the Docker container
|
|
66
|
+
// - push: Should run `do/push` to push container to ECR
|
|
67
|
+
// - deploy: Should run `do/deploy` to create SageMaker endpoint
|
|
68
|
+
// - test: Should run `do/test` to invoke endpoint and verify correctness
|
|
69
|
+
// - tune: Should run `do/tune` for fine-tuning jobs (gated by shouldExecuteTuneStages)
|
|
70
|
+
// - adapter: Should run `do/adapter` for LoRA adapter serving
|
|
71
|
+
// - test-adapter: Should test adapter endpoints after deployment
|
|
72
|
+
// - benchmark: Should run `do/benchmark` for performance measurement
|
|
73
|
+
// - register: Should register proven config in Athena/DynamoDB
|
|
74
|
+
// - clean: Should tear down deployed resources
|
|
75
|
+
// These were never finished before the original developer's laptop was bricked.
|
|
76
|
+
// They are explicitly post-v1 scope, not "broken" code.
|
|
77
|
+
|
|
43
78
|
/**
|
|
44
79
|
* Possible staging states for status output.
|
|
45
80
|
*/
|
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
1
|
# SPDX-License-Identifier: Apache-2.0
|
|
4
2
|
|
|
5
3
|
"""Benchmark Writer — Converts do/benchmark output to enriched Parquet for Athena.
|
|
@@ -340,7 +338,7 @@ def _extract_base_image_version(base_image):
|
|
|
340
338
|
return ''
|
|
341
339
|
|
|
342
340
|
|
|
343
|
-
def enrich_records(config, results, run_timestamp=None):
|
|
341
|
+
def enrich_records(config, results, run_timestamp=None, instance_catalog=None):
|
|
344
342
|
"""Build enriched records from config context and benchmark results.
|
|
345
343
|
|
|
346
344
|
Each metrics entry becomes one enriched record with all Athena columns populated.
|
|
@@ -349,6 +347,7 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
349
347
|
config: dict with config context fields (project_name, model_name, etc.)
|
|
350
348
|
results: dict with benchmark results (job_name, metrics array)
|
|
351
349
|
run_timestamp: Optional datetime for run_timestamp. Defaults to now UTC.
|
|
350
|
+
instance_catalog: Optional pre-loaded instance catalog dict. If None, loaded from disk.
|
|
352
351
|
|
|
353
352
|
Returns:
|
|
354
353
|
list of enriched record dicts (one per concurrency level).
|
|
@@ -364,10 +363,21 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
364
363
|
|
|
365
364
|
# Derived fields
|
|
366
365
|
model_family = derive_model_family(model_name)
|
|
366
|
+
instance_family = derive_instance_family(instance_type)
|
|
367
|
+
|
|
368
|
+
# Resolve instance metadata from catalog (AC-2.8)
|
|
369
|
+
hw_meta = resolve_instance_metadata(instance_type, instance_catalog)
|
|
370
|
+
gpu_count = hw_meta['gpu_count']
|
|
371
|
+
gpu_type = hw_meta['gpu_type']
|
|
372
|
+
gpu_memory_gb = hw_meta['gpu_memory_gb']
|
|
367
373
|
|
|
368
374
|
# Optional context fields
|
|
369
375
|
deployment_target = config.get('deployment_target', 'realtime-inference')
|
|
370
|
-
|
|
376
|
+
try:
|
|
377
|
+
tensor_parallel_degree = int(config.get('tensor_parallel_degree', 1))
|
|
378
|
+
except (ValueError, TypeError):
|
|
379
|
+
tensor_parallel_degree = 1
|
|
380
|
+
|
|
371
381
|
quantization = config.get('quantization', 'none')
|
|
372
382
|
enable_lora = config.get('enable_lora', False)
|
|
373
383
|
base_image = config.get('base_image', '')
|
|
@@ -377,6 +387,11 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
377
387
|
ci_run_id = config.get('ci_run_id', '')
|
|
378
388
|
account_id = config.get('account_id', '')
|
|
379
389
|
|
|
390
|
+
# Configuration dimensions (nullable)
|
|
391
|
+
max_model_len_raw = config.get('max_model_len')
|
|
392
|
+
max_model_len = int(max_model_len_raw) if max_model_len_raw not in (None, '', 0) else None
|
|
393
|
+
kv_cache_dtype = config.get('kv_cache_dtype') or None
|
|
394
|
+
|
|
380
395
|
|
|
381
396
|
# Get metrics from results
|
|
382
397
|
metrics = results.get('metrics', []) if isinstance(results, dict) else []
|
|
@@ -447,6 +462,13 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
447
462
|
'deployment_target': deployment_target,
|
|
448
463
|
'quantization': quantization,
|
|
449
464
|
'tensor_parallel_degree': tensor_parallel_degree,
|
|
465
|
+
'instance_family': instance_family,
|
|
466
|
+
'gpu_count': gpu_count,
|
|
467
|
+
'gpu_type': gpu_type,
|
|
468
|
+
'gpu_memory_gb': gpu_memory_gb,
|
|
469
|
+
'max_model_len': max_model_len,
|
|
470
|
+
'enable_lora': enable_lora,
|
|
471
|
+
'kv_cache_dtype': kv_cache_dtype,
|
|
450
472
|
'serving_config': json.dumps(serving_config_dict),
|
|
451
473
|
'workload': config.get('workload', 'manual'),
|
|
452
474
|
'concurrency': concurrency,
|
|
@@ -481,6 +503,7 @@ def enrich_records(config, results, run_timestamp=None):
|
|
|
481
503
|
'output_sequence_length_avg': scalar(metric.get('output_sequence_length', metric.get('output_sequence_length_avg', 0.0))),
|
|
482
504
|
'input_sequence_length_avg': scalar(metric.get('input_sequence_length', metric.get('input_sequence_length_avg', 0.0))),
|
|
483
505
|
'error_rate': error_rate,
|
|
506
|
+
'cost_per_1m_tokens': cost,
|
|
484
507
|
'benchmark_duration_sec': metric.get('benchmark_duration_sec', duration_seconds),
|
|
485
508
|
'run_type': run_type,
|
|
486
509
|
'benchmark_job_name': results.get('job_name', '') if isinstance(results, dict) else '',
|
|
@@ -792,6 +815,54 @@ def register_partition(bucket, model, instance, target,
|
|
|
792
815
|
# ── Parquet Serialization ─────────────────────────────────────────────────────
|
|
793
816
|
|
|
794
817
|
|
|
818
|
+
def load_instance_catalog(catalog_path=None):
    """Load the instance catalog mapping instance types to their metadata.

    By default the catalog is read from ``servers/lib/catalogs/instances.json``
    under the directory two levels above this file. The function never raises
    for a missing, unreadable or malformed catalog: in every such case it
    returns an empty dict so callers can degrade gracefully.

    Args:
        catalog_path: Optional explicit path to the catalog JSON file. When
            None, the default location relative to this file is used.

    Returns:
        dict mapping instance type strings to their metadata dicts, taken from
        the top-level ``catalog`` key. Empty dict if the file cannot be read,
        is not valid JSON, is not a JSON object, or has no dict-valued
        ``catalog`` key.
    """
    if catalog_path is None:
        # Default location: <dir of this file>/../../servers/lib/catalogs/instances.json
        this_dir = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.normpath(os.path.join(this_dir, '..', '..'))
        catalog_path = os.path.join(project_root, 'servers', 'lib', 'catalogs', 'instances.json')

    try:
        with open(catalog_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return {}

    # Guard against valid JSON of the wrong shape (e.g. a top-level list, or
    # "catalog": null) so callers can always rely on getting a dict back.
    catalog = data.get('catalog') if isinstance(data, dict) else None
    return catalog if isinstance(catalog, dict) else {}
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
def resolve_instance_metadata(instance_type, instance_catalog=None):
    """Resolve GPU metadata for an instance type from the instance catalog.

    Args:
        instance_type: Instance type used as the catalog key
            (e.g., 'ml.g5.xlarge').
        instance_catalog: Optional pre-loaded catalog mapping instance types to
            metadata dicts. If None, the catalog is loaded from disk via
            ``load_instance_catalog()``.

    Returns:
        dict with keys ``gpu_count``, ``gpu_type`` and ``gpu_memory_gb``, read
        from the entry's ``gpus``, ``gpuType`` and ``gpuMemoryGb`` fields
        (values are passed through unconverted; a missing field yields None).
        All three values are None when the instance type is not in the
        catalog, or when the catalog or the entry is not a mapping.
    """
    # Local import keeps this block self-contained; Mapping accepts dicts as
    # well as other read-only mapping types a caller may pass in.
    from collections.abc import Mapping

    if instance_catalog is None:
        instance_catalog = load_instance_catalog()

    entry = (
        instance_catalog.get(instance_type)
        if isinstance(instance_catalog, Mapping)
        else None
    )
    if not isinstance(entry, Mapping):
        # Unknown instance type or malformed catalog data: hardware metadata is
        # optional, so degrade to nulls instead of raising.
        return {'gpu_count': None, 'gpu_type': None, 'gpu_memory_gb': None}

    return {
        'gpu_count': entry.get('gpus'),
        'gpu_type': entry.get('gpuType'),
        'gpu_memory_gb': entry.get('gpuMemoryGb'),
    }
|
|
864
|
+
|
|
865
|
+
|
|
795
866
|
def get_parquet_schema():
|
|
796
867
|
"""Return the pyarrow schema matching the Athena DDL for benchmark_results.
|
|
797
868
|
|
|
@@ -814,6 +885,17 @@ def get_parquet_schema():
|
|
|
814
885
|
pa.field("quantization", pa.string()),
|
|
815
886
|
pa.field("tensor_parallel_degree", pa.int32()),
|
|
816
887
|
|
|
888
|
+
# Hardware metadata (resolved from instance catalog at write time)
|
|
889
|
+
pa.field("instance_family", pa.string()),
|
|
890
|
+
pa.field("gpu_count", pa.int32()),
|
|
891
|
+
pa.field("gpu_type", pa.string()),
|
|
892
|
+
pa.field("gpu_memory_gb", pa.float64()),
|
|
893
|
+
|
|
894
|
+
# Configuration dimensions (top-level for Athena queryability)
|
|
895
|
+
pa.field("max_model_len", pa.int32()),
|
|
896
|
+
pa.field("enable_lora", pa.bool_()),
|
|
897
|
+
pa.field("kv_cache_dtype", pa.string()),
|
|
898
|
+
|
|
817
899
|
# Full serving config (extensible JSON blob)
|
|
818
900
|
pa.field("serving_config", pa.string()),
|
|
819
901
|
|
|
@@ -852,6 +934,7 @@ def get_parquet_schema():
|
|
|
852
934
|
pa.field("output_sequence_length_avg", pa.float64()),
|
|
853
935
|
pa.field("input_sequence_length_avg", pa.float64()),
|
|
854
936
|
pa.field("error_rate", pa.float64()),
|
|
937
|
+
pa.field("cost_per_1m_tokens", pa.float64()),
|
|
855
938
|
pa.field("benchmark_duration_sec", pa.float64()),
|
|
856
939
|
|
|
857
940
|
# Run Metadata
|
|
@@ -1182,6 +1265,9 @@ def cmd_write(args):
|
|
|
1182
1265
|
if args.adapter_name:
|
|
1183
1266
|
input_data['adapter_name'] = args.adapter_name
|
|
1184
1267
|
|
|
1268
|
+
if getattr(args, 'instance_type', None):
|
|
1269
|
+
input_data['instance_type'] = args.instance_type
|
|
1270
|
+
|
|
1185
1271
|
# ── Validate before any S3 interaction ────────────────────────────────
|
|
1186
1272
|
errors = validate_benchmark_input(input_data)
|
|
1187
1273
|
if errors:
|
|
@@ -1391,6 +1477,8 @@ def _load_config_file(config_path):
|
|
|
1391
1477
|
'MODEL_NAME': 'model_name',
|
|
1392
1478
|
'HF_MODEL_ID': 'hf_model_id',
|
|
1393
1479
|
'INSTANCE_TYPE': 'instance_type',
|
|
1480
|
+
'INSTANCE_POOLS': 'instance_pools',
|
|
1481
|
+
'BENCHMARK_INSTANCE_TYPE': 'benchmark_instance_type',
|
|
1394
1482
|
'DEPLOYMENT_CONFIG': 'deployment_config',
|
|
1395
1483
|
'DEPLOYMENT_TARGET': 'deployment_target',
|
|
1396
1484
|
'AWS_REGION': 'region',
|
|
@@ -1429,6 +1517,24 @@ def _load_config_file(config_path):
|
|
|
1429
1517
|
parts = context['model_name'].rstrip('/').split('/')
|
|
1430
1518
|
context['model_name'] = parts[-1] if parts else context['model_name']
|
|
1431
1519
|
|
|
1520
|
+
# Resolve instance_type precedence:
|
|
1521
|
+
# BENCHMARK_INSTANCE_TYPE (live-resolved, persisted by do/benchmark) > INSTANCE_TYPE > INSTANCE_POOLS fallback
|
|
1522
|
+
if context.get('benchmark_instance_type'):
|
|
1523
|
+
context['instance_type'] = context.pop('benchmark_instance_type')
|
|
1524
|
+
# Fall back to INSTANCE_POOLS when neither is set.
|
|
1525
|
+
# Heterogeneous pool configs may not have a standalone INSTANCE_TYPE value
|
|
1526
|
+
# but always define INSTANCE_POOLS as a JSON array with Priority fields.
|
|
1527
|
+
if not context.get('instance_type') and context.get('instance_pools'):
|
|
1528
|
+
try:
|
|
1529
|
+
pools = json.loads(context['instance_pools'])
|
|
1530
|
+
if pools:
|
|
1531
|
+
# Pick the highest-priority (lowest number) instance
|
|
1532
|
+
best = min(pools, key=lambda p: p.get('Priority', 999))
|
|
1533
|
+
context['instance_type'] = best.get('InstanceType', '')
|
|
1534
|
+
except (json.JSONDecodeError, TypeError, KeyError):
|
|
1535
|
+
pass
|
|
1536
|
+
context.pop('instance_pools', None) # Don't leak raw JSON into record
|
|
1537
|
+
|
|
1432
1538
|
# Also scan IC config files (do/ic/*.conf) for IC_ENV_* serving params
|
|
1433
1539
|
# These override do/config values for serving-specific settings
|
|
1434
1540
|
try:
|
|
@@ -1505,6 +1611,10 @@ def main():
|
|
|
1505
1611
|
help='LoRA adapter name (differentiates adapter benchmarks from base model in Athena)'
|
|
1506
1612
|
)
|
|
1507
1613
|
|
|
1614
|
+
write_parser.add_argument(
|
|
1615
|
+
'--instance-type', dest='instance_type', default=None,
|
|
1616
|
+
help='Override instance type (use when actual provisioned instance differs from config, e.g. heterogeneous pools)'
|
|
1617
|
+
)
|
|
1508
1618
|
write_parser.add_argument(
|
|
1509
1619
|
'--dry-run', dest='dry_run', action='store_true',
|
|
1510
1620
|
help='Output enriched records as JSON without writing to S3'
|