@aws/ml-container-creator 0.8.0 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,143 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Train Config Persistence
6
+ *
7
+ * JavaScript module that models the config persistence logic from the bash
8
+ * `_update_config_var` function in `templates/do/train`. This module provides
9
+ * a pure JavaScript implementation for property-based testing of the config
10
+ * persistence behavior after job submission and completion.
11
+ *
12
+ * The config file uses the format:
13
+ * export VAR_NAME="value"
14
+ *
15
+ * Behavior:
16
+ * - If the variable already exists: update its value in-place
17
+ * - If the variable doesn't exist: append it to the end
18
+ * - Existing variables in the config are preserved
19
+ *
20
+ * Requirements: 3.4, 5.1
21
+ */
22
+
23
+ import { readFileSync, writeFileSync } from 'node:fs';
24
+
25
/**
 * Update or add a config variable in a do/config-style file.
 * Models the bash _update_config_var() function from templates/do/train:
 *
 *   if grep -q "^export ${var_name}=" "${config_file}"; then
 *     sed -i.bak "s|^export ${var_name}=.*|export ${var_name}=\"${var_value}\"|" "${config_file}"
 *     rm -f "${config_file}.bak"
 *   else
 *     echo "export ${var_name}=\"${var_value}\"" >> "${config_file}"
 *   fi
 *
 * Behaviour:
 * - Every line assigning the variable is rewritten (sed applies the
 *   substitution to each matching line), so duplicate assignments stay in sync.
 * - If no line assigns it, `export NAME="value"` is appended on its own line.
 * - The variable name is matched literally (regex metacharacters are escaped).
 * - The value is inserted verbatim: `$&`, `$1`, `$$`, ... are NOT expanded.
 *   NOTE(review): the bash sed replacement treats `&` and `\` specially, so
 *   values containing those may still differ between the two implementations.
 *
 * @param {string} configContent - Current content of the config file
 * @param {string} varName - Variable name (e.g., TRAIN_JOB_NAME)
 * @param {string} varValue - Variable value
 * @returns {string} Updated config content
 */
export function updateConfigVar(configContent, varName, varValue) {
  const assignment = `export ${varName}="${varValue}"`;
  // Escape regex metacharacters so the name is matched literally.
  const literalName = String(varName).replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
  const linePattern = `^export ${literalName}=.*$`;

  if (new RegExp(linePattern, 'm').test(configContent)) {
    // Variable exists — update in-place. 'g' rewrites every matching line
    // (like sed); the function replacer keeps `$` sequences in the value literal.
    return configContent.replace(new RegExp(linePattern, 'gm'), () => assignment);
  }

  // Variable doesn't exist — append, making sure it starts on its own line.
  const separator = configContent.length > 0 && !configContent.endsWith('\n') ? '\n' : '';
  return `${configContent}${separator}${assignment}\n`;
}
57
+
58
/**
 * Look up the value assigned by an `export NAME="value"` line in
 * do/config-style content.
 *
 * Only double-quoted assignments are recognised. The first matching line
 * wins, and the captured value stops at the first closing double quote.
 * NOTE(review): varName is interpolated into a RegExp unescaped; callers are
 * expected to pass a plain shell identifier.
 *
 * @param {string} configContent - Content of the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The quoted value, or null when no matching line exists
 */
export function readConfigVar(configContent, varName) {
  const assignmentLine = new RegExp(`^export ${varName}="([^"]*)"`, 'm');
  const found = configContent.match(assignmentLine);
  return found === null ? null : found[1];
}
70
+
71
/**
 * Record a freshly submitted training job in the config content.
 *
 * Writes TRAIN_JOB_NAME, mirroring the
 * `_update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"` call made by
 * do/train's _submit_job().
 *
 * @param {string} configContent - Current content of the config file
 * @param {object} params - Submission parameters
 * @param {string} params.jobName - Generated job name (pattern: ${PROJECT_NAME}-train-${TIMESTAMP})
 * @returns {string} Config content with TRAIN_JOB_NAME set to jobName
 */
