@aws/ml-container-creator 1.0.2 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/bin/cli.js +1 -1
- package/config/tune-catalog.json +303 -1
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +3 -2
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1516
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/app.js +6 -4
- package/src/lib/bootstrap-command-handler.js +12 -2
- package/src/lib/bootstrap-profile-manager.js +16 -0
- package/src/lib/cross-cutting-checker.js +6 -1
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -0,0 +1,339 @@
|
|
|
1
|
+
#!/usr/bin/env node
|
|
2
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
/**
|
|
6
|
+
* Interactive Training Job Configuration Builder.
|
|
7
|
+
*
|
|
8
|
+
* Guides users through configuring a custom training job by prompting
|
|
9
|
+
* for technique, model, dataset, instance type, and hyperparameters.
|
|
10
|
+
* Writes the result to the --config-file path (do/training/config.yaml in the generated project).
|
|
11
|
+
*
|
|
12
|
+
* Invoked from do/train --interactive:
|
|
13
|
+
* node -e "import('.../train-config-builder.js').then(m => m.run({...}))"
|
|
14
|
+
*
|
|
15
|
+
* Imports @inquirer/prompts directly (not through the project's prompt-adapter.js)
* for UX consistency with the main ml-container-creator generation flow.
|
|
17
|
+
*/
|
|
18
|
+
|
|
19
|
+
import { select, input, confirm } from '@inquirer/prompts';
import { readFileSync, writeFileSync, readdirSync, existsSync } from 'node:fs';
import { basename, dirname, join, resolve } from 'node:path';
import { fileURLToPath } from 'node:url';
import { parseArgs } from 'node:util';
|
|
23
|
+
|
|
24
|
+
// ── YAML helpers (minimal, no dependency) ────────────────────────────────────
|
|
25
|
+
|
|
26
|
+
/**
 * Parse a minimal subset of YAML into a flat object (no dependency).
 *
 * Every `key: value` line is read regardless of indentation, so nested
 * mappings are flattened: `hyperparameters:` becomes `hyperparameters: ''`
 * and its indented children become top-level keys. run() relies on this to
 * read previously written hyperparameters back as prompt defaults.
 *
 * Value handling, in order:
 *   - a trailing YAML comment (`# ...` preceded by whitespace) is dropped;
 *   - one pair of matching surrounding quotes is stripped;
 *   - `true` / `false` become booleans, empty values become '';
 *   - numeric text (quoted or not) becomes a Number;
 *   - anything else stays a string.
 * Blank lines, full-line comments and lines without a colon are skipped.
 *
 * @param {string} content - Raw file text.
 * @returns {Record<string, string|number|boolean>} Flat key/value map.
 */
function parseSimpleYaml(content) {
  const result = {};
  for (const line of content.split('\n')) {
    const trimmed = line.trim(); // also removes a trailing '\r' from CRLF files
    if (!trimmed || trimmed.startsWith('#')) continue;
    const colonIdx = trimmed.indexOf(':');
    if (colonIdx === -1) continue;
    const key = trimmed.slice(0, colonIdx).trim();
    const value = unquoteYamlScalar(stripYamlComment(trimmed.slice(colonIdx + 1).trim()));
    // Type coercion
    if (value === 'true') result[key] = true;
    else if (value === 'false') result[key] = false;
    else if (value === '' || value === '""' || value === '\'\'') result[key] = '';
    else if (!Number.isNaN(Number(value))) result[key] = Number(value);
    else result[key] = value;
  }
  return result;
}

/**
 * Drop a trailing comment from a raw YAML value.
 *
 * YAML only starts a comment at '#' preceded by whitespace, so values such as
 * `s3://bucket/path#frag` are kept intact. For quoted values the comment may
 * only follow the closing quote; unusual quoting (unterminated, escaped or
 * doubled quotes followed by more text) is returned untouched.
 */
