@aws/ml-container-creator 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +50760 -16218
- package/package.json +3 -1
- package/servers/lib/catalogs/instances.json +52 -1275
- package/servers/lib/catalogs/models.json +0 -132
- package/servers/lib/catalogs/popular-diffusors.json +1 -110
- package/src/app.js +24 -2
- package/src/lib/mcp-client.js +16 -1
- package/src/lib/mcp-command-handler.js +10 -2
- package/src/lib/prompt-runner.js +16 -2
- package/src/lib/train-config-parser.js +136 -0
- package/src/lib/train-config-persistence.js +143 -0
- package/src/lib/train-config-validator.js +112 -0
- package/src/lib/train-feedback.js +46 -0
- package/src/lib/train-idempotency.js +97 -0
- package/src/lib/train-request-builder.js +120 -0
- package/templates/do/.train_build_request.py +141 -0
- package/templates/do/.train_poll_parser.py +135 -0
- package/templates/do/.train_status_parser.py +187 -0
- package/templates/do/lib/feedback.sh +41 -0
- package/templates/do/train +786 -0
- package/templates/do/training/config.yaml +140 -0
- package/templates/do/training/train.py +463 -0
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Train Config Persistence
|
|
6
|
+
*
|
|
7
|
+
* JavaScript module that models the config persistence logic from the bash
|
|
8
|
+
* `_update_config_var` function in `templates/do/train`. This module provides
|
|
9
|
+
* a pure JavaScript implementation for property-based testing of the config
|
|
10
|
+
* persistence behavior after job submission and completion.
|
|
11
|
+
*
|
|
12
|
+
* The config file uses the format:
|
|
13
|
+
* export VAR_NAME="value"
|
|
14
|
+
*
|
|
15
|
+
* Behavior:
|
|
16
|
+
* - If the variable already exists: update its value in-place
|
|
17
|
+
* - If the variable doesn't exist: append it to the end
|
|
18
|
+
* - Existing variables in the config are preserved
|
|
19
|
+
*
|
|
20
|
+
* Requirements: 3.4, 5.1
|
|
21
|
+
*/
|
|
22
|
+
|
|
23
|
+
import { readFileSync, writeFileSync } from 'node:fs';
|
|
24
|
+
|
|
25
|
+
/**
 * Update or add a config variable in a do/config-style file.
 * Mimics the bash _update_config_var() function from templates/do/train:
 *
 *   if grep -q "^export ${var_name}=" "${config_file}"; then
 *     sed -i.bak "s|^export ${var_name}=.*|export ${var_name}=\"${var_value}\"|" "${config_file}"
 *     rm -f "${config_file}.bak"
 *   else
 *     echo "export ${var_name}=\"${var_value}\"" >> "${config_file}"
 *   fi
 *
 * Like `sed`, EVERY existing `export VAR=` line is rewritten, not just the
 * first one. Otherwise a later duplicate definition would silently win when the
 * file is sourced.
 *
 * The variable name is matched literally (regex metacharacters are escaped).
 * The value is inserted verbatim, so sequences such as `$&`, `$1` or `$$` in
 * the value are NOT treated as String.prototype.replace substitution patterns.
 *
 * When appending, a newline is added first if the content does not already end
 * with one, so the new export always starts on its own line.
 *
 * @param {string} configContent - Current content of the config file
 * @param {string} varName - Variable name (e.g., TRAIN_JOB_NAME)
 * @param {string} varValue - Variable value
 * @returns {string} Updated config content
 */