export function persistTrainSubmission(configContent, { jobName }) {
  const configKey = 'TRAIN_JOB_NAME';
  const updated = updateConfigVar(configContent, configKey, jobName);
  return updated;
}
84
+
85
/**
 * Record where a completed training job wrote its artifacts.
 *
 * Writes TRAIN_OUTPUT_PATH, mirroring the
 * `_update_config_var "TRAIN_OUTPUT_PATH" "${output_path}"` call made by
 * do/train's _handle_completion().
 *
 * @param {string} configContent - Current content of the config file
 * @param {object} params - Completion parameters
 * @param {string} params.outputPath - S3 path to the output artifacts
 * @returns {string} Config content with TRAIN_OUTPUT_PATH set to outputPath
 */
export function persistTrainCompletion(configContent, { outputPath }) {
  const configKey = 'TRAIN_OUTPUT_PATH';
  const updated = updateConfigVar(configContent, configKey, outputPath);
  return updated;
}
98
+
99
/**
 * Build a training job name in the form `${projectName}-train-YYYYMMDD-HHMMSS`,
 * the pattern used by do/train.
 *
 * Date and time components come from the Date's local-time getters
 * (getFullYear, getHours, ...), not UTC. Month, day, hour, minute and second
 * are zero-padded to two digits; the year is used as-is.
 *
 * @param {string} projectName - Project name
 * @param {Date} [timestamp] - Moment to encode (defaults to now)
 * @returns {string} Generated job name
 */
export function generateTrainJobName(projectName, timestamp = new Date()) {
  const twoDigits = (num) => String(num).padStart(2, '0');

  const datePart = [
    String(timestamp.getFullYear()),
    twoDigits(timestamp.getMonth() + 1), // getMonth() is 0-based
    twoDigits(timestamp.getDate()),
  ].join('');

  const timePart = [timestamp.getHours(), timestamp.getMinutes(), timestamp.getSeconds()]
    .map(twoDigits)
    .join('');

  return `${projectName}-train-${datePart}-${timePart}`;
}
118
+
119
/**
 * On-disk counterpart of updateConfigVar: read the config file as UTF-8,
 * apply the update, and write the result back to the same path.
 * Intended for integration-style tests that need real file I/O.
 *
 * Errors from readFileSync/writeFileSync (e.g. a missing file) propagate.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name
 * @param {string} varValue - Variable value
 */
export function updateConfigVarFile(configPath, varName, varValue) {
  const encoding = 'utf8';
  const current = readFileSync(configPath, encoding);
  writeFileSync(configPath, updateConfigVar(current, varName, varValue), encoding);
}
132
+
133
/**
 * On-disk counterpart of readConfigVar: read the config file as UTF-8 and
 * look up one variable in it. Errors from readFileSync propagate.
 *
 * @param {string} configPath - Path to the config file
 * @param {string} varName - Variable name to read
 * @returns {string|null} The variable value, or null if not found
 */
export function readConfigVarFile(configPath, varName) {
  return readConfigVar(readFileSync(configPath, 'utf8'), varName);
}
@@ -0,0 +1,112 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Train Config Validator
6
+ *
7
+ * Validates training configuration objects parsed from do/training/config.yaml.
8
+ * Checks that all required fields are present and provides descriptive error
9
+ * messages naming the specific missing field.
10
+ *
11
+ * This module mirrors the validation logic in the bash `_validate_config`
12
+ * function in templates/do/train, enabling property-based testing of the
13
+ * validation rules in isolation.
14
+ *
15
+ * Requirements: 2.12, 10.1
16
+ */
17
+
18
/**
 * Required fields for a valid training configuration, keyed by config field
 * name in the order they are checked.
 *
 * Each entry carries a human-readable `description` and a `format` example
 * line showing how the field should look in do/training/config.yaml; both are
 * interpolated into the "Missing required field" error messages.
 */
