@aws/ml-container-creator 1.0.2 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. package/README.md +1 -1
  2. package/bin/cli.js +1 -1
  3. package/config/tune-catalog.json +303 -1
  4. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  5. package/package.json +3 -2
  6. package/servers/base-image-picker/index.js +65 -18
  7. package/servers/instance-sizer/index.js +32 -0
  8. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  9. package/servers/lib/catalogs/model-arch-support.json +51 -0
  10. package/servers/lib/catalogs/model-servers.json +2842 -1516
  11. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  12. package/src/app.js +6 -4
  13. package/src/lib/bootstrap-command-handler.js +12 -2
  14. package/src/lib/bootstrap-profile-manager.js +16 -0
  15. package/src/lib/cross-cutting-checker.js +6 -1
  16. package/src/lib/generated/cli-options.js +1 -1
  17. package/src/lib/generated/parameter-matrix.js +1 -1
  18. package/src/lib/generated/validation-rules.js +1 -1
  19. package/src/lib/mcp-query-runner.js +110 -3
  20. package/src/lib/prompt-runner.js +66 -22
  21. package/src/lib/template-variable-resolver.js +8 -0
  22. package/src/lib/train-config-builder.js +339 -0
  23. package/templates/do/.benchmark_writer.py +3 -0
  24. package/templates/do/.eval_helper.py +409 -0
  25. package/templates/do/.register_helper.py +185 -11
  26. package/templates/do/.train_build_request.py +102 -113
  27. package/templates/do/.train_helper.py +433 -0
  28. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  29. package/templates/do/adapter +157 -0
  30. package/templates/do/benchmark +60 -3
  31. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  32. package/templates/do/evaluate +272 -0
  33. package/templates/do/lib/resolve-instance.sh +155 -0
  34. package/templates/do/register +5 -0
  35. package/templates/do/test +1 -0
  36. package/templates/do/train +879 -126
  37. package/templates/do/training/config.yaml +83 -11
  38. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  39. package/templates/do/training/dpo/defaults.yaml +26 -0
  40. package/templates/do/training/dpo/prompts.json +8 -0
  41. package/templates/do/training/dpo/train.py +363 -0
  42. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  43. package/templates/do/training/sft/defaults.yaml +18 -0
  44. package/templates/do/training/sft/prompts.json +7 -0
  45. package/templates/do/training/sft/train.py +310 -0
  46. package/templates/do/tune +11 -2
  47. package/templates/do/.train_poll_parser.py +0 -135
  48. package/templates/do/.train_status_parser.py +0 -187
  49. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -16,6 +16,8 @@ set -o pipefail
16
16
  # ── Source project configuration ──────────────────────────────────────────────
17
17
  SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
18
18
  source "${SCRIPT_DIR}/config"
19
+ source "${SCRIPT_DIR}/lib/profile.sh"
20
+ source "${SCRIPT_DIR}/lib/resolve-instance.sh"
19
21
 
20
22
  # ── Constants ─────────────────────────────────────────────────────────────────
21
23
  CONFIG_FILE="${SCRIPT_DIR}/training/config.yaml"
@@ -27,11 +29,30 @@ ARG_STATUS=false
27
29
  ARG_DRY_RUN=false
28
30
  ARG_NO_WAIT=false
29
31
  ARG_HELP=false
32
+ ARG_TECHNIQUE=""
33
+ ARG_DATASET=""
34
+ ARG_LIST_DATASETS=false
35
+ ARG_NO_REGISTER=false
36
+ ARG_INTERACTIVE=false
37
+ ARG_LEARNING_RATE=""
38
+ ARG_EPOCHS=""
39
+ ARG_BATCH_SIZE=""
40
+ ARG_LORA_R=""
41
+ ARG_BETA=""
42
+ ARG_RESUME=false
43
+ ARG_RESUME_JOB=""
30
44
 
31
45
  # ── Job Variables (set by _build_job_request) ─────────────────────────────────
32
46
  JOB_NAME=""
33
47
  JOB_REQUEST_FILE=""
34
48
 
49
+ # ── Technique Variables (set by _resolve_technique) ───────────────────────────
50
+ TECHNIQUE=""
51
+ TRAIN_SCRIPT_PATH=""
52
+
53
+ # ── Dataset Variables (set by _resolve_dataset) ───────────────────────────────
54
+ RESOLVED_DATASET_S3_URI=""
55
+
35
56
  # ── Training Config Variables (set by _parse_config) ──────────────────────────
36
57
  TRAIN_IMAGE=""
37
58
  TRAIN_SCRIPT=""
@@ -70,6 +91,46 @@ _parse_args() {
70
91
  --status) ARG_STATUS=true; shift ;;
71
92
  --dry-run) ARG_DRY_RUN=true; shift ;;
72
93
  --no-wait) ARG_NO_WAIT=true; shift ;;
94
+ --technique)
95
+ if [ -z "${2:-}" ]; then
96
+ echo "❌ --technique requires a value"
97
+ echo " Available: $(_list_techniques)"
98
+ exit 1
99
+ fi
100
+ ARG_TECHNIQUE="$2"; shift 2 ;;
101
+ --dataset)
102
+ if [ -z "${2:-}" ]; then
103
+ echo "❌ --dataset requires a value (s3://..., hf://..., or registry name)"
104
+ exit 1
105
+ fi
106
+ ARG_DATASET="$2"; shift 2 ;;
107
+ --list-datasets) ARG_LIST_DATASETS=true; shift ;;
108
+ --no-register) ARG_NO_REGISTER=true; shift ;;
109
+ --interactive|-i) ARG_INTERACTIVE=true; shift ;;
110
+ --learning-rate)
111
+ if [ -z "${2:-}" ]; then echo "❌ --learning-rate requires a value"; exit 1; fi
112
+ ARG_LEARNING_RATE="$2"; shift 2 ;;
113
+ --epochs)
114
+ if [ -z "${2:-}" ]; then echo "❌ --epochs requires a value"; exit 1; fi
115
+ ARG_EPOCHS="$2"; shift 2 ;;
116
+ --batch-size)
117
+ if [ -z "${2:-}" ]; then echo "❌ --batch-size requires a value"; exit 1; fi
118
+ ARG_BATCH_SIZE="$2"; shift 2 ;;
119
+ --lora-r)
120
+ if [ -z "${2:-}" ]; then echo "❌ --lora-r requires a value"; exit 1; fi
121
+ ARG_LORA_R="$2"; shift 2 ;;
122
+ --beta)
123
+ if [ -z "${2:-}" ]; then echo "❌ --beta requires a value"; exit 1; fi
124
+ ARG_BETA="$2"; shift 2 ;;
125
+ --resume)
126
+ ARG_RESUME=true
127
+ # Optional: next arg is a job name (not another flag)
128
+ if [ -n "${2:-}" ] && [[ "${2}" != -* ]]; then
129
+ ARG_RESUME_JOB="$2"; shift 2
130
+ else
131
+ shift
132
+ fi
133
+ ;;
73
134
  --help|-h) ARG_HELP=true; shift ;;
74
135
  *)
75
136
  echo "❌ Unknown option: $1"