export function updateConfigVar(configContent, varName, varValue) {
  const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const linePattern = new RegExp(`^export ${escapedName}=.*$`, 'gm');
  const newLine = `export ${varName}="${varValue}"`;

  // A replacer function inserts newLine literally, with no `$` pattern expansion.
  let found = false;
  const replaced = configContent.replace(linePattern, () => {
    found = true;
    return newLine;
  });
  if (found) {
    // Variable exists — updated in-place (all occurrences)
    return replaced;
  }

  // Variable doesn't exist — append, making sure it starts on a fresh line
  const separator = configContent.length > 0 && !configContent.endsWith('\n') ? '\n' : '';
  return `${configContent}${separator}${newLine}\n`;
}
|
|
57
|
+
|
|
58
|
+
/**
 * Read a config variable value from a do/config-style file content.
 *
 * Only double-quoted assignments of the form `export VAR="value"` are
 * recognized (the form written by updateConfigVar). If the variable is defined
 * more than once, the first matching line is used. The variable name is matched
 * literally, so regex metacharacters in `varName` cannot match unrelated lines.
 *
 * @param {string} configContent - Content of the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The variable value, or null if not found
 */
export function readConfigVar(configContent, varName) {
  const escapedName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const pattern = new RegExp(`^export ${escapedName}="([^"]*)"`, 'm');
  const match = configContent.match(pattern);
  return match ? match[1] : null;
}
|
|
70
|
+
|
|
71
|
+
/**
 * Record a freshly submitted training job in the config content.
 *
 * JavaScript counterpart of the call made by do/train's _submit_job():
 *   _update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
 *
 * @param {string} configContent - Current content of the config file
 * @param {object} params - Submission parameters
 * @param {string} params.jobName - Generated job name (pattern: ${PROJECT_NAME}-train-${TIMESTAMP})
 * @returns {string} Updated config content
 */
export function persistTrainSubmission(configContent, { jobName }) {
  const varName = 'TRAIN_JOB_NAME';
  const updated = updateConfigVar(configContent, varName, jobName);
  return updated;
}
|
|
84
|
+
|
|
85
|
+
/**
 * Record where a finished training job wrote its artifacts.
 *
 * JavaScript counterpart of the call made by do/train's _handle_completion():
 *   _update_config_var "TRAIN_OUTPUT_PATH" "${output_path}"
 *
 * @param {string} configContent - Current content of the config file
 * @param {object} params - Completion parameters
 * @param {string} params.outputPath - S3 path to the output artifacts
 * @returns {string} Updated config content
 */
export function persistTrainCompletion(configContent, { outputPath }) {
  const varName = 'TRAIN_OUTPUT_PATH';
  const updated = updateConfigVar(configContent, varName, outputPath);
  return updated;
}
|
|
98
|
+
|
|
99
|
+
/**
 * Build a training job name in the same shape do/train uses:
 *   ${projectName}-train-YYYYMMDD-HHMMSS
 *
 * The date/time parts come from the Date's local-time getters. Month, day,
 * hour, minute and second are zero-padded to two digits; the year is used as is.
 *
 * @param {string} projectName - Project name
 * @param {Date} [timestamp] - Optional timestamp (defaults to now)
 * @returns {string} Generated job name
 */
export function generateTrainJobName(projectName, timestamp = new Date()) {
  const pad2 = (value) => String(value).padStart(2, '0');

  const datePart = `${timestamp.getFullYear()}${pad2(timestamp.getMonth() + 1)}${pad2(timestamp.getDate())}`;
  const timePart = [timestamp.getHours(), timestamp.getMinutes(), timestamp.getSeconds()]
    .map(pad2)
    .join('');

  return `${projectName}-train-${datePart}-${timePart}`;
}
|
|
118
|
+
|
|
119
|
+
/**
 * Disk-backed wrapper around updateConfigVar.
 *
 * Reads the whole file as UTF-8, applies updateConfigVar, and writes the
 * result back to the same path. Intended for integration-style tests that
 * need real file I/O.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name
 * @param {string} varValue - Variable value
 */
export function updateConfigVarFile(configPath, varName, varValue) {
  const encoding = 'utf8';
  writeFileSync(
    configPath,
    updateConfigVar(readFileSync(configPath, encoding), varName, varValue),
    encoding
  );
}
|
|
132
|
+
|
|
133
|
+
/**
 * Disk-backed wrapper around readConfigVar.
 *
 * Reads the file as UTF-8 and looks up `varName` in its content.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The variable value, or null if not found
 */