export const REQUIRED_FIELDS = Object.fromEntries(
  [
    ['image', 'The container image URI', 'image: "123456789012.dkr.ecr.us-east-1.amazonaws.com/my-training:latest"'],
    ['script', 'The training script S3 path', 'script: "s3://my-bucket/scripts/train.py"'],
    ['instance_type', 'The SageMaker instance type', 'instance_type: "ml.g5.xlarge"'],
    ['dataset', 'The S3 dataset path', 'dataset: "s3://my-bucket/data/train/"'],
    ['output_path', 'The S3 output path', 'output_path: "s3://my-bucket/output/"'],
  ].map(([field, description, format]) => [field, { description, format }])
);
44
+
45
/**
 * Check that every entry of REQUIRED_FIELDS is present in a training config.
 *
 * A field counts as missing when its value is `undefined`, `null` or the empty
 * string; other falsy values (0, false) are accepted. A null/undefined config
 * reports every required field as missing.
 *
 * @param {Object} config - The parsed training configuration object
 * @returns {{ valid: boolean, errors: Array<{ field: string, message: string }> }}
 *   - valid: true when no required field is missing
 *   - errors: one entry per missing field, in REQUIRED_FIELDS order, whose
 *     message names the field and shows its expected format
 */
export function validateRequiredFields(config) {
  const isMissing = (value) => value === undefined || value === null || value === '';

  const errors = Object.entries(REQUIRED_FIELDS)
    .filter(([field]) => isMissing(config ? config[field] : undefined))
    .map(([field, { description, format }]) => ({
      field,
      message: `Missing required field: ${field}\n ${description} is required in do/training/config.yaml\n\n Expected format: ${format}`
    }));

  return { valid: errors.length === 0, errors };
}
72
+
73
/**
 * Validate the spot-training checkpoint requirement.
 *
 * When spot training is enabled, checkpoint_path must be a non-empty value so
 * training can resume after a spot interruption. `enable_spot` counts as
 * enabled when it is boolean `true` or the string `'true'`. This matches
 * buildTrainingJobRequest, which turns on EnableManagedSpotTraining for
 * either form. Without it, a string-typed config would enable spot with no
 * checkpoint check.
 *
 * @param {Object} config - The parsed training configuration object
 * @returns {{ valid: boolean, errors: Array<{ field: string, message: string }> }}
 *   - valid: false only when spot is enabled and checkpoint_path is missing/empty
 *   - errors: at most one entry, for field 'checkpoint_path'
 */
export function validateSpotConfig(config) {
  const errors = [];
  const spotEnabled = Boolean(config) && (config.enable_spot === true || config.enable_spot === 'true');

  if (spotEnabled && !config.checkpoint_path) {
    errors.push({
      field: 'checkpoint_path',
      message: 'Checkpoint path required for spot training\n When enable_spot is true, a checkpoint S3 path must be specified\n so training can resume after spot interruptions.'
    });
  }

  return {
    valid: errors.length === 0,
    errors
  };
}
95
+
96
/**
 * Run every training-config check and merge the results.
 *
 * Errors from validateRequiredFields come first, followed by those from
 * validateSpotConfig; the config is valid only if neither reported anything.
 *
 * @param {Object} config - The parsed training configuration object
 * @returns {{ valid: boolean, errors: Array<{ field: string, message: string }> }}
 */
export function validateTrainingConfig(config) {
  const checks = [validateRequiredFields, validateSpotConfig];
  const combinedErrors = checks.flatMap((check) => check(config).errors);

  return { valid: combinedErrors.length === 0, errors: combinedErrors };
}
@@ -0,0 +1,46 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Train Feedback Loop — JavaScript equivalent of do/lib/feedback.sh
6
+ *
7
+ * Generates post-completion feedback output with artifact locations
8
+ * and deployment suggestions based on artifact type.
9
+ */
10
+
11
/**
 * Build the message printed after a training/tuning job completes.
 *
 * Intended to match print_completion_feedback() in do/lib/feedback.sh. The
 * message:
 * - names the job;
 * - lists the artifact location, plus the model package ARN when one is given;
 * - prints "Next steps" keyed by `outputType`. 'adapter' suggests
 *   ./do/adapter; 'full-model' suggests ./do/add-ic and ./do/deploy. Any other
 *   value leaves the "Next steps:" section empty.
 *
 * @param {object} params
 * @param {string} params.outputPath - S3 URI to the output artifacts
 * @param {string} params.outputType - "adapter" or "full-model"
 * @param {string} params.jobName - Job name shown in the header line
 * @param {string} [params.modelPackageArn=''] - Model package ARN; its line is omitted when falsy
 * @returns {string} The lines joined with '\n'; the first and last lines are
 *   empty, so the string starts and ends with a newline
 */
