@aws/ml-container-creator 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +50760 -16218
- package/package.json +3 -1
- package/servers/lib/catalogs/instances.json +52 -1275
- package/servers/lib/catalogs/models.json +0 -132
- package/servers/lib/catalogs/popular-diffusors.json +1 -110
- package/src/app.js +24 -2
- package/src/lib/mcp-client.js +16 -1
- package/src/lib/mcp-command-handler.js +10 -2
- package/src/lib/prompt-runner.js +16 -2
- package/src/lib/train-config-parser.js +136 -0
- package/src/lib/train-config-persistence.js +143 -0
- package/src/lib/train-config-validator.js +112 -0
- package/src/lib/train-feedback.js +46 -0
- package/src/lib/train-idempotency.js +97 -0
- package/src/lib/train-request-builder.js +120 -0
- package/templates/do/.train_build_request.py +141 -0
- package/templates/do/.train_poll_parser.py +135 -0
- package/templates/do/.train_status_parser.py +187 -0
- package/templates/do/lib/feedback.sh +41 -0
- package/templates/do/train +786 -0
- package/templates/do/training/config.yaml +140 -0
- package/templates/do/training/train.py +463 -0
|
@@ -0,0 +1,786 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
# do/train — SageMaker Bespoke Training Job
|
|
6
|
+
# Wraps CreateTrainingJob for custom training where you provide your own
|
|
7
|
+
# training script, container, dataset, and hyperparameters.
|
|
8
|
+
# Configuration is read from do/training/config.yaml.
|
|
9
|
+
#
|
|
10
|
+
# Project: <%= projectName %>
|
|
11
|
+
|
|
12
|
+
# Strict mode: abort on command failure, on use of unset variables, and on a
# failure in any stage of a pipeline.
set -euo pipefail

# ── Source project configuration ──────────────────────────────────────────────
# Resolve the directory holding this script, then load do/config from it.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config"

# ── Constants ─────────────────────────────────────────────────────────────────
CONFIG_FILE="${SCRIPT_DIR}/training/config.yaml"
POLL_INTERVAL=60    # seconds slept between DescribeTrainingJob polls

# ── CLI Variables (set by _parse_args) ────────────────────────────────────────
ARG_FORCE=false
ARG_STATUS=false
ARG_DRY_RUN=false
ARG_NO_WAIT=false
ARG_HELP=false

# ── Job Variables (set by _build_job_request) ─────────────────────────────────
JOB_NAME=""
JOB_REQUEST_FILE=""

# ── Training Config Variables (set by _parse_config) ──────────────────────────
# Required fields
TRAIN_IMAGE=""
TRAIN_SCRIPT=""
TRAIN_INSTANCE_TYPE=""
# Optional fields (defaults are applied by the parsers)
TRAIN_INSTANCE_COUNT=""
# Required fields
TRAIN_DATASET=""
TRAIN_OUTPUT_PATH=""
# Structured / optional fields
TRAIN_HYPERPARAMS=""
TRAIN_MAX_RUNTIME=""
TRAIN_VOLUME_SIZE=""
TRAIN_ENABLE_SPOT=""
TRAIN_MAX_WAIT=""
TRAIN_CHECKPOINT_PATH=""
TRAIN_METRIC_DEFINITIONS=""
TRAIN_ENVIRONMENT=""
TRAIN_TAGS=""