export function readConfigVarFile(configPath, varName) {
  return readConfigVar(readFileSync(configPath, 'utf8'), varName);
}
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Train Config Validator
|
|
6
|
+
*
|
|
7
|
+
* Validates training configuration objects parsed from do/training/config.yaml.
|
|
8
|
+
* Checks that all required fields are present and provides descriptive error
|
|
9
|
+
* messages naming the specific missing field.
|
|
10
|
+
*
|
|
11
|
+
* This module mirrors the validation logic in the bash `_validate_config`
|
|
12
|
+
* function in templates/do/train, enabling property-based testing of the
|
|
13
|
+
* validation rules in isolation.
|
|
14
|
+
*
|
|
15
|
+
* Requirements: 2.12, 10.1
|
|
16
|
+
*/
|
|
17
|
+
|
|
18
|
+
/**
 * Required fields for a valid training configuration.
 * Each entry maps the field name to a human-readable description and an
 * example of the expected YAML line. Key order is the order in which fields
 * are checked, and therefore the order in which errors are reported.
 *
 * The table and every entry are frozen. This is shared, exported validation
 * metadata, and an accidental mutation by one caller would silently change
 * validation for all others.
 */
export const REQUIRED_FIELDS = Object.freeze({
  image: Object.freeze({
    description: 'The container image URI',
    format: 'image: "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-training:latest"'
  }),
  script: Object.freeze({
    description: 'The training script S3 path',
    format: 'script: "s3://my-bucket/scripts/train.py"'
  }),
  instance_type: Object.freeze({
    description: 'The SageMaker instance type',
    format: 'instance_type: "ml.g5.xlarge"'
  }),
  dataset: Object.freeze({
    description: 'The S3 dataset path',
    format: 'dataset: "s3://my-bucket/data/train/"'
  }),
  output_path: Object.freeze({
    description: 'The S3 output path',
    format: 'output_path: "s3://my-bucket/output/"'
  })
});
|
|
44
|
+
|
|
45
|
+
/**
 * Check a parsed training config for every field listed in REQUIRED_FIELDS.
 *
 * A field counts as missing when its value is `undefined`, `null` or the empty
 * string. Other falsy values such as `0` or `false` count as present. A falsy
 * `config` makes every field missing.
 *
 * @param {Object} config - The parsed training configuration object
 * @returns {{ valid: boolean, errors: Array<{ field: string, message: string }> }}
 *   - valid: true when no required field is missing
 *   - errors: one entry per missing field, in REQUIRED_FIELDS order
 */
export function validateRequiredFields(config) {
  const isMissing = (value) => value === undefined || value === null || value === '';

  const errors = Object.entries(REQUIRED_FIELDS)
    .filter(([name]) => isMissing(config ? config[name] : undefined))
    .map(([name, { description, format }]) => ({
      field: name,
      message: `Missing required field: ${name}\n ${description} is required in do/training/config.yaml\n\n Expected format: ${format}`
    }));

  return { valid: errors.length === 0, errors };
}
|
|
72
|
+
|
|
73
|
+
/**
 * Validate spot training checkpoint requirement.
 * When enable_spot is enabled, checkpoint_path must be specified (non-empty).
 *
 * Spot counts as enabled for both the boolean `true` and the string `'true'`.
 * This matches buildTrainingJobRequest, which turns on managed spot training
 * for either form. Accepting only the boolean would let a string-valued
 * `enable_spot` through without the checkpoint the request then relies on.
 *
 * @param {Object} config - The parsed training configuration object
 * @returns {{ valid: boolean, errors: Array<{ field: string, message: string }> }}
 */