export function generateCompletionFeedback({ outputPath, outputType, jobName, modelPackageArn = '' }) {
  // A Map avoids picking up inherited Object.prototype keys for odd outputType values.
  const nextStepsByType = new Map([
    ['adapter', [
      ` • Deploy as LoRA adapter: ./do/adapter add my-adapter --weights ${outputPath}`,
      ' • (Requires running endpoint with LoRA enabled)'
    ]],
    ['full-model', [
      ` • Deploy as new IC: ./do/add-ic my-model --model-data ${outputPath}`,
      ` • Replace current base: ./do/deploy --force-ic --model-data ${outputPath}`
    ]]
  ]);

  const packageLines = modelPackageArn ? [` Model Package: ${modelPackageArn}`] : [];
  const nextSteps = nextStepsByType.get(outputType) || [];

  return [
    '',
    `✅ Training complete: ${jobName}`,
    '',
    ` Artifacts: ${outputPath}`,
    ...packageLines,
    '',
    ' Next steps:',
    ...nextSteps,
    ''
  ].join('\n');
}
@@ -0,0 +1,97 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Train Idempotency Decision Logic
6
+ *
7
+ * Models the idempotency check logic from the bash `_check_idempotency` function
8
+ * in `templates/do/train` as a pure JavaScript function for property-based testing.
9
+ *
10
+ * The idempotency pattern:
11
+ * - If --force is set, always create a new job regardless of existing status
12
+ * - If no existing job, create a new job
13
+ * - If existing job is InProgress, poll it
14
+ * - If existing job is Completed, display results
15
+ * - If existing job is Failed or Stopped, display failure and suggest --force
16
+ *
17
+ * Requirements: 5.1–5.5
18
+ */
19
+
20
/**
 * Existing-job statuses (as reported by DescribeTrainingJob) that this module
 * enumerates for testing.
 *
 * NOTE: not exhaustive. SageMaker also defines the transitional 'Stopping'
 * status, which is not listed here.
 */
export const JOB_STATUSES = [
  'InProgress',
  'Completed',
  'Failed',
  'Stopped',
];
24
+
25
/**
 * Outcomes of the idempotency check. determineAction() returns one of these
 * string values in the `action` field of its result.
 */
export const ACTIONS = {
  CREATE_NEW_JOB: 'create_new_job',
  POLL_EXISTING: 'poll_existing',
  DISPLAY_RESULTS: 'display_results',
  DISPLAY_FAILURE: 'display_failure',
};
34
+
35
/**
 * Determine the action to take based on existing job status and force flag.
 *
 * Testable form of the bash `_check_idempotency` logic:
 * - forceFlag === true                  → create_new_job
 * - no existing status (null/undefined/'') → create_new_job
 * - InProgress                          → poll_existing
 * - Completed                           → display_results
 * - Failed / Stopped / Stopping         → display_failure (suggest --force)
 * - anything else                       → display_failure ("Unexpected job status")
 *
 * 'Stopping' is SageMaker's transitional state on the way to 'Stopped'. It is
 * treated like Stopped, so it is not reported as unexpected.
 *
 * @param {string|null|undefined} existingStatus - The current job status from DescribeTrainingJob
 * @param {boolean} forceFlag - Whether --force was specified (only boolean true counts)
 * @returns {{ action: string, reason: string }}
 *   - action: one of ACTIONS values
 *   - reason: human-readable explanation of why this action was chosen
 */
