@aws/ml-container-creator 0.8.0 → 0.9.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. package/LICENSE-THIRD-PARTY +50760 -16218
  2. package/bin/cli.js +31 -137
  3. package/package.json +7 -2
  4. package/servers/lib/catalogs/instances.json +52 -1275
  5. package/servers/lib/catalogs/models.json +0 -132
  6. package/servers/lib/catalogs/popular-diffusors.json +1 -110
  7. package/src/app.js +29 -2
  8. package/src/lib/config-manager.js +17 -0
  9. package/src/lib/generated/cli-options.js +467 -0
  10. package/src/lib/generated/validation-rules.js +202 -0
  11. package/src/lib/mcp-client.js +16 -1
  12. package/src/lib/mcp-command-handler.js +10 -2
  13. package/src/lib/prompt-runner.js +16 -2
  14. package/src/lib/train-config-parser.js +136 -0
  15. package/src/lib/train-config-persistence.js +143 -0
  16. package/src/lib/train-config-validator.js +112 -0
  17. package/src/lib/train-feedback.js +46 -0
  18. package/src/lib/train-idempotency.js +97 -0
  19. package/src/lib/train-request-builder.js +120 -0
  20. package/templates/code/serve +5 -134
  21. package/templates/code/serve.d/lmi.ejs +19 -0
  22. package/templates/code/serve.d/sglang.ejs +47 -0
  23. package/templates/code/serve.d/tensorrt-llm.ejs +53 -0
  24. package/templates/code/serve.d/vllm.ejs +48 -0
  25. package/templates/do/.train_build_request.py +141 -0
  26. package/templates/do/.train_poll_parser.py +135 -0
  27. package/templates/do/.train_status_parser.py +187 -0
  28. package/templates/do/clean +1 -1387
  29. package/templates/do/clean.d/async-inference.ejs +508 -0
  30. package/templates/do/clean.d/batch-transform.ejs +512 -0
  31. package/templates/do/clean.d/hyperpod-eks.ejs +481 -0
  32. package/templates/do/clean.d/managed-inference.ejs +1043 -0
  33. package/templates/do/deploy +1 -1766
  34. package/templates/do/deploy.d/async-inference.ejs +501 -0
  35. package/templates/do/deploy.d/batch-transform.ejs +529 -0
  36. package/templates/do/deploy.d/hyperpod-eks.ejs +339 -0
  37. package/templates/do/deploy.d/managed-inference.ejs +726 -0
  38. package/templates/do/lib/feedback.sh +41 -0
  39. package/templates/do/train +786 -0
  40. package/templates/do/training/config.yaml +140 -0
  41. package/templates/do/training/train.py +463 -0
@@ -0,0 +1,786 @@
1
+ #!/bin/bash
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ # do/train — SageMaker Bespoke Training Job
6
+ # Wraps CreateTrainingJob for custom training where you provide your own
7
+ # training script, container, dataset, and hyperparameters.
8
+ # Configuration is read from do/training/config.yaml.
9
+ #
10
+ # Project: <%= projectName %>
11
+
12
# Strict mode: stop on the first failing command, on any reference to an unset
# variable, and when any stage of a pipeline fails.
set -euo pipefail

# ── Project configuration ─────────────────────────────────────────────────────
# Resolve the directory that holds this script, then load the "config" file
# that sits beside it.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config"

# ── Constants ─────────────────────────────────────────────────────────────────
CONFIG_FILE="${SCRIPT_DIR}/training/config.yaml"
POLL_INTERVAL=60 # handed to sleep between status polls

# ── CLI flags (filled in by _parse_args) ──────────────────────────────────────
ARG_FORCE=false ARG_STATUS=false ARG_DRY_RUN=false
ARG_NO_WAIT=false ARG_HELP=false

# ── Job variables (filled in by _build_job_request) ───────────────────────────
JOB_NAME="" JOB_REQUEST_FILE=""

# ── Training config variables (filled in by _parse_config) ────────────────────
# These are reset after the config file is sourced, so any TRAIN_* value listed
# here that the sourced file defined is cleared again.
TRAIN_IMAGE="" TRAIN_SCRIPT=""
TRAIN_INSTANCE_TYPE="" TRAIN_INSTANCE_COUNT=""
TRAIN_DATASET="" TRAIN_OUTPUT_PATH=""
TRAIN_HYPERPARAMS=""
TRAIN_MAX_RUNTIME="" TRAIN_VOLUME_SIZE=""
TRAIN_ENABLE_SPOT="" TRAIN_MAX_WAIT=""
TRAIN_CHECKPOINT_PATH=""
TRAIN_METRIC_DEFINITIONS="" TRAIN_ENVIRONMENT="" TRAIN_TAGS=""