# ── SIGINT Trap ───────────────────────────────────────────────────────────────
# The handler is defined further down; it is only looked up when Ctrl+C fires.
trap _handle_interrupt INT
|
|
54
|
+
|
|
55
|
+
# SIGINT handler: reassures the user that Ctrl+C only stops this script (the
# SageMaker job keeps running), explains how to resume, then exits with 130.
_handle_interrupt() {
    printf '%s\n' \
        "" \
        "" \
        "⚠️ Interrupted — training job continues in background." \
        " The job will keep running in SageMaker." \
        " Re-run ./do/train to resume polling, or ./do/train --status to check progress."
    exit 130
}
|
|
63
|
+
|
|
64
|
+
# ── _parse_args() ─────────────────────────────────────────────────────────────
# Translate command-line flags into the ARG_* globals.
# Arguments: the script's positional parameters ("$@").
# Any unrecognised argument prints an error and exits with status 1.
_parse_args() {
    local flag
    for flag in "$@"; do
        case "${flag}" in
            --force)
                ARG_FORCE=true
                ;;
            --status)
                ARG_STATUS=true
                ;;
            --dry-run)
                ARG_DRY_RUN=true
                ;;
            --no-wait)
                ARG_NO_WAIT=true
                ;;
            --help | -h)
                ARG_HELP=true
                ;;
            *)
                echo "❌ Unknown option: ${flag}"
                echo " Run ./do/train --help for usage."
                exit 1
                ;;
        esac
    done
}
|
|
82
|
+
|
|
83
|
+
# ── _show_help() ──────────────────────────────────────────────────────────────
# Print the usage text to stdout and exit 0 (this function never returns).
_show_help() {
    cat <<'USAGE'
Usage: ./do/train [OPTIONS]
 ./do/train --status
 ./do/train --help

SageMaker Bespoke Training — submit custom training jobs using your own
training script, container, dataset, and hyperparameters.

Configuration is read from do/training/config.yaml

Options:
 --force Create a new job even if a previous job exists
 --status Show current job status without submitting
 --dry-run Validate inputs and show the CreateTrainingJob request without submitting
 --no-wait Submit job and exit without polling for completion
 --help, -h Show this help message

Examples:
 ./do/train # Submit a training job
 ./do/train --status # Check status of current job
 ./do/train --dry-run # Validate and preview request
 ./do/train --force # Force a new job after failure
USAGE
    exit 0
}
|
|
108
|
+
|
|
109
|
+
# ── _parse_config() ───────────────────────────────────────────────────────────
# Load do/training/config.yaml into the TRAIN_* globals.
# Prefers yq when it is on PATH and falls back to python3 otherwise.
# Exits 1 if the config file is missing or neither tool is available.
_parse_config() {
    if [ ! -f "${CONFIG_FILE}" ]; then
        printf '%s\n' \
            "❌ Configuration file not found: ${CONFIG_FILE}" \
            " Expected at: do/training/config.yaml" \
            "" \
            " Create it with the required fields (image, script, instance_type, dataset, output_path)."
        exit 1
    fi

    # Pick the parser backend first, then dispatch to it once.
    local backend=""
    if command -v yq &>/dev/null; then
        backend="_parse_config_yq"
    elif command -v python3 &>/dev/null; then
        backend="_parse_config_python"
    fi

    if [ -z "${backend}" ]; then
        printf '%s\n' \
            "❌ Neither yq nor python3 found." \
            " Install yq (https://github.com/mikefarah/yq) or ensure python3 is available."
        exit 1
    fi

    "${backend}"
}
|
|
131
|
+
|
|
132
|
+
# ── _parse_config_yq() ───────────────────────────────────────────────────────
# Populate the TRAIN_* globals from config.yaml using yq.
# Scalar keys are read as plain strings with a per-key default; structured keys
# (hyperparameters, metric_definitions, environment, tags) are read as JSON text.
_parse_config_yq() {
    local spec target remainder key fallback value

    # Scalar fields, written as "GLOBAL:yaml_key:default".
    for spec in \
        "TRAIN_IMAGE:image:" \
        "TRAIN_SCRIPT:script:" \
        "TRAIN_INSTANCE_TYPE:instance_type:" \
        "TRAIN_INSTANCE_COUNT:instance_count:1" \
        "TRAIN_DATASET:dataset:" \
        "TRAIN_OUTPUT_PATH:output_path:" \
        "TRAIN_MAX_RUNTIME:max_runtime_seconds:86400" \
        "TRAIN_VOLUME_SIZE:volume_size_gb:50" \
        "TRAIN_ENABLE_SPOT:enable_spot:false" \
        "TRAIN_MAX_WAIT:max_wait_seconds:172800" \
        "TRAIN_CHECKPOINT_PATH:checkpoint_path:"; do
        target="${spec%%:*}"
        remainder="${spec#*:}"
        key="${remainder%%:*}"
        fallback="${remainder#*:}"
        value=$(yq -r ".${key} // \"${fallback}\"" "${CONFIG_FILE}")
        printf -v "${target}" '%s' "${value}"
    done

    # Structured fields: a missing key or an explicit null collapses to the
    # empty container of the matching kind ([] for the list, {} for the maps).
    for key in hyperparameters metric_definitions environment tags; do
        case "${key}" in
            hyperparameters)
                target="TRAIN_HYPERPARAMS"
                fallback='{}'
                ;;
            metric_definitions)
                target="TRAIN_METRIC_DEFINITIONS"
                fallback='[]'
                ;;
            environment)
                target="TRAIN_ENVIRONMENT"
                fallback='{}'
                ;;
            tags)
                target="TRAIN_TAGS"
                fallback='{}'
                ;;
        esac
        value=$(yq -r ".${key} // ${fallback}" -o=json "${CONFIG_FILE}")
        if [ "${value}" = "null" ]; then
            value="${fallback}"
        fi
        printf -v "${target}" '%s' "${value}"
    done
}
|
|
183
|
+
|
|
184
|
+
# ── _parse_config_python() ───────────────────────────────────────────────────
# Parse config.yaml using python3 as fallback when yq is not available.
# The embedded Python program (which imports the `yaml` module) prints one value
# per line in a fixed order; those lines are then mapped onto the TRAIN_* globals.
#
# Globals read:    CONFIG_FILE
# Globals written: TRAIN_IMAGE ... TRAIN_TAGS (see the `targets` array below)
# Exits 1 with a diagnostic when the file cannot be loaded or parsed.
_parse_config_python() {
    local parse_output
    local parse_status=0

    # The config path is passed as argv[1] instead of being pasted into the
    # Python source, so quotes in the path cannot break (or alter) the program.
    # stderr is left alone so warnings never shift the positional stdout lines;
    # load errors are reported by the program itself on stdout.
    # `|| parse_status=$?` keeps `set -e` from aborting before we can report.
    parse_output=$(python3 - "${CONFIG_FILE}" <<'PYEOF'
import json
import sys

try:
    import yaml
    with open(sys.argv[1], 'r') as f:
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError('top-level YAML value must be a mapping')
except Exception as exc:
    print(f'{type(exc).__name__}: {exc}')
    sys.exit(1)

def s(val, default=''):
    if val is None:
        return default
    if isinstance(val, bool):
        return 'true' if val else 'false'
    return str(val)

print(s(cfg.get('image'), ''))
print(s(cfg.get('script'), ''))
print(s(cfg.get('instance_type'), ''))
print(s(cfg.get('instance_count'), '1'))
print(s(cfg.get('dataset'), ''))
print(s(cfg.get('output_path'), ''))
print(s(cfg.get('max_runtime_seconds'), '86400'))
print(s(cfg.get('volume_size_gb'), '50'))
print(s(cfg.get('enable_spot'), 'false'))
print(s(cfg.get('max_wait_seconds'), '172800'))
print(s(cfg.get('checkpoint_path'), ''))
print(json.dumps(cfg.get('hyperparameters') or {}))
print(json.dumps(cfg.get('metric_definitions') or []))
print(json.dumps(cfg.get('environment') or {}))
print(json.dumps(cfg.get('tags') or {}))
PYEOF
    ) || parse_status=$?

    if [ "${parse_status}" -ne 0 ]; then
        echo "❌ Failed to parse ${CONFIG_FILE}"
        echo " ${parse_output}"
        echo ""
        echo " Ensure the file is valid YAML syntax."
        exit 1
    fi

    # Output line N (0-based) belongs to targets[N]; the order must match the
    # print() sequence in the Python program above.
    local -a targets=(
        TRAIN_IMAGE
        TRAIN_SCRIPT
        TRAIN_INSTANCE_TYPE
        TRAIN_INSTANCE_COUNT
        TRAIN_DATASET
        TRAIN_OUTPUT_PATH
        TRAIN_MAX_RUNTIME
        TRAIN_VOLUME_SIZE
        TRAIN_ENABLE_SPOT
        TRAIN_MAX_WAIT
        TRAIN_CHECKPOINT_PATH
        TRAIN_HYPERPARAMS
        TRAIN_METRIC_DEFINITIONS
        TRAIN_ENVIRONMENT
        TRAIN_TAGS
    )
    local i=0
    local line
    while IFS= read -r line; do
        if [ "${i}" -lt "${#targets[@]}" ]; then
            printf -v "${targets[i]}" '%s' "${line}"
        fi
        i=$((i + 1))
    done <<< "${parse_output}"

    # Apply defaults for any empty optional fields
    TRAIN_INSTANCE_COUNT="${TRAIN_INSTANCE_COUNT:-1}"
    TRAIN_MAX_RUNTIME="${TRAIN_MAX_RUNTIME:-86400}"
    TRAIN_VOLUME_SIZE="${TRAIN_VOLUME_SIZE:-50}"
    TRAIN_ENABLE_SPOT="${TRAIN_ENABLE_SPOT:-false}"
    TRAIN_MAX_WAIT="${TRAIN_MAX_WAIT:-172800}"
    # Braces/brackets are kept in variables so they cannot be confused with the
    # closing brace of the ${var:-default} expansion.
    local empty_obj='{}'
    local empty_arr='[]'
    TRAIN_HYPERPARAMS="${TRAIN_HYPERPARAMS:-$empty_obj}"
    TRAIN_METRIC_DEFINITIONS="${TRAIN_METRIC_DEFINITIONS:-$empty_arr}"
    TRAIN_ENVIRONMENT="${TRAIN_ENVIRONMENT:-$empty_obj}"
    TRAIN_TAGS="${TRAIN_TAGS:-$empty_obj}"
}
|
|
262
|
+
|
|
263
|
+
# ── _validate_config() ────────────────────────────────────────────────────────
# Verify that every required TRAIN_* value is non-empty and that spot training
# has a checkpoint path. All problems are reported before exiting with status 1,
# so the user can fix them in one pass.
_validate_config() {
    local failed=false
    local var field what example

    # One row per required field: GLOBAL|yaml key|description|example value.
    while IFS='|' read -r var field what example; do
        if [ -z "${!var}" ]; then
            echo "❌ Missing required field: ${field}"
            echo " ${what} is required in do/training/config.yaml"
            echo ""
            echo " Expected format: ${field}: \"${example}\""
            echo ""
            failed=true
        fi
    done <<'REQUIRED_FIELDS'
TRAIN_IMAGE|image|The container image URI|123456789012.dkr.ecr.us-east-1.amazonaws.com/my-training:latest
TRAIN_SCRIPT|script|The training script S3 path|s3://my-bucket/scripts/train.py
TRAIN_INSTANCE_TYPE|instance_type|The SageMaker instance type|ml.g5.xlarge
TRAIN_DATASET|dataset|The S3 dataset path|s3://my-bucket/data/train/
TRAIN_OUTPUT_PATH|output_path|The S3 output path|s3://my-bucket/output/
REQUIRED_FIELDS

    # Spot training requires a checkpoint path for resumption
    if [ "${TRAIN_ENABLE_SPOT}" = "true" ] && [ -z "${TRAIN_CHECKPOINT_PATH}" ]; then
        printf '%s\n' \
            "❌ Checkpoint path required for spot training" \
            " When enable_spot is true, a checkpoint S3 path must be specified" \
            " so training can resume after spot interruptions." \
            "" \
            " Add to do/training/config.yaml:" \
            " checkpoint_path: \"s3://my-bucket/checkpoints/\"" \
            ""
        failed=true
    fi

    if [ "${failed}" = true ]; then
        exit 1
    fi
}
|
|
329
|
+
|
|
330
|
+
# ── _check_idempotency() ─────────────────────────────────────────────────────
# Decide whether a previously tracked job (TRAIN_JOB_NAME, sourced from
# do/config) should be reused instead of submitting a new one.
#
# Returns 0 (caller goes on to submit a new job) when:
#   - --force was given,
#   - no job is tracked, or
#   - DescribeTrainingJob fails for the tracked job.
# Otherwise the script exits here, based on the tracked job's status:
#   InProgress → poll until completion, run completion handling, exit 0
#   Completed  → run completion handling, exit 0
#   Failed     → display failure reason, suggest --force, exit 2
#   Stopped    → display stopped message, suggest --force, exit 2
#   other      → display the unexpected status, suggest --force, exit 2
_check_idempotency() {
    # If --force is set, skip idempotency check entirely
    if [ "${ARG_FORCE}" = true ]; then
        return 0
    fi

    # Nothing tracked yet: nothing to reuse.
    if [ -z "${TRAIN_JOB_NAME:-}" ]; then
        return 0
    fi

    echo "🔍 Found existing training job: ${TRAIN_JOB_NAME}"
    echo " Checking status..."
    echo ""

    # `|| describe_exit_code=$?` captures the status without tripping `set -e`.
    local describe_output
    local describe_exit_code=0
    describe_output=$(aws sagemaker describe-training-job \
        --training-job-name "${TRAIN_JOB_NAME}" 2>&1) || describe_exit_code=$?

    if [ "${describe_exit_code}" -ne 0 ]; then
        # If describe fails (e.g., job was deleted), proceed to new job
        echo "⚠️ Could not describe existing job: ${TRAIN_JOB_NAME}"
        echo " ${describe_output}"
        echo " Proceeding to create a new job."
        echo ""
        return 0
    fi

    # Extract status from the JSON response using python3; anything that cannot
    # be parsed is treated as the catch-all "Unknown" status below.
    local job_status
    job_status=$(printf '%s\n' "${describe_output}" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
print(resp.get('TrainingJobStatus', 'Unknown'))
" 2>/dev/null) || job_status="Unknown"

    case "${job_status}" in
        InProgress)
            echo "⏳ Training job is still running: ${TRAIN_JOB_NAME}"
            echo " Resuming polling..."
            echo ""
            _poll_job
            _handle_completion
            exit 0
            ;;
        Completed)
            echo "✅ Training job already completed: ${TRAIN_JOB_NAME}"
            echo ""
            # _handle_completion issues its own DescribeTrainingJob call, so the
            # response fetched above does not need to be handed over. (A copy
            # used to be dumped to a fixed-name file in /tmp that was never read
            # or removed.)
            _handle_completion
            exit 0
            ;;
        Failed)
            local failure_reason
            failure_reason=$(printf '%s\n' "${describe_output}" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
print(resp.get('FailureReason', 'No failure reason provided'))
" 2>/dev/null) || failure_reason="No failure reason provided"

            echo "❌ Previous training job failed: ${TRAIN_JOB_NAME}"
            echo " Reason: ${failure_reason}"
            echo ""
            echo " To submit a new job, re-run with --force:"
            echo " ./do/train --force"
            exit 2
            ;;
        Stopped)
            echo "⏹️ Previous training job was stopped: ${TRAIN_JOB_NAME}"
            echo ""
            echo " To submit a new job, re-run with --force:"
            echo " ./do/train --force"
            exit 2
            ;;
        *)
            echo "⚠️ Unexpected job status: ${job_status} for ${TRAIN_JOB_NAME}"
            echo " To submit a new job, re-run with --force:"
            echo " ./do/train --force"
            exit 2
            ;;
    esac
}
|
|
425
|
+
|
|
426
|
+
# ── _build_job_request() ──────────────────────────────────────────────────────
# Construct the CreateTrainingJob JSON request body.
#
# Globals read:    PROJECT_NAME, ROLE_ARN, SCRIPT_DIR, TRAIN_* (parsed config)
# Globals written: JOB_NAME         - "<PROJECT_NAME>-train-<YYYYmmdd-HHMMSS>"
#                  JOB_REQUEST_FILE - temp file holding the request JSON; the
#                                     consumer (_submit_job) deletes it
# Exits 1 if the temp file cannot be created or the request builder fails.
_build_job_request() {
    # Generate job name with timestamp
    local timestamp
    timestamp=$(date +%Y%m%d-%H%M%S)
    JOB_NAME="${PROJECT_NAME}-train-${timestamp}"

    # Use mktemp rather than a name derived from JOB_NAME: a predictable path in
    # a shared temp directory can be pre-created or symlinked by another user.
    # The template ends in XXXXXX so it works with both GNU and BSD mktemp.
    JOB_REQUEST_FILE=$(mktemp "${TMPDIR:-/tmp}/train-request-XXXXXX") || {
        echo "❌ Failed to create a temporary file for the CreateTrainingJob request"
        exit 1
    }

    # Construct the JSON request file using python3. The status is captured with
    # `||` because under `set -e` a bare failure would abort the script before
    # the error message below could be printed.
    local build_status=0
    python3 "${SCRIPT_DIR}/.train_build_request.py" \
        --job-name "${JOB_NAME}" \
        --role-arn "${ROLE_ARN}" \
        --image "${TRAIN_IMAGE}" \
        --instance-type "${TRAIN_INSTANCE_TYPE}" \
        --instance-count "${TRAIN_INSTANCE_COUNT}" \
        --volume-size "${TRAIN_VOLUME_SIZE}" \
        --dataset "${TRAIN_DATASET}" \
        --output-path "${TRAIN_OUTPUT_PATH}" \
        --max-runtime "${TRAIN_MAX_RUNTIME}" \
        --hyperparams "${TRAIN_HYPERPARAMS}" \
        --enable-spot "${TRAIN_ENABLE_SPOT}" \
        --max-wait "${TRAIN_MAX_WAIT}" \
        --checkpoint-path "${TRAIN_CHECKPOINT_PATH}" \
        --metric-definitions "${TRAIN_METRIC_DEFINITIONS}" \
        --environment "${TRAIN_ENVIRONMENT}" \
        --tags "${TRAIN_TAGS}" \
        --output-file "${JOB_REQUEST_FILE}" || build_status=$?

    if [ "${build_status}" -ne 0 ]; then
        rm -f "${JOB_REQUEST_FILE}"
        echo "❌ Failed to construct CreateTrainingJob request"
        exit 1
    fi

    echo "📋 Training Job: ${JOB_NAME}"
    echo " Image: ${TRAIN_IMAGE}"
    echo " Instance: ${TRAIN_INSTANCE_TYPE} x ${TRAIN_INSTANCE_COUNT}"
    echo " Dataset: ${TRAIN_DATASET}"
    echo " Output: ${TRAIN_OUTPUT_PATH}"
    if [ "${TRAIN_ENABLE_SPOT}" = "true" ]; then
        echo " Spot: enabled (max wait ${TRAIN_MAX_WAIT}s)"
    fi
}
|
|
471
|
+
|
|
472
|
+
# ── _submit_job() ─────────────────────────────────────────────────────────────
# Call aws sagemaker create-training-job with the constructed JSON.
#
# Globals read: ARG_DRY_RUN, JOB_NAME, JOB_REQUEST_FILE
# Behaviour:
#   --dry-run → print the request JSON, delete the request file, exit 0
#   success   → persist TRAIN_JOB_NAME to do/config and return 0
#   failure   → print the CLI error (with IAM remediation hints when it is an
#               AccessDeniedException) and exit 1
# The request file is always removed before returning or exiting.
_submit_job() {
    # Handle --dry-run: print the request JSON and exit without submitting
    if [ "${ARG_DRY_RUN}" = true ]; then
        echo ""
        echo "🔍 Dry run — CreateTrainingJob request:"
        echo ""
        cat "${JOB_REQUEST_FILE}"
        echo ""
        rm -f "${JOB_REQUEST_FILE}"
        exit 0
    fi

    echo ""
    echo "🚀 Submitting training job..."

    # Submit the job via AWS CLI; `||` captures the status without tripping set -e.
    local submit_output
    local submit_exit_code=0
    submit_output=$(aws sagemaker create-training-job \
        --cli-input-json "file://${JOB_REQUEST_FILE}" 2>&1) || submit_exit_code=$?

    # Clean up the temporary request file
    rm -f "${JOB_REQUEST_FILE}"

    if [ "${submit_exit_code}" -eq 0 ]; then
        # Success — persist job name to do/config
        _update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
        echo " ✅ Job submitted successfully: ${JOB_NAME}"
        echo ""
        return 0
    fi

    # Failure — detect error type and provide remediation.
    # Pattern matching is done in-shell: piping into `grep -q` can report a
    # false negative under pipefail, and `grep -P` is unavailable on BSD/macOS.
    if [[ "${submit_output}" == *AccessDeniedException* ]]; then
        # Pull the denied action out of the message. IAM denials are phrased
        # "... is not authorized to perform: <action> on resource: ...", so
        # "perform: " is tried first; the two older markers are kept as fallbacks.
        local missing_permission=""
        local marker
        for marker in 'perform: ' 'performing: ' 'action: '; do
            if [[ "${submit_output}" =~ ${marker}([^[:space:]]+) ]]; then
                missing_permission="${BASH_REMATCH[1]}"
                break
            fi
        done
        if [ -z "${missing_permission}" ]; then
            missing_permission="sagemaker:CreateTrainingJob"
        fi

        echo "❌ Access denied: ${missing_permission}"
        echo " ${submit_output}"
        echo ""
        echo " Remediation:"
        echo " Ensure your IAM role or user has the '${missing_permission}' permission."
        echo " If using the bootstrap stack, re-run ./do/bootstrap to update permissions."
        echo " Otherwise, attach a policy granting '${missing_permission}' to your role."
        exit 1
    fi

    echo "❌ Failed to submit training job"
    echo " ${submit_output}"
    echo ""
    echo " Check the error above and verify your configuration in do/training/config.yaml."
    exit 1
}
|
|
535
|
+
|
|
536
|
+
# ── _poll_job() ───────────────────────────────────────────────────────────────
# Poll DescribeTrainingJob every POLL_INTERVAL seconds until terminal state.
#
# Globals read: JOB_NAME (falls back to TRAIN_JOB_NAME), POLL_INTERVAL, SCRIPT_DIR
# Each round prints the DISPLAY line produced by .train_poll_parser.py.
#   Completed → leaves the loop and returns (caller handles completion)
#   Failed    → prints the failure reason (if any) and exits 2
#   Stopped   → prints a stopped message and exits 2
# A failed describe call or parser run is reported and retried after the usual
# interval. A secondary status containing "interrupted" prints a note that the
# job resumes from its last checkpoint, and polling continues.
_poll_job() {
    local job_name="${JOB_NAME:-$TRAIN_JOB_NAME}"

    local describe_output describe_exit_code
    local poll_result parse_exit_code
    local job_status secondary_status display_text failure_reason
    local line

    echo "⏳ Polling training job: ${job_name}"
    echo " (Ctrl+C to stop polling — job continues in background)"
    echo ""

    while true; do
        # Reset on every round. `local` does not clear an existing value, so
        # without this a single failed call would mark all later ones as failed
        # and the loop would retry forever.
        describe_exit_code=0
        describe_output=$(aws sagemaker describe-training-job \
            --training-job-name "${job_name}" 2>&1) || describe_exit_code=$?

        if [ "${describe_exit_code}" -ne 0 ]; then
            echo "⚠️ Failed to describe job (will retry): ${describe_output}"
            sleep "${POLL_INTERVAL}"
            continue
        fi

        # Parse the response using python3 helper. The status is captured with
        # `||`; a bare failing assignment would abort the script under `set -e`
        # before the retry branch could run.
        parse_exit_code=0
        poll_result=$(printf '%s\n' "${describe_output}" \
            | python3 "${SCRIPT_DIR}/.train_poll_parser.py" 2>&1) || parse_exit_code=$?

        if [ "${parse_exit_code}" -ne 0 ]; then
            echo "⚠️ Failed to parse job status (will retry): ${poll_result}"
            sleep "${POLL_INTERVAL}"
            continue
        fi

        # The parser outputs structured KEY=value lines:
        #   STATUS, SECONDARY, ELAPSED, METRICS, FAILURE_REASON, DISPLAY
        # They are split in-shell so that a key the parser did not emit simply
        # stays empty (a `grep | cut` pipeline fails under pipefail when the key
        # is absent). Values may themselves contain '='.
        job_status=""
        secondary_status=""
        display_text=""
        failure_reason=""
        while IFS= read -r line; do
            case "${line}" in
                STATUS=*) job_status="${line#*=}" ;;
                SECONDARY=*) secondary_status="${line#*=}" ;;
                FAILURE_REASON=*) failure_reason="${line#*=}" ;;
                DISPLAY=*) display_text="${line#*=}" ;;
            esac
        done <<< "${poll_result}"

        # Print the formatted status line
        echo "${display_text}"

        # Handle terminal states
        case "${job_status}" in
            Completed)
                echo ""
                echo "✅ Training job completed: ${job_name}"
                break
                ;;
            Failed)
                echo ""
                echo "❌ Training job failed: ${job_name}"
                if [ -n "${failure_reason}" ]; then
                    echo " Reason: ${failure_reason}"
                fi
                echo ""
                echo " To investigate: check CloudWatch logs for /aws/sagemaker/TrainingJobs/${job_name}"
                echo " To retry: ./do/train --force"
                exit 2
                ;;
            Stopped)
                echo ""
                echo "⏹️ Training job was stopped: ${job_name}"
                echo ""
                echo " To submit a new job: ./do/train --force"
                exit 2
                ;;
        esac

        # Handle spot interruption (job still InProgress but interrupted);
        # the comparison is case-insensitive.
        shopt -q nocasematch && line="keep" || line="restore"
        shopt -s nocasematch
        if [[ "${secondary_status}" == *interrupted* ]]; then
            echo ""
            echo " ℹ️ Spot instance interrupted. The job will automatically resume"
            echo " from the last checkpoint. Continuing to poll..."
            echo ""
        fi
        if [ "${line}" = "restore" ]; then
            shopt -u nocasematch
        fi

        # Wait before next poll
        sleep "${POLL_INTERVAL}"
    done
}
|
|
634
|
+
|
|
635
|
+
# ── _handle_completion() ──────────────────────────────────────────────────────
# Store output paths and invoke feedback loop.
#
# Globals read: JOB_NAME (falls back to TRAIN_JOB_NAME), SCRIPT_DIR,
#               TRAIN_ENABLE_SPOT
# Steps:
#   1. Describe the job and read ModelArtifacts.S3ModelArtifacts.
#   2. Persist that path as TRAIN_OUTPUT_PATH in do/config.
#   3. Label the output "adapter" if adapter_config.json is listed under the
#      artifact path, otherwise "full-model".
#   4. Source lib/feedback.sh and call print_completion_feedback.
#   5. When spot training is enabled, print training vs. billed time.
# Returns 1 if the job cannot be described or reports no model artifacts.
_handle_completion() {
    local job_name="${JOB_NAME:-$TRAIN_JOB_NAME}"

    # Get the full DescribeTrainingJob response
    local describe_output
    local describe_exit_code=0
    describe_output=$(aws sagemaker describe-training-job \
        --training-job-name "${job_name}" 2>&1) || describe_exit_code=$?

    if [ "${describe_exit_code}" -ne 0 ]; then
        echo "⚠️ Failed to describe completed job: ${job_name}"
        echo " ${describe_output}"
        return 1
    fi

    # Parse the response once and emit four lines: artifact path, billable
    # seconds, training seconds, savings percentage. If parsing fails the
    # summary is left empty, which surfaces below as "no model artifacts"
    # instead of a silent `set -e` exit.
    local summary
    summary=$(printf '%s\n' "${describe_output}" | python3 -c '
import json, sys
resp = json.load(sys.stdin)
billable = int(resp.get("BillableTimeInSeconds") or 0)
training = int(resp.get("TrainingTimeInSeconds") or 0)
print((resp.get("ModelArtifacts") or {}).get("S3ModelArtifacts", ""))
print(billable)
print(training)
print(f"{((training - billable) / training) * 100:.0f}" if training > 0 else "0")
' 2>/dev/null) || summary=""

    local output_path="" billable_seconds="" training_seconds="" savings_pct=""
    {
        IFS= read -r output_path
        IFS= read -r billable_seconds
        IFS= read -r training_seconds
        IFS= read -r savings_pct
    } <<< "${summary}" || true    # short input just leaves later fields empty

    if [ -z "${output_path}" ]; then
        echo "⚠️ No model artifacts found in job response."
        echo " The job may not have produced output artifacts."
        return 1
    fi

    # Write TRAIN_OUTPUT_PATH to do/config
    _update_config_var "TRAIN_OUTPUT_PATH" "${output_path}"

    # Detect output type: check for adapter_config.json in output path.
    # A trailing slash is dropped first so the probed key never contains "//".
    local output_type="full-model"
    if aws s3 ls "${output_path%/}/adapter_config.json" &>/dev/null; then
        output_type="adapter"
    fi

    # Source feedback.sh and call print_completion_feedback
    source "${SCRIPT_DIR}/lib/feedback.sh"
    print_completion_feedback "${output_path}" "${output_type}" "${job_name}"

    # If spot training was enabled, display cost savings
    if [ "${TRAIN_ENABLE_SPOT:-false}" = "true" ]; then
        if [ "${training_seconds:-0}" -gt 0 ] && [ "${billable_seconds:-0}" -gt 0 ]; then
            echo " 💰 Spot training savings:"
            echo " Training time: ${training_seconds}s"
            echo " Billed time: ${billable_seconds}s"
            echo " Estimated savings: ~${savings_pct}%"
            echo ""
        fi
    fi
}
|
|
714
|
+
|
|
715
|
+
# ── _update_config_var() ──────────────────────────────────────────────────────
# Write or update a variable in do/config.
# Usage: _update_config_var VAR_NAME "value"
#
# Arguments: $1 - variable name
#            $2 - value to store
# An existing `export VAR_NAME=...` line is replaced in place; otherwise a new
# `export VAR_NAME="value"` line is appended.
_update_config_var() {
    local var_name="$1"
    local var_value="$2"
    local config_file="${SCRIPT_DIR}/config"

    # do/config is sourced by the scripts, so characters that are special inside
    # double quotes are escaped; the value then reads back exactly as given and
    # is never executed. Backslashes go first so later escapes are not doubled.
    local escaped="${var_value}"
    escaped=${escaped//\\/\\\\}
    escaped=${escaped//\"/\\\"}
    escaped=${escaped//\$/\\\$}
    escaped=${escaped//\`/\\\`}
    local new_line="export ${var_name}=\"${escaped}\""

    if grep -q "^export ${var_name}=" "${config_file}" 2>/dev/null; then
        # Rebuild the file line by line rather than using a sed substitution:
        # sed would give '&', '\' and the delimiter in the value special meaning.
        local tmp_file line
        tmp_file=$(mktemp "${config_file}.XXXXXX")
        while IFS= read -r line || [ -n "${line}" ]; do
            case "${line}" in
                "export ${var_name}="*) printf '%s\n' "${new_line}" ;;
                *) printf '%s\n' "${line}" ;;
            esac
        done < "${config_file}" > "${tmp_file}"
        # Copy the content back instead of renaming, so the config file keeps
        # its original permissions.
        cat "${tmp_file}" > "${config_file}"
        rm -f "${tmp_file}"
    else
        printf '%s\n' "${new_line}" >> "${config_file}"
    fi
}
|
|
730
|
+
|
|
731
|
+
# ── Main ──────────────────────────────────────────────────────────────────────
# Flow: parse flags → (--help | --status short-circuits) → load and validate
# config → reuse a tracked job if there is one → build and submit → poll.
_parse_args "$@"

# --help: _show_help prints the usage text and exits 0 itself.
if [ "${ARG_HELP}" = true ]; then
    _show_help
fi

# --status: report on the tracked job without reading config.yaml or submitting.
if [ "${ARG_STATUS}" = true ]; then
    if [ -z "${TRAIN_JOB_NAME:-}" ]; then
        printf '%s\n' \
            "📊 No training job tracked." \
            " Run ./do/train to submit a training job."
        exit 0
    fi

    printf '%s\n' \
        "📊 Training Job Status" \
        " Job: ${TRAIN_JOB_NAME}"

    # Call DescribeTrainingJob; on failure STATUS_JSON holds the CLI's error text.
    if ! STATUS_JSON=$(aws sagemaker describe-training-job \
        --training-job-name "${TRAIN_JOB_NAME}" \
        --region "${AWS_REGION}" 2>&1); then
        printf '%s\n' \
            "" \
            "❌ Failed to describe training job: ${TRAIN_JOB_NAME}" \
            " ${STATUS_JSON}" \
            "" \
            " The job may have been deleted or the name is incorrect." \
            " Run ./do/train --force to start a new job."
        exit 1
    fi

    # Hand the response to the helper script, which formats and prints it.
    python3 "${SCRIPT_DIR}/.train_status_parser.py" <<< "${STATUS_JSON}"
    exit 0
fi

# Parse and validate configuration
_parse_config
_validate_config

# Existing-job handling: may resume polling or exit instead of returning.
_check_idempotency

# Build and submit the job
_build_job_request
_submit_job

# --no-wait: stop here and leave the job running.
if [ "${ARG_NO_WAIT}" = true ]; then
    printf '%s\n' \
        " --no-wait specified. Job submitted, exiting without polling." \
        " Re-run ./do/train --status to check progress."
    exit 0
fi

# Otherwise follow the job to a terminal state and report the results.
_poll_job
_handle_completion
|