export function determineAction(existingStatus, forceFlag) {
  // Force flag always overrides — create a new job regardless of existing status
  if (forceFlag === true) {
    return {
      action: ACTIONS.CREATE_NEW_JOB,
      reason: '--force specified, creating new job regardless of existing status'
    };
  }

  // No existing job (null, undefined or '') — create a new one
  if (!existingStatus) {
    return {
      action: ACTIONS.CREATE_NEW_JOB,
      reason: 'No existing job found, creating new job'
    };
  }

  // Existing job found — action depends on status
  switch (existingStatus) {
    case 'InProgress':
      return {
        action: ACTIONS.POLL_EXISTING,
        reason: `Existing job is ${existingStatus}, resuming polling`
      };

    case 'Completed':
      return {
        action: ACTIONS.DISPLAY_RESULTS,
        reason: `Existing job is ${existingStatus}, displaying results`
      };

    case 'Failed':
    case 'Stopped':
    case 'Stopping':
      return {
        action: ACTIONS.DISPLAY_FAILURE,
        reason: `Existing job is ${existingStatus}, suggest --force to create new job`
      };

    default:
      // Unknown status — treat as failure, suggest --force
      return {
        action: ACTIONS.DISPLAY_FAILURE,
        reason: `Unexpected job status: ${existingStatus}, suggest --force`
      };
  }
}
@@ -0,0 +1,120 @@
1
+ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2
+ // SPDX-License-Identifier: Apache-2.0
3
+
4
+ /**
5
+ * Train Request Builder
6
+ *
7
+ * JavaScript module that replicates the Python helper's (.train_build_request.py)
8
+ * logic for constructing a CreateTrainingJob JSON request from a parsed config.
9
+ *
10
+ * This module mirrors the behavior of the Python build_request() function,
11
+ * providing a testable implementation of the config-to-API mapping logic.
12
+ */
13
+
14
/**
 * Build a SageMaker CreateTrainingJob request body from a parsed training config.
 *
 * Mirrors build_request() in .train_build_request.py. Field mapping:
 *   image               → AlgorithmSpecification.TrainingImage (input mode 'File')
 *   metric_definitions  → AlgorithmSpecification.MetricDefinitions ({name, regex} → {Name, Regex})
 *   dataset             → InputDataConfig[0] ('training' channel, S3Prefix, FullyReplicated)
 *   output_path         → OutputDataConfig.S3OutputPath
 *   instance_type       → ResourceConfig.InstanceType
 *   instance_count      → ResourceConfig.InstanceCount (parsed base-10)
 *   volume_size_gb      → ResourceConfig.VolumeSizeInGB (parsed base-10)
 *   max_runtime_seconds → StoppingCondition.MaxRuntimeInSeconds (parsed base-10)
 *   enable_spot         → EnableManagedSpotTraining = true and
 *                         StoppingCondition.MaxWaitTimeInSeconds (from max_wait_seconds)
 *                         when enable_spot is true or 'true'
 *   hyperparameters     → HyperParameters (keys and values stringified)
 *   checkpoint_path     → CheckpointConfig.S3Uri
 *   environment         → Environment (passed through unchanged)
 *   tags                → Tags ({k: v} → [{Key, Value}], stringified)
 *
 * Optional sections are added only when the corresponding config value is
 * non-empty. Unparseable integers come out as NaN; they are not validated here.
 *
 * @param {object} options - Build options
 * @param {string} options.jobName - Training job name
 * @param {string} options.roleArn - SageMaker execution role ARN
 * @param {object} options.config - Parsed training config (from parseTrainingConfig)
 * @returns {object} CreateTrainingJob request body
 */