export function validateSpotConfig(config) {
  const errors = [];

  const spotEnabled = Boolean(config) && (config.enable_spot === true || config.enable_spot === 'true');

  // A falsy checkpoint_path (undefined, null, '') means "not specified".
  if (spotEnabled && !config.checkpoint_path) {
    errors.push({
      field: 'checkpoint_path',
      message: 'Checkpoint path required for spot training\n When enable_spot is true, a checkpoint S3 path must be specified\n so training can resume after spot interruptions.'
    });
  }

  return {
    valid: errors.length === 0,
    errors
  };
}
|
|
95
|
+
|
|
96
|
+
/**
 * Run every training-config check and merge their findings.
 *
 * Required-field errors come first, followed by spot/checkpoint errors.
 *
 * @param {Object} config - The parsed training configuration object
 * @returns {{ valid: boolean, errors: Array<{ field: string, message: string }> }}
 */
export function validateTrainingConfig(config) {
  const checks = [validateRequiredFields, validateSpotConfig];
  const errors = checks.flatMap((check) => check(config).errors);

  return { valid: errors.length === 0, errors };
}
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Train Feedback Loop — JavaScript equivalent of do/lib/feedback.sh
|
|
6
|
+
*
|
|
7
|
+
* Generates post-completion feedback output with artifact locations
|
|
8
|
+
* and deployment suggestions based on artifact type.
|
|
9
|
+
*/
|
|
10
|
+
|
|
11
|
+
/**
 * Render the post-completion message for a training/tuning job.
 *
 * Port of print_completion_feedback() in do/lib/feedback.sh. The output always
 * contains the job name, the artifact location and a "Next steps:" heading.
 * The model package ARN line appears only when an ARN is given. Deployment
 * suggestions are listed only for the "adapter" and "full-model" output types.
 *
 * @param {object} params
 * @param {string} params.outputPath - S3 URI to the output artifacts
 * @param {string} params.outputType - "adapter" or "full-model"
 * @param {string} params.jobName - Job name for reference
 * @param {string} [params.modelPackageArn] - Optional model package ARN
 * @returns {string} The formatted feedback output, lines joined with "\n"
 */
export function generateCompletionFeedback({ outputPath, outputType, jobName, modelPackageArn = '' }) {
  const summary = [
    '',
    `✅ Training complete: ${jobName}`,
    '',
    ` Artifacts: ${outputPath}`,
    ...(modelPackageArn ? [` Model Package: ${modelPackageArn}`] : []),
    '',
    ' Next steps:'
  ];

  let suggestions;
  switch (outputType) {
    case 'adapter':
      suggestions = [
        ` • Deploy as LoRA adapter: ./do/adapter add my-adapter --weights ${outputPath}`,
        ' • (Requires running endpoint with LoRA enabled)'
      ];
      break;
    case 'full-model':
      suggestions = [
        ` • Deploy as new IC: ./do/add-ic my-model --model-data ${outputPath}`,
        ` • Replace current base: ./do/deploy --force-ic --model-data ${outputPath}`
      ];
      break;
    default:
      suggestions = [];
  }

  return [...summary, ...suggestions, ''].join('\n');
}
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Train Idempotency Decision Logic
|
|
6
|
+
*
|
|
7
|
+
* Models the idempotency check logic from the bash `_check_idempotency` function
|
|
8
|
+
* in `templates/do/train` as a pure JavaScript function for property-based testing.
|
|
9
|
+
*
|
|
10
|
+
* The idempotency pattern:
|
|
11
|
+
* - If --force is set, always create a new job regardless of existing status
|
|
12
|
+
* - If no existing job, create a new job
|
|
13
|
+
* - If existing job is InProgress, poll it
|
|
14
|
+
* - If existing job is Completed, display results
|
|
15
|
+
* - If existing job is Failed or Stopped, display failure and suggest --force
|
|
16
|
+
*
|
|
17
|
+
* Requirements: 5.1–5.5
|
|
18
|
+
*/
|
|
19
|
+
|
|
20
|
+
/**
 * Existing-job statuses that determineAction handles explicitly.
 *
 * Frozen so that callers (including tests that draw values from it) cannot
 * accidentally mutate the shared list.
 *
 * NOTE(review): SageMaker's DescribeTrainingJob can also report the transient
 * `Stopping` status, which is not listed here — confirm whether it should be.
 */
