@aws/ml-container-creator 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +50760 -16218
- package/package.json +3 -1
- package/servers/lib/catalogs/instances.json +52 -1275
- package/servers/lib/catalogs/models.json +0 -132
- package/servers/lib/catalogs/popular-diffusors.json +1 -110
- package/src/app.js +24 -2
- package/src/lib/mcp-client.js +16 -1
- package/src/lib/mcp-command-handler.js +10 -2
- package/src/lib/prompt-runner.js +16 -2
- package/src/lib/train-config-parser.js +136 -0
- package/src/lib/train-config-persistence.js +143 -0
- package/src/lib/train-config-validator.js +112 -0
- package/src/lib/train-feedback.js +46 -0
- package/src/lib/train-idempotency.js +97 -0
- package/src/lib/train-request-builder.js +120 -0
- package/templates/do/.train_build_request.py +141 -0
- package/templates/do/.train_poll_parser.py +135 -0
- package/templates/do/.train_status_parser.py +187 -0
- package/templates/do/lib/feedback.sh +41 -0
- package/templates/do/train +786 -0
- package/templates/do/training/config.yaml +140 -0
- package/templates/do/training/train.py +463 -0
|
@@ -555,98 +555,6 @@
|
|
|
555
555
|
"text-generation"
|
|
556
556
|
]
|
|
557
557
|
},
|
|
558
|
-
"stabilityai/stable-diffusion-3.5-medium": {
|
|
559
|
-
"family": "stable-diffusion-3",
|
|
560
|
-
"gated": false,
|
|
561
|
-
"tags": [
|
|
562
|
-
"image-generation",
|
|
563
|
-
"diffusion",
|
|
564
|
-
"stable-diffusion"
|
|
565
|
-
],
|
|
566
|
-
"architecture": "StableDiffusion3Pipeline",
|
|
567
|
-
"profiles": {
|
|
568
|
-
"default": {
|
|
569
|
-
"displayName": "SD3.5 Medium",
|
|
570
|
-
"envVars": {}
|
|
571
|
-
}
|
|
572
|
-
},
|
|
573
|
-
"notes": "Stable Diffusion 3.5 medium model. Supported natively by vLLM-Omni StableDiffusion3Pipeline.",
|
|
574
|
-
"chatTemplate": null,
|
|
575
|
-
"frameworkCompatibility": {
|
|
576
|
-
"vllm-omni": ">=0.14.0"
|
|
577
|
-
},
|
|
578
|
-
"validationLevel": "experimental",
|
|
579
|
-
"modelType": "diffusor",
|
|
580
|
-
"tasks": [
|
|
581
|
-
"text-to-image"
|
|
582
|
-
]
|
|
583
|
-
},
|
|
584
|
-
"black-forest-labs/FLUX.1-dev": {
|
|
585
|
-
"family": "flux",
|
|
586
|
-
"gated": true,
|
|
587
|
-
"tags": [
|
|
588
|
-
"image-generation",
|
|
589
|
-
"diffusion",
|
|
590
|
-
"flux"
|
|
591
|
-
],
|
|
592
|
-
"architecture": "FluxPipeline",
|
|
593
|
-
"profiles": {
|
|
594
|
-
"default": {
|
|
595
|
-
"displayName": "FLUX.1 Dev",
|
|
596
|
-
"envVars": {}
|
|
597
|
-
}
|
|
598
|
-
},
|
|
599
|
-
"notes": "FLUX.1-dev high-quality generation model. Uses dual text encoders (CLIP + T5) and FlowMatchEuler scheduler. Requires significant VRAM.",
|
|
600
|
-
"chatTemplate": null,
|
|
601
|
-
"frameworkCompatibility": {
|
|
602
|
-
"vllm-omni": ">=0.14.0"
|
|
603
|
-
},
|
|
604
|
-
"validationLevel": "experimental",
|
|
605
|
-
"modelType": "diffusor",
|
|
606
|
-
"tasks": [
|
|
607
|
-
"text-to-image"
|
|
608
|
-
]
|
|
609
|
-
},
|
|
610
|
-
"black-forest-labs/FLUX.1-schnell": {
|
|
611
|
-
"family": "flux",
|
|
612
|
-
"gated": false,
|
|
613
|
-
"tags": [
|
|
614
|
-
"image-generation",
|
|
615
|
-
"diffusion",
|
|
616
|
-
"flux"
|
|
617
|
-
],
|
|
618
|
-
"architecture": "FluxPipeline",
|
|
619
|
-
"notes": "FLUX.1-schnell fast generation model. Fewer denoising steps for faster inference at slightly lower quality",
|
|
620
|
-
"chatTemplate": null,
|
|
621
|
-
"frameworkCompatibility": {
|
|
622
|
-
"vllm-omni": ">=0.14.0"
|
|
623
|
-
},
|
|
624
|
-
"validationLevel": "experimental",
|
|
625
|
-
"modelType": "diffusor",
|
|
626
|
-
"tasks": [
|
|
627
|
-
"text-to-image"
|
|
628
|
-
]
|
|
629
|
-
},
|
|
630
|
-
"Wan-AI/Wan2.1-T2V-14B-Diffusers": {
|
|
631
|
-
"family": "wan",
|
|
632
|
-
"gated": false,
|
|
633
|
-
"tags": [
|
|
634
|
-
"video-generation",
|
|
635
|
-
"diffusion",
|
|
636
|
-
"wan"
|
|
637
|
-
],
|
|
638
|
-
"architecture": "WanPipeline",
|
|
639
|
-
"notes": "Wan2.1 text-to-video 14B model (diffusers format). Requires multi-GPU instance (ml.g5.12xlarge or larger). Must use the -Diffusers variant — the base Wan2.1-T2V-14B repo lacks model_index.json required by vLLM-Omni",
|
|
640
|
-
"chatTemplate": null,
|
|
641
|
-
"frameworkCompatibility": {
|
|
642
|
-
"vllm-omni": ">=0.16.0"
|
|
643
|
-
},
|
|
644
|
-
"validationLevel": "experimental",
|
|
645
|
-
"modelType": "diffusor",
|
|
646
|
-
"tasks": [
|
|
647
|
-
"text-to-video"
|
|
648
|
-
]
|
|
649
|
-
},
|
|
650
558
|
"meta-llama/Llama-3*": {
|
|
651
559
|
"family": "llama-3",
|
|
652
560
|
"gated": true,
|
|
@@ -731,45 +639,5 @@
|
|
|
731
639
|
"tasks": [
|
|
732
640
|
"text-generation"
|
|
733
641
|
]
|
|
734
|
-
},
|
|
735
|
-
"stabilityai/stable-diffusion-*": {
|
|
736
|
-
"family": "stable-diffusion",
|
|
737
|
-
"gated": false,
|
|
738
|
-
"tags": [
|
|
739
|
-
"image-generation",
|
|
740
|
-
"diffusion",
|
|
741
|
-
"stable-diffusion"
|
|
742
|
-
],
|
|
743
|
-
"architecture": null,
|
|
744
|
-
"notes": "Fallback for Stable Diffusion variants not explicitly listed",
|
|
745
|
-
"chatTemplate": null,
|
|
746
|
-
"frameworkCompatibility": {
|
|
747
|
-
"vllm-omni": ">=0.14.0"
|
|
748
|
-
},
|
|
749
|
-
"validationLevel": "experimental",
|
|
750
|
-
"modelType": "diffusor",
|
|
751
|
-
"tasks": [
|
|
752
|
-
"text-to-image"
|
|
753
|
-
]
|
|
754
|
-
},
|
|
755
|
-
"black-forest-labs/FLUX*": {
|
|
756
|
-
"family": "flux",
|
|
757
|
-
"gated": false,
|
|
758
|
-
"tags": [
|
|
759
|
-
"image-generation",
|
|
760
|
-
"diffusion",
|
|
761
|
-
"flux"
|
|
762
|
-
],
|
|
763
|
-
"architecture": null,
|
|
764
|
-
"notes": "Fallback for FLUX model variants not explicitly listed",
|
|
765
|
-
"chatTemplate": null,
|
|
766
|
-
"frameworkCompatibility": {
|
|
767
|
-
"vllm-omni": ">=0.14.0"
|
|
768
|
-
},
|
|
769
|
-
"validationLevel": "experimental",
|
|
770
|
-
"modelType": "diffusor",
|
|
771
|
-
"tasks": [
|
|
772
|
-
"text-to-image"
|
|
773
|
-
]
|
|
774
642
|
}
|
|
775
643
|
}
|
|
@@ -1,110 +1 @@
|
|
|
1
|
-
{
|
|
2
|
-
"stabilityai/stable-diffusion-3.5-medium": {
|
|
3
|
-
"family": "stable-diffusion-3",
|
|
4
|
-
"chat_template": null,
|
|
5
|
-
"gated": false,
|
|
6
|
-
"tags": [
|
|
7
|
-
"image-generation",
|
|
8
|
-
"diffusion",
|
|
9
|
-
"stable-diffusion"
|
|
10
|
-
],
|
|
11
|
-
"architecture": "StableDiffusion3Pipeline",
|
|
12
|
-
"framework_compatibility": {
|
|
13
|
-
"vllm-omni": ">=0.14.0"
|
|
14
|
-
},
|
|
15
|
-
"validation_level": "experimental",
|
|
16
|
-
"profiles": {
|
|
17
|
-
"default": {
|
|
18
|
-
"displayName": "SD3.5 Medium",
|
|
19
|
-
"envVars": {}
|
|
20
|
-
}
|
|
21
|
-
},
|
|
22
|
-
"notes": "Stable Diffusion 3.5 medium model. Supported natively by vLLM-Omni StableDiffusion3Pipeline."
|
|
23
|
-
},
|
|
24
|
-
"black-forest-labs/FLUX.1-dev": {
|
|
25
|
-
"family": "flux",
|
|
26
|
-
"chat_template": null,
|
|
27
|
-
"gated": true,
|
|
28
|
-
"tags": [
|
|
29
|
-
"image-generation",
|
|
30
|
-
"diffusion",
|
|
31
|
-
"flux"
|
|
32
|
-
],
|
|
33
|
-
"architecture": "FluxPipeline",
|
|
34
|
-
"framework_compatibility": {
|
|
35
|
-
"vllm-omni": ">=0.14.0"
|
|
36
|
-
},
|
|
37
|
-
"validation_level": "experimental",
|
|
38
|
-
"profiles": {
|
|
39
|
-
"default": {
|
|
40
|
-
"displayName": "FLUX.1 Dev",
|
|
41
|
-
"envVars": {}
|
|
42
|
-
}
|
|
43
|
-
},
|
|
44
|
-
"notes": "FLUX.1-dev high-quality generation model. Uses dual text encoders (CLIP + T5) and FlowMatchEuler scheduler. Requires significant VRAM."
|
|
45
|
-
},
|
|
46
|
-
"black-forest-labs/FLUX.1-schnell": {
|
|
47
|
-
"family": "flux",
|
|
48
|
-
"chat_template": null,
|
|
49
|
-
"gated": false,
|
|
50
|
-
"tags": [
|
|
51
|
-
"image-generation",
|
|
52
|
-
"diffusion",
|
|
53
|
-
"flux"
|
|
54
|
-
],
|
|
55
|
-
"architecture": "FluxPipeline",
|
|
56
|
-
"framework_compatibility": {
|
|
57
|
-
"vllm-omni": ">=0.14.0"
|
|
58
|
-
},
|
|
59
|
-
"validation_level": "experimental",
|
|
60
|
-
"notes": "FLUX.1-schnell fast generation model. Fewer denoising steps for faster inference at slightly lower quality"
|
|
61
|
-
},
|
|
62
|
-
"Wan-AI/Wan2.1-T2V-14B-Diffusers": {
|
|
63
|
-
"family": "wan",
|
|
64
|
-
"chat_template": null,
|
|
65
|
-
"gated": false,
|
|
66
|
-
"tags": [
|
|
67
|
-
"video-generation",
|
|
68
|
-
"diffusion",
|
|
69
|
-
"wan"
|
|
70
|
-
],
|
|
71
|
-
"architecture": "WanPipeline",
|
|
72
|
-
"framework_compatibility": {
|
|
73
|
-
"vllm-omni": ">=0.16.0"
|
|
74
|
-
},
|
|
75
|
-
"validation_level": "experimental",
|
|
76
|
-
"notes": "Wan2.1 text-to-video 14B model (diffusers format). Requires multi-GPU instance (ml.g5.12xlarge or larger). Must use the -Diffusers variant — the base Wan2.1-T2V-14B repo lacks model_index.json required by vLLM-Omni"
|
|
77
|
-
},
|
|
78
|
-
"stabilityai/stable-diffusion-*": {
|
|
79
|
-
"family": "stable-diffusion",
|
|
80
|
-
"chat_template": null,
|
|
81
|
-
"gated": false,
|
|
82
|
-
"tags": [
|
|
83
|
-
"image-generation",
|
|
84
|
-
"diffusion",
|
|
85
|
-
"stable-diffusion"
|
|
86
|
-
],
|
|
87
|
-
"architecture": null,
|
|
88
|
-
"framework_compatibility": {
|
|
89
|
-
"vllm-omni": ">=0.14.0"
|
|
90
|
-
},
|
|
91
|
-
"validation_level": "experimental",
|
|
92
|
-
"notes": "Fallback for Stable Diffusion variants not explicitly listed"
|
|
93
|
-
},
|
|
94
|
-
"black-forest-labs/FLUX*": {
|
|
95
|
-
"family": "flux",
|
|
96
|
-
"chat_template": null,
|
|
97
|
-
"gated": false,
|
|
98
|
-
"tags": [
|
|
99
|
-
"image-generation",
|
|
100
|
-
"diffusion",
|
|
101
|
-
"flux"
|
|
102
|
-
],
|
|
103
|
-
"architecture": null,
|
|
104
|
-
"framework_compatibility": {
|
|
105
|
-
"vllm-omni": ">=0.14.0"
|
|
106
|
-
},
|
|
107
|
-
"validation_level": "experimental",
|
|
108
|
-
"notes": "Fallback for FLUX model variants not explicitly listed"
|
|
109
|
-
}
|
|
110
|
-
}
|
|
1
|
+
{}
|
package/src/app.js
CHANGED
|
@@ -349,11 +349,27 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
|
|
|
349
349
|
}
|
|
350
350
|
|
|
351
351
|
// Exclude tune files when framework is NOT transformers OR deploymentTarget is batch-transform
|
|
352
|
-
|
|
352
|
+
const tuneIncluded = architecture === 'transformers' && answers.deploymentTarget !== 'batch-transform';
|
|
353
|
+
if (!tuneIncluded) {
|
|
353
354
|
ignorePatterns.push('**/do/tune');
|
|
354
355
|
ignorePatterns.push('**/do/.tune_helper.py');
|
|
355
356
|
}
|
|
356
357
|
|
|
358
|
+
// Exclude train files when deploymentTarget is batch-transform
|
|
359
|
+
const trainIncluded = answers.deploymentTarget !== 'batch-transform';
|
|
360
|
+
if (!trainIncluded) {
|
|
361
|
+
ignorePatterns.push('**/do/train');
|
|
362
|
+
ignorePatterns.push('**/do/.train_build_request.py');
|
|
363
|
+
ignorePatterns.push('**/do/.train_status_parser.py');
|
|
364
|
+
ignorePatterns.push('**/do/.train_poll_parser.py');
|
|
365
|
+
ignorePatterns.push('**/do/training/**');
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
// Exclude feedback.sh when neither tune nor train is included
|
|
369
|
+
if (!tuneIncluded && !trainIncluded) {
|
|
370
|
+
ignorePatterns.push('**/do/lib/feedback.sh');
|
|
371
|
+
}
|
|
372
|
+
|
|
357
373
|
// Exclude do/test when hosted-model-endpoint is not selected
|
|
358
374
|
const testTypes = answers.testTypes || [];
|
|
359
375
|
if (!testTypes.includes('hosted-model-endpoint')) {
|
|
@@ -371,6 +387,11 @@ export async function writeProject(templateDir, destDir, answers, registryConfig
|
|
|
371
387
|
ignorePatterns.push('**/do/adapters/**');
|
|
372
388
|
ignorePatterns.push('**/do/tune');
|
|
373
389
|
ignorePatterns.push('**/do/.tune_helper.py');
|
|
390
|
+
ignorePatterns.push('**/do/train');
|
|
391
|
+
ignorePatterns.push('**/do/.train_build_request.py');
|
|
392
|
+
ignorePatterns.push('**/do/.train_status_parser.py');
|
|
393
|
+
ignorePatterns.push('**/do/.train_poll_parser.py');
|
|
394
|
+
ignorePatterns.push('**/do/training/**');
|
|
374
395
|
ignorePatterns.push('**/do/add-ic');
|
|
375
396
|
ignorePatterns.push('**/do/run');
|
|
376
397
|
ignorePatterns.push('**/sample_model/**');
|
|
@@ -1177,7 +1198,8 @@ function _setExecutablePermissions(destDir, answers = {}) {
|
|
|
1177
1198
|
'do/status',
|
|
1178
1199
|
'do/add-ic',
|
|
1179
1200
|
'do/adapter',
|
|
1180
|
-
'do/tune'
|
|
1201
|
+
'do/tune',
|
|
1202
|
+
'do/train'
|
|
1181
1203
|
];
|
|
1182
1204
|
|
|
1183
1205
|
const shellScripts = architecture === 'marketplace' ? marketplaceScripts : defaultScripts;
|
package/src/lib/mcp-client.js
CHANGED
|
@@ -14,6 +14,12 @@
|
|
|
14
14
|
|
|
15
15
|
import { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
|
16
16
|
import { StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js';
|
|
17
|
+
import path from 'path';
|
|
18
|
+
import { fileURLToPath } from 'url';
|
|
19
|
+
|
|
20
|
+
const __mcp_filename = fileURLToPath(import.meta.url);
|
|
21
|
+
const __mcp_dirname = path.dirname(__mcp_filename);
|
|
22
|
+
const PACKAGE_ROOT = path.resolve(__mcp_dirname, '../..');
|
|
17
23
|
|
|
18
24
|
const DEFAULT_TOOL_NAME = 'get_ml_config';
|
|
19
25
|
const DEFAULT_LIMIT = 10;
|
|
@@ -96,6 +102,15 @@ class McpClient {
|
|
|
96
102
|
async _executeQuery() {
|
|
97
103
|
const { command, args = [], env } = this.serverConfig;
|
|
98
104
|
|
|
105
|
+
// Resolve relative paths in args against the package root
|
|
106
|
+
const resolvedArgs = args.map(arg => {
|
|
107
|
+
if (arg && !path.isAbsolute(arg) && !arg.startsWith('-')) {
|
|
108
|
+
const resolved = path.resolve(PACKAGE_ROOT, arg);
|
|
109
|
+
return resolved;
|
|
110
|
+
}
|
|
111
|
+
return arg;
|
|
112
|
+
});
|
|
113
|
+
|
|
99
114
|
// Build environment: merge process.env with server-specific env
|
|
100
115
|
// When --smart flag is active, inject BEDROCK_SMART=true for this run
|
|
101
116
|
// Discover mode is now default; inject DISCOVER_MODE=false only when explicitly disabled
|
|
@@ -108,7 +123,7 @@ class McpClient {
|
|
|
108
123
|
// Create stdio transport — spawns the server process
|
|
109
124
|
this._transport = new StdioClientTransport({
|
|
110
125
|
command,
|
|
111
|
-
args,
|
|
126
|
+
args: resolvedArgs,
|
|
112
127
|
env: spawnEnv,
|
|
113
128
|
stderr: 'pipe'
|
|
114
129
|
});
|
|
@@ -91,8 +91,12 @@ export default class McpCommandHandler {
|
|
|
91
91
|
const installed = await this._installBundledDependencies(resolved.serverDir, name);
|
|
92
92
|
if (!installed) return;
|
|
93
93
|
|
|
94
|
+
// Store path relative to package root for portability
|
|
95
|
+
const packageRoot = path.resolve(__dirname, '../..');
|
|
96
|
+
const relativePath = path.relative(packageRoot, resolved.entryPoint);
|
|
97
|
+
|
|
94
98
|
command = 'node';
|
|
95
|
-
commandArgs = [
|
|
99
|
+
commandArgs = [relativePath];
|
|
96
100
|
} else {
|
|
97
101
|
// Find the '--' separator to split name from command
|
|
98
102
|
const separatorIndex = positionalArgs.indexOf('--');
|
|
@@ -195,9 +199,13 @@ export default class McpCommandHandler {
|
|
|
195
199
|
const installed = await this._installBundledDependencies(resolved.serverDir, server.name);
|
|
196
200
|
if (!installed) continue;
|
|
197
201
|
|
|
202
|
+
// Store path relative to package root for portability across machines
|
|
203
|
+
const packageRoot = path.resolve(__dirname, '../..');
|
|
204
|
+
const relativePath = path.relative(packageRoot, resolved.entryPoint);
|
|
205
|
+
|
|
198
206
|
config.mcpServers[server.name] = {
|
|
199
207
|
command: 'node',
|
|
200
|
-
args: [
|
|
208
|
+
args: [relativePath]
|
|
201
209
|
};
|
|
202
210
|
added++;
|
|
203
211
|
}
|
package/src/lib/prompt-runner.js
CHANGED
|
@@ -50,6 +50,20 @@ const __pr_filename = fileURLToPath(import.meta.url);
|
|
|
50
50
|
const __pr_dirname = path.dirname(__pr_filename);
|
|
51
51
|
const GENERATOR_ROOT = path.resolve(__pr_dirname, '..', '..');
|
|
52
52
|
|
|
53
|
+
/**
|
|
54
|
+
* Resolve MCP server args — converts relative paths to absolute using GENERATOR_ROOT.
|
|
55
|
+
* @param {string[]} args - The args array from mcp.json serverConfig
|
|
56
|
+
* @returns {string[]} Args with relative paths resolved
|
|
57
|
+
*/
|
|
58
|
+
function resolveMcpArgs(args) {
|
|
59
|
+
return (args || []).map(arg => {
|
|
60
|
+
if (arg && !path.isAbsolute(arg) && !arg.startsWith('-')) {
|
|
61
|
+
return path.resolve(GENERATOR_ROOT, arg);
|
|
62
|
+
}
|
|
63
|
+
return arg;
|
|
64
|
+
});
|
|
65
|
+
}
|
|
66
|
+
|
|
53
67
|
export default class PromptRunner {
|
|
54
68
|
constructor({ configManager, options, registryConfigManager, baseConfig, promptFn }) {
|
|
55
69
|
this.configManager = configManager;
|
|
@@ -1384,7 +1398,7 @@ export default class PromptRunner {
|
|
|
1384
1398
|
const { Client } = await import('@modelcontextprotocol/sdk/client/index.js');
|
|
1385
1399
|
const { StdioClientTransport } = await import('@modelcontextprotocol/sdk/client/stdio.js');
|
|
1386
1400
|
|
|
1387
|
-
const serverArgs = [...(serverConfig.args
|
|
1401
|
+
const serverArgs = [...resolveMcpArgs(serverConfig.args)];
|
|
1388
1402
|
if (!discover && !serverArgs.includes('--no-discover')) {
|
|
1389
1403
|
serverArgs.push('--no-discover');
|
|
1390
1404
|
}
|
|
@@ -1939,7 +1953,7 @@ export default class PromptRunner {
|
|
|
1939
1953
|
|
|
1940
1954
|
const transport = new StdioClientTransport({
|
|
1941
1955
|
command: serverConfig.command,
|
|
1942
|
-
args: serverConfig.args
|
|
1956
|
+
args: resolveMcpArgs(serverConfig.args),
|
|
1943
1957
|
env: { ...process.env, ...(serverConfig.env || {}) },
|
|
1944
1958
|
stderr: 'pipe'
|
|
1945
1959
|
});
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Train Config Parser
|
|
6
|
+
*
|
|
7
|
+
* JavaScript module that replicates the YAML config parsing logic from
|
|
8
|
+
* do/train's _parse_config_python() function. Parses do/training/config.yaml
|
|
9
|
+
* and extracts all supported fields into a structured object.
|
|
10
|
+
*
|
|
11
|
+
* This module mirrors the behavior of both the yq and Python fallback paths
|
|
12
|
+
* in the bash script, providing a testable implementation of the parsing logic.
|
|
13
|
+
*/
|
|
14
|
+
|
|
15
|
+
import { readFileSync } from 'node:fs';
|
|
16
|
+
import yaml from 'js-yaml';
|
|
17
|
+
|
|
18
|
+
/**
|
|
19
|
+
* Default values for optional fields, matching the bash script defaults.
|
|
20
|
+
*/
|
|
21
|
+
const DEFAULTS = {
|
|
22
|
+
instance_count: '1',
|
|
23
|
+
max_runtime_seconds: '86400',
|
|
24
|
+
volume_size_gb: '50',
|
|
25
|
+
enable_spot: 'false',
|
|
26
|
+
max_wait_seconds: '172800',
|
|
27
|
+
checkpoint_path: '',
|
|
28
|
+
hyperparameters: {},
|
|
29
|
+
metric_definitions: [],
|
|
30
|
+
environment: {},
|
|
31
|
+
tags: {}
|
|
32
|
+
};
|
|
33
|
+
|
|
34
|
+
/**
|
|
35
|
+
* Convert a value to its string representation, matching the Python helper's
|
|
36
|
+
* `s()` function behavior in _parse_config_python.
|
|
37
|
+
*
|
|
38
|
+
* @param {*} val - The value to convert
|
|
39
|
+
* @param {string} defaultVal - Default value if val is null/undefined
|
|
40
|
+
* @returns {string} String representation
|
|
41
|
+
*/
|
|
42
|
+
function toStringValue(val, defaultVal = '') {
|
|
43
|
+
if (val === null || val === undefined) {
|
|
44
|
+
return defaultVal;
|
|
45
|
+
}
|
|
46
|
+
if (typeof val === 'boolean') {
|
|
47
|
+
return val ? 'true' : 'false';
|
|
48
|
+
}
|
|
49
|
+
return String(val);
|
|
50
|
+
}
|
|
51
|
+
|
|
52
|
+
/**
|
|
53
|
+
* Parse a training config YAML file and extract all supported fields.
|
|
54
|
+
*
|
|
55
|
+
* This mirrors the behavior of _parse_config_python() in do/train:
|
|
56
|
+
* - Scalar fields are converted to strings
|
|
57
|
+
* - Boolean fields are converted to "true"/"false" strings
|
|
58
|
+
* - Missing optional fields get default values
|
|
59
|
+
* - Complex fields (hyperparameters, metric_definitions, environment, tags)
|
|
60
|
+
* are kept as their native types (objects/arrays)
|
|
61
|
+
*
|
|
62
|
+
* @param {string} configPath - Path to the YAML config file
|
|
63
|
+
* @returns {object} Parsed config with all supported fields
|
|
64
|
+
* @throws {Error} If the file cannot be read or parsed
|
|
65
|
+
*/
|
|
66
|
+
export function parseTrainingConfig(configPath) {
|
|
67
|
+
const content = readFileSync(configPath, 'utf8');
|
|
68
|
+
return parseTrainingConfigFromString(content);
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
/**
|
|
72
|
+
* Parse a training config from a YAML string.
|
|
73
|
+
* Useful for testing without file I/O.
|
|
74
|
+
*
|
|
75
|
+
* @param {string} yamlContent - YAML content string
|
|
76
|
+
* @returns {object} Parsed config with all supported fields
|
|
77
|
+
* @throws {Error} If the YAML cannot be parsed
|
|
78
|
+
*/
|
|
79
|
+
export function parseTrainingConfigFromString(yamlContent) {
|
|
80
|
+
const cfg = yaml.load(yamlContent) || {};
|
|
81
|
+
|
|
82
|
+
return {
|
|
83
|
+
// Required fields (empty string if missing)
|
|
84
|
+
image: toStringValue(cfg.image, ''),
|
|
85
|
+
script: toStringValue(cfg.script, ''),
|
|
86
|
+
instance_type: toStringValue(cfg.instance_type, ''),
|
|
87
|
+
instance_count: toStringValue(cfg.instance_count, DEFAULTS.instance_count),
|
|
88
|
+
dataset: toStringValue(cfg.dataset, ''),
|
|
89
|
+
output_path: toStringValue(cfg.output_path, ''),
|
|
90
|
+
|
|
91
|
+
// Optional scalar fields with defaults
|
|
92
|
+
max_runtime_seconds: toStringValue(cfg.max_runtime_seconds, DEFAULTS.max_runtime_seconds),
|
|
93
|
+
volume_size_gb: toStringValue(cfg.volume_size_gb, DEFAULTS.volume_size_gb),
|
|
94
|
+
enable_spot: toStringValue(cfg.enable_spot, DEFAULTS.enable_spot),
|
|
95
|
+
max_wait_seconds: toStringValue(cfg.max_wait_seconds, DEFAULTS.max_wait_seconds),
|
|
96
|
+
checkpoint_path: toStringValue(cfg.checkpoint_path, DEFAULTS.checkpoint_path),
|
|
97
|
+
|
|
98
|
+
// Complex fields (objects/arrays)
|
|
99
|
+
hyperparameters: cfg.hyperparameters || DEFAULTS.hyperparameters,
|
|
100
|
+
metric_definitions: cfg.metric_definitions || DEFAULTS.metric_definitions,
|
|
101
|
+
environment: cfg.environment || DEFAULTS.environment,
|
|
102
|
+
tags: cfg.tags || DEFAULTS.tags
|
|
103
|
+
};
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
/**
|
|
107
|
+
* List of all supported fields in the training config.
|
|
108
|
+
*/
|
|
109
|
+
export const SUPPORTED_FIELDS = [
|
|
110
|
+
'image',
|
|
111
|
+
'script',
|
|
112
|
+
'instance_type',
|
|
113
|
+
'instance_count',
|
|
114
|
+
'dataset',
|
|
115
|
+
'output_path',
|
|
116
|
+
'max_runtime_seconds',
|
|
117
|
+
'volume_size_gb',
|
|
118
|
+
'enable_spot',
|
|
119
|
+
'max_wait_seconds',
|
|
120
|
+
'checkpoint_path',
|
|
121
|
+
'hyperparameters',
|
|
122
|
+
'metric_definitions',
|
|
123
|
+
'environment',
|
|
124
|
+
'tags'
|
|
125
|
+
];
|
|
126
|
+
|
|
127
|
+
/**
|
|
128
|
+
* List of required fields that must be non-empty.
|
|
129
|
+
*/
|
|
130
|
+
export const REQUIRED_FIELDS = [
|
|
131
|
+
'image',
|
|
132
|
+
'script',
|
|
133
|
+
'instance_type',
|
|
134
|
+
'dataset',
|
|
135
|
+
'output_path'
|
|
136
|
+
];
|