export function buildTrainingJobRequest({ jobName, roleArn, config }) {
  const toInt = (raw) => Number.parseInt(raw, 10);

  const algorithmSpecification = {
    TrainingImage: config.image,
    TrainingInputMode: 'File'
  };
  const trainingChannel = {
    ChannelName: 'training',
    DataSource: {
      S3DataSource: {
        S3DataType: 'S3Prefix',
        S3Uri: config.dataset,
        S3DataDistributionType: 'FullyReplicated'
      }
    }
  };
  const stoppingCondition = {
    MaxRuntimeInSeconds: toInt(config.max_runtime_seconds)
  };

  const request = {
    TrainingJobName: jobName,
    RoleArn: roleArn,
    AlgorithmSpecification: algorithmSpecification,
    InputDataConfig: [trainingChannel],
    OutputDataConfig: { S3OutputPath: config.output_path },
    ResourceConfig: {
      InstanceType: config.instance_type,
      InstanceCount: toInt(config.instance_count),
      VolumeSizeInGB: toInt(config.volume_size_gb)
    },
    StoppingCondition: stoppingCondition
  };

  // SageMaker requires hyperparameter values to be strings.
  const hyperparameterEntries = Object.entries(config.hyperparameters || {});
  if (hyperparameterEntries.length > 0) {
    request.HyperParameters = hyperparameterEntries.reduce((acc, [key, value]) => {
      acc[String(key)] = String(value);
      return acc;
    }, {});
  }

  // Managed spot training also needs a wait budget in the stopping condition.
  if (config.enable_spot === true || config.enable_spot === 'true') {
    request.EnableManagedSpotTraining = true;
    stoppingCondition.MaxWaitTimeInSeconds = toInt(config.max_wait_seconds);
  }

  // Checkpointing lets spot training resume after an interruption.
  const checkpointUri = config.checkpoint_path || '';
  if (checkpointUri) {
    request.CheckpointConfig = { S3Uri: checkpointUri };
  }

  // Custom CloudWatch metrics scraped from the training logs.
  const metrics = config.metric_definitions || [];
  if (Array.isArray(metrics) && metrics.length > 0) {
    algorithmSpecification.MetricDefinitions = metrics.map((metric) => ({
      Name: metric.name,
      Regex: metric.regex
    }));
  }

  // Container environment variables are forwarded as-is.
  const containerEnv = config.environment || {};
  if (Object.keys(containerEnv).length > 0) {
    request.Environment = containerEnv;
  }

  // Tags: {key: value} map → [{Key, Value}] list.
  const tagEntries = Object.entries(config.tags || {});
  if (tagEntries.length > 0) {
    request.Tags = tagEntries.map(([key, value]) => ({
      Key: String(key),
      Value: String(value)
    }));
  }

  return request;
}
@@ -0,0 +1,141 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """
6
+ Build the CreateTrainingJob JSON request for SageMaker.
7
+
8
+ This helper is called by do/train to construct the full API request body.
9
+ It handles conditional fields (spot training, metric definitions, environment,
10
+ tags) and writes the result to a JSON file for use with:
11
+ aws sagemaker create-training-job --cli-input-json file://path.json
12
+ """
13
+
14
+ import argparse
15
+ import json
16
+ import sys
17
+
18
+
19
def parse_args():
    """Parse the command-line flags supplied by do/train.

    Every flag is mandatory. Values are returned as raw strings on the
    resulting ``argparse.Namespace``; dashes become underscores, e.g.
    ``--job-name`` -> ``args.job_name``. Integer and JSON conversion happens
    later in build_request().
    """
    parser = argparse.ArgumentParser(description='Build CreateTrainingJob request JSON')
    required_flags = [
        ('--job-name', 'Training job name'),
        ('--role-arn', 'SageMaker execution role ARN'),
        ('--image', 'Training container image URI'),
        ('--instance-type', 'Instance type'),
        ('--instance-count', 'Instance count'),
        ('--volume-size', 'Volume size in GB'),
        ('--dataset', 'S3 URI for training dataset'),
        ('--output-path', 'S3 URI for output'),
        ('--max-runtime', 'Max runtime in seconds'),
        ('--hyperparams', 'Hyperparameters as JSON string'),
        ('--enable-spot', 'Enable spot training (true/false)'),
        ('--max-wait', 'Max wait time for spot in seconds'),
        ('--checkpoint-path', 'S3 checkpoint path'),
        ('--metric-definitions', 'Metric definitions as JSON array'),
        ('--environment', 'Environment variables as JSON object'),
        ('--tags', 'Tags as JSON object (key-value map)'),
        ('--output-file', 'Output file path for the JSON'),
    ]
    for flag, help_text in required_flags:
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()
40
+
41
+
42
def _load_json_arg(raw, expected_type, flag):
    """Decode one JSON-valued CLI flag into a dict or list.

    An empty string, or a JSON value that is falsy (null, {}, [], 0, false,
    ""), means "not provided" and yields an empty ``expected_type()``. A
    non-empty value of the wrong JSON type raises ValueError naming the flag.
    Without that check it would later crash with an AttributeError/TypeError
    traceback, or produce an invalid request.
    """
    if not raw:
        return expected_type()
    value = json.loads(raw)  # json.JSONDecodeError is a ValueError subclass
    if not value:
        return expected_type()
    if not isinstance(value, expected_type):
        kind = 'object' if expected_type is dict else 'array'
        raise ValueError(f'{flag} must be a JSON {kind}, got {type(value).__name__}')
    return value