export const JOB_STATUSES = Object.freeze(['InProgress', 'Completed', 'Failed', 'Stopped']);

/**
 * Possible actions the train script can take after the idempotency check.
 * Frozen: these values are compared against by callers and must not change at runtime.
 */
export const ACTIONS = Object.freeze({
  CREATE_NEW_JOB: 'create_new_job',
  POLL_EXISTING: 'poll_existing',
  DISPLAY_RESULTS: 'display_results',
  DISPLAY_FAILURE: 'display_failure'
});
|
|
34
|
+
|
|
35
|
+
/**
 * Decide what do/train should do next, given any previously recorded job.
 *
 * Testable counterpart of the bash `_check_idempotency` logic. The rules are
 * evaluated top to bottom and the first one that applies wins:
 *   - forceFlag is exactly `true`                    → create_new_job
 *   - existingStatus is falsy (null/undefined/'')    → create_new_job
 *   - 'InProgress'                                   → poll_existing
 *   - 'Completed'                                    → display_results
 *   - 'Failed' or 'Stopped'                          → display_failure
 *   - any other status                               → display_failure (unexpected)
 *
 * @param {string|null|undefined} existingStatus - The current job status from DescribeTrainingJob
 * @param {boolean} forceFlag - Whether --force was specified
 * @returns {{ action: string, reason: string }}
 *   - action: one of ACTIONS values
 *   - reason: human-readable explanation of the choice
 */
export function determineAction(existingStatus, forceFlag) {
  // Only a literal `true` counts as --force; truthy non-booleans do not.
  if (forceFlag === true) {
    return {
      action: ACTIONS.CREATE_NEW_JOB,
      reason: '--force specified, creating new job regardless of existing status'
    };
  }

  if (!existingStatus) {
    return {
      action: ACTIONS.CREATE_NEW_JOB,
      reason: 'No existing job found, creating new job'
    };
  }

  if (existingStatus === 'InProgress') {
    return {
      action: ACTIONS.POLL_EXISTING,
      reason: `Existing job is ${existingStatus}, resuming polling`
    };
  }

  if (existingStatus === 'Completed') {
    return {
      action: ACTIONS.DISPLAY_RESULTS,
      reason: `Existing job is ${existingStatus}, displaying results`
    };
  }

  const endedUnsuccessfully = existingStatus === 'Failed' || existingStatus === 'Stopped';
  if (endedUnsuccessfully) {
    return {
      action: ACTIONS.DISPLAY_FAILURE,
      reason: `Existing job is ${existingStatus}, suggest --force to create new job`
    };
  }

  // Anything else is unexpected — report it as a failure and point at --force.
  return {
    action: ACTIONS.DISPLAY_FAILURE,
    reason: `Unexpected job status: ${existingStatus}, suggest --force`
  };
}
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
2
|
+
// SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
/**
|
|
5
|
+
* Train Request Builder
|
|
6
|
+
*
|
|
7
|
+
* JavaScript module that replicates the Python helper's (.train_build_request.py)
|
|
8
|
+
* logic for constructing a CreateTrainingJob JSON request from a parsed config.
|
|
9
|
+
*
|
|
10
|
+
* This module mirrors the behavior of the Python build_request() function,
|
|
11
|
+
* providing a testable implementation of the config-to-API mapping logic.
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
/**
 * Build a CreateTrainingJob request from a parsed training config.
 *
 * Mirrors build_request() in templates/do/.train_build_request.py. Config
 * fields are mapped to the SageMaker CreateTrainingJob API structure:
 *   - image → AlgorithmSpecification.TrainingImage
 *   - instance_type → ResourceConfig.InstanceType
 *   - instance_count → ResourceConfig.InstanceCount
 *   - volume_size_gb → ResourceConfig.VolumeSizeInGB
 *   - dataset → InputDataConfig[0].DataSource.S3DataSource.S3Uri
 *   - output_path → OutputDataConfig.S3OutputPath
 *   - hyperparameters → HyperParameters (string key-value pairs)
 *   - max_runtime_seconds → StoppingCondition.MaxRuntimeInSeconds
 *   - enable_spot=true → EnableManagedSpotTraining = true
 *   - enable_spot=true → StoppingCondition.MaxWaitTimeInSeconds (from max_wait_seconds)
 *   - checkpoint_path → CheckpointConfig.S3Uri
 *   - metric_definitions → AlgorithmSpecification.MetricDefinitions
 *   - environment → Environment
 *   - tags → Tags (converted from {k:v} to [{Key:k, Value:v}])
 *
 * The optional sections are added only when the corresponding config value is
 * non-empty. These are HyperParameters, spot settings, CheckpointConfig,
 * MetricDefinitions, Environment and Tags. Spot is enabled by either the
 * boolean `true` or the string `'true'`.
 *
 * Numeric fields are parsed with base-10 parseInt, so a missing value yields
 * NaN. NOTE(review): confirm the config parser always supplies defaults for
 * instance_count, volume_size_gb, max_runtime_seconds and max_wait_seconds.
 *
 * The request does not alias the caller's config. The environment map is
 * shallow-copied, so editing the request cannot mutate `config.environment`.
 *
 * @param {object} options - Build options
 * @param {string} options.jobName - Training job name
 * @param {string} options.roleArn - SageMaker execution role ARN
 * @param {object} options.config - Parsed training config (from parseTrainingConfig)
 * @returns {object} CreateTrainingJob request body
 */