# ── SIGINT trap ───────────────────────────────────────────────────────────────
# The handler is defined further down; bash looks it up when the signal fires.
trap _handle_interrupt INT
54
+
55
# SIGINT handler: prints two blank lines followed by a three-line notice that
# the job keeps running and how to resume or check it, then terminates the
# script with exit status 130.
_handle_interrupt() {
  printf '%s\n' \
    "" \
    "" \
    "⚠️ Interrupted — training job continues in background." \
    " The job will keep running in SageMaker." \
    " Re-run ./do/train to resume polling, or ./do/train --status to check progress."
  exit 130
}
63
+
64
# ── _parse_args() ─────────────────────────────────────────────────────────────
# Translate command-line flags into the ARG_* globals.
# Arguments: the script's positional parameters.
# Globals written: ARG_FORCE, ARG_STATUS, ARG_DRY_RUN, ARG_NO_WAIT, ARG_HELP.
# Exits 1 with a usage hint on the first argument that is not a known flag.
_parse_args() {
  local flag
  for flag in "$@"; do
    case "${flag}" in
      --force)
        ARG_FORCE=true
        ;;
      --status)
        ARG_STATUS=true
        ;;
      --dry-run)
        ARG_DRY_RUN=true
        ;;
      --no-wait)
        ARG_NO_WAIT=true
        ;;
      --help | -h)
        ARG_HELP=true
        ;;
      *)
        echo "❌ Unknown option: ${flag}"
        echo " Run ./do/train --help for usage."
        exit 1
        ;;
    esac
  done
}
82
+
83
# ── _show_help() ──────────────────────────────────────────────────────────────
# Print the usage text to stdout and exit with status 0.
_show_help() {
  # Quoted delimiter: the text is emitted literally, with no expansion.
  cat <<'HELP_TEXT'
Usage: ./do/train [OPTIONS]
 ./do/train --status
 ./do/train --help

SageMaker Bespoke Training — submit custom training jobs using your own
training script, container, dataset, and hyperparameters.

Configuration is read from do/training/config.yaml

Options:
 --force Create a new job even if a previous job exists
 --status Show current job status without submitting
 --dry-run Validate inputs and show the CreateTrainingJob request without submitting
 --no-wait Submit job and exit without polling for completion
 --help, -h Show this help message

Examples:
 ./do/train # Submit a training job
 ./do/train --status # Check status of current job
 ./do/train --dry-run # Validate and preview request
 ./do/train --force # Force a new job after failure
HELP_TEXT
  exit 0
}
108
+
109
# ── _parse_config() ───────────────────────────────────────────────────────────
# Load do/training/config.yaml into the TRAIN_* globals.
# Delegates to _parse_config_yq when a yq executable is on PATH, otherwise to
# _parse_config_python when python3 is. Exits 1 if the config file does not
# exist or if neither tool is available.
_parse_config() {
  if ! [ -f "${CONFIG_FILE}" ]; then
    echo "❌ Configuration file not found: ${CONFIG_FILE}"
    echo " Expected at: do/training/config.yaml"
    echo ""
    echo " Create it with the required fields (image, script, instance_type, dataset, output_path)."
    exit 1
  fi

  # yq is preferred; python3 is the fallback.
  if command -v yq &>/dev/null; then
    _parse_config_yq
    return
  fi

  if command -v python3 &>/dev/null; then
    _parse_config_python
    return
  fi

  echo "❌ Neither yq nor python3 found."
  echo " Install yq (https://github.com/mikefarah/yq) or ensure python3 is available."
  exit 1
}
131
+
132
# ── _parse_config_yq() ───────────────────────────────────────────────────────
# Populate the TRAIN_* globals from CONFIG_FILE using yq.
# Scalar keys are read one yq call each, with the fallback supplied by the
# "//" alternative in the expression. Structured keys (hyperparameters,
# metric_definitions, environment, tags) are requested as JSON; an output of
# "null" is replaced by the empty container for that key.
_parse_config_yq() {
  local entry target query fallback value

  # Scalars, encoded as "VARIABLE|yq expression".
  for entry in \
    'TRAIN_IMAGE|.image // ""' \
    'TRAIN_SCRIPT|.script // ""' \
    'TRAIN_INSTANCE_TYPE|.instance_type // ""' \
    'TRAIN_INSTANCE_COUNT|.instance_count // "1"' \
    'TRAIN_DATASET|.dataset // ""' \
    'TRAIN_OUTPUT_PATH|.output_path // ""' \
    'TRAIN_MAX_RUNTIME|.max_runtime_seconds // "86400"' \
    'TRAIN_VOLUME_SIZE|.volume_size_gb // "50"' \
    'TRAIN_ENABLE_SPOT|.enable_spot // "false"' \
    'TRAIN_MAX_WAIT|.max_wait_seconds // "172800"' \
    'TRAIN_CHECKPOINT_PATH|.checkpoint_path // ""'; do
    target="${entry%%|*}"
    query="${entry#*|}"
    # Assign in a separate step so a yq failure is not masked by printf.
    value=$(yq -r "${query}" "${CONFIG_FILE}")
    printf -v "${target}" '%s' "${value}"
  done

  # Maps and lists, encoded as "VARIABLE|yq expression|empty container".
  for entry in \
    'TRAIN_HYPERPARAMS|.hyperparameters // {}|{}' \
    'TRAIN_METRIC_DEFINITIONS|.metric_definitions // []|[]' \
    'TRAIN_ENVIRONMENT|.environment // {}|{}' \
    'TRAIN_TAGS|.tags // {}|{}'; do
    target="${entry%%|*}"
    fallback="${entry##*|}"
    query="${entry#*|}"
    query="${query%|*}"
    value=$(yq -r "${query}" -o=json "${CONFIG_FILE}")
    if [ "${value}" = "${fallback}" ] || [ "${value}" = "null" ]; then
      value="${fallback}"
    fi
    printf -v "${target}" '%s' "${value}"
  done
}
183
+
184
# ── _parse_config_python() ───────────────────────────────────────────────────
# Populate the TRAIN_* globals from CONFIG_FILE using python3 with PyYAML.
# Used as the fallback when yq is not available.
#
# The embedded Python prints exactly 15 lines in a fixed order (11 scalars,
# then 4 single-line JSON documents); the shell loop maps each line position
# to a variable.
#
# Globals read:    CONFIG_FILE
# Globals written: TRAIN_IMAGE, TRAIN_SCRIPT, TRAIN_INSTANCE_TYPE,
#                  TRAIN_INSTANCE_COUNT, TRAIN_DATASET, TRAIN_OUTPUT_PATH,
#                  TRAIN_MAX_RUNTIME, TRAIN_VOLUME_SIZE, TRAIN_ENABLE_SPOT,
#                  TRAIN_MAX_WAIT, TRAIN_CHECKPOINT_PATH, TRAIN_HYPERPARAMS,
#                  TRAIN_METRIC_DEFINITIONS, TRAIN_ENVIRONMENT, TRAIN_TAGS
# Exits 1 if python3 reports a failure; python3 writes the cause to stderr.
_parse_config_python() {
  local parse_output

  # Notes on the invocation:
  #  * The status is tested inside "if !": a bare assignment followed by a
  #    "$?" check never reaches the check when errexit is active.
  #  * The config path is passed as argv[1] rather than spliced into the
  #    Python source, so quotes in the path cannot break the program.
  #  * stderr is not merged into the captured output, so warnings cannot
  #    shift the position-based lines.
  #  * The Python source is single-quoted for the shell and therefore must
  #    not contain an apostrophe.
  if ! parse_output=$(python3 -c '
import json
import sys

try:
    import yaml
except ImportError:
    sys.exit("PyYAML is not installed for python3 (try: pip install pyyaml)")

try:
    with open(sys.argv[1], "r") as handle:
        cfg = yaml.safe_load(handle) or {}
except (OSError, yaml.YAMLError) as exc:
    sys.exit(str(exc))

if not isinstance(cfg, dict):
    sys.exit("top-level YAML value must be a mapping")


def s(val, default=""):
    if val is None:
        return default
    if isinstance(val, bool):
        return "true" if val else "false"
    text = str(val)
    if "\n" in text:
        sys.exit("multi-line values are not supported: %r" % text)
    return text


SCALARS = (
    ("image", ""),
    ("script", ""),
    ("instance_type", ""),
    ("instance_count", "1"),
    ("dataset", ""),
    ("output_path", ""),
    ("max_runtime_seconds", "86400"),
    ("volume_size_gb", "50"),
    ("enable_spot", "false"),
    ("max_wait_seconds", "172800"),
    ("checkpoint_path", ""),
)
for key, default in SCALARS:
    print(s(cfg.get(key), default))

print(json.dumps(cfg.get("hyperparameters") or {}))
print(json.dumps(cfg.get("metric_definitions") or []))
print(json.dumps(cfg.get("environment") or {}))
print(json.dumps(cfg.get("tags") or {}))
' "${CONFIG_FILE}"); then
    echo "❌ Failed to parse ${CONFIG_FILE}"
    echo ""
    echo " Ensure the file is valid YAML syntax and that PyYAML is installed for python3."
    exit 1
  fi

  # Map each output line, by position, onto its variable.
  local i=0
  local line
  while IFS= read -r line; do
    case "${i}" in
      0) TRAIN_IMAGE="${line}" ;;
      1) TRAIN_SCRIPT="${line}" ;;
      2) TRAIN_INSTANCE_TYPE="${line}" ;;
      3) TRAIN_INSTANCE_COUNT="${line}" ;;
      4) TRAIN_DATASET="${line}" ;;
      5) TRAIN_OUTPUT_PATH="${line}" ;;
      6) TRAIN_MAX_RUNTIME="${line}" ;;
      7) TRAIN_VOLUME_SIZE="${line}" ;;
      8) TRAIN_ENABLE_SPOT="${line}" ;;
      9) TRAIN_MAX_WAIT="${line}" ;;
      10) TRAIN_CHECKPOINT_PATH="${line}" ;;
      11) TRAIN_HYPERPARAMS="${line}" ;;
      12) TRAIN_METRIC_DEFINITIONS="${line}" ;;
      13) TRAIN_ENVIRONMENT="${line}" ;;
      14) TRAIN_TAGS="${line}" ;;
    esac
    i=$((i + 1))
  done <<< "${parse_output}"

  # Apply defaults for any optional field that came back empty.
  TRAIN_INSTANCE_COUNT="${TRAIN_INSTANCE_COUNT:-1}"
  TRAIN_MAX_RUNTIME="${TRAIN_MAX_RUNTIME:-86400}"
  TRAIN_VOLUME_SIZE="${TRAIN_VOLUME_SIZE:-50}"
  TRAIN_ENABLE_SPOT="${TRAIN_ENABLE_SPOT:-false}"
  TRAIN_MAX_WAIT="${TRAIN_MAX_WAIT:-172800}"
  # Braces inside ${var:-...} confuse the parser, hence the helper variables.
  local empty_obj='{}'
  local empty_arr='[]'
  TRAIN_HYPERPARAMS="${TRAIN_HYPERPARAMS:-$empty_obj}"
  TRAIN_METRIC_DEFINITIONS="${TRAIN_METRIC_DEFINITIONS:-$empty_arr}"
  TRAIN_ENVIRONMENT="${TRAIN_ENVIRONMENT:-$empty_obj}"
  TRAIN_TAGS="${TRAIN_TAGS:-$empty_obj}"
}
262
+
263
# ── _validate_config() ────────────────────────────────────────────────────────
# Verify that every required TRAIN_* value is non-empty and that spot training
# comes with a checkpoint path. Every problem found is reported before the
# script exits 1; nothing is printed when the configuration passes.
_validate_config() {
  local failures=0
  local spec var_name field blurb sample

  # Required fields, encoded as "VARIABLE|yaml key|description|example value".
  for spec in \
    'TRAIN_IMAGE|image|The container image URI|123456789012.dkr.ecr.us-east-1.amazonaws.com/my-training:latest' \
    'TRAIN_SCRIPT|script|The training script S3 path|s3://my-bucket/scripts/train.py' \
    'TRAIN_INSTANCE_TYPE|instance_type|The SageMaker instance type|ml.g5.xlarge' \
    'TRAIN_DATASET|dataset|The S3 dataset path|s3://my-bucket/data/train/' \
    'TRAIN_OUTPUT_PATH|output_path|The S3 output path|s3://my-bucket/output/'; do
    IFS='|' read -r var_name field blurb sample <<< "${spec}"
    if [ -n "${!var_name}" ]; then
      continue
    fi
    echo "❌ Missing required field: ${field}"
    echo " ${blurb} is required in do/training/config.yaml"
    echo ""
    echo " Expected format: ${field}: \"${sample}\""
    echo ""
    failures=$((failures + 1))
  done

  # Spot training needs somewhere to resume from after an interruption.
  if [ "${TRAIN_ENABLE_SPOT}" = "true" ] && [ -z "${TRAIN_CHECKPOINT_PATH}" ]; then
    echo "❌ Checkpoint path required for spot training"
    echo " When enable_spot is true, a checkpoint S3 path must be specified"
    echo " so training can resume after spot interruptions."
    echo ""
    echo " Add to do/training/config.yaml:"
    echo " checkpoint_path: \"s3://my-bucket/checkpoints/\""
    echo ""
    failures=$((failures + 1))
  fi

  if [ "${failures}" -gt 0 ]; then
    exit 1
  fi
}
329
+
330
# ── _check_idempotency() ─────────────────────────────────────────────────────
# Decide whether a previously tracked job (TRAIN_JOB_NAME, expected to come
# from the sourced do/config) should be resumed instead of submitting a new one.
#
# Returns 0 (caller goes on to submit a new job) when:
#   * --force was given,
#   * no TRAIN_JOB_NAME is tracked, or
#   * describe-training-job fails for the tracked job.
# Otherwise it acts on the reported TrainingJobStatus and exits the script:
#   InProgress, Stopping → _poll_job, then _handle_completion, exit 0
#   Completed            → _handle_completion, exit 0
#   Failed               → print FailureReason, suggest --force, exit 2
#   Stopped              → suggest --force, exit 2
#   anything else        → report the status, suggest --force, exit 2
_check_idempotency() {
  # --force skips the check entirely.
  if [ "${ARG_FORCE}" = true ]; then
    return 0
  fi

  # Nothing tracked yet, so there is nothing to resume.
  if [ -z "${TRAIN_JOB_NAME:-}" ]; then
    return 0
  fi

  echo "🔍 Found existing training job: ${TRAIN_JOB_NAME}"
  echo " Checking status..."
  echo ""

  # Capture the status through "||" so a failing call does not trip errexit.
  local describe_output
  local describe_exit_code=0
  describe_output=$(aws sagemaker describe-training-job \
    --training-job-name "${TRAIN_JOB_NAME}" 2>&1) || describe_exit_code=$?

  if [ "${describe_exit_code}" -ne 0 ]; then
    # Any describe failure (for example a deleted job) leads to a new job.
    echo "⚠️ Could not describe existing job: ${TRAIN_JOB_NAME}"
    echo " ${describe_output}"
    echo " Proceeding to create a new job."
    echo ""
    return 0
  fi

  # Pull TrainingJobStatus out of the JSON response.
  local job_status
  job_status=$(printf '%s\n' "${describe_output}" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
print(resp.get('TrainingJobStatus', 'Unknown'))
" 2>/dev/null) || job_status="Unknown"

  case "${job_status}" in
    InProgress | Stopping)
      # Stopping is a transient status; _poll_job keeps polling until a
      # terminal status is reported.
      if [ "${job_status}" = "Stopping" ]; then
        echo "⏳ Training job is stopping: ${TRAIN_JOB_NAME}"
      else
        echo "⏳ Training job is still running: ${TRAIN_JOB_NAME}"
      fi
      echo " Resuming polling..."
      echo ""
      _poll_job
      _handle_completion
      exit 0
      ;;
    Completed)
      echo "✅ Training job already completed: ${TRAIN_JOB_NAME}"
      echo ""
      # _handle_completion describes the job itself, so the response is not
      # saved to disk here.
      _handle_completion
      exit 0
      ;;
    Failed)
      local failure_reason
      failure_reason=$(printf '%s\n' "${describe_output}" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
print(resp.get('FailureReason', 'No failure reason provided'))
" 2>/dev/null) || failure_reason="No failure reason provided"

      echo "❌ Previous training job failed: ${TRAIN_JOB_NAME}"
      echo " Reason: ${failure_reason}"
      echo ""
      echo " To submit a new job, re-run with --force:"
      echo " ./do/train --force"
      exit 2
      ;;
    Stopped)
      echo "⏹️ Previous training job was stopped: ${TRAIN_JOB_NAME}"
      echo ""
      echo " To submit a new job, re-run with --force:"
      echo " ./do/train --force"
      exit 2
      ;;
    *)
      echo "⚠️ Unexpected job status: ${job_status} for ${TRAIN_JOB_NAME}"
      echo " To submit a new job, re-run with --force:"
      echo " ./do/train --force"
      exit 2
      ;;
  esac
}
425
+
426
# ── _build_job_request() ──────────────────────────────────────────────────────
# Build the CreateTrainingJob request body and print a short job summary.
#
# Globals read:    PROJECT_NAME, ROLE_ARN, SCRIPT_DIR and the TRAIN_* settings
# Globals written: JOB_NAME          "<PROJECT_NAME>-train-<YYYYmmdd-HHMMSS>"
#                  JOB_REQUEST_FILE  path of the generated request file
# Exits 1 if the temporary file cannot be created or the helper script
# .train_build_request.py reports a failure.
_build_job_request() {
  # The job name carries a timestamp.
  local timestamp
  timestamp=$(date +%Y%m%d-%H%M%S)
  JOB_NAME="${PROJECT_NAME}-train-${timestamp}"

  # mktemp gives an unpredictable, freshly created file instead of a
  # guessable name under /tmp. The template ends in X so it works with both
  # GNU and BSD mktemp.
  # NOTE(review): assumes the helper script overwrites an existing output
  # file — confirm against .train_build_request.py.
  if ! JOB_REQUEST_FILE=$(mktemp "${TMPDIR:-/tmp}/train-request-XXXXXX"); then
    echo "❌ Failed to create a temporary file for the CreateTrainingJob request"
    exit 1
  fi

  # Test the status directly: with errexit active, a separate "$?" check after
  # the command would never run on failure.
  if ! python3 "${SCRIPT_DIR}/.train_build_request.py" \
    --job-name "${JOB_NAME}" \
    --role-arn "${ROLE_ARN}" \
    --image "${TRAIN_IMAGE}" \
    --instance-type "${TRAIN_INSTANCE_TYPE}" \
    --instance-count "${TRAIN_INSTANCE_COUNT}" \
    --volume-size "${TRAIN_VOLUME_SIZE}" \
    --dataset "${TRAIN_DATASET}" \
    --output-path "${TRAIN_OUTPUT_PATH}" \
    --max-runtime "${TRAIN_MAX_RUNTIME}" \
    --hyperparams "${TRAIN_HYPERPARAMS}" \
    --enable-spot "${TRAIN_ENABLE_SPOT}" \
    --max-wait "${TRAIN_MAX_WAIT}" \
    --checkpoint-path "${TRAIN_CHECKPOINT_PATH}" \
    --metric-definitions "${TRAIN_METRIC_DEFINITIONS}" \
    --environment "${TRAIN_ENVIRONMENT}" \
    --tags "${TRAIN_TAGS}" \
    --output-file "${JOB_REQUEST_FILE}"; then
    rm -f "${JOB_REQUEST_FILE}"
    echo "❌ Failed to construct CreateTrainingJob request"
    exit 1
  fi

  echo "📋 Training Job: ${JOB_NAME}"
  echo " Image: ${TRAIN_IMAGE}"
  echo " Instance: ${TRAIN_INSTANCE_TYPE} x ${TRAIN_INSTANCE_COUNT}"
  echo " Dataset: ${TRAIN_DATASET}"
  echo " Output: ${TRAIN_OUTPUT_PATH}"
  if [ "${TRAIN_ENABLE_SPOT}" = "true" ]; then
    echo " Spot: enabled (max wait ${TRAIN_MAX_WAIT}s)"
  fi
}
471
+
472
# ── _submit_job() ─────────────────────────────────────────────────────────────
# Submit the request in JOB_REQUEST_FILE with "aws sagemaker create-training-job".
#
# With --dry-run, prints the request JSON, deletes the request file and exits 0
# without calling AWS.
# On success, stores JOB_NAME as TRAIN_JOB_NAME via _update_config_var.
# On failure, prints the AWS error and exits 1; AccessDeniedException errors
# additionally name the denied action and give remediation steps.
# The request file is removed in every case.
_submit_job() {
  # --dry-run: show the request and stop.
  if [ "${ARG_DRY_RUN}" = true ]; then
    echo ""
    echo "🔍 Dry run — CreateTrainingJob request:"
    echo ""
    cat "${JOB_REQUEST_FILE}"
    echo ""
    rm -f "${JOB_REQUEST_FILE}"
    exit 0
  fi

  echo ""
  echo "🚀 Submitting training job..."

  # Capture the status through "||" so a failing call does not trip errexit.
  local submit_output
  local submit_exit_code=0
  submit_output=$(aws sagemaker create-training-job \
    --cli-input-json "file://${JOB_REQUEST_FILE}" 2>&1) || submit_exit_code=$?

  # The request file is no longer needed, whatever the outcome.
  rm -f "${JOB_REQUEST_FILE}"

  if [ "${submit_exit_code}" -eq 0 ]; then
    # Persist the job name so later runs can find the job.
    _update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
    echo " ✅ Job submitted successfully: ${JOB_NAME}"
    echo ""
    return 0
  fi

  case "${submit_output}" in
    *AccessDeniedException*) ;;
    *)
      echo "❌ Failed to submit training job"
      echo " ${submit_output}"
      echo ""
      echo " Check the error above and verify your configuration in do/training/config.yaml."
      exit 1
      ;;
  esac

  # Work out which action was denied. AWS words this as
  # "... is not authorized to perform: <action> on resource ...", so both
  # "perform: " and "performing: " are accepted, with "action: " as a second
  # choice. Bash's own regex matching is used so no GNU-only grep flag is needed.
  local missing_permission="sagemaker:CreateTrainingJob"
  local perform_re='perform(ing)?: ([^[:space:]]+)'
  local action_re='action: ([^[:space:]]+)'
  if [[ "${submit_output}" =~ ${perform_re} ]]; then
    missing_permission="${BASH_REMATCH[2]}"
  elif [[ "${submit_output}" =~ ${action_re} ]]; then
    missing_permission="${BASH_REMATCH[1]}"
  fi

  echo "❌ Access denied: ${missing_permission}"
  echo " ${submit_output}"
  echo ""
  echo " Remediation:"
  echo " Ensure your IAM role or user has the '${missing_permission}' permission."
  echo " If using the bootstrap stack, re-run ./do/bootstrap to update permissions."
  echo " Otherwise, attach a policy granting '${missing_permission}' to your role."
  exit 1
}
535
+
536
# ── _poll_job() ───────────────────────────────────────────────────────────────
# Poll describe-training-job every POLL_INTERVAL seconds until the job reaches
# a terminal status, printing one status line per poll.
#
# The job polled is JOB_NAME, or TRAIN_JOB_NAME when JOB_NAME is empty.
# Each response is piped through .train_poll_parser.py, whose output is read
# as KEY=value lines; STATUS, SECONDARY, FAILURE_REASON and DISPLAY are used.
#
#   Completed → prints a confirmation and returns (the caller handles it)
#   Failed    → prints the failure reason, if any, and exits 2
#   Stopped   → prints a stopped message and exits 2
# A failed describe call or parser run is reported and retried after the
# usual interval. A secondary status containing "interrupted" triggers a note
# about spot resumption.
_poll_job() {
  local job_name="${JOB_NAME:-${TRAIN_JOB_NAME:-}}"

  echo "⏳ Polling training job: ${job_name}"
  echo " (Ctrl+C to stop polling — job continues in background)"
  echo ""

  local describe_output describe_exit_code poll_result line
  local job_status secondary_status display_text failure_reason

  while true; do
    # Reset on every pass: "local" inside the loop would keep the value left
    # by an earlier failed call, and every later poll would look failed too.
    describe_exit_code=0
    describe_output=$(aws sagemaker describe-training-job \
      --training-job-name "${job_name}" 2>&1) || describe_exit_code=$?

    if [ "${describe_exit_code}" -ne 0 ]; then
      echo "⚠️ Failed to describe job (will retry): ${describe_output}"
      sleep "${POLL_INTERVAL}"
      continue
    fi

    # Test the parser status inside "if !" so errexit/pipefail cannot abort
    # the script before the retry branch runs.
    if ! poll_result=$(printf '%s\n' "${describe_output}" \
      | python3 "${SCRIPT_DIR}/.train_poll_parser.py" 2>&1); then
      echo "⚠️ Failed to parse job status (will retry): ${poll_result}"
      sleep "${POLL_INTERVAL}"
      continue
    fi

    # Split the KEY=value lines in the shell. A missing key simply leaves its
    # variable empty instead of failing a grep pipeline.
    job_status=""
    secondary_status=""
    display_text=""
    failure_reason=""
    while IFS= read -r line; do
      case "${line}" in
        STATUS=*) job_status="${line#STATUS=}" ;;
        SECONDARY=*) secondary_status="${line#SECONDARY=}" ;;
        FAILURE_REASON=*) failure_reason="${line#FAILURE_REASON=}" ;;
        DISPLAY=*) display_text="${line#DISPLAY=}" ;;
      esac
    done <<< "${poll_result}"

    # The formatted status line for this poll.
    echo "${display_text}"

    # Terminal states.
    case "${job_status}" in
      Completed)
        echo ""
        echo "✅ Training job completed: ${job_name}"
        break
        ;;
      Failed)
        echo ""
        echo "❌ Training job failed: ${job_name}"
        if [ -n "${failure_reason}" ]; then
          echo " Reason: ${failure_reason}"
        fi
        echo ""
        echo " To investigate: check CloudWatch logs for /aws/sagemaker/TrainingJobs/${job_name}"
        echo " To retry: ./do/train --force"
        exit 2
        ;;
      Stopped)
        echo ""
        echo "⏹️ Training job was stopped: ${job_name}"
        echo ""
        echo " To submit a new job: ./do/train --force"
        exit 2
        ;;
    esac

    # Spot interruption: the job is not terminal, so keep polling.
    if grep -qi "interrupted" <<< "${secondary_status}"; then
      echo ""
      echo " ℹ️ Spot instance interrupted. The job will automatically resume"
      echo " from the last checkpoint. Continuing to poll..."
      echo ""
    fi

    sleep "${POLL_INTERVAL}"
  done
}
634
+
635
# ── _handle_completion() ──────────────────────────────────────────────────────
# Post-process a finished job: record where the model artifacts are, print the
# completion feedback, and report spot savings when spot training is enabled.
#
# The job inspected is JOB_NAME, or TRAIN_JOB_NAME when JOB_NAME is empty.
# Steps:
#   1. describe-training-job; on failure warn and return 1.
#   2. Read ModelArtifacts.S3ModelArtifacts; if empty warn and return 1.
#   3. Store it as TRAIN_OUTPUT_PATH via _update_config_var.
#   4. Classify the output as "adapter" if "aws s3 ls" finds
#      "<artifacts>/adapter_config.json", otherwise "full-model".
#   5. Source lib/feedback.sh and call print_completion_feedback.
#   6. If TRAIN_ENABLE_SPOT is "true" and both time counters are positive,
#      print training time, billed time and the percentage difference.
_handle_completion() {
  local job_name="${JOB_NAME:-${TRAIN_JOB_NAME:-}}"

  # Capture the status through "||" so a failing call does not trip errexit.
  local describe_output
  local describe_exit_code=0
  describe_output=$(aws sagemaker describe-training-job \
    --training-job-name "${job_name}" 2>&1) || describe_exit_code=$?

  if [ "${describe_exit_code}" -ne 0 ]; then
    echo "⚠️ Failed to describe completed job: ${job_name}"
    echo " ${describe_output}"
    return 1
  fi

  # The "||" fallback keeps a parse failure from aborting the script under
  # errexit before the warning below can be shown.
  local output_path
  output_path=$(printf '%s\n' "${describe_output}" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
artifacts = resp.get('ModelArtifacts', {})
print(artifacts.get('S3ModelArtifacts', ''))
" 2>/dev/null) || output_path=""

  if [ -z "${output_path}" ]; then
    echo "⚠️ No model artifacts found in job response."
    echo " The job may not have produced output artifacts."
    return 1
  fi

  # Persist the artifact location to do/config.
  _update_config_var "TRAIN_OUTPUT_PATH" "${output_path}"

  # NOTE(review): S3ModelArtifacts presumably names an archive object rather
  # than a prefix, in which case this key would never exist — confirm that
  # the adapter detection can actually trigger.
  local output_type="full-model"
  if aws s3 ls "${output_path}/adapter_config.json" &>/dev/null; then
    output_type="adapter"
  fi

  source "${SCRIPT_DIR}/lib/feedback.sh"
  print_completion_feedback "${output_path}" "${output_type}" "${job_name}"

  # Spot savings report.
  if [ "${TRAIN_ENABLE_SPOT:-false}" = "true" ]; then
    local billable_seconds training_seconds savings_pct
    billable_seconds=$(printf '%s\n' "${describe_output}" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
print(resp.get('BillableTimeInSeconds', 0))
" 2>/dev/null) || billable_seconds=0
    training_seconds=$(printf '%s\n' "${describe_output}" | python3 -c "
import sys, json
resp = json.load(sys.stdin)
print(resp.get('TrainingTimeInSeconds', 0))
" 2>/dev/null) || training_seconds=0

    # Anything that is not a plain non-negative integer counts as 0, so the
    # numeric comparisons below cannot fail on unexpected text.
    case "${billable_seconds}" in '' | *[!0-9]*) billable_seconds=0 ;; esac
    case "${training_seconds}" in '' | *[!0-9]*) training_seconds=0 ;; esac

    if [ "${training_seconds}" -gt 0 ] && [ "${billable_seconds}" -gt 0 ]; then
      # Values are passed as arguments instead of being spliced into the code.
      savings_pct=$(python3 -c '
import sys
billable, training = int(sys.argv[1]), int(sys.argv[2])
print("%.0f" % (((training - billable) / training) * 100))
' "${billable_seconds}" "${training_seconds}") || savings_pct="?"
      echo " 💰 Spot training savings:"
      echo " Training time: ${training_seconds}s"
      echo " Billed time: ${billable_seconds}s"
      echo " Estimated savings: ~${savings_pct}%"
      echo ""
    fi
  fi
}
714
+
715
# ── _update_config_var() ──────────────────────────────────────────────────────
# Set a variable in do/config, which this script sources as shell code.
# Usage: _update_config_var VAR_NAME "value"
#
# If a line starting with `export VAR_NAME=` exists it is rewritten in place;
# otherwise `export VAR_NAME="value"` is appended.
#
# The value is escaped twice:
#   1. for a double-quoted shell string (\ " $ `), so sourcing the file
#      yields the value verbatim and cannot run embedded commands;
#   2. for the sed replacement text (\ & and the | delimiter), so characters
#      such as "&" in an S3 key do not corrupt the rewritten line.
_update_config_var() {
  local var_name="$1"
  local var_value="$2"
  local config_file="${SCRIPT_DIR}/config"
  local quoted_value sed_replacement

  quoted_value=$(printf '%s' "${var_value}" | sed -e 's/[\\"$`]/\\&/g')
  sed_replacement=$(printf '%s' "${quoted_value}" | sed -e 's/[\\&|]/\\&/g')

  if grep -q "^export ${var_name}=" "${config_file}" 2>/dev/null; then
    # The backup suffix is given so -i behaves the same with GNU and BSD sed.
    sed -i.bak "s|^export ${var_name}=.*|export ${var_name}=\"${sed_replacement}\"|" "${config_file}"
    rm -f "${config_file}.bak"
  else
    printf 'export %s="%s"\n' "${var_name}" "${quoted_value}" >> "${config_file}"
  fi
}
730
+
731
# ── Main ──────────────────────────────────────────────────────────────────────
# Entry point. Order of operations:
#   1. parse flags; --help prints usage and exits
#   2. --status describes the tracked job, prints it via
#      .train_status_parser.py and exits
#   3. parse and validate do/training/config.yaml
#   4. resume or report an already tracked job (unless --force)
#   5. build and submit a new job
#   6. poll until completion and post-process, unless --no-wait
main() {
  _parse_args "$@"

  case "${ARG_HELP}" in
    true) _show_help ;;
  esac

  if [ "${ARG_STATUS}" = true ]; then
    # Report on the tracked job only; nothing is submitted on this path.
    if [ -z "${TRAIN_JOB_NAME:-}" ]; then
      echo "📊 No training job tracked."
      echo " Run ./do/train to submit a training job."
      exit 0
    fi

    echo "📊 Training Job Status"
    echo " Job: ${TRAIN_JOB_NAME}"

    # STATUS_JSON stays a global, as before; on failure it holds the AWS
    # CLI error text.
    if ! STATUS_JSON=$(aws sagemaker describe-training-job \
      --training-job-name "${TRAIN_JOB_NAME}" \
      --region "${AWS_REGION}" 2>&1); then
      echo ""
      echo "❌ Failed to describe training job: ${TRAIN_JOB_NAME}"
      echo " ${STATUS_JSON}"
      echo ""
      echo " The job may have been deleted or the name is incorrect."
      echo " Run ./do/train --force to start a new job."
      exit 1
    fi

    # The helper script formats the response for display.
    echo "${STATUS_JSON}" | python3 "${SCRIPT_DIR}/.train_status_parser.py"
    exit 0
  fi

  # Configuration.
  _parse_config
  _validate_config

  # Existing-job handling; may poll and exit without returning here.
  _check_idempotency

  # New job.
  _build_job_request
  _submit_job

  # --no-wait: stop after submission.
  if [ "${ARG_NO_WAIT}" = true ]; then
    echo " --no-wait specified. Job submitted, exiting without polling."
    echo " Re-run ./do/train --status to check progress."
    exit 0
  fi

  _poll_job
  _handle_completion
}

main "$@"