def _metric_definition(entry):
    """Convert one ``{name, regex}`` config entry to SageMaker's ``{Name, Regex}``.

    Raises ValueError if the entry is not an object with both keys.
    """
    if not isinstance(entry, dict) or 'name' not in entry or 'regex' not in entry:
        raise ValueError(f"each metric definition needs 'name' and 'regex' keys, got {entry!r}")
    return {'Name': entry['name'], 'Regex': entry['regex']}


def build_request(args):
    """Construct the CreateTrainingJob request dictionary.

    Args:
        args: argparse.Namespace from parse_args(); every attribute is the raw
            command-line string.

    Returns:
        dict for ``aws sagemaker create-training-job --cli-input-json``.
        Optional sections (HyperParameters, spot settings, CheckpointConfig,
        MetricDefinitions, Environment, Tags) are only added when the
        corresponding input is non-empty.

    Raises:
        ValueError: if a numeric flag is not an integer, or a JSON flag is not
            valid JSON or has the wrong JSON type. Also raised when a metric
            definition lacks ``name``/``regex``.
    """
    # Parse JSON inputs
    hyperparams = _load_json_arg(args.hyperparams, dict, '--hyperparams')
    metric_definitions = _load_json_arg(args.metric_definitions, list, '--metric-definitions')
    environment = _load_json_arg(args.environment, dict, '--environment')
    tags = _load_json_arg(args.tags, dict, '--tags')

    # Base request structure
    request = {
        'TrainingJobName': args.job_name,
        'RoleArn': args.role_arn,
        'AlgorithmSpecification': {
            'TrainingImage': args.image,
            'TrainingInputMode': 'File'
        },
        'InputDataConfig': [
            {
                'ChannelName': 'training',
                'DataSource': {
                    'S3DataSource': {
                        'S3DataType': 'S3Prefix',
                        'S3Uri': args.dataset,
                        'S3DataDistributionType': 'FullyReplicated'
                    }
                }
            }
        ],
        'OutputDataConfig': {
            'S3OutputPath': args.output_path
        },
        'ResourceConfig': {
            'InstanceType': args.instance_type,
            'InstanceCount': int(args.instance_count),
            'VolumeSizeInGB': int(args.volume_size)
        },
        'StoppingCondition': {
            'MaxRuntimeInSeconds': int(args.max_runtime)
        }
    }

    # Hyperparameters — ensure all values are strings (SageMaker requirement)
    if hyperparams:
        request['HyperParameters'] = {str(k): str(v) for k, v in hyperparams.items()}

    # Managed spot training (do/train passes the literal string 'true')
    if args.enable_spot == 'true':
        request['EnableManagedSpotTraining'] = True
        request['StoppingCondition']['MaxWaitTimeInSeconds'] = int(args.max_wait)

    # Checkpoint configuration (for spot training resumption)
    if args.checkpoint_path:
        request['CheckpointConfig'] = {'S3Uri': args.checkpoint_path}

    # Metric definitions (custom CloudWatch metrics)
    if metric_definitions:
        request['AlgorithmSpecification']['MetricDefinitions'] = [
            _metric_definition(m) for m in metric_definitions
        ]

    # Environment variables for the container
    if environment:
        request['Environment'] = environment

    # Tags — convert from {key: value} map to [{Key: k, Value: v}] array
    if tags:
        request['Tags'] = [{'Key': str(k), 'Value': str(v)} for k, v in tags.items()]

    return request
119
+
120
+
121
def main():
    """Entry point: parse flags, build the request, and write it as JSON.

    On failure, prints a message prefixed with "❌" to stderr and exits with
    status 1. That covers a request that cannot be built from the given flags
    and an output file that cannot be written.
    """
    args = parse_args()

    try:
        request = build_request(args)
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        # ValueError covers json.JSONDecodeError and bad int() conversions.
        # KeyError/TypeError/AttributeError arise when a JSON flag has the
        # wrong shape (e.g. a metric definition missing a key, or a list where
        # an object was expected); report them instead of dumping a traceback.
        print(f'❌ Failed to build request: {e}', file=sys.stderr)
        sys.exit(1)

    # Write the JSON request to the output file
    try:
        with open(args.output_file, 'w') as f:
            json.dump(request, f, indent=2)
    except OSError as e:  # IOError is an alias of OSError in Python 3
        print(f'❌ Failed to write request file: {e}', file=sys.stderr)
        sys.exit(1)


if __name__ == '__main__':
    main()