export function buildTrainingJobRequest({ jobName, roleArn, config }) {
  const request = {
    TrainingJobName: jobName,
    RoleArn: roleArn,
    AlgorithmSpecification: {
      TrainingImage: config.image,
      TrainingInputMode: 'File'
    },
    InputDataConfig: [
      {
        ChannelName: 'training',
        DataSource: {
          S3DataSource: {
            S3DataType: 'S3Prefix',
            S3Uri: config.dataset,
            S3DataDistributionType: 'FullyReplicated'
          }
        }
      }
    ],
    OutputDataConfig: {
      S3OutputPath: config.output_path
    },
    ResourceConfig: {
      InstanceType: config.instance_type,
      InstanceCount: Number.parseInt(config.instance_count, 10),
      VolumeSizeInGB: Number.parseInt(config.volume_size_gb, 10)
    },
    StoppingCondition: {
      MaxRuntimeInSeconds: Number.parseInt(config.max_runtime_seconds, 10)
    }
  };

  // Hyperparameters — SageMaker requires every value to be a string
  const hyperparams = config.hyperparameters || {};
  if (Object.keys(hyperparams).length > 0) {
    request.HyperParameters = Object.fromEntries(
      Object.entries(hyperparams).map(([k, v]) => [String(k), String(v)])
    );
  }

  // Managed spot training (the config parser may yield the flag as a string)
  const enableSpot = config.enable_spot === 'true' || config.enable_spot === true;
  if (enableSpot) {
    request.EnableManagedSpotTraining = true;
    request.StoppingCondition.MaxWaitTimeInSeconds = Number.parseInt(config.max_wait_seconds, 10);
  }

  // Checkpoint configuration (for spot training resumption)
  const checkpointPath = config.checkpoint_path || '';
  if (checkpointPath) {
    request.CheckpointConfig = {
      S3Uri: checkpointPath
    };
  }

  // Metric definitions (custom CloudWatch metrics)
  const metricDefs = config.metric_definitions || [];
  if (Array.isArray(metricDefs) && metricDefs.length > 0) {
    request.AlgorithmSpecification.MetricDefinitions = metricDefs.map((m) => ({
      Name: m.name,
      Regex: m.regex
    }));
  }

  // Environment variables for the container — copied so the request does not
  // share (and later mutate) the caller's config object
  const environment = config.environment || {};
  if (Object.keys(environment).length > 0) {
    request.Environment = { ...environment };
  }

  // Tags — convert from {key: value} map to [{Key: k, Value: v}] array
  const tags = config.tags || {};
  if (Object.keys(tags).length > 0) {
    request.Tags = Object.entries(tags).map(([k, v]) => ({
      Key: String(k),
      Value: String(v)
    }));
  }

  return request;
}
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""
|
|
6
|
+
Build the CreateTrainingJob JSON request for SageMaker.
|
|
7
|
+
|
|
8
|
+
This helper is called by do/train to construct the full API request body.
|
|
9
|
+
It handles conditional fields (spot training, metric definitions, environment,
|
|
10
|
+
tags) and writes the result to a JSON file for use with:
|
|
11
|
+
aws sagemaker create-training-job --cli-input-json file://path.json
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import argparse
|
|
15
|
+
import json
|
|
16
|
+
import sys
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def parse_args():
    """Define and parse the command-line interface used by do/train.

    Every option is mandatory and is kept as a raw string. Integer conversion
    and JSON decoding are left to build_request().

    Returns:
        argparse.Namespace: one attribute per option (dashes become underscores).
    """
    parser = argparse.ArgumentParser(description='Build CreateTrainingJob request JSON')

    # (flag, help text) pairs, in the order they appear in --help output.
    options = (
        ('--job-name', 'Training job name'),
        ('--role-arn', 'SageMaker execution role ARN'),
        ('--image', 'Training container image URI'),
        ('--instance-type', 'Instance type'),
        ('--instance-count', 'Instance count'),
        ('--volume-size', 'Volume size in GB'),
        ('--dataset', 'S3 URI for training dataset'),
        ('--output-path', 'S3 URI for output'),
        ('--max-runtime', 'Max runtime in seconds'),
        ('--hyperparams', 'Hyperparameters as JSON string'),
        ('--enable-spot', 'Enable spot training (true/false)'),
        ('--max-wait', 'Max wait time for spot in seconds'),
        ('--checkpoint-path', 'S3 checkpoint path'),
        ('--metric-definitions', 'Metric definitions as JSON array'),
        ('--environment', 'Environment variables as JSON object'),
        ('--tags', 'Tags as JSON object (key-value map)'),
        ('--output-file', 'Output file path for the JSON'),
    )
    for flag, help_text in options:
        parser.add_argument(flag, required=True, help=help_text)

    return parser.parse_args()
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _load_json_arg(raw, expected_type, flag):
    """Decode a JSON-valued CLI argument into ``expected_type`` (dict or list).

    A blank argument, or JSON that decodes to a falsy value ("null", "[]", "{}",
    "0", ...), is treated as "not provided" and yields an empty ``expected_type``.

    Raises:
        json.JSONDecodeError: if ``raw`` is not valid JSON.
        ValueError: if the decoded value is non-empty but of the wrong type.
    """
    if not raw:
        return expected_type()
    value = json.loads(raw)
    if not value:
        return expected_type()
    if not isinstance(value, expected_type):
        kind = 'object' if expected_type is dict else 'array'
        raise ValueError(f'{flag} must be a JSON {kind}, got {type(value).__name__}')
    return value


