@aws/ml-container-creator 1.0.3 → 1.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -1
- package/bin/cli.js +57 -0
- package/config/agent.json +16 -0
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +5 -2
- package/pyproject.toml +3 -0
- package/servers/agent-knowledge/index.js +592 -0
- package/servers/agent-knowledge/package.json +15 -0
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/agent/__init__.py +2 -0
- package/src/agent/__pycache__/__init__.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/config_loader.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/context.cpython-312.pyc +0 -0
- package/src/agent/__pycache__/health_check.cpython-312.pyc +0 -0
- package/src/agent/agent.py +513 -0
- package/src/agent/config_loader.py +215 -0
- package/src/agent/context.py +380 -0
- package/src/agent/data/capability-matrix.json +106 -0
- package/src/agent/health_check.py +341 -0
- package/src/agent/prompts/system.md +173 -0
- package/src/agent/requirements-agent.txt +3 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/src/lib/tune-config-state.js +89 -68
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/config +6 -1
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/src/lib/auto-prompt-builder.js +0 -172
- package/src/lib/cli-handler.js +0 -529
- package/src/lib/community-reports-validator.js +0 -91
- package/src/lib/configuration-exporter.js +0 -204
- package/src/lib/dataset-slug.js +0 -152
- package/src/lib/docker-introspection-validator.js +0 -51
- package/src/lib/known-flags-validator.js +0 -200
- package/src/lib/schema-validator.js +0 -157
- package/src/lib/train-config-parser.js +0 -136
- package/src/lib/train-config-persistence.js +0 -143
- package/src/lib/train-config-validator.js +0 -112
- package/src/lib/train-feedback.js +0 -46
- package/src/lib/train-idempotency.js +0 -97
- package/src/lib/train-request-builder.js +0 -120
- package/src/lib/tune-dataset-validator.js +0 -279
- package/src/lib/tune-output-resolver.js +0 -66
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
// AUTO-GENERATED by scripts/codegen-parameter-matrix.js — DO NOT EDIT
|
|
2
2
|
// Source: config/parameter-schema-v2.json
|
|
3
|
-
// Generated: 2026-
|
|
3
|
+
// Generated: 2026-07-01T20:12:14.996Z
|
|
4
4
|
|
|
5
5
|
/**
|
|
6
6
|
* Parameter matrix defining how each parameter is loaded from various sources.
|
|
@@ -384,6 +384,9 @@ export default class McpQueryRunner {
|
|
|
384
384
|
const endpointNames = result.choices.endpointName;
|
|
385
385
|
const metadata = result.metadata || {};
|
|
386
386
|
|
|
387
|
+
// Store endpoint metadata for later instance type resolution (US-1)
|
|
388
|
+
this.runner._endpointPickerMetadata = metadata;
|
|
389
|
+
|
|
387
390
|
// Build choices with metadata annotations
|
|
388
391
|
this.runner._mcpEndpointChoices = endpointNames.map(name => {
|
|
389
392
|
const meta = metadata[name];
|
|
@@ -412,12 +415,15 @@ export default class McpQueryRunner {
|
|
|
412
415
|
}
|
|
413
416
|
|
|
414
417
|
/**
|
|
415
|
-
* Query MCP base-image-picker server after deployment config
|
|
418
|
+
* Query MCP base-image-picker server after deployment config and instance type are known.
|
|
416
419
|
* Populates _mcpBaseImageChoices for the base image selection prompt.
|
|
417
|
-
* Requirements: 5.1, 5.2, 5.3, 5.4, 9.1, 9.2, 9.3
|
|
420
|
+
* Requirements: 5.1, 5.2, 5.3, 5.4, 9.1, 9.2, 9.3, US-1 (ordering constraint)
|
|
421
|
+
* @param {Object} frameworkAnswers - Framework/architecture answers
|
|
422
|
+
* @param {Object} _explicitConfig - Explicit CLI/config values
|
|
423
|
+
* @param {Object} [infraContext] - Infrastructure context (instanceType, tensorParallelSize, modelId)
|
|
418
424
|
* @private
|
|
419
425
|
*/
|
|
420
|
-
async _queryMcpForBaseImage(frameworkAnswers, _explicitConfig) {
|
|
426
|
+
async _queryMcpForBaseImage(frameworkAnswers, _explicitConfig, infraContext = {}) {
|
|
421
427
|
// Skip if base image provided via CLI --base-image flag
|
|
422
428
|
if (this.runner.options['base-image']) return;
|
|
423
429
|
|
|
@@ -454,6 +460,17 @@ export default class McpQueryRunner {
|
|
|
454
460
|
context.searchCriteria = searchCriteria.trim();
|
|
455
461
|
}
|
|
456
462
|
|
|
463
|
+
// Pass infrastructure context for driver-aware filtering (US-1 ordering constraint)
|
|
464
|
+
if (infraContext.instanceType) {
|
|
465
|
+
context.instanceType = infraContext.instanceType;
|
|
466
|
+
}
|
|
467
|
+
if (infraContext.tensorParallelSize !== null && infraContext.tensorParallelSize !== undefined) {
|
|
468
|
+
context.tensorParallelSize = infraContext.tensorParallelSize;
|
|
469
|
+
}
|
|
470
|
+
if (infraContext.modelId) {
|
|
471
|
+
context.modelId = infraContext.modelId;
|
|
472
|
+
}
|
|
473
|
+
|
|
457
474
|
const result = await cm.queryMcpServer('base-image-picker', context);
|
|
458
475
|
|
|
459
476
|
if (result && result.metadata?.baseImage?.length > 0) {
|
|
@@ -716,6 +733,96 @@ export default class McpQueryRunner {
|
|
|
716
733
|
}
|
|
717
734
|
}
|
|
718
735
|
|
|
736
|
+
/**
|
|
737
|
+
* Resolve instance type from an existing endpoint.
|
|
738
|
+
* Priority:
|
|
739
|
+
* 1. Endpoint-picker metadata (already fetched, no network call)
|
|
740
|
+
* 2. Direct AWS SDK call: DescribeEndpoint → DescribeEndpointConfig
|
|
741
|
+
*
|
|
742
|
+
* Reuses the resolution pattern from do/lib/resolve-instance.sh:
|
|
743
|
+
* - Check ProductionVariants[0].CurrentInstanceType or InstanceType
|
|
744
|
+
* - Fallback: DescribeEndpointConfig → ProductionVariants[0].InstanceType
|
|
745
|
+
* - Final fallback: InstancePools[0] (highest priority)
|
|
746
|
+
*
|
|
747
|
+
* Requirements: US-1 (ordering constraint — resolve instance type before base image picker)
|
|
748
|
+
* @param {string} endpointName - The existing endpoint name
|
|
749
|
+
* @param {string} awsRegion - AWS region for API calls
|
|
750
|
+
* @returns {Promise<string|null>} Resolved instance type or null on failure
|
|
751
|
+
* @private
|
|
752
|
+
*/
|
|
753
|
+
async _resolveEndpointInstanceType(endpointName, awsRegion) {
|
|
754
|
+
// Strategy 1: Use endpoint-picker metadata (already fetched, no network call)
|
|
755
|
+
if (this.runner._endpointPickerMetadata) {
|
|
756
|
+
const meta = this.runner._endpointPickerMetadata[endpointName];
|
|
757
|
+
if (meta?.instanceType) {
|
|
758
|
+
// Strip pool annotation if present: "ml.g5.12xlarge (pool: 3 types)" → "ml.g5.12xlarge"
|
|
759
|
+
const rawInstanceType = meta.instanceType.includes(' (pool:')
|
|
760
|
+
? meta.instanceType.split(' (pool:')[0]
|
|
761
|
+
: meta.instanceType;
|
|
762
|
+
if (rawInstanceType && rawInstanceType !== 'unknown') {
|
|
763
|
+
console.log(` ✓ Resolved instance type from endpoint metadata: ${rawInstanceType}`);
|
|
764
|
+
return rawInstanceType;
|
|
765
|
+
}
|
|
766
|
+
}
|
|
767
|
+
}
|
|
768
|
+
|
|
769
|
+
// Strategy 2: Direct AWS SDK call (for custom endpoint names not in picker results)
|
|
770
|
+
console.log(' 🔍 Resolving instance type from existing endpoint...');
|
|
771
|
+
try {
|
|
772
|
+
const { SageMakerClient, DescribeEndpointCommand, DescribeEndpointConfigCommand } = await import('@aws-sdk/client-sagemaker');
|
|
773
|
+
|
|
774
|
+
const region = awsRegion || process.env.AWS_REGION || 'us-east-1';
|
|
775
|
+
const clientOptions = { region };
|
|
776
|
+
|
|
777
|
+
// Use AWS profile if available
|
|
778
|
+
const awsProfile = this.runner.configManager?.config?.awsProfile
|
|
779
|
+
|| this.runner.options?.profile || process.env.AWS_PROFILE || null;
|
|
780
|
+
if (awsProfile) {
|
|
781
|
+
try {
|
|
782
|
+
const { fromIni } = await import('@aws-sdk/credential-providers');
|
|
783
|
+
clientOptions.credentials = fromIni({ profile: awsProfile });
|
|
784
|
+
} catch {
|
|
785
|
+
// credential-providers not available, use default chain
|
|
786
|
+
}
|
|
787
|
+
}
|
|
788
|
+
|
|
789
|
+
const client = new SageMakerClient(clientOptions);
|
|
790
|
+
|
|
791
|
+
// DescribeEndpoint — check ProductionVariants for instance type
|
|
792
|
+
const epResponse = await client.send(new DescribeEndpointCommand({ EndpointName: endpointName }));
|
|
793
|
+
|
|
794
|
+
const primaryVariant = (epResponse.ProductionVariants || [])[0] || {};
|
|
795
|
+
let instanceType = primaryVariant.CurrentInstanceType || primaryVariant.InstanceType || null;
|
|
796
|
+
|
|
797
|
+
// Fallback: DescribeEndpointConfig
|
|
798
|
+
if (!instanceType && epResponse.EndpointConfigName) {
|
|
799
|
+
const ecResponse = await client.send(
|
|
800
|
+
new DescribeEndpointConfigCommand({ EndpointConfigName: epResponse.EndpointConfigName })
|
|
801
|
+
);
|
|
802
|
+
const ecVariant = (ecResponse.ProductionVariants || [])[0];
|
|
803
|
+
if (ecVariant?.InstanceType) {
|
|
804
|
+
instanceType = ecVariant.InstanceType;
|
|
805
|
+
} else if (ecVariant?.InstancePools?.length > 0) {
|
|
806
|
+
// Use highest-priority pool entry (lowest Priority number)
|
|
807
|
+
const sorted = [...ecVariant.InstancePools].sort((a, b) => (a.Priority || 99) - (b.Priority || 99));
|
|
808
|
+
instanceType = sorted[0].InstanceType || null;
|
|
809
|
+
}
|
|
810
|
+
}
|
|
811
|
+
|
|
812
|
+
if (instanceType) {
|
|
813
|
+
console.log(` ✓ Resolved instance type from endpoint: ${instanceType}`);
|
|
814
|
+
return instanceType;
|
|
815
|
+
}
|
|
816
|
+
|
|
817
|
+
console.log(' ↳ Could not determine instance type from endpoint');
|
|
818
|
+
return null;
|
|
819
|
+
} catch (err) {
|
|
820
|
+
// Graceful fallback: if AWS call fails, skip filtering (no driver-aware filter)
|
|
821
|
+
console.log(` ⚠️ Could not resolve instance type from endpoint: ${err.message}`);
|
|
822
|
+
return null;
|
|
823
|
+
}
|
|
824
|
+
}
|
|
825
|
+
|
|
719
826
|
/**
|
|
720
827
|
* Validate and display instance type compatibility
|
|
721
828
|
* Requirements: 4.1, 4.2, 4.3, 4.4, 4.5, 4.6
|
package/src/lib/prompt-runner.js
CHANGED
|
@@ -68,6 +68,7 @@ export default class PromptRunner {
|
|
|
68
68
|
// Thin delegate: forwards every argument to this.mcpQueryRunner._queryMcpForInstance and returns its result.
_queryMcpForInstance(...args) { return this.mcpQueryRunner._queryMcpForInstance(...args); }
|
|
69
69
|
// Thin delegate: forwards every argument to this.mcpQueryRunner._queryMcpForInstanceSizing and returns its result.
_queryMcpForInstanceSizing(...args) { return this.mcpQueryRunner._queryMcpForInstanceSizing(...args); }
|
|
70
70
|
// Thin delegate: forwards every argument to this.mcpQueryRunner._queryMcpForEndpoints and returns its result.
_queryMcpForEndpoints(...args) { return this.mcpQueryRunner._queryMcpForEndpoints(...args); }
|
|
71
|
+
// Thin delegate: forwards every argument to this.mcpQueryRunner._resolveEndpointInstanceType and returns its result.
_resolveEndpointInstanceType(...args) { return this.mcpQueryRunner._resolveEndpointInstanceType(...args); }
|
|
71
72
|
// Thin delegate: forwards every argument to this.mcpQueryRunner._queryMcpForHyperPod and returns its result.
_queryMcpForHyperPod(...args) { return this.mcpQueryRunner._queryMcpForHyperPod(...args); }
|
|
72
73
|
// Thin delegate: forwards every argument to this.mcpQueryRunner._fetchAndDisplayModelInfo and returns its result.
_fetchAndDisplayModelInfo(...args) { return this.mcpQueryRunner._fetchAndDisplayModelInfo(...args); }
|
|
73
74
|
// Thin delegate: forwards every argument to this.mcpQueryRunner._validateAndDisplayInstanceType and returns its result.
_validateAndDisplayInstanceType(...args) { return this.mcpQueryRunner._validateAndDisplayInstanceType(...args); }
|
|
@@ -182,8 +183,8 @@ export default class PromptRunner {
|
|
|
182
183
|
}
|
|
183
184
|
|
|
184
185
|
// ══════════════════════════════════════════════════════════════════════
|
|
185
|
-
// Phase 2 — How (deployment target + serving profile
|
|
186
|
-
// Requirements:
|
|
186
|
+
// Phase 2 — How (deployment target + serving profile)
|
|
187
|
+
// Requirements: US-1 — base image selection moved AFTER instance resolution
|
|
187
188
|
// ══════════════════════════════════════════════════════════════════════
|
|
188
189
|
console.log('\n💪 Infrastructure & Deployment');
|
|
189
190
|
|
|
@@ -192,25 +193,8 @@ export default class PromptRunner {
|
|
|
192
193
|
const regionPreviousAnswers = bootstrapRegion ? { _bootstrapRegion: bootstrapRegion } : {};
|
|
193
194
|
const regionAndTargetAnswers = await this._runPhase(infraRegionAndTargetPrompts, { ...frameworkAnswers, ...regionPreviousAnswers }, explicitConfig, existingConfig);
|
|
194
195
|
|
|
195
|
-
//
|
|
196
|
-
|
|
197
|
-
const baseImagePreviousAnswers = {
|
|
198
|
-
...frameworkAnswers,
|
|
199
|
-
...engineAnswers,
|
|
200
|
-
...(this._mcpBaseImageChoices ? { _mcpBaseImageChoices: this._mcpBaseImageChoices } : {})
|
|
201
|
-
};
|
|
202
|
-
const baseImageAnswers = await this._runPhase(
|
|
203
|
-
baseImagePrompts,
|
|
204
|
-
baseImagePreviousAnswers,
|
|
205
|
-
explicitConfig,
|
|
206
|
-
existingConfig
|
|
207
|
-
);
|
|
208
|
-
|
|
209
|
-
// Requirements: 4.2-4.5 — Check model architecture compatibility after base image selection
|
|
210
|
-
this._checkModelArchitectureCompatibility(baseImageAnswers, frameworkAnswers);
|
|
211
|
-
|
|
212
|
-
// Extract CUDA version from selected base image for instance-sizer context
|
|
213
|
-
const selectedBaseImageCuda = this._extractCudaFromBaseImage(baseImageAnswers);
|
|
196
|
+
// NOTE: Base image selection moved to Phase 3 (after instance type resolution)
|
|
197
|
+
// to enable driver-aware filtering. See US-1 ordering constraint in requirements.
|
|
214
198
|
|
|
215
199
|
// ══════════════════════════════════════════════════════════════════════
|
|
216
200
|
// Phase 3 — Where (region + instance [derived] + CUDA/AMI + HyperPod + build target)
|
|
@@ -283,7 +267,7 @@ export default class PromptRunner {
|
|
|
283
267
|
} else if (phase1ModelId && phase1ModelId !== 'Custom (enter manually)') {
|
|
284
268
|
// Query instance-sizer with full context
|
|
285
269
|
await this.mcpQueryRunner._queryMcpForInstanceSizing(frameworkAnswers, modelFormatAnswers, explicitConfig, {
|
|
286
|
-
cudaVersion:
|
|
270
|
+
cudaVersion: null, // base image not yet selected (moved after instance resolution)
|
|
287
271
|
profileEnvVars: this._selectedProfileEnvVars || {}
|
|
288
272
|
});
|
|
289
273
|
} else {
|
|
@@ -422,6 +406,66 @@ export default class PromptRunner {
|
|
|
422
406
|
}
|
|
423
407
|
}
|
|
424
408
|
|
|
409
|
+
// 3b2. Base image selection — AFTER instance type resolved (US-1 ordering constraint)
|
|
410
|
+
// Pass resolved instanceType and tensorParallelSize for driver-aware filtering
|
|
411
|
+
let resolvedInstanceType = instanceAnswers.customInstanceType || instanceAnswers.instanceType;
|
|
412
|
+
let resolvedTensorParallelSize = this._autoTensorParallelism || 1;
|
|
413
|
+
|
|
414
|
+
// For existing endpoints: resolve instance type from the endpoint (US-1 ordering constraint)
|
|
415
|
+
// The instance type is needed for driver-aware base image filtering even though the user
|
|
416
|
+
// doesn't select it manually. Pattern reused from do/lib/resolve-instance.sh.
|
|
417
|
+
if (!resolvedInstanceType && existingEndpointAnswers.existingEndpointName) {
|
|
418
|
+
const resolvedRegion = regionAndTargetAnswers.customAwsRegion || regionAndTargetAnswers.awsRegion;
|
|
419
|
+
resolvedInstanceType = await this.mcpQueryRunner._resolveEndpointInstanceType(
|
|
420
|
+
existingEndpointAnswers.existingEndpointName,
|
|
421
|
+
resolvedRegion
|
|
422
|
+
);
|
|
423
|
+
// Store resolved instance type for downstream use (IC config, GPU count derivation)
|
|
424
|
+
if (resolvedInstanceType) {
|
|
425
|
+
existingEndpointAnswers._resolvedEndpointInstanceType = resolvedInstanceType;
|
|
426
|
+
// Propagate as instanceType so template-variable-resolver derives
|
|
427
|
+
// icGpuCount and tensorParallelSize from the instance catalog.
|
|
428
|
+
// Without this, IC_GPU_COUNT defaults to 1 even for multi-GPU instances.
|
|
429
|
+
existingEndpointAnswers.instanceType = resolvedInstanceType;
|
|
430
|
+
|
|
431
|
+
// Derive GPU count from instance catalog for immediate use (TP for base image filtering)
|
|
432
|
+
const endpointInstanceEntry = instanceCatalogRaw[resolvedInstanceType];
|
|
433
|
+
if (endpointInstanceEntry?.gpus && endpointInstanceEntry.gpus > 1) {
|
|
434
|
+
existingEndpointAnswers.gpuCount = endpointInstanceEntry.gpus;
|
|
435
|
+
existingEndpointAnswers.tensorParallelSize = endpointInstanceEntry.gpus;
|
|
436
|
+
this._autoTensorParallelism = endpointInstanceEntry.gpus;
|
|
437
|
+
this._autoGpuCount = endpointInstanceEntry.gpus;
|
|
438
|
+
console.log(` ✓ Endpoint instance ${resolvedInstanceType}: ${endpointInstanceEntry.gpus} GPUs (TP=${endpointInstanceEntry.gpus})`);
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
}
|
|
442
|
+
|
|
443
|
+
// Re-read tensor parallel size after potential endpoint resolution update
|
|
444
|
+
resolvedTensorParallelSize = this._autoTensorParallelism || 1;
|
|
445
|
+
|
|
446
|
+
await this.mcpQueryRunner._queryMcpForBaseImage(frameworkAnswers, explicitConfig, {
|
|
447
|
+
instanceType: resolvedInstanceType,
|
|
448
|
+
tensorParallelSize: resolvedTensorParallelSize,
|
|
449
|
+
modelId: phase1ModelId || undefined
|
|
450
|
+
});
|
|
451
|
+
const baseImagePreviousAnswers = {
|
|
452
|
+
...frameworkAnswers,
|
|
453
|
+
...engineAnswers,
|
|
454
|
+
...(this._mcpBaseImageChoices ? { _mcpBaseImageChoices: this._mcpBaseImageChoices } : {})
|
|
455
|
+
};
|
|
456
|
+
const baseImageAnswers = await this._runPhase(
|
|
457
|
+
baseImagePrompts,
|
|
458
|
+
baseImagePreviousAnswers,
|
|
459
|
+
explicitConfig,
|
|
460
|
+
existingConfig
|
|
461
|
+
);
|
|
462
|
+
|
|
463
|
+
// Requirements: 4.2-4.5 — Check model architecture compatibility after base image selection
|
|
464
|
+
this._checkModelArchitectureCompatibility(baseImageAnswers, frameworkAnswers);
|
|
465
|
+
|
|
466
|
+
// Extract CUDA version from selected base image for CUDA/AMI auto-resolution
|
|
467
|
+
const selectedBaseImageCuda = this._extractCudaFromBaseImage(baseImageAnswers);
|
|
468
|
+
|
|
425
469
|
// 3c. Async-specific prompts (only when deploymentTarget === 'async-inference')
|
|
426
470
|
let asyncAnswers = {};
|
|
427
471
|
if (regionAndTargetAnswers.deploymentTarget === 'async-inference') {
|
|
@@ -454,6 +454,14 @@ export async function _ensureTemplateVariables(answers, registryConfigManager =
|
|
|
454
454
|
answers.tensorParallelSize = instanceGpuCount;
|
|
455
455
|
answers._tpAutoResolved = true;
|
|
456
456
|
answers._tpAutoResolvedFrom = answers.instanceType;
|
|
457
|
+
|
|
458
|
+
// Also propagate to icEnvVars so IC_ENV_VLLM_TENSOR_PARALLEL_SIZE
|
|
459
|
+
// (or equivalent) is written in do/config for deploy-time IC creation.
|
|
460
|
+
if (!answers.icEnvVars) {
|
|
461
|
+
answers.icEnvVars = {};
|
|
462
|
+
}
|
|
463
|
+
answers.icEnvVars[tpEnvKey] = String(instanceGpuCount);
|
|
464
|
+
|
|
457
465
|
console.log(` ℹ️ TP degree: ${instanceGpuCount} (auto-detected from ${answers.instanceType})`);
|
|
458
466
|
}
|
|
459
467
|
}
|
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
/**
|
|
6
|
+
* Interactive Training Job Configuration Builder.
|
|
7
|
+
*
|
|
8
|
+
* Guides users through configuring a custom training job by prompting
|
|
9
|
+
* for technique, model, dataset, instance type, and hyperparameters.
|
|
10
|
+
* Writes the result to training/config.yaml.
|
|
11
|
+
*
|
|
12
|
+
* Invoked from do/train --interactive:
|
|
13
|
+
* node -e "import('.../train-config-builder.js').then(m => m.run({...}))"
|
|
14
|
+
*
|
|
15
|
+
* Uses @inquirer/prompts via the project's prompt-adapter.js for UX
|
|
16
|
+
* consistency with the main ml-container-creator generation flow.
|
|
17
|
+
*/
|
|
18
|
+
|
|
19
|
+
import { select, input, confirm } from '@inquirer/prompts';
import { readFileSync, writeFileSync, readdirSync, existsSync } from 'node:fs';
import { basename, join, resolve } from 'node:path';
import { fileURLToPath } from 'node:url';
import { parseArgs } from 'node:util';
|
|
23
|
+
|
|
24
|
+
// ── YAML helpers (minimal, no dependency) ────────────────────────────────────
|
|
25
|
+
|
|
26
|
+
/**
 * Parse a deliberately small subset of YAML into one flat object.
 *
 * Every non-blank, non-comment line of the form `key: value` becomes one
 * entry. Indentation is ignored, so nested keys are flattened into the same
 * object and a bare section header such as `hyperparameters:` maps to ''.
 * Lines without a colon are skipped. Lists, multi-line scalars, anchors and
 * escape sequences are not supported.
 *
 * Trailing comments are removed: `epochs: 3  # note` yields 3 and
 * `name: "x"  # note` yields 'x'. A '#' that is not preceded by whitespace
 * (e.g. `tag: model#1`) is part of the value.
 *
 * Scalars are coerced after unquoting: 'true'/'false' become booleans and
 * numeric text becomes a Number (even when it was quoted); everything else
 * stays a string.
 *
 * @param {string} content - Raw YAML text.
 * @returns {Object} Flat map of key → string | number | boolean.
 */
function parseSimpleYaml(content) {
    const result = {};
    for (const line of content.split('\n')) {
        // trim() also drops the '\r' left behind by CRLF line endings.
        const trimmed = line.trim();
        if (!trimmed || trimmed.startsWith('#')) continue;

        const colonIdx = trimmed.indexOf(':');
        if (colonIdx === -1) continue;

        const key = trimmed.slice(0, colonIdx).trim();
        let value = trimmed.slice(colonIdx + 1).trim();

        const firstChar = value[0];
        if (firstChar === '"' || firstChar === '\'') {
            // Quoted scalar, possibly followed by a trailing comment.
            const closingIdx = value.indexOf(firstChar, 1);
            if (closingIdx > 0 && /^\s+#/.test(value.slice(closingIdx + 1))) {
                value = value.slice(1, closingIdx);
            } else if (value.endsWith(firstChar)) {
                value = value.slice(1, -1);
            }
            // Otherwise the quoting is unbalanced: keep the text untouched.
        } else {
            // Unquoted scalar: a comment starts at '#' when it opens the value
            // or follows whitespace.
            value = value.replace(/(?:^|\s+)#.*$/, '');
        }

        // Type coercion
        if (value === 'true') {
            result[key] = true;
        } else if (value === 'false') {
            result[key] = false;
        } else if (value === '' || value === '""' || value === '\'\'') {
            result[key] = '';
        } else if (!Number.isNaN(Number(value))) {
            result[key] = Number(value);
        } else {
            result[key] = value;
        }
    }
    return result;
}
|
|
53
|
+
|
|
54
|
+
// ── Technique scanning ───────────────────────────────────────────────────────
|
|
55
|
+
|
|
56
|
+
/**
 * List the training techniques available under a training directory.
 *
 * A technique is any immediate sub-directory that contains a `train.py`.
 * When the directory cannot be read, or no sub-directory qualifies, the
 * single fallback technique 'custom' is returned.
 *
 * @param {string} trainingDir - Directory whose sub-directories are scanned.
 * @returns {string[]} Technique names, or ['custom'] when none were found.
 */
function scanTechniques(trainingDir) {
    let found = [];
    try {
        found = readdirSync(trainingDir, { withFileTypes: true })
            .filter((dirent) => dirent.isDirectory())
            .map((dirent) => dirent.name)
            .filter((name) => existsSync(join(trainingDir, name, 'train.py')));
    } catch {
        // Directory doesn't exist or not readable
    }
    if (found.length === 0) {
        return ['custom'];
    }
    return found;
}
|
|
73
|
+
|
|
74
|
+
// ── Prompts.json loading ─────────────────────────────────────────────────────
|
|
75
|
+
|
|
76
|
+
/**
 * Load the optional `prompts.json` that sits next to a technique's script.
 *
 * @param {string} trainingDir - Directory containing the technique folders.
 * @param {string} technique - Technique folder name.
 * @returns {*} The parsed JSON, or null when the file is missing or is not
 *   valid JSON.
 */
function loadTechniquePrompts(trainingDir, technique) {
    const schemaPath = join(trainingDir, technique, 'prompts.json');
    let schema = null;
    if (existsSync(schemaPath)) {
        try {
            schema = JSON.parse(readFileSync(schemaPath, 'utf8'));
        } catch {
            // Unreadable or malformed file: treated the same as "no prompts".
            schema = null;
        }
    }
    return schema;
}
|
|
85
|
+
|
|
86
|
+
// ── Defaults loading ─────────────────────────────────────────────────────────
|
|
87
|
+
|
|
88
|
+
/**
 * Load the optional `defaults.yaml` that sits next to a technique's script.
 *
 * @param {string} trainingDir - Directory containing the technique folders.
 * @param {string} technique - Technique folder name.
 * @returns {Object} Result of parseSimpleYaml on the file, or an empty
 *   object when the file is missing or reading/parsing throws.
 */
function loadTechniqueDefaults(trainingDir, technique) {
    const yamlPath = join(trainingDir, technique, 'defaults.yaml');
    let loaded = {};
    if (existsSync(yamlPath)) {
        try {
            loaded = parseSimpleYaml(readFileSync(yamlPath, 'utf8'));
        } catch {
            // Any failure falls back to "no defaults".
            loaded = {};
        }
    }
    return loaded;
}
|
|
97
|
+
|
|
98
|
+
// ── Main interactive flow ────────────────────────────────────────────────────
|
|
99
|
+
|
|
100
|
+
/**
 * Interactively build a training-job configuration and write it as YAML.
 *
 * Flow: read the existing config (if any) for defaults → ask for technique,
 * model, dataset, instance type and common hyperparameters → ask the
 * technique-specific questions declared in `<technique>/prompts.json` →
 * write the YAML file → print a summary → ask whether to run immediately.
 *
 * The last line printed to stdout is the JSON form of the returned object.
 *
 * @param {Object} params
 * @param {string} params.configFile - Path of the YAML file to read defaults
 *   from and to overwrite with the new configuration.
 * @param {string} params.trainingDir - Directory holding one sub-directory
 *   per technique.
 * @returns {Promise<{config_written: boolean, technique: string, run_now: boolean}>}
 */
export async function run({ configFile, trainingDir }) {
    const configPath = resolve(configFile);
    const trainingPath = resolve(trainingDir);

    // Fallback output_path derived from the user's bootstrap profile; only
    // used when the existing config does not already define one.
    let profileOutputPath = '';
    try {
        const homedir = process.env.HOME || process.env.USERPROFILE || '';
        const profilePath = join(homedir, '.ml-container-creator', 'config.json');
        if (existsSync(profilePath)) {
            const profileData = JSON.parse(readFileSync(profilePath, 'utf8'));
            const activeProfile = profileData.profiles?.[profileData.activeProfile] || {};
            const bucket = activeProfile.benchmarkS3Bucket || '';
            if (bucket) {
                // Project name = name of the directory that contains the training dir.
                // basename() handles both '/' and '\\' separators.
                const projectName = basename(resolve(trainingPath, '..'));
                profileOutputPath = `s3://${bucket}/${projectName}/training-output/`;
            }
        }
    } catch { /* best-effort */ }

    // Load existing config as defaults
    let existingConfig = {};
    if (existsSync(configPath)) {
        try {
            existingConfig = parseSimpleYaml(readFileSync(configPath, 'utf8'));
        } catch {
            // Ignore parse errors — start fresh
        }
    }

    console.log('');
    console.log('🏋️ Custom Training Job Builder');
    console.log('━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━');
    console.log('');

    // ── Technique selection ──────────────────────────────────────────────────
    const techniques = scanTechniques(trainingPath);
    const technique = await select({
        message: 'Training technique?',
        choices: techniques.map(t => ({ name: t, value: t })),
        default: existingConfig.technique || 'sft'
    });

    // ── Common questions ─────────────────────────────────────────────────────
    const modelId = await input({
        message: 'Base model (HuggingFace ID)?',
        default: existingConfig.model_id || process.env.HF_MODEL_ID || 'Qwen/Qwen3-0.6B'
    });

    const dataset = await input({
        message: 'Dataset (hf://org/name, s3://..., or registry name)?',
        default: existingConfig.dataset || ''
    });

    const instanceType = await input({
        message: 'Instance type?',
        default: existingConfig.instance_type || 'ml.g5.xlarge'
    });

    // ── Load technique defaults for hyperparam questions ─────────────────────
    const defaults = loadTechniqueDefaults(trainingPath, technique);

    const epochs = await input({
        message: 'Epochs?',
        default: String(existingConfig.epochs || defaults.epochs || 3),
        validate: (v) => isDecimalNumberText(v) && Number(v) > 0 ? true : 'Must be a positive number'
    });

    const learningRate = await input({
        message: 'Learning rate?',
        default: String(existingConfig.learning_rate || defaults.learning_rate || '2e-4'),
        validate: (v) => isDecimalNumberText(v) ? true : 'Must be a number'
    });

    const loraR = await input({
        message: 'LoRA rank (r)?',
        default: String(existingConfig.lora_r || defaults.lora_r || 16),
        validate: (v) => isIntegerText(v) && Number(v) > 0 ? true : 'Must be a positive integer'
    });

    // ── Technique-specific prompts ───────────────────────────────────────────
    const techniquePromptsSchema = loadTechniquePrompts(trainingPath, technique);
    const techniqueAnswers = {};

    if (techniquePromptsSchema && Array.isArray(techniquePromptsSchema.prompts)) {
        console.log('');
        console.log(`─── ${techniquePromptsSchema.section_title || `${technique} settings`} ───`);

        for (const prompt of techniquePromptsSchema.prompts) {
            // Precedence: value already in the config → technique defaults.yaml →
            // default declared in prompts.json. `??` keeps legitimate 0/false values.
            const defaultVal = String(
                existingConfig[prompt.name] ?? defaults[prompt.name] ?? prompt.default ?? ''
            );

            const answer = await input({
                message: `${prompt.message}`,
                default: defaultVal,
                validate: (v) => {
                    if (prompt.validate === 'float') return isDecimalNumberText(v) ? true : 'Must be a number';
                    if (prompt.validate === 'int') return isIntegerText(v) ? true : 'Must be an integer';
                    return true;
                }
            });
            techniqueAnswers[prompt.name] = answer;
        }
    }

    // ── Build config ─────────────────────────────────────────────────────────
    const hyperparameters = {
        epochs,
        learning_rate: learningRate,
        lora_r: loraR,
        ...techniqueAnswers
    };

    // ── Write config ─────────────────────────────────────────────────────────
    // Build YAML output (preserving the original file structure where possible)
    const yamlLines = [
        '# do/training/config.yaml — Generated by interactive builder',
        `# Technique: ${technique}`,
        `# Generated: ${new Date().toISOString()}`,
        '',
        `technique: "${technique}"`,
        '',
        '# Base model',
        `model_id: "${modelId}"`,
        '',
        '# Dataset',
        `dataset: "${dataset}"`,
        '',
        '# Instance',
        `instance_type: "${instanceType}"`,
        `instance_count: ${existingConfig.instance_count || 1}`,
        '',
        '# Container image',
        `image: "${existingConfig.image || ''}"`,
        '',
        '# Script (auto-selected from technique)',
        `script: "do/training/${technique}/train.py"`,
        '',
        '# Output',
        `output_path: "${existingConfig.output_path || profileOutputPath}"`,
        '',
        '# Hyperparameters',
        'hyperparameters:'
    ];

    // Hyperparameter values are always written as quoted strings.
    for (const [key, val] of Object.entries(hyperparameters)) {
        yamlLines.push(` ${key}: "${val}"`);
    }

    // Preserve other existing fields
    if (existingConfig.max_runtime_seconds) {
        yamlLines.push('', `max_runtime_seconds: ${existingConfig.max_runtime_seconds}`);
    }
    if (existingConfig.volume_size_gb) {
        yamlLines.push(`volume_size_gb: ${existingConfig.volume_size_gb}`);
    }
    if (existingConfig.enable_spot) {
        yamlLines.push(`enable_spot: ${existingConfig.enable_spot}`);
    }

    yamlLines.push('');
    writeFileSync(configPath, yamlLines.join('\n'), 'utf8');

    // ── Summary ──────────────────────────────────────────────────────────────
    console.log('');
    console.log('✅ Configuration written to training/config.yaml');
    console.log('');
    console.log(` technique: ${technique}`);
    console.log(` model: ${modelId}`);
    console.log(` dataset: ${dataset || '(none)'}`);
    console.log(` instance_type: ${instanceType}`);
    console.log(` epochs: ${epochs}`);
    console.log(` learning_rate: ${learningRate}`);
    console.log(` lora_r: ${loraR}`);
    for (const [k, v] of Object.entries(techniqueAnswers)) {
        console.log(` ${k}: ${v}`);
    }
    console.log('');

    // ── Run now? ─────────────────────────────────────────────────────────────
    const runNow = await confirm({
        message: 'Run training job now?',
        default: false
    });

    // Output JSON for bash consumption
    const resultObj = {
        config_written: true,
        technique,
        run_now: runNow
    };

    // Print to stdout (for CLI entry point / backward compat)
    console.log(JSON.stringify(resultObj));

    // Return for programmatic callers (do/train writes to temp file)
    return resultObj;
}

/**
 * True when the whole text (ignoring surrounding whitespace) is a plain
 * decimal number such as "3", "0.5", ".5" or "2e-4". Unlike parseFloat,
 * trailing garbage ("2e-4abc") is rejected.
 */
function isDecimalNumberText(text) {
    return /^[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?$/.test(String(text).trim());
}

/**
 * True when the whole text (ignoring surrounding whitespace) is a base-10
 * integer such as "16" or "-1". Unlike parseInt, "12abc" and "1.5" are rejected.
 */
function isIntegerText(text) {
    return /^[+-]?\d+$/.test(String(text).trim());
}
|
|
303
|
+
|
|
304
|
+
// ── CLI entry point ──────────────────────────────────────────────────────────
|
|
305
|
+
|
|
306
|
+
/**
 * Command-line entry point: parses `--config-file` and `--training-dir`,
 * then runs the interactive builder.
 *
 * Exit codes: 1 for bad/missing arguments or any unexpected error,
 * 130 when the user cancels a prompt (Ctrl+C).
 *
 * @returns {Promise<void>}
 */
async function main() {
    const usage = 'Usage: train-config-builder --config-file <path> --training-dir <path>';

    // parseArgs throws on unknown or malformed options; report that as a
    // usage error instead of letting it escape as an unhandled rejection.
    let values;
    try {
        ({ values } = parseArgs({
            options: {
                'config-file': { type: 'string' },
                'training-dir': { type: 'string' }
            }
        }));
    } catch (err) {
        console.error(`❌ Error: ${err.message}`);
        console.error(usage);
        process.exit(1);
        return;
    }

    const configFile = values['config-file'];
    const trainingDir = values['training-dir'];

    if (!configFile || !trainingDir) {
        console.error(usage);
        process.exit(1);
        return;
    }

    try {
        await run({ configFile, trainingDir });
    } catch (err) {
        if (err.name === 'ExitPromptError') {
            // User pressed Ctrl+C
            console.log('\n⚠️ Cancelled.');
            process.exit(130);
            return;
        }
        console.error(`❌ Error: ${err.message}`);
        process.exit(1);
    }
}
|
|
334
|
+
|
|
335
|
+
// Run only when this file is the process entry point (not when it is imported).
// fileURLToPath() decodes percent-encoded characters and produces a native
// Windows path, both of which a raw URL pathname comparison gets wrong.
const isMainModule = process.argv[1] && resolve(process.argv[1]) === fileURLToPath(import.meta.url);
if (isMainModule) {
    main();
}
|