@@ -83,6 +144,7 @@ _parse_args() {
83
144
  # ── _show_help() ──────────────────────────────────────────────────────────────
84
145
  _show_help() {
85
146
  echo "Usage: ./do/train [OPTIONS]"
147
+ echo " ./do/train --technique sft"
86
148
  echo " ./do/train --status"
87
149
  echo " ./do/train --help"
88
150
  echo ""
@@ -92,20 +154,372 @@ _show_help() {
92
154
  echo "Configuration is read from do/training/config.yaml"
93
155
  echo ""
94
156
  echo "Options:"
95
- echo " --force Create a new job even if a previous job exists"
96
- echo " --status Show current job status without submitting"
97
- echo " --dry-run Validate inputs and show the CreateTrainingJob request without submitting"
98
- echo " --no-wait Submit job and exit without polling for completion"
99
- echo " --help, -h Show this help message"
157
+ echo " --technique <name> Select training technique (overrides config.yaml)"
158
+ echo " --dataset <source> Dataset: s3://..., hf://org/name, or registry name"
159
+ echo " --interactive, -i Guided interactive configuration builder"
160
+ echo " --learning-rate <v> Override learning rate"
161
+ echo " --epochs <n> Override number of epochs"
162
+ echo " --batch-size <n> Override per-device batch size"
163
+ echo " --lora-r <n> Override LoRA rank"
164
+ echo " --beta <v> Override DPO beta (KL penalty)"
165
+ echo " --resume [job-name] Resume from previous job's checkpoint"
166
+ echo " --list-datasets Show registered datasets"
167
+ echo " --no-register Skip auto-registration after completion"
168
+ echo " --force Create a new job even if a previous job exists"
169
+ echo " --status Show current job status without submitting"
170
+ echo " --dry-run Validate inputs and show the request without submitting"
171
+ echo " --no-wait Submit job and exit without polling for completion"
172
+ echo " --help, -h Show this help message"
173
+ echo ""
174
+ echo "Available techniques:"
175
+ echo " $(_list_techniques)"
100
176
  echo ""
101
177
  echo "Examples:"
102
- echo " ./do/train # Submit a training job"
103
- echo " ./do/train --status # Check status of current job"
104
- echo " ./do/train --dry-run # Validate and preview request"
105
- echo " ./do/train --force # Force a new job after failure"
178
+ echo " ./do/train # Submit using config.yaml technique"
179
+ echo " ./do/train --technique sft # Submit SFT training job"
180
+ echo " ./do/train --status # Check status of current job"
181
+ echo " ./do/train --dry-run # Validate and preview request"
182
+ echo " ./do/train --force # Force a new job after failure"
106
183
  exit 0
107
184
  }
108
185
 
# ── _list_techniques() ────────────────────────────────────────────────────────
# Print the available training techniques as a comma-separated list.
# A technique is any direct subdirectory of ${SCRIPT_DIR}/training that
# contains a train.py file. Prints "none found" when there is none.
# Globals: SCRIPT_DIR (read)
# Outputs: a single line on stdout
_list_techniques() {
    local training_dir="${SCRIPT_DIR}/training"
    local techniques=""
    # Declared local so the loop variable does not leak into the caller's scope.
    local dir name

    for dir in "${training_dir}"/*/; do
        # An unmatched glob stays literal; the -f test rejects it as well.
        [ -f "${dir}train.py" ] || continue
        name="${dir%/}"
        name="${name##*/}"
        techniques="${techniques:+${techniques}, }${name}"
    done

    echo "${techniques:-none found}"
}
204
+
# ── _resolve_technique() ─────────────────────────────────────────────────────
# Determine the training technique and validate that its train.py exists.
# Priority: --technique flag > `technique` key in config.yaml > "custom".
# Globals: ARG_TECHNIQUE, CONFIG_FILE, SCRIPT_DIR (read);
#          TECHNIQUE, TRAIN_SCRIPT_PATH (written)
# Exits 1 when training/<technique>/train.py does not exist.
_resolve_technique() {
    local training_dir="${SCRIPT_DIR}/training"

    if [ -n "${ARG_TECHNIQUE}" ]; then
        TECHNIQUE="${ARG_TECHNIQUE}"
    elif command -v yq &>/dev/null; then
        # Fall back to "custom" on a yq failure, like the python3 branch does.
        TECHNIQUE=$(yq -r '.technique // "custom"' "${CONFIG_FILE}") || TECHNIQUE="custom"
    elif command -v python3 &>/dev/null; then
        # The path is passed as argv instead of being spliced into the Python
        # source, so quotes or backslashes in it cannot break the snippet.
        # `or "custom"` also covers an explicit `technique: null`.
        TECHNIQUE=$(python3 -c '
import sys
import yaml

with open(sys.argv[1], "r") as f:
    cfg = yaml.safe_load(f) or {}
print(cfg.get("technique") or "custom")
' "${CONFIG_FILE}" 2>/dev/null) || TECHNIQUE="custom"
    else
        TECHNIQUE="custom"
    fi

    # Fallback if empty
    TECHNIQUE="${TECHNIQUE:-custom}"

    # Validate technique directory exists
    TRAIN_SCRIPT_PATH="${training_dir}/${TECHNIQUE}/train.py"
    if [ ! -f "${TRAIN_SCRIPT_PATH}" ]; then
        echo "❌ Training technique '${TECHNIQUE}' not found"
        echo " Expected: ${TRAIN_SCRIPT_PATH}"
        echo ""
        echo " Available techniques: $(_list_techniques)"
        echo ""
        echo " To add a new technique, create: do/training/${TECHNIQUE}/train.py"
        exit 1
    fi

    echo "📋 Technique: ${TECHNIQUE}"
    echo " Script: ${TRAIN_SCRIPT_PATH}"
}
245
+
# ── _load_technique_defaults() ────────────────────────────────────────────────
# Load technique-specific default hyperparameters from
# training/<technique>/defaults.yaml and expose them as a JSON object string.
# Globals: SCRIPT_DIR, TECHNIQUE (read); TECHNIQUE_DEFAULTS_JSON (written)
# Always returns 0; TECHNIQUE_DEFAULTS_JSON is "{}" when the file is missing,
# cannot be converted, or converts to nothing / null.
_load_technique_defaults() {
    local defaults_file="${SCRIPT_DIR}/training/${TECHNIQUE}/defaults.yaml"
    TECHNIQUE_DEFAULTS_JSON="{}"

    [ -f "${defaults_file}" ] || return 0

    if command -v yq &>/dev/null; then
        TECHNIQUE_DEFAULTS_JSON=$(yq -o=json '.' "${defaults_file}" 2>/dev/null) || TECHNIQUE_DEFAULTS_JSON="{}"
    elif command -v python3 &>/dev/null; then
        # The path travels as argv so it is never parsed as Python source.
        TECHNIQUE_DEFAULTS_JSON=$(python3 -c '
import json
import sys
import yaml

with open(sys.argv[1], "r") as f:
    data = yaml.safe_load(f) or {}
# Drop keys whose value is empty (e.g. a key followed only by a comment)
data = {k: v for k, v in data.items() if v is not None}
print(json.dumps(data))
' "${defaults_file}" 2>/dev/null) || TECHNIQUE_DEFAULTS_JSON="{}"
    fi

    # A converter that succeeds on an empty document may print nothing or
    # "null"; normalise both so later consumers always get a JSON object.
    case "${TECHNIQUE_DEFAULTS_JSON}" in
        ""|null) TECHNIQUE_DEFAULTS_JSON="{}" ;;
    esac
    return 0
}
271
+
# ── _merge_hyperparameters() ──────────────────────────────────────────────────
# Merge hyperparameters with precedence:
#   CLI flags > config.yaml (TRAIN_HYPERPARAMS) > technique defaults.
# All keys and values are stringified in the merged JSON object.
# Globals: TECHNIQUE_DEFAULTS_JSON, ARG_LEARNING_RATE, ARG_EPOCHS,
#          ARG_BATCH_SIZE, ARG_LORA_R, ARG_BETA (read);
#          TRAIN_HYPERPARAMS (read, replaced with the merged JSON string)
# Always returns 0. If the merge fails (python3 missing, invalid JSON) a
# warning goes to stderr and TRAIN_HYPERPARAMS is left unchanged.
_merge_hyperparameters() {
    local merged

    # Inputs are handed to Python through the environment rather than being
    # pasted into the source text: quotes, backslashes or ''' inside a value
    # can then neither corrupt the JSON nor execute as code.
    if ! merged=$(
        HP_DEFAULTS="${TECHNIQUE_DEFAULTS_JSON:-}" \
        HP_CONFIG="${TRAIN_HYPERPARAMS:-}" \
        HP_LEARNING_RATE="${ARG_LEARNING_RATE:-}" \
        HP_EPOCHS="${ARG_EPOCHS:-}" \
        HP_BATCH_SIZE="${ARG_BATCH_SIZE:-}" \
        HP_LORA_R="${ARG_LORA_R:-}" \
        HP_BETA="${ARG_BETA:-}" \
        python3 -c '
import json
import os


def load(env_name):
    raw = os.environ.get(env_name, "").strip()
    data = json.loads(raw) if raw else {}
    return data if isinstance(data, dict) else {}


# Layer 1: technique defaults as base
merged = {str(k): str(v) for k, v in load("HP_DEFAULTS").items()}

# Layer 2: config.yaml hyperparameters override defaults
merged.update({str(k): str(v) for k, v in load("HP_CONFIG").items()})

# Layer 3: CLI flags override everything (highest priority)
cli_flags = (
    ("learning_rate", "HP_LEARNING_RATE"),
    ("epochs", "HP_EPOCHS"),
    ("batch_size", "HP_BATCH_SIZE"),
    ("lora_r", "HP_LORA_R"),
    ("beta", "HP_BETA"),
)
for key, env_name in cli_flags:
    value = os.environ.get(env_name, "")
    if value:
        merged[key] = value

print(json.dumps(merged))
' 2>/dev/null
    ); then
        echo "⚠️ Could not merge hyperparameters; keeping existing values" >&2
        return 0
    fi

    TRAIN_HYPERPARAMS="${merged}"
}
309
+
# ── _list_datasets_cmd() ─────────────────────────────────────────────────────
# Show registered datasets (delegates to .register_helper.py list-datasets),
# then exit 0 in every case.
# Globals: SCRIPT_DIR, AWS_REGION (read; region defaults to us-east-1)
# Outputs: listing / hints on stdout; the helper's error text (if it failed)
#          on stderr so a real failure is not mistaken for an empty registry.
_list_datasets_cmd() {
    local helper="${SCRIPT_DIR}/.register_helper.py"
    local helper_err=""

    echo "📋 Registered Datasets"
    echo ""

    if [ ! -f "${helper}" ]; then
        echo " Register helper not available."
        exit 0
    fi

    # fd 3 forwards the helper's stdout unchanged while its stderr is captured,
    # so diagnostics are shown only when the helper actually fails.
    if ! { helper_err=$(python3 "${helper}" list-datasets \
            --region "${AWS_REGION:-us-east-1}" 2>&1 1>&3); } 3>&1; then
        echo " No datasets registered yet."
        echo " Register: ./do/register dataset <name> --s3-uri <uri> --technique <sft|dpo>"
        if [ -n "${helper_err}" ]; then
            printf ' Helper error: %s\n' "${helper_err}" >&2
        fi
    fi
    exit 0
}
326
+
# ── _resolve_dataset() ────────────────────────────────────────────────────────
# Resolve the dataset (from --dataset, else TRAIN_DATASET) to an S3 URI:
#   s3://bucket/path/      → used directly
#   hf://org/name[/split]  → staged via _stage_hf_dataset
#   name[@v<N>]            → looked up via _resolve_dataset_from_registry
# Globals: ARG_DATASET, TRAIN_DATASET, TECHNIQUE (read);
#          RESOLVED_DATASET_S3_URI (written by this function or its helpers)
# On success the URI is persisted with _update_config_var as
# TRAIN_DATASET_S3_URI_<TECHNIQUE>. Exits 1 when no dataset is given for a
# technique other than "custom".
_resolve_dataset() {
    local dataset="${ARG_DATASET:-${TRAIN_DATASET:-}}"

    # Dataset is optional for the custom technique, required otherwise.
    if [ -z "${dataset}" ]; then
        if [ "${TECHNIQUE}" = "custom" ]; then
            return 0
        fi
        echo "❌ --dataset is required for technique '${TECHNIQUE}'"
        echo " Provide: s3://bucket/path.jsonl, hf://org/name, or a registered name"
        echo " Run ./do/train --list-datasets to see registered datasets."
        exit 1
    fi

    # The URI schemes are matched first so that an S3/HF reference which merely
    # ends in "@v<N>" is not mistaken for a versioned registry name.
    case "${dataset}" in
        s3://*)
            echo "📂 Dataset: ${dataset} (S3 direct)"
            RESOLVED_DATASET_S3_URI="${dataset}"
            ;;
        hf://*)
            _stage_hf_dataset "${dataset}"
            ;;
        *)
            if [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
                _resolve_dataset_from_registry "${BASH_REMATCH[1]}" "${BASH_REMATCH[2]}"
            else
                _resolve_dataset_from_registry "${dataset}" ""
            fi
            ;;
    esac

    # Persist resolved URI
    if [ -n "${RESOLVED_DATASET_S3_URI}" ]; then
        local var_suffix
        # Upper-case the technique and map anything that is not valid in a
        # shell variable name (e.g. "-") to "_".
        var_suffix=$(printf '%s' "${TECHNIQUE}" \
            | LC_ALL=C tr '[:lower:]' '[:upper:]' \
            | LC_ALL=C tr -c 'A-Z0-9_' '_')
        _update_config_var "TRAIN_DATASET_S3_URI_${var_suffix}" "${RESOLVED_DATASET_S3_URI}"
    fi
}
386
+
# ── _resolve_dataset_from_registry() ─────────────────────────────────────────
# Look up a registered dataset through `.register_helper.py resolve-dataset`.
# Arguments: $1 - dataset name
#            $2 - optional version number (without the "v" prefix)
# Globals:   SCRIPT_DIR, TECHNIQUE (read); RESOLVED_DATASET_S3_URI (written)
# The helper's output is scanned for its last line starting with "{" and the
# "s3_uri" field of that JSON object is used.
# Returns 0 on success; exits 1 when the helper is missing or no URI is found.
_resolve_dataset_from_registry() {
    local name="$1"
    local version="${2:-}"
    local helper="${SCRIPT_DIR}/.register_helper.py"

    echo "🔍 Resolving dataset '${name}' from registry..."

    # Report a missing helper as such instead of as an unknown dataset.
    if [ ! -f "${helper}" ]; then
        echo "❌ Register helper not available (.register_helper.py)"
        echo " Provide the dataset directly: --dataset s3://... or --dataset hf://..."
        exit 1
    fi

    local resolve_args=("--name" "${name}")
    if [ -n "${version}" ]; then
        resolve_args+=("--version" "${version}")
        echo " Version: v${version}"
    fi

    local resolve_result=""
    local resolved_uri=""
    resolve_result=$(python3 "${helper}" resolve-dataset \
        "${resolve_args[@]}" 2>/dev/null) || resolve_result=""

    if [ -n "${resolve_result}" ]; then
        # `or ""` keeps a JSON null from turning into the string "None".
        resolved_uri=$(printf '%s\n' "${resolve_result}" | grep -E '^\{' | tail -1 \
            | python3 -c 'import sys, json; print(json.load(sys.stdin).get("s3_uri") or "")' \
            2>/dev/null) || resolved_uri=""
    fi

    if [ -n "${resolved_uri}" ]; then
        echo " Resolved to: ${resolved_uri}"
        RESOLVED_DATASET_S3_URI="${resolved_uri}"
        return 0
    fi

    echo "❌ Dataset '${name}' not found in registry"
    echo " Register it: ./do/register dataset ${name} --s3-uri <uri> --technique ${TECHNIQUE:-}"
    echo " Or provide directly: --dataset s3://... or --dataset hf://..."
    exit 1
}
419
+
# ── _stage_hf_dataset() ──────────────────────────────────────────────────────
# Stage a HuggingFace dataset to S3 via `.tune_helper.py stage-hf`.
# Arguments: $1 - reference of the form hf://org/name[/split][?file=<file>]
# Globals:   SCRIPT_DIR, S3_BUCKET, PROJECT_NAME, AWS_REGION, TECHNIQUE,
#            HF_TOKEN_ARN (read); RESOLVED_DATASET_S3_URI (written)
# When S3_BUCKET is empty the bucket is read from the "s3Bucket" key of
# ~/.ml-container-creator/config.json.
# Exits 1 on a malformed reference, a missing helper or bucket, a helper
# failure, an "error" key in the helper's JSON, or a response without s3_uri.
_stage_hf_dataset() {
    local dataset="$1"
    local hf_path="${dataset#hf://}"
    local hf_file=""

    # Extract ?file= parameter
    if [[ "${hf_path}" == *"?file="* ]]; then
        hf_file="${hf_path#*\?file=}"
        hf_path="${hf_path%%\?file=*}"
    fi

    # Split org/name[/split...] with parameter expansion. `cut -d/ -fN` echoes
    # the whole input when it contains no "/", which let "hf://org" through.
    local hf_org="${hf_path%%/*}"
    local hf_name=""
    local hf_split=""
    local remainder=""
    if [[ "${hf_path}" == */* ]]; then
        remainder="${hf_path#*/}"
        hf_name="${remainder%%/*}"
        if [[ "${remainder}" == */* ]]; then
            hf_split="${remainder#*/}"
        fi
    fi

    if [ -z "${hf_org}" ] || [ -z "${hf_name}" ]; then
        echo "❌ Invalid HF dataset reference: ${dataset}"
        echo " Expected format: hf://org/name or hf://org/name/split"
        exit 1
    fi

    local helper_script="${SCRIPT_DIR}/.tune_helper.py"
    if [ ! -f "${helper_script}" ]; then
        echo "❌ Dataset staging helper not available (.tune_helper.py)"
        echo " Stage the dataset manually to S3 and use: --dataset s3://..."
        exit 1
    fi

    echo "📦 Staging HuggingFace dataset: ${hf_org}/${hf_name}"
    if [ -n "${hf_split}" ]; then
        echo " Split: ${hf_split}"
    fi

    # Resolve output bucket: do/config first, then the user-level profile.
    local output_bucket="${S3_BUCKET:-}"
    if [ -z "${output_bucket}" ]; then
        output_bucket=$(python3 -c '
import json
import os

config_path = os.path.expanduser("~/.ml-container-creator/config.json")
if os.path.exists(config_path):
    with open(config_path) as f:
        cfg = json.load(f)
    print(cfg.get("s3Bucket", ""))
' 2>/dev/null) || output_bucket=""
    fi

    if [ -z "${output_bucket}" ]; then
        echo "❌ No S3 bucket configured for dataset staging"
        echo " Run ./do/bootstrap to configure, or set S3_BUCKET in do/config"
        exit 1
    fi

    # Build stage-hf arguments
    local stage_args=(
        --hf-org "${hf_org}"
        --hf-name "${hf_name}"
        --output-bucket "${output_bucket}"
        --project-name "${PROJECT_NAME:-}"
        --region "${AWS_REGION:-us-east-1}"
        --technique "${TECHNIQUE:-}"
    )
    if [ -n "${hf_split}" ]; then
        stage_args+=(--hf-split "${hf_split}")
    fi
    if [ -n "${HF_TOKEN_ARN:-}" ]; then
        stage_args+=(--hf-secret-name "${HF_TOKEN_ARN}")
    fi
    if [ -n "${hf_file}" ]; then
        stage_args+=(--hf-file "${hf_file}")
    fi

    local stage_output
    stage_output=$(python3 "${helper_script}" stage-hf "${stage_args[@]}") || {
        echo "❌ Failed to stage HF dataset"
        exit 1
    }

    # Parse the response once. The whole output is tried as JSON first; if log
    # lines precede it, the last line starting with "{" is used instead.
    # Python prints "ok", the URI and the record count on separate lines, or
    # "error" followed by a one-line message.
    local parsed
    parsed=$(printf '%s' "${stage_output}" | python3 -c '
import json
import sys

raw = sys.stdin.read()
try:
    data = json.loads(raw)
except ValueError:
    candidates = [ln for ln in raw.splitlines() if ln.startswith("{")]
    data = json.loads(candidates[-1]) if candidates else None


def one_line(value):
    return " ".join(str(value).split())


if not isinstance(data, dict):
    print("error")
    print("Unknown error")
elif "error" in data:
    print("error")
    print(one_line(data["error"] or "Unknown error"))
elif not data.get("s3_uri"):
    print("error")
    print("Staging helper returned no s3_uri")
else:
    print("ok")
    print(one_line(data["s3_uri"]))
    print(data.get("num_records", 0))
' 2>/dev/null) || parsed=$'error\nUnknown error'

    local outcome="" detail="" row_count=""
    {
        IFS= read -r outcome
        IFS= read -r detail
        IFS= read -r row_count
    } <<<"${parsed}" || true

    if [ "${outcome}" != "ok" ] || [ -z "${detail}" ]; then
        echo "❌ ${detail:-Unknown error}"
        exit 1
    fi

    RESOLVED_DATASET_S3_URI="${detail}"

    echo " ✅ Staged to: ${RESOLVED_DATASET_S3_URI}"
    echo " Records: ${row_count:-0}"
    echo ""
}
522
+
109
523
  # ── _parse_config() ───────────────────────────────────────────────────────────
110
524
  # Read and parse do/training/config.yaml into bash variables.
111
525
  # Uses yq if available, falls back to python3 YAML parsing.
@@ -135,6 +549,11 @@ _parse_config_yq() {
135
549
  TRAIN_IMAGE=$(yq -r '.image // ""' "${CONFIG_FILE}")
136
550
  TRAIN_SCRIPT=$(yq -r '.script // ""' "${CONFIG_FILE}")
137
551
  TRAIN_INSTANCE_TYPE=$(yq -r '.instance_type // ""' "${CONFIG_FILE}")
552
+
553
+ # Resolve shell variables in image URI (backward compat with old-style config)
554
+ if echo "${TRAIN_IMAGE}" | grep -q '^\${\|^\${'; then
555
+ TRAIN_IMAGE=$(eval echo "${TRAIN_IMAGE}")
556
+ fi
138
557
  TRAIN_INSTANCE_COUNT=$(yq -r '.instance_count // "1"' "${CONFIG_FILE}")
139
558
  TRAIN_DATASET=$(yq -r '.dataset // ""' "${CONFIG_FILE}")
140
559
  TRAIN_OUTPUT_PATH=$(yq -r '.output_path // ""' "${CONFIG_FILE}")
@@ -302,12 +721,22 @@ _validate_config() {
302
721
  fi
303
722
 
304
723
  if [ -z "${TRAIN_OUTPUT_PATH}" ]; then
305
- echo "❌ Missing required field: output_path"
306
- echo " The S3 output path is required in do/training/config.yaml"
307
- echo ""
308
- echo " Expected format: output_path: \"s3://my-bucket/output/\""
309
- echo ""
310
- has_error=true
724
+ # Auto-resolve from profile: use benchmarkS3Bucket or construct from project name
725
+ if [ -n "${_PROFILE_benchmarkS3Bucket:-}" ]; then
726
+ TRAIN_OUTPUT_PATH="s3://${_PROFILE_benchmarkS3Bucket}/${PROJECT_NAME}/training-output/"
727
+ elif [ -n "${BENCHMARK_S3_OUTPUT_PATH:-}" ]; then
728
+ # Derive from benchmark output path (replace /benchmarks/ with /training-output/)
729
+ TRAIN_OUTPUT_PATH="${BENCHMARK_S3_OUTPUT_PATH%/benchmarks/*}/training-output/${PROJECT_NAME}/"
730
+ fi
731
+ if [ -z "${TRAIN_OUTPUT_PATH}" ]; then
732
+ echo "❌ Missing required field: output_path"
733
+ echo " The S3 output path is required in do/training/config.yaml"
734
+ echo " Or run 'ml-container-creator bootstrap' to configure an S3 bucket."
735
+ echo ""
736
+ echo " Expected format: output_path: \"s3://my-bucket/output/\""
737
+ echo ""
738
+ has_error=true
739
+ fi
311
740
  fi
312
741
 
313
742
  # Spot training requires a checkpoint path for resumption
@@ -350,29 +779,33 @@ _check_idempotency() {
350
779
  echo " Checking status..."
351
780
  echo ""
352
781
 
353
- # Call DescribeTrainingJob to get current status
354
- local describe_output
355
- local describe_exit_code
356
- describe_output=$(aws sagemaker describe-training-job \
357
- --training-job-name "${TRAIN_JOB_NAME}" 2>&1) || describe_exit_code=$?
358
- describe_exit_code=${describe_exit_code:-0}
782
+ # Query status via SDK v3 helper
783
+ local status_json
784
+ status_json=$(python3 "${SCRIPT_DIR}/.train_helper.py" status \
785
+ --job-name "${TRAIN_JOB_NAME}" \
786
+ --region "${AWS_REGION}" 2>/dev/null | grep -E '^\{' | tail -1) || status_json=""
359
787
 
360
- if [ ${describe_exit_code} -ne 0 ]; then
361
- # If describe fails (e.g., job was deleted), proceed to new job
788
+ if [ -z "${status_json}" ]; then
789
+ # If status query fails, proceed to new job
790
+ echo "⚠️ Could not query existing job: ${TRAIN_JOB_NAME}"
791
+ echo " Proceeding to create a new job."
792
+ echo ""
793
+ return 0
794
+ fi
795
+
796
+ # Check for error in response
797
+ local has_error
798
+ has_error=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if d.get('error') else 'no')" 2>/dev/null) || has_error="no"
799
+
800
+ if [ "${has_error}" = "yes" ]; then
362
801
  echo "⚠️ Could not describe existing job: ${TRAIN_JOB_NAME}"
363
- echo " ${describe_output}"
364
802
  echo " Proceeding to create a new job."
365
803
  echo ""
366
804
  return 0
367
805
  fi
368
806
 
369
- # Extract status from the JSON response using python3
370
807
  local job_status
371
- job_status=$(echo "${describe_output}" | python3 -c "
372
- import sys, json
373
- resp = json.load(sys.stdin)
374
- print(resp.get('TrainingJobStatus', 'Unknown'))
375
- " 2>/dev/null) || job_status="Unknown"
808
+ job_status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || job_status="Unknown"
376
809
 
377
810
  case "${job_status}" in
378
811
  InProgress)
@@ -386,19 +819,12 @@ print(resp.get('TrainingJobStatus', 'Unknown'))
386
819
  Completed)
387
820
  echo "✅ Training job already completed: ${TRAIN_JOB_NAME}"
388
821
  echo ""
389
- # Pass the describe output to _handle_completion via a temp file
390
- local describe_file="/tmp/train-describe-${TRAIN_JOB_NAME}.json"
391
- echo "${describe_output}" > "${describe_file}"
392
822
  _handle_completion
393
823
  exit 0
394
824
  ;;
395
825
  Failed)
396
826
  local failure_reason
397
- failure_reason=$(echo "${describe_output}" | python3 -c "
398
- import sys, json
399
- resp = json.load(sys.stdin)
400
- print(resp.get('FailureReason', 'No failure reason provided'))
401
- " 2>/dev/null) || failure_reason="No failure reason provided"
827
+ failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'No failure reason provided')" 2>/dev/null) || failure_reason="No failure reason provided"
402
828
 
403
829
  echo "❌ Previous training job failed: ${TRAIN_JOB_NAME}"
404
830
  echo " Reason: ${failure_reason}"
@@ -469,9 +895,177 @@ _build_job_request() {
469
895
  fi
470
896
  }
471
897
 
898
+ <% if (deploymentTarget === 'hyperpod-eks') { %>
472
899
  # ── _submit_job() ─────────────────────────────────────────────────────────────
473
- # Call aws sagemaker create-training-job with the constructed JSON.
474
- # Handles --dry-run, AccessDenied detection, and config persistence.
# HyperPod EKS variant: render a Kubernetes Job manifest for single-pod
# training and apply it with kubectl.
# Multi-node (parallelism > 1) requires PyTorchJob CRD — see Epic 8.
# With --dry-run the manifest is printed, removed again, and the script exits 0.
# Otherwise the job name is persisted to do/config after a successful apply;
# a failed apply exits 1.
_submit_job() {
    local manifest_path="${SCRIPT_DIR}/../.train-job.yaml"

    # Dry run: show what would be applied, leave nothing behind.
    if [[ "${ARG_DRY_RUN}" == true ]]; then
        _generate_k8s_manifest "${manifest_path}"
        echo ""
        echo "🔍 Dry run — K8s Job manifest:"
        echo ""
        cat "${manifest_path}"
        echo ""
        rm -f "${manifest_path}"
        exit 0
    fi

    echo ""
    echo "🚀 Submitting training job to HyperPod EKS..."
    _generate_k8s_manifest "${manifest_path}"

    # kubectl's own diagnostics are merged into stdout ahead of the hint below.
    if ! kubectl apply -f "${manifest_path}" 2>&1; then
        echo "❌ Failed to apply K8s Job manifest"
        echo " Check: kubectl access, namespace exists, PVC available"
        exit 1
    fi

    # Remember the job name so later runs can find the job.
    _update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
    echo " ✅ Job submitted: ${JOB_NAME}"
    echo " Namespace: ${HYPERPOD_NAMESPACE:-default}"
    echo ""
}
935
+
936
+ # ── _generate_k8s_manifest() ─────────────────────────────────────────────────
937
+ # Generate a K8s batch/v1 Job YAML for single-node multi-GPU training.
938
+ _generate_k8s_manifest() {
939
+ local manifest_file="$1"
940
+ local gpu_count="${TRAIN_INSTANCE_COUNT:-1}"
941
+
942
+ # Detect EFA availability on cluster nodes
943
+ local efa_count
944
+ efa_count=$(kubectl get nodes -o jsonpath='{.items[0].status.allocatable.vpc\.amazonaws\.com/efa}' 2>/dev/null || echo "0")
945
+
946
+ # Build EFA resource request if available
947
+ local efa_resource=""
948
+ local nccl_env=""
949
+ if [ "${efa_count}" != "0" ] && [ "${efa_count}" != "" ]; then
950
+ efa_resource=" vpc.amazonaws.com/efa: ${efa_count}"
951
+ nccl_env=" - name: NCCL_SOCKET_IFNAME
952
+ value: \"eth0\"
953
+ - name: FI_PROVIDER
954
+ value: \"efa\"
955
+ - name: FI_EFA_USE_DEVICE_RDMA
956
+ value: \"1\""
957
+ fi
958
+
959
+ cat > "${manifest_file}" <<MANIFEST
960
+ # Generated by do/train for HyperPod EKS (single-node)
961
+ # Multi-node distributed training requires PyTorchJob CRD — see Epic 8.
962
+ apiVersion: batch/v1
963
+ kind: Job
964
+ metadata:
965
+ name: ${JOB_NAME}
966
+ namespace: ${HYPERPOD_NAMESPACE:-default}
967
+ labels:
968
+ app: ml-container-creator
969
+ project: ${PROJECT_NAME}
970
+ technique: ${TECHNIQUE:-custom}
971
+ spec:
972
+ parallelism: 1
973
+ completions: 1
974
+ backoffLimit: 0
975
+ template:
976
+ metadata:
977
+ labels:
978
+ app: ml-container-creator
979
+ job-name: ${JOB_NAME}
980
+ spec:
981
+ containers:
982
+ - name: training
983
+ image: ${TRAIN_IMAGE}
984
+ command: ["accelerate", "launch", "--config_file", "/workspace/training/${TECHNIQUE:-custom}/accelerate_config.yaml", "/workspace/training/${TECHNIQUE:-custom}/train.py"]
985
+ env:
986
+ - name: DATA_DIR
987
+ value: "/data/training"
988
+ - name: OUTPUT_DIR
989
+ value: "/output/model"
990
+ - name: CHECKPOINT_DIR
991
+ value: "/output/checkpoints"
992
+ - name: HF_MODEL_ID
993
+ value: "${HF_MODEL_ID:-}"
994
+ - name: NUM_GPUS
995
+ value: "${gpu_count}"
996
+ - name: TRAIN_TECHNIQUE
997
+ value: "${TECHNIQUE:-custom}"
998
+ ${nccl_env}
999
+ resources:
1000
+ limits:
1001
+ nvidia.com/gpu: ${gpu_count}
1002
+ ${efa_resource}
1003
+ volumeMounts:
1004
+ - name: data
1005
+ mountPath: /data
1006
+ - name: output
1007
+ mountPath: /output
1008
+ - name: workspace
1009
+ mountPath: /workspace
1010
+ volumes:
1011
+ - name: data
1012
+ persistentVolumeClaim:
1013
+ claimName: ${DATA_PVC:-training-data}
1014
+ - name: output
1015
+ persistentVolumeClaim:
1016
+ claimName: ${OUTPUT_PVC:-training-output}
1017
+ - name: workspace
1018
+ configMap:
1019
+ name: ${PROJECT_NAME}-training-scripts
1020
+ restartPolicy: Never
1021
+ MANIFEST
1022
+ }
1023
+
# ── _poll_job() ───────────────────────────────────────────────────────────────
# HyperPod EKS: poll a K8s Job via kubectl until it succeeds or fails.
# Globals: JOB_NAME (falls back to TRAIN_JOB_NAME), HYPERPOD_NAMESPACE
#          (default "default"), POLL_INTERVAL (seconds between polls) (read)
# Returns 0 once the job reports a succeeded pod.
# Exits 2 when the job reports a failed pod or a "Failed" condition.
# Exits 1 after 5 consecutive kubectl query failures (job missing, no access),
# instead of reporting "Running" forever.
_poll_job() {
    local job_name="${JOB_NAME:-${TRAIN_JOB_NAME:-}}"
    local namespace="${HYPERPOD_NAMESPACE:-default}"
    local -r max_query_errors=5
    local query_errors=0
    local snapshot succeeded failed active conditions
    # One query per tick yields a consistent snapshot. "|" is a non-whitespace
    # separator, so `read` keeps empty fields in place.
    local -r status_jsonpath='{.status.succeeded}{"|"}{.status.failed}{"|"}{.status.active}{"|"}{.status.conditions[*].type}'

    echo "⏳ Polling K8s Job: ${job_name} (namespace: ${namespace})"
    echo " (Ctrl+C to stop polling — job continues in cluster)"
    echo ""

    while true; do
        if snapshot=$(kubectl get job "${job_name}" -n "${namespace}" \
                -o jsonpath="${status_jsonpath}" 2>/dev/null); then
            query_errors=0
        else
            query_errors=$((query_errors + 1))
            if [ "${query_errors}" -ge "${max_query_errors}" ]; then
                echo "❌ Could not query K8s Job: ${job_name} (namespace: ${namespace})"
                echo " Check: kubectl get job ${job_name} -n ${namespace}"
                exit 1
            fi
            echo " ⚠️ kubectl query failed (${query_errors}/${max_query_errors}), retrying..."
            sleep "${POLL_INTERVAL}"
            continue
        fi

        IFS='|' read -r succeeded failed active conditions <<<"${snapshot}" || true

        # Counters are absent (empty) until the first pod finishes; the numeric
        # test simply evaluates false for anything that is not an integer.
        if [ "${succeeded:-0}" -ge 1 ] 2>/dev/null; then
            echo " ✅ Completed"
            echo ""
            echo "✅ Training job completed: ${job_name}"
            break
        elif [ "${failed:-0}" -ge 1 ] 2>/dev/null \
                || [[ " ${conditions} " == *" Failed "* ]]; then
            echo " ❌ Failed"
            echo ""
            echo "❌ Training job failed: ${job_name}"
            echo " Logs: kubectl logs job/${job_name} -n ${namespace}"
            exit 2
        else
            echo " 🔄 Running (active: ${active:-0})"
        fi

        sleep "${POLL_INTERVAL}"
    done
}
1065
+ <% } else { %>
1066
+ # ── _submit_job() ─────────────────────────────────────────────────────────────
1067
+ # Submit training job via .train_helper.py (SDK v3).
1068
+ # Handles --dry-run, error detection, and config persistence.
475
1069
  _submit_job() {
476
1070
  # Handle --dry-run: print the request JSON and exit without submitting
477
1071
  if [ "${ARG_DRY_RUN}" = true ]; then
@@ -487,59 +1081,44 @@ _submit_job() {
487
1081
  echo ""
488
1082
  echo "🚀 Submitting training job..."
489
1083
 
490
- # Submit the job via AWS CLI
1084
+ # Submit via SDK v3 helper
491
1085
  local submit_output
492
- local submit_exit_code
493
- submit_output=$(aws sagemaker create-training-job \
494
- --cli-input-json "file://${JOB_REQUEST_FILE}" 2>&1) || submit_exit_code=$?
495
- submit_exit_code=${submit_exit_code:-0}
1086
+ submit_output=$(python3 "${SCRIPT_DIR}/.train_helper.py" submit \
1087
+ --config "${JOB_REQUEST_FILE}" \
1088
+ --region "${AWS_REGION}" 2>&1 | grep -E '^\{' | tail -1) || submit_output=""
496
1089
 
497
1090
  # Clean up the temporary request file
498
1091
  rm -f "${JOB_REQUEST_FILE}"
499
1092
 
500
- if [ ${submit_exit_code} -eq 0 ]; then
501
- # Success persist job name to do/config
502
- _update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
503
- echo " ✅ Job submitted successfully: ${JOB_NAME}"
504
- echo ""
505
- else
506
- # Failure — detect error type and provide remediation
507
- if echo "${submit_output}" | grep -q "AccessDeniedException"; then
508
- # Extract the permission or action from the error message
509
- local missing_permission
510
- missing_permission=$(echo "${submit_output}" | grep -oP '(?<=performing: )[^ ]+' || echo "")
511
- if [ -z "${missing_permission}" ]; then
512
- missing_permission=$(echo "${submit_output}" | grep -oP '(?<=action: )[^ ]+' || echo "")
513
- fi
514
- if [ -z "${missing_permission}" ]; then
515
- missing_permission="sagemaker:CreateTrainingJob"
516
- fi
1093
+ if [ -z "${submit_output}" ]; then
1094
+ echo "❌ Failed to submit training job (no response from helper)"
1095
+ echo " Ensure sagemaker SDK v3 is installed: pip install 'sagemaker>=3.0'"
1096
+ exit 1
1097
+ fi
517
1098
 
518
- echo "❌ Access denied: ${missing_permission}"
519
- echo " ${submit_output}"
520
- echo ""
521
- echo " Remediation:"
522
- echo " Ensure your IAM role or user has the '${missing_permission}' permission."
523
- echo " If using the bootstrap stack, re-run ./do/bootstrap to update permissions."
524
- echo " Otherwise, attach a policy granting '${missing_permission}' to your role."
525
- exit 1
526
- else
527
- echo "❌ Failed to submit training job"
528
- echo " ${submit_output}"
529
- echo ""
530
- echo " Check the error above and verify your configuration in do/training/config.yaml."
531
- exit 1
532
- fi
1099
+ # Check for error in response
1100
+ local has_error
1101
+ has_error=$(echo "${submit_output}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if d.get('error') else 'no')" 2>/dev/null) || has_error="yes"
1102
+
1103
+ if [ "${has_error}" = "yes" ]; then
1104
+ local error_msg
1105
+ error_msg=$(echo "${submit_output}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('message','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
1106
+ echo "❌ ${error_msg}"
1107
+ exit 1
533
1108
  fi
1109
+
1110
+ # Success — persist job name to do/config
1111
+ _update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
1112
+ echo " ✅ Job submitted successfully: ${JOB_NAME}"
1113
+ echo ""
534
1114
  }
535
1115
 
536
1116
  # ── _poll_job() ───────────────────────────────────────────────────────────────
537
- # Poll DescribeTrainingJob every POLL_INTERVAL seconds until terminal state.
538
- # Displays: job status, secondary status, elapsed time, and training metrics.
1117
+ # Poll training job status via .train_helper.py every POLL_INTERVAL seconds.
1118
+ # Displays: job status, secondary status, elapsed time.
539
1119
  # On Completed: breaks loop and returns (caller handles completion).
540
1120
  # On Failed: displays FailureReason and exits 2.
541
1121
  # On Stopped: displays stopped message and exits 2.
542
- # On spot interruption: explains auto-resume from checkpoint.
543
1122
  _poll_job() {
544
1123
  local job_name="${JOB_NAME:-$TRAIN_JOB_NAME}"
545
1124
 
@@ -548,49 +1127,39 @@ _poll_job() {
548
1127
  echo ""
549
1128
 
550
1129
  while true; do
551
- # Call DescribeTrainingJob
552
- local describe_output
553
- local describe_exit_code
554
- describe_output=$(aws sagemaker describe-training-job \
555
- --training-job-name "${job_name}" 2>&1) || describe_exit_code=$?
556
- describe_exit_code=${describe_exit_code:-0}
557
-
558
- if [ ${describe_exit_code} -ne 0 ]; then
559
- echo "⚠️ Failed to describe job (will retry): ${describe_output}"
1130
+ # Query status via SDK v3 helper
1131
+ local status_json
1132
+ status_json=$(python3 "${SCRIPT_DIR}/.train_helper.py" status \
1133
+ --job-name "${job_name}" \
1134
+ --region "${AWS_REGION}" 2>/dev/null | grep -E '^\{' | tail -1)
1135
+
1136
+ if [ -z "${status_json}" ]; then
1137
+ echo "⚠️ Failed to query job status (will retry)"
560
1138
  sleep "${POLL_INTERVAL}"
561
1139
  continue
562
1140
  fi
563
1141
 
564
- # Parse the response using python3 helper
565
- local poll_result
566
- poll_result=$(echo "${describe_output}" | python3 "${SCRIPT_DIR}/.train_poll_parser.py" 2>&1)
567
- local parse_exit_code=$?
1142
+ # Check for error
1143
+ local has_error
1144
+ has_error=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if d.get('error') else 'no')" 2>/dev/null) || has_error="no"
568
1145
 
569
- if [ ${parse_exit_code} -ne 0 ]; then
570
- echo "⚠️ Failed to parse job status (will retry): ${poll_result}"
1146
+ if [ "${has_error}" = "yes" ]; then
1147
+ echo "⚠️ Status query error (will retry)"
571
1148
  sleep "${POLL_INTERVAL}"
572
1149
  continue
573
1150
  fi
574
1151
 
575
- # The parser outputs structured lines:
576
- # STATUS=<status>
577
- # SECONDARY=<secondary_status>
578
- # ELAPSED=<elapsed_string>
579
- # METRICS=<metrics_string>
580
- # FAILURE_REASON=<reason>
581
- # DISPLAY=<formatted display text>
582
- local job_status=""
583
- local secondary_status=""
584
- local display_text=""
585
- local failure_reason=""
586
-
587
- job_status=$(echo "${poll_result}" | grep '^STATUS=' | cut -d= -f2-)
588
- secondary_status=$(echo "${poll_result}" | grep '^SECONDARY=' | cut -d= -f2-)
589
- failure_reason=$(echo "${poll_result}" | grep '^FAILURE_REASON=' | cut -d= -f2-)
590
- display_text=$(echo "${poll_result}" | grep '^DISPLAY=' | cut -d= -f2-)
1152
+ # Extract fields
1153
+ local job_status display_text failure_reason secondary_status
1154
+ job_status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || job_status="Unknown"
1155
+ display_text=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('display',''))" 2>/dev/null) || display_text=""
1156
+ failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or '')" 2>/dev/null) || failure_reason=""
1157
+ secondary_status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('secondary_status','') or '')" 2>/dev/null) || secondary_status=""
591
1158
 
592
1159
  # Print the formatted status line
593
- echo "${display_text}"
1160
+ if [ -n "${display_text}" ]; then
1161
+ echo "${display_text}"
1162
+ fi
594
1163
 
595
1164
  # Handle terminal states
596
1165
  case "${job_status}" in
@@ -619,7 +1188,7 @@ _poll_job() {
619
1188
  ;;
620
1189
  esac
621
1190
 
622
- # Handle spot interruption (job still InProgress but interrupted)
1191
+ # Handle spot interruption
623
1192
  if echo "${secondary_status}" | grep -qi "interrupted"; then
624
1193
  echo ""
625
1194
  echo " ℹ️ Spot instance interrupted. The job will automatically resume"
@@ -631,9 +1200,10 @@ _poll_job() {
631
1200
  sleep "${POLL_INTERVAL}"
632
1201
  done
633
1202
  }
1203
+ <% } %>
634
1204
 
635
1205
  # ── _handle_completion() ──────────────────────────────────────────────────────
636
- # Store output paths and invoke feedback loop.
1206
+ # Store output paths, write TRAIN_* lifecycle vars, and invoke feedback loop.
637
1207
  # Extracts model artifacts path, detects output type, and prints next steps.
638
1208
  _handle_completion() {
639
1209
  local job_name="${JOB_NAME:-$TRAIN_JOB_NAME}"
@@ -666,8 +1236,15 @@ print(artifacts.get('S3ModelArtifacts', ''))
666
1236
  return 1
667
1237
  fi
668
1238
 
669
- # Write TRAIN_OUTPUT_PATH to do/config
670
- _update_config_var "TRAIN_OUTPUT_PATH" "${output_path}"
1239
+ # ── Write TRAIN_* lifecycle variables to do/config ────────────────────────
1240
+ _update_config_var "TRAIN_OUTPUT_PATH_LATEST" "${output_path}"
1241
+ _update_config_var "TRAIN_JOB_NAME" "${job_name}"
1242
+
1243
+ # Write technique-specific adapter path
1244
+ local technique_upper
1245
+ technique_upper=$(echo "${TECHNIQUE:-custom}" | tr '[:lower:]' '[:upper:]')
1246
+ _update_config_var "TRAIN_ADAPTER_PATH_${technique_upper}" "${output_path}"
1247
+ _update_config_var "TRAIN_TECHNIQUE" "${TECHNIQUE:-custom}"
671
1248
 
672
1249
  # Detect output type: check for adapter_config.json in output path
673
1250
  local output_type="full-model"
@@ -679,6 +1256,23 @@ print(artifacts.get('S3ModelArtifacts', ''))
679
1256
  source "${SCRIPT_DIR}/lib/feedback.sh"
680
1257
  print_completion_feedback "${output_path}" "${output_type}" "${job_name}"
681
1258
 
1259
+ # ── Auto-register (unless --no-register) ─────────────────────────────────
1260
+ if [ "${ARG_NO_REGISTER}" != true ] && [ -f "${SCRIPT_DIR}/register" ]; then
1261
+ if [ -n "${RESOLVED_DATASET_S3_URI}" ]; then
1262
+ echo "📝 Auto-registering dataset for technique '${TECHNIQUE}'..."
1263
+ "${SCRIPT_DIR}/register" dataset --from-train "${TECHNIQUE}" 2>/dev/null || {
1264
+ echo " ⚠️ Auto-registration skipped (non-fatal)"
1265
+ }
1266
+ fi
1267
+ fi
1268
+
1269
+ # ── Print next steps ─────────────────────────────────────────────────────
1270
+ if [ "${output_type}" = "adapter" ]; then
1271
+ echo ""
1272
+ echo " Next: ./do/adapter --from-train ${TECHNIQUE}"
1273
+ echo " to stage the adapter for deployment."
1274
+ fi
1275
+
682
1276
  # If spot training was enabled, display cost savings
683
1277
  if [ "${TRAIN_ENABLE_SPOT:-false}" = "true" ]; then
684
1278
  local billable_seconds training_seconds savings_pct
@@ -735,6 +1329,12 @@ if [ "${ARG_HELP}" = true ]; then
735
1329
  _show_help
736
1330
  fi
737
1331
 
1332
# --list-datasets: load the generated do/config into this shell, then hand off
# to the dataset listing command.
if [ true = "${ARG_LIST_DATASETS}" ]; then
  . "${SCRIPT_DIR}/config"
  _list_datasets_cmd
fi
1337
+
738
1338
  if [ "${ARG_STATUS}" = true ]; then
739
1339
  # Show status of current tracked job without submitting
740
1340
  if [ -z "${TRAIN_JOB_NAME:-}" ]; then
@@ -746,28 +1346,181 @@ if [ "${ARG_STATUS}" = true ]; then
746
1346
  echo "📊 Training Job Status"
747
1347
  echo " Job: ${TRAIN_JOB_NAME}"
748
1348
 
749
- # Call DescribeTrainingJob and parse the response
750
- STATUS_JSON=$(aws sagemaker describe-training-job \
751
- --training-job-name "${TRAIN_JOB_NAME}" \
752
- --region "${AWS_REGION}" 2>&1) || {
1349
+ # Query status via SDK v3 helper
1350
+ STATUS_RESULT=$(python3 "${SCRIPT_DIR}/.train_helper.py" status \
1351
+ --job-name "${TRAIN_JOB_NAME}" \
1352
+ --region "${AWS_REGION}" 2>/dev/null | grep -E '^\{' | tail -1) || {
753
1353
  echo ""
754
- echo "❌ Failed to describe training job: ${TRAIN_JOB_NAME}"
755
- echo " ${STATUS_JSON}"
756
- echo ""
757
- echo " The job may have been deleted or the name is incorrect."
1354
+ echo "❌ Failed to query training job status: ${TRAIN_JOB_NAME}"
1355
+ echo " Ensure sagemaker SDK v3 is installed: pip install 'sagemaker>=3.0'"
758
1356
  echo " Run ./do/train --force to start a new job."
759
1357
  exit 1
760
1358
  }
761
1359
 
762
- # Parse and display the status using the helper script
763
- echo "${STATUS_JSON}" | python3 "${SCRIPT_DIR}/.train_status_parser.py"
1360
+ if [ -z "${STATUS_RESULT}" ]; then
1361
+ echo ""
1362
+ echo "❌ No response from status helper for: ${TRAIN_JOB_NAME}"
1363
+ echo " Run ./do/train --force to start a new job."
1364
+ exit 1
1365
+ fi
1366
+
1367
+ # Display the status info
1368
+ DISPLAY_TEXT=$(echo "${STATUS_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('display',''))" 2>/dev/null) || DISPLAY_TEXT=""
1369
+ if [ -n "${DISPLAY_TEXT}" ]; then
1370
+ echo "${DISPLAY_TEXT}"
1371
+ fi
1372
+
1373
+ # Show additional details for completed jobs
1374
+ STATUS_VAL=$(echo "${STATUS_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || STATUS_VAL="Unknown"
1375
+ if [ "${STATUS_VAL}" = "Completed" ]; then
1376
+ ARTIFACTS_VAL=$(echo "${STATUS_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('model_artifacts','') or '')" 2>/dev/null) || ARTIFACTS_VAL=""
1377
+ if [ -n "${ARTIFACTS_VAL}" ]; then
1378
+ echo " 📦 Artifacts: ${ARTIFACTS_VAL}"
1379
+ fi
1380
+ fi
1381
+ echo ""
764
1382
  exit 0
765
1383
  fi
766
1384
 
1385
# ── Handle --interactive mode ─────────────────────────────────────────────────
# Runs the Node.js train-config-builder against the terminal, reads its JSON
# result back through a temp file, and either stops (config only) or falls
# through to the normal submission flow when the result has run_now = true.
#
# Globals (read):    ARG_INTERACTIVE, SCRIPT_DIR, CONFIG_FILE, TMPDIR
# Globals (written): TRAINING_DIR, INTERACTIVE_RESULT, RUN_NOW
# Exits: 1 when the builder is unavailable or produced no result;
#        0 when the user chose not to run the job now.
if [ "${ARG_INTERACTIVE}" = true ]; then
  TRAINING_DIR="${SCRIPT_DIR}/training"
  _RESULT_FILE=$(mktemp "${TMPDIR:-/tmp}/mlcc-train-interactive.XXXXXX") || {
    echo "❌ Could not create a temporary file for interactive mode."
    exit 1
  }

  INTERACTIVE_RESULT=""
  _BUILDER_RAN=false
  if command -v node &>/dev/null; then
    _NPM_GLOBAL=$(npm root -g 2>/dev/null || echo /usr/local/lib/node_modules)
    _project_root=$(cd "${SCRIPT_DIR}/.." && pwd)

    # Search order: global npm install, project-local node_modules, source tree.
    _BUILDER_PATH=""
    for _candidate in \
      "${_NPM_GLOBAL}/@aws/ml-container-creator/src/lib/train-config-builder.js" \
      "${_project_root}/node_modules/@aws/ml-container-creator/src/lib/train-config-builder.js" \
      "${_project_root}/src/lib/train-config-builder.js"; do
      if [ -f "${_candidate}" ]; then
        _BUILDER_PATH="${_candidate}"
        break
      fi
    done

    if [ -n "${_BUILDER_PATH}" ]; then
      _BUILDER_RAN=true
      # Run interactively (no stdout capture — prompts render to TTY).
      # Paths travel through the environment, not through the JS source text,
      # so quotes or backslashes in a path cannot break or alter the script.
      MLCC_BUILDER_PATH="${_BUILDER_PATH}" \
      MLCC_CONFIG_FILE="${CONFIG_FILE}" \
      MLCC_TRAINING_DIR="${TRAINING_DIR}" \
      MLCC_RESULT_FILE="${_RESULT_FILE}" \
        node --input-type=module -e '
          import { writeFileSync } from "node:fs";
          import { pathToFileURL } from "node:url";
          const { run } = await import(pathToFileURL(process.env.MLCC_BUILDER_PATH).href);
          const result = await run({
            configFile: process.env.MLCC_CONFIG_FILE,
            trainingDir: process.env.MLCC_TRAINING_DIR,
          });
          writeFileSync(process.env.MLCC_RESULT_FILE, JSON.stringify(result));
        ' || true
      INTERACTIVE_RESULT=$(cat "${_RESULT_FILE}" 2>/dev/null || echo "")
    fi
  fi

  rm -f "${_RESULT_FILE}"

  if [ -z "${INTERACTIVE_RESULT}" ]; then
    if [ "${_BUILDER_RAN}" = true ]; then
      # The builder was found and started, so a missing install is not the cause.
      echo "❌ Interactive configuration did not complete (cancelled or failed)."
      echo " Re-run ./do/train --interactive, or edit do/training/config.yaml directly."
    else
      echo "❌ Interactive mode requires Node.js and @aws/ml-container-creator installed."
      echo " Install: npm install -g @aws/ml-container-creator"
      echo ""
      echo " Alternatively, edit do/training/config.yaml directly."
    fi
    exit 1
  fi

  # Check if user wants to run now (Python prints the boolean as True/False)
  RUN_NOW=$(echo "${INTERACTIVE_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('run_now', False))" 2>/dev/null) || RUN_NOW="False"

  if [ "${RUN_NOW}" != "True" ]; then
    echo " Run ./do/train to submit the configured job."
    exit 0
  fi

  # Continue to normal submission flow (config was already written by builder)
  echo "🚀 Proceeding to job submission..."
  echo ""
fi
1440
+
767
1441
  # Parse and validate configuration
768
1442
  _parse_config
1443
+
1444
+ # Resolve shell variables in image URI (backward compat with old-style config.yaml
1445
+ # that has ${AWS_ACCOUNT_ID}... instead of resolved values)
1446
+ if echo "${TRAIN_IMAGE}" | grep -q '${'; then
1447
+ TRAIN_IMAGE=$(eval echo "${TRAIN_IMAGE}" 2>/dev/null || echo "${TRAIN_IMAGE}")
1448
+ fi
1449
+
769
1450
  _validate_config
770
1451
 
1452
# Technique: pick it (CLI flag > config.yaml > default), record it in do/config,
# then layer the technique defaults under the configured hyperparameters.
_resolve_technique
_update_config_var "TRAIN_TECHNIQUE" "${TECHNIQUE}"
_load_technique_defaults
_merge_hyperparameters

# Multi-node runs: tell the user whether the chosen instance type is on the
# EFA-capable list below.
if [ 1 -lt "${TRAIN_INSTANCE_COUNT:-1}" ]; then
  printf '%s\n' "📡 Multi-node training: ${TRAIN_INSTANCE_COUNT} instances"
  case "${TRAIN_INSTANCE_TYPE}" in
    ml.p4d* | ml.p5* | ml.trn1* | ml.g5.48xlarge | ml.g6e.48xlarge)
      printf '%s\n' " ✅ EFA-capable instance: ${TRAIN_INSTANCE_TYPE}"
      ;;
    *)
      printf '%s\n' \
        " ⚠️ Instance type '${TRAIN_INSTANCE_TYPE}' may not support EFA." \
        " Multi-node training will use TCP networking (slower)." \
        " For best performance, use: p4d.24xlarge, p5.48xlarge, or g5.48xlarge"
      ;;
  esac
  printf '\n'
fi

# Dataset: resolve it (from --dataset or config.yaml); when a resolved S3 URI is
# available it replaces TRAIN_DATASET for the job request.
_resolve_dataset
[ -z "${RESOLVED_DATASET_S3_URI}" ] || TRAIN_DATASET="${RESOLVED_DATASET_S3_URI}"
1486
+
1487
# --resume: find the checkpoint location of a previous job (explicit job name,
# else the tracked TRAIN_JOB_NAME) and force creation of a new job.
if [ "${ARG_RESUME}" = true ]; then
  RESUME_JOB="${ARG_RESUME_JOB:-${TRAIN_JOB_NAME:-}}"
  if [ -z "${RESUME_JOB}" ]; then
    printf '%s\n' \
      "❌ --resume requires a previous job name." \
      " Provide it: ./do/train --resume <job-name>" \
      " Or run a training job first (TRAIN_JOB_NAME will be set in do/config)."
    exit 1
  fi

  printf '%s\n' "🔄 Resuming from job: ${RESUME_JOB}"

  # The helper's answer is the last stdout line that starts with "{".
  CHECKPOINT_RESOLVE=$(python3 "${SCRIPT_DIR}/.train_helper.py" resolve \
    --job-name "${RESUME_JOB}" \
    --checkpoints \
    --region "${AWS_REGION}" 2>/dev/null | grep -E '^\{' | tail -1) || CHECKPOINT_RESOLVE=""

  if [ -z "${CHECKPOINT_RESOLVE}" ]; then
    printf '%s\n' \
      " ⚠️ Could not resolve checkpoints for job ${RESUME_JOB}" \
      " Training will start from scratch."
  else
    RESUME_CHECKPOINT_PATH=$(python3 -c "import sys,json; print(json.load(sys.stdin).get('checkpoint_path',''))" 2>/dev/null <<< "${CHECKPOINT_RESOLVE}") || RESUME_CHECKPOINT_PATH=""
    if [ -z "${RESUME_CHECKPOINT_PATH}" ]; then
      printf '%s\n' \
        " ⚠️ No checkpoint path found for job ${RESUME_JOB}" \
        " Training will start from scratch."
    else
      printf '%s\n' " Checkpoint: ${RESUME_CHECKPOINT_PATH}"
      TRAIN_CHECKPOINT_PATH="${RESUME_CHECKPOINT_PATH}"
    fi
  fi
  printf '\n'

  # Resuming sets the force flag so a new job is created
  ARG_FORCE=true
fi

# Check idempotency (existing job handling)
_check_idempotency
773
1526