def build_request(args):
    """Construct the CreateTrainingJob request dictionary.

    ``args`` is the namespace produced by parse_args(), so every attribute is a
    string. The JSON-valued ones (hyperparams, metric_definitions, environment,
    tags) may be empty strings, meaning "not provided". Optional request sections
    are only added when their input is non-empty.

    Raises:
        json.JSONDecodeError: if a JSON-valued argument is not valid JSON.
        ValueError: if a numeric argument is not an integer, or a JSON-valued
            argument decodes to the wrong shape (e.g. hyperparams is not an
            object, or a metric definition lacks "name"/"regex"). Both are
            ValueError subclasses/instances, which main() reports cleanly.
    """
    # Parse JSON inputs (type-checked so malformed input raises ValueError)
    hyperparams = _load_json_arg(args.hyperparams, dict, '--hyperparams')
    metric_definitions = _load_json_arg(args.metric_definitions, list, '--metric-definitions')
    environment = _load_json_arg(args.environment, dict, '--environment')
    tags = _load_json_arg(args.tags, dict, '--tags')

    # Base request structure
    request = {
        'TrainingJobName': args.job_name,
        'RoleArn': args.role_arn,
        'AlgorithmSpecification': {
            'TrainingImage': args.image,
            'TrainingInputMode': 'File'
        },
        'InputDataConfig': [
            {
                'ChannelName': 'training',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': args.dataset,
                        'S3DataDistributionType': 'FullyReplicated'
                    }
                }
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': args.output_path
        },
        'ResourceConfig': {
            'InstanceType': args.instance_type,
            'InstanceCount': int(args.instance_count),
            'VolumeSizeInGB': int(args.volume_size)
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': int(args.max_runtime)
        }
    }

    # Hyperparameters — ensure all values are strings (SageMaker requirement).
    # NOTE(review): str() renders booleans as 'True'/'False' and null as 'None',
    # whereas the JS mirror (train-request-builder.js) uses String(), giving
    # 'true'/'false'/'null' — confirm which form the training script expects.
    if hyperparams:
        request['HyperParameters'] = {
            str(k): str(v) for k, v in hyperparams.items()
        }

    # Managed spot training
    if args.enable_spot == 'true':
        request['EnableManagedSpotTraining'] = True
        request['StoppingCondition']['MaxWaitTimeInSeconds'] = int(args.max_wait)

    # Checkpoint configuration (for spot training resumption)
    if args.checkpoint_path:
        request['CheckpointConfig'] = {
            'S3Uri': args.checkpoint_path
        }

    # Metric definitions (custom CloudWatch metrics)
    if metric_definitions:
        definitions = []
        for index, m in enumerate(metric_definitions):
            if not isinstance(m, dict) or 'name' not in m or 'regex' not in m:
                raise ValueError(
                    f'--metric-definitions entry {index} must be an object with "name" and "regex"')
            definitions.append({'Name': m['name'], 'Regex': m['regex']})
        request['AlgorithmSpecification']['MetricDefinitions'] = definitions

    # Environment variables for the container
    if environment:
        request['Environment'] = environment

    # Tags — convert from {key: value} map to [{Key: k, Value: v}] array
    if tags:
        request['Tags'] = [
            {'Key': str(k), 'Value': str(v)}
            for k, v in tags.items()
        ]

    return request
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def main():
    """Entry point: build the request from CLI args and write it as JSON.

    On failure, prints a message prefixed with an error marker to stderr and
    exits with status 1. This happens when the request cannot be built
    (invalid JSON, non-integer numbers, or JSON of the wrong shape) or when the
    output file cannot be written.
    """
    args = parse_args()

    try:
        request = build_request(args)
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # json.JSONDecodeError is a ValueError subclass. KeyError, TypeError and
        # AttributeError cover JSON inputs of the wrong shape (e.g. a metric
        # definition without "regex", or hyperparams that is not an object).
        # Without these, such input would escape as an unhandled traceback.
        print(f'❌ Failed to build request: {e}', file=sys.stderr)
        sys.exit(1)

    # Write the JSON request to the output file
    try:
        with open(args.output_file, 'w') as f:
            json.dump(request, f, indent=2)
    except OSError as e:  # IOError is merely an alias of OSError on Python 3
        print(f'❌ Failed to write request file: {e}', file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()
|