function stripYamlComment(raw) {
  const quote = raw[0];
  if (quote === '"' || quote === '\'') {
    const closeIdx = raw.indexOf(quote, 1);
    if (closeIdx !== -1) {
      const rest = raw.slice(closeIdx + 1).trim();
      if (rest === '' || rest.startsWith('#')) return raw.slice(0, closeIdx + 1);
    }
    return raw;
  }
  return raw.replace(/(^|\s)#.*$/, '').trim();
}

/**
 * Remove one pair of matching surrounding quotes. A single quote character on
 * its own is not a quoted string and is returned as-is.
 */
function unquoteYamlScalar(raw) {
  const quote = raw[0];
  if (raw.length >= 2 && (quote === '"' || quote === '\'') && raw.endsWith(quote)) {
    return raw.slice(1, -1);
  }
  return raw;
}
|
|
53
|
+
|
|
54
|
+
// ── Technique scanning ───────────────────────────────────────────────────────
|
|
55
|
+
|
|
56
|
+
/**
 * List the training techniques available under `trainingDir`.
 *
 * A technique is any immediate subdirectory that contains a `train.py`.
 * Names are sorted so the selection menu order is stable; readdirSync
 * returns entries in filesystem-dependent order.
 *
 * @param {string} trainingDir - Directory holding one subdirectory per technique.
 * @returns {string[]} Sorted technique names, or ['custom'] when none are
 *   found or the directory cannot be read.
 */
function scanTechniques(trainingDir) {
  let entries;
  try {
    entries = readdirSync(trainingDir, { withFileTypes: true });
  } catch {
    // Directory doesn't exist or not readable — use the fallback below.
    entries = [];
  }
  const techniques = entries
    .filter((entry) => entry.isDirectory() && existsSync(join(trainingDir, entry.name, 'train.py')))
    .map((entry) => entry.name)
    .sort();
  return techniques.length > 0 ? techniques : ['custom'];
}
|
|
73
|
+
|
|
74
|
+
// ── Prompts.json loading ─────────────────────────────────────────────────────
|
|
75
|
+
|
|
76
|
+
/**
 * Load `<trainingDir>/<technique>/prompts.json`, the optional technique-specific
 * prompt schema.
 *
 * A missing file is normal and returns null silently. An unreadable or
 * malformed file also returns null, so the builder keeps going without the
 * extra prompts. A warning is printed to stderr in that case, so the user
 * knows why the technique-specific questions were skipped.
 *
 * @param {string} trainingDir - Training directory.
 * @param {string} technique - Technique subdirectory name.
 * @returns {object|null} Parsed JSON, or null when absent or invalid.
 */
function loadTechniquePrompts(trainingDir, technique) {
  const promptsFile = join(trainingDir, technique, 'prompts.json');
  if (!existsSync(promptsFile)) return null;
  try {
    return JSON.parse(readFileSync(promptsFile, 'utf8'));
  } catch (err) {
    console.warn(`⚠️  Ignoring ${promptsFile}: ${err.message}`);
    return null;
  }
}
|
|
85
|
+
|
|
86
|
+
// ── Defaults loading ─────────────────────────────────────────────────────────
|
|
87
|
+
|
|
88
|
+
/**
 * Read `<trainingDir>/<technique>/defaults.yaml` into a flat object via
 * parseSimpleYaml.
 *
 * Returns {} when the file is missing or reading/parsing it throws, so the
 * caller can fall back to its built-in defaults.
 *
 * @param {string} trainingDir - Training directory.
 * @param {string} technique - Technique subdirectory name.
 * @returns {object} Parsed defaults (possibly empty).
 */
function loadTechniqueDefaults(trainingDir, technique) {
  const defaultsPath = join(trainingDir, technique, 'defaults.yaml');
  let parsed = {};
  if (existsSync(defaultsPath)) {
    try {
      parsed = parseSimpleYaml(readFileSync(defaultsPath, 'utf8'));
    } catch {
      parsed = {};
    }
  }
  return parsed;
}
|
|
97
|
+
|
|
98
|
+
// ── Main interactive flow ────────────────────────────────────────────────────
|
|
99
|
+
|
|
100
|
+
/**
 * Default S3 output location derived from the active bootstrap profile.
 *
 * Reads ~/.ml-container-creator/config.json. When the active profile defines
 * `benchmarkS3Bucket`, it returns `s3://<bucket>/<project>/training-output/`.
 * This is best-effort: a missing home directory, a missing or malformed
 * profile file, or a missing bucket all yield ''.
 *
 * @param {string} trainingPath - Absolute training directory path.
 * @returns {string}
 */
function resolveProfileOutputPath(trainingPath) {
  const homedir = process.env.HOME || process.env.USERPROFILE || '';
  // With no home directory, join('') would probe a path relative to the cwd.
  if (!homedir) return '';
  try {
    const profilePath = join(homedir, '.ml-container-creator', 'config.json');
    if (!existsSync(profilePath)) return '';
    const profileData = JSON.parse(readFileSync(profilePath, 'utf8'));
    const activeProfile = profileData?.profiles?.[profileData.activeProfile] || {};
    const bucket = activeProfile.benchmarkS3Bucket || '';
    return bucket ? `s3://${bucket}/${deriveProjectName(trainingPath)}/training-output/` : '';
  } catch {
    return ''; // best-effort
  }
}

/**
 * Project name used in the default S3 prefix.
 *
 * The builder writes `script: do/training/<technique>/train.py`. The training
 * directory therefore normally sits at `<project>/do/training`, and its
 * immediate parent is `do`. That segment is skipped so the prefix names the
 * project rather than the literal "do". basename/dirname are used instead of
 * splitting on '/' so Windows paths also work.
 *
 * @param {string} trainingPath - Absolute training directory path.
 * @returns {string}
 */
function deriveProjectName(trainingPath) {
  let projectDir = dirname(trainingPath);
  if (basename(projectDir) === 'do') projectDir = dirname(projectDir);
  return basename(projectDir);
}

/**
 * Read the existing config file (if any) so its values become prompt defaults.
 * Read/parse errors are ignored and yield {} (start fresh).
 */
function loadExistingConfig(configPath) {
  if (!existsSync(configPath)) return {};
  try {
    return parseSimpleYaml(readFileSync(configPath, 'utf8'));
  } catch {
    return {};
  }
}

/** True when `value` is non-blank and converts in full to a finite number. */
function isNumeric(value) {
  const text = String(value).trim();
  return text !== '' && Number.isFinite(Number(text));
}

/**
 * Validate a technique-specific answer against its prompts.json `validate` rule.
 * Unknown or absent rules accept anything.
 */
function validateTechniqueAnswer(rule, value) {
  if (rule === 'float') return isNumeric(value) ? true : 'Must be a number';
  if (rule === 'int') return isNumeric(value) && Number.isInteger(Number(value)) ? true : 'Must be an integer';
  return true;
}

/**
 * Quote a scalar as a YAML double-quoted string. JSON string literals are
 * valid YAML double-quoted scalars, so quotes, backslashes and control
 * characters in user input are escaped instead of corrupting the file.
 */
function yamlQuote(value) {
  return JSON.stringify(String(value));
}

/**
 * Render the training config as YAML text.
 *
 * Hyperparameters are always written as quoted strings. max_runtime_seconds,
 * volume_size_gb and enable_spot are carried over from the existing config
 * only when they are truthy.
 */
function buildConfigYaml({ technique, modelId, dataset, instanceType, hyperparameters, existingConfig, profileOutputPath }) {
  const lines = [
    '# do/training/config.yaml — Generated by interactive builder',
    `# Technique: ${technique}`,
    `# Generated: ${new Date().toISOString()}`,
    '',
    `technique: ${yamlQuote(technique)}`,
    '',
    '# Base model',
    `model_id: ${yamlQuote(modelId)}`,
    '',
    '# Dataset',
    `dataset: ${yamlQuote(dataset)}`,
    '',
    '# Instance',
    `instance_type: ${yamlQuote(instanceType)}`,
    `instance_count: ${existingConfig.instance_count || 1}`,
    '',
    '# Container image',
    `image: ${yamlQuote(existingConfig.image || '')}`,
    '',
    '# Script (auto-selected from technique)',
    `script: ${yamlQuote(`do/training/${technique}/train.py`)}`,
    '',
    '# Output',
    `output_path: ${yamlQuote(existingConfig.output_path || profileOutputPath)}`,
    '',
    '# Hyperparameters',
    'hyperparameters:',
    ...Object.entries(hyperparameters).map(([key, val]) => `  ${key}: ${yamlQuote(val)}`)
  ];

  // Preserve other existing fields
  if (existingConfig.max_runtime_seconds) {
    lines.push('', `max_runtime_seconds: ${existingConfig.max_runtime_seconds}`);
  }
  if (existingConfig.volume_size_gb) {
    lines.push(`volume_size_gb: ${existingConfig.volume_size_gb}`);
  }
  if (existingConfig.enable_spot) {
    lines.push(`enable_spot: ${existingConfig.enable_spot}`);
  }

  lines.push('');
  return lines.join('\n');
}

/**
 * Interactively build a training job configuration and write it to `configFile`.
 *
 * The builder prompts for:
 *   - the technique (subdirectories of `trainingDir` containing train.py);
 *   - the base model, dataset and instance type;
 *   - epochs, learning rate and LoRA rank;
 *   - any technique-specific questions from `<technique>/prompts.json`.
 * Defaults come from the existing config file first, then
 * `<technique>/defaults.yaml`, then built-in values. After writing the file,
 * it asks whether to run the job now.
 *
 * @param {{configFile: string, trainingDir: string}} options
 * @returns {Promise<{config_written: boolean, technique: string, run_now: boolean}>}
 *   The same object is also printed to stdout as one JSON line.
 * @throws Prompt errors (e.g. ExitPromptError on Ctrl+C) and file write errors propagate.
 */
export async function run({ configFile, trainingDir }) {
  const configPath = resolve(configFile);
  const trainingPath = resolve(trainingDir);

  // Fallback output_path from the bootstrap profile (used only if the config has none)
  const profileOutputPath = resolveProfileOutputPath(trainingPath);

  // Load existing config as defaults
  const existingConfig = loadExistingConfig(configPath);

  console.log('');
  console.log('🏋️ Custom Training Job Builder');
  console.log('━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━');
  console.log('');

  // ── Technique selection ──────────────────────────────────────────────────
  const techniques = scanTechniques(trainingPath);
  const technique = await select({
    message: 'Training technique?',
    choices: techniques.map(t => ({ name: t, value: t })),
    default: existingConfig.technique || 'sft'
  });

  // ── Common questions ─────────────────────────────────────────────────────
  const modelId = await input({
    message: 'Base model (HuggingFace ID)?',
    default: String(existingConfig.model_id || process.env.HF_MODEL_ID || 'Qwen/Qwen3-0.6B')
  });

  const dataset = await input({
    message: 'Dataset (hf://org/name, s3://..., or registry name)?',
    default: String(existingConfig.dataset || '')
  });

  const instanceType = await input({
    message: 'Instance type?',
    default: String(existingConfig.instance_type || 'ml.g5.xlarge')
  });

  // ── Load technique defaults for hyperparam questions ─────────────────────
  const defaults = loadTechniqueDefaults(trainingPath, technique);

  const epochs = await input({
    message: 'Epochs?',
    default: String(existingConfig.epochs || defaults.epochs || 3),
    // Any positive number is accepted, fractional values included.
    validate: (v) => isNumeric(v) && Number(v) > 0 ? true : 'Must be a positive number'
  });

  const learningRate = await input({
    message: 'Learning rate?',
    default: String(existingConfig.learning_rate || defaults.learning_rate || '2e-4'),
    // Number() (unlike parseFloat) rejects trailing garbage such as "2e-4x".
    validate: (v) => isNumeric(v) && Number(v) > 0 ? true : 'Must be a positive number'
  });

  const loraR = await input({
    message: 'LoRA rank (r)?',
    default: String(existingConfig.lora_r || defaults.lora_r || 16),
    validate: (v) => isNumeric(v) && Number.isInteger(Number(v)) && Number(v) > 0 ? true : 'Must be a positive integer'
  });

  // ── Technique-specific prompts ───────────────────────────────────────────
  const promptSchema = loadTechniquePrompts(trainingPath, technique);
  const techniqueAnswers = {};

  // Array.isArray guard: a non-array `prompts` would make for...of throw.
  if (Array.isArray(promptSchema?.prompts)) {
    console.log('');
    console.log(`─── ${promptSchema.section_title || `${technique} settings`} ───`);

    for (const prompt of promptSchema.prompts) {
      // ?? (not ||) so falsy-but-valid defaults such as 0 are kept.
      const defaultVal = existingConfig[prompt.name] ?? defaults[prompt.name] ?? prompt.default ?? '';
      techniqueAnswers[prompt.name] = await input({
        message: `${prompt.message}`,
        default: String(defaultVal),
        validate: (v) => validateTechniqueAnswer(prompt.validate, v)
      });
    }
  }

  // ── Build & write config ─────────────────────────────────────────────────
  const hyperparameters = {
    epochs,
    learning_rate: learningRate,
    lora_r: loraR,
    ...techniqueAnswers
  };

  const yamlText = buildConfigYaml({
    technique,
    modelId,
    dataset,
    instanceType,
    hyperparameters,
    existingConfig,
    profileOutputPath
  });
  writeFileSync(configPath, yamlText, 'utf8');

  // ── Summary ──────────────────────────────────────────────────────────────
  console.log('');
  console.log(`✅ Configuration written to ${configPath}`);
  console.log('');
  console.log(` technique: ${technique}`);
  console.log(` model: ${modelId}`);
  console.log(` dataset: ${dataset || '(none)'}`);
  console.log(` instance_type: ${instanceType}`);
  console.log(` epochs: ${epochs}`);
  console.log(` learning_rate: ${learningRate}`);
  console.log(` lora_r: ${loraR}`);
  for (const [k, v] of Object.entries(techniqueAnswers)) {
    console.log(` ${k}: ${v}`);
  }
  console.log('');

  // ── Run now? ─────────────────────────────────────────────────────────────
  const runNow = await confirm({
    message: 'Run training job now?',
    default: false
  });

  // Output JSON for bash consumption
  const resultObj = {
    config_written: true,
    technique,
    run_now: runNow
  };

  // Print to stdout (for CLI entry point / backward compat)
  console.log(JSON.stringify(resultObj));

  // Return for programmatic callers (do/train writes to temp file)
  return resultObj;
}
|
|
303
|
+
|
|
304
|
+
// ── CLI entry point ──────────────────────────────────────────────────────────
|
|
305
|
+
|
|
306
|
+
/**
 * CLI entry point: `train-config-builder --config-file <path> --training-dir <path>`.
 *
 * Exit codes:
 *   - 1 on bad or missing arguments and on unexpected errors (reported on stderr);
 *   - 130 when the user cancels a prompt (inquirer's ExitPromptError, e.g. Ctrl+C).
 */
async function main() {
  const usage = 'Usage: train-config-builder --config-file <path> --training-dir <path>';

  let values;
  try {
    ({ values } = parseArgs({
      options: {
        'config-file': { type: 'string' },
        'training-dir': { type: 'string' }
      }
    }));
  } catch (err) {
    // parseArgs throws on unknown flags or a flag missing its value; report
    // usage instead of letting it escape as an unhandled promise rejection.
    console.error(`❌ Error: ${err.message}`);
    console.error(usage);
    process.exit(1);
  }

  const configFile = values['config-file'];
  const trainingDir = values['training-dir'];

  if (!configFile || !trainingDir) {
    console.error(usage);
    process.exit(1);
  }

  try {
    await run({ configFile, trainingDir });
  } catch (err) {
    if (err?.name === 'ExitPromptError') {
      // User pressed Ctrl+C
      console.log('\n⚠️ Cancelled.');
      process.exit(130);
    }
    console.error(`❌ Error: ${err?.message ?? err}`);
    process.exit(1);
  }
}
|
|
334
|
+
|
|
335
|
+
// Run main() only when this file is executed directly, not when it is
// imported (e.g. by do/train via `node -e "import(...)"`).
// fileURLToPath is used because URL#pathname keeps a leading slash before
// Windows drive letters (/C:/...) and leaves characters such as spaces
// percent-encoded, so the comparison would fail on those paths.
const isMainModule = Boolean(process.argv[1]) &&
  resolve(process.argv[1]) === resolve(fileURLToPath(import.meta.url));
if (isMainModule) {
  // main() handles its own errors; this catch ensures nothing escapes as an
  // unhandled rejection.
  main().catch((err) => {
    console.error(`❌ Error: ${err?.message ?? err}`);
    process.exit(1);
  });
}
|
|
@@ -1478,6 +1478,7 @@ def _load_config_file(config_path):
|
|
|
1478
1478
|
'HF_MODEL_ID': 'hf_model_id',
|
|
1479
1479
|
'INSTANCE_TYPE': 'instance_type',
|
|
1480
1480
|
'INSTANCE_POOLS': 'instance_pools',
|
|
1481
|
+
'DEPLOYED_INSTANCE_TYPE': 'deployed_instance_type',
|
|
1481
1482
|
'BENCHMARK_INSTANCE_TYPE': 'benchmark_instance_type',
|
|
1482
1483
|
'DEPLOYMENT_CONFIG': 'deployment_config',
|
|
1483
1484
|
'DEPLOYMENT_TARGET': 'deployment_target',
|
|
@@ -1521,6 +1522,8 @@ def _load_config_file(config_path):
|
|
|
1521
1522
|
# BENCHMARK_INSTANCE_TYPE (live-resolved, persisted by do/benchmark) > INSTANCE_TYPE > INSTANCE_POOLS fallback
|
|
1522
1523
|
if context.get('benchmark_instance_type'):
|
|
1523
1524
|
context['instance_type'] = context.pop('benchmark_instance_type')
|
|
1525
|
+
elif context.get('deployed_instance_type'):
|
|
1526
|
+
context['instance_type'] = context.pop('deployed_instance_type')
|
|
1524
1527
|
# Fall back to INSTANCE_POOLS when neither is set.
|
|
1525
1528
|
# Heterogeneous pool configs may not have a standalone INSTANCE_TYPE value
|
|
1526
1529
|
# but always define INSTANCE_POOLS as a JSON array with Priority fields.
|