@aws/ml-container-creator 1.0.2 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/bin/cli.js +1 -1
- package/config/tune-catalog.json +303 -1
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +3 -2
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1516
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/app.js +6 -4
- package/src/lib/bootstrap-command-handler.js +12 -2
- package/src/lib/bootstrap-profile-manager.js +16 -0
- package/src/lib/cross-cutting-checker.js +6 -1
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
package/templates/do/train
CHANGED
|
@@ -16,6 +16,8 @@ set -o pipefail
|
|
|
16
16
|
# ── Source project configuration ──────────────────────────────────────────────
|
|
17
17
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
|
18
18
|
source "${SCRIPT_DIR}/config"
|
|
19
|
+
source "${SCRIPT_DIR}/lib/profile.sh"
|
|
20
|
+
source "${SCRIPT_DIR}/lib/resolve-instance.sh"
|
|
19
21
|
|
|
20
22
|
# ── Constants ─────────────────────────────────────────────────────────────────
|
|
21
23
|
CONFIG_FILE="${SCRIPT_DIR}/training/config.yaml"
|
|
@@ -27,11 +29,30 @@ ARG_STATUS=false
|
|
|
27
29
|
ARG_DRY_RUN=false
|
|
28
30
|
ARG_NO_WAIT=false
|
|
29
31
|
ARG_HELP=false
|
|
32
|
+
ARG_TECHNIQUE=""
|
|
33
|
+
ARG_DATASET=""
|
|
34
|
+
ARG_LIST_DATASETS=false
|
|
35
|
+
ARG_NO_REGISTER=false
|
|
36
|
+
ARG_INTERACTIVE=false
|
|
37
|
+
ARG_LEARNING_RATE=""
|
|
38
|
+
ARG_EPOCHS=""
|
|
39
|
+
ARG_BATCH_SIZE=""
|
|
40
|
+
ARG_LORA_R=""
|
|
41
|
+
ARG_BETA=""
|
|
42
|
+
ARG_RESUME=false
|
|
43
|
+
ARG_RESUME_JOB=""
|
|
30
44
|
|
|
31
45
|
# ── Job Variables (set by _build_job_request) ─────────────────────────────────
|
|
32
46
|
JOB_NAME=""
|
|
33
47
|
JOB_REQUEST_FILE=""
|
|
34
48
|
|
|
49
|
+
# ── Technique Variables (set by _resolve_technique) ───────────────────────────
|
|
50
|
+
TECHNIQUE=""
|
|
51
|
+
TRAIN_SCRIPT_PATH=""
|
|
52
|
+
|
|
53
|
+
# ── Dataset Variables (set by _resolve_dataset) ───────────────────────────────
|
|
54
|
+
RESOLVED_DATASET_S3_URI=""
|
|
55
|
+
|
|
35
56
|
# ── Training Config Variables (set by _parse_config) ──────────────────────────
|
|
36
57
|
TRAIN_IMAGE=""
|
|
37
58
|
TRAIN_SCRIPT=""
|
|
@@ -70,6 +91,46 @@ _parse_args() {
|
|
|
70
91
|
--status) ARG_STATUS=true; shift ;;
|
|
71
92
|
--dry-run) ARG_DRY_RUN=true; shift ;;
|
|
72
93
|
--no-wait) ARG_NO_WAIT=true; shift ;;
|
|
94
|
+
--technique)
|
|
95
|
+
if [ -z "${2:-}" ]; then
|
|
96
|
+
echo "❌ --technique requires a value"
|
|
97
|
+
echo " Available: $(_list_techniques)"
|
|
98
|
+
exit 1
|
|
99
|
+
fi
|
|
100
|
+
ARG_TECHNIQUE="$2"; shift 2 ;;
|
|
101
|
+
--dataset)
|
|
102
|
+
if [ -z "${2:-}" ]; then
|
|
103
|
+
echo "❌ --dataset requires a value (s3://..., hf://..., or registry name)"
|
|
104
|
+
exit 1
|
|
105
|
+
fi
|
|
106
|
+
ARG_DATASET="$2"; shift 2 ;;
|
|
107
|
+
--list-datasets) ARG_LIST_DATASETS=true; shift ;;
|
|
108
|
+
--no-register) ARG_NO_REGISTER=true; shift ;;
|
|
109
|
+
--interactive|-i) ARG_INTERACTIVE=true; shift ;;
|
|
110
|
+
--learning-rate)
|
|
111
|
+
if [ -z "${2:-}" ]; then echo "❌ --learning-rate requires a value"; exit 1; fi
|
|
112
|
+
ARG_LEARNING_RATE="$2"; shift 2 ;;
|
|
113
|
+
--epochs)
|
|
114
|
+
if [ -z "${2:-}" ]; then echo "❌ --epochs requires a value"; exit 1; fi
|
|
115
|
+
ARG_EPOCHS="$2"; shift 2 ;;
|
|
116
|
+
--batch-size)
|
|
117
|
+
if [ -z "${2:-}" ]; then echo "❌ --batch-size requires a value"; exit 1; fi
|
|
118
|
+
ARG_BATCH_SIZE="$2"; shift 2 ;;
|
|
119
|
+
--lora-r)
|
|
120
|
+
if [ -z "${2:-}" ]; then echo "❌ --lora-r requires a value"; exit 1; fi
|
|
121
|
+
ARG_LORA_R="$2"; shift 2 ;;
|
|
122
|
+
--beta)
|
|
123
|
+
if [ -z "${2:-}" ]; then echo "❌ --beta requires a value"; exit 1; fi
|
|
124
|
+
ARG_BETA="$2"; shift 2 ;;
|
|
125
|
+
--resume)
|
|
126
|
+
ARG_RESUME=true
|
|
127
|
+
# Optional: next arg is a job name (not another flag)
|
|
128
|
+
if [ -n "${2:-}" ] && [[ "${2}" != -* ]]; then
|
|
129
|
+
ARG_RESUME_JOB="$2"; shift 2
|
|
130
|
+
else
|
|
131
|
+
shift
|
|
132
|
+
fi
|
|
133
|
+
;;
|
|
73
134
|
--help|-h) ARG_HELP=true; shift ;;
|
|
74
135
|
*)
|
|
75
136
|
echo "❌ Unknown option: $1"
|
|
@@ -83,6 +144,7 @@ _parse_args() {
|
|
|
83
144
|
# ── _show_help() ──────────────────────────────────────────────────────────────
|
|
84
145
|
_show_help() {
|
|
85
146
|
echo "Usage: ./do/train [OPTIONS]"
|
|
147
|
+
echo " ./do/train --technique sft"
|
|
86
148
|
echo " ./do/train --status"
|
|
87
149
|
echo " ./do/train --help"
|
|
88
150
|
echo ""
|
|
@@ -92,20 +154,372 @@ _show_help() {
|
|
|
92
154
|
echo "Configuration is read from do/training/config.yaml"
|
|
93
155
|
echo ""
|
|
94
156
|
echo "Options:"
|
|
95
|
-
echo " --
|
|
96
|
-
echo " --
|
|
97
|
-
echo " --
|
|
98
|
-
echo " --
|
|
99
|
-
echo " --
|
|
157
|
+
echo " --technique <name> Select training technique (overrides config.yaml)"
|
|
158
|
+
echo " --dataset <source> Dataset: s3://..., hf://org/name, or registry name"
|
|
159
|
+
echo " --interactive, -i Guided interactive configuration builder"
|
|
160
|
+
echo " --learning-rate <v> Override learning rate"
|
|
161
|
+
echo " --epochs <n> Override number of epochs"
|
|
162
|
+
echo " --batch-size <n> Override per-device batch size"
|
|
163
|
+
echo " --lora-r <n> Override LoRA rank"
|
|
164
|
+
echo " --beta <v> Override DPO beta (KL penalty)"
|
|
165
|
+
echo " --resume [job-name] Resume from previous job's checkpoint"
|
|
166
|
+
echo " --list-datasets Show registered datasets"
|
|
167
|
+
echo " --no-register Skip auto-registration after completion"
|
|
168
|
+
echo " --force Create a new job even if a previous job exists"
|
|
169
|
+
echo " --status Show current job status without submitting"
|
|
170
|
+
echo " --dry-run Validate inputs and show the request without submitting"
|
|
171
|
+
echo " --no-wait Submit job and exit without polling for completion"
|
|
172
|
+
echo " --help, -h Show this help message"
|
|
173
|
+
echo ""
|
|
174
|
+
echo "Available techniques:"
|
|
175
|
+
echo " $(_list_techniques)"
|
|
100
176
|
echo ""
|
|
101
177
|
echo "Examples:"
|
|
102
|
-
echo " ./do/train
|
|
103
|
-
echo " ./do/train --
|
|
104
|
-
echo " ./do/train --
|
|
105
|
-
echo " ./do/train --
|
|
178
|
+
echo " ./do/train # Submit using config.yaml technique"
|
|
179
|
+
echo " ./do/train --technique sft # Submit SFT training job"
|
|
180
|
+
echo " ./do/train --status # Check status of current job"
|
|
181
|
+
echo " ./do/train --dry-run # Validate and preview request"
|
|
182
|
+
echo " ./do/train --force # Force a new job after failure"
|
|
106
183
|
exit 0
|
|
107
184
|
}
|
|
108
185
|
|
|
186
|
+
# ── _list_techniques() ────────────────────────────────────────────────────────
# Print a comma-separated list of available techniques: every subdirectory of
# do/training/ that contains a train.py (e.g. "dpo, sft"). Prints "none found"
# when there are none. All loop state is local so callers' variables survive.
_list_techniques() {
  local training_dir="${SCRIPT_DIR}/training"
  local techniques="" dir name
  for dir in "${training_dir}"/*/; do
    # An unmatched glob stays literal; the -f test filters that case out too.
    [ -f "${dir}train.py" ] || continue
    name="${dir%/}"
    name="${name##*/}"
    techniques="${techniques:+${techniques}, }${name}"
  done
  echo "${techniques:-none found}"
}
|
|
204
|
+
|
|
205
|
+
# ── _resolve_technique() ─────────────────────────────────────────────────────
# Determine the training technique and validate that its train.py exists.
# Priority: --technique flag > `technique:` in config.yaml > "custom".
# Globals read: ARG_TECHNIQUE, CONFIG_FILE, SCRIPT_DIR
# Globals written: TECHNIQUE, TRAIN_SCRIPT_PATH
# Exits 1 when do/training/<technique>/train.py does not exist.
_resolve_technique() {
  local training_dir="${SCRIPT_DIR}/training"

  if [ -n "${ARG_TECHNIQUE:-}" ]; then
    TECHNIQUE="${ARG_TECHNIQUE}"
  elif command -v yq &>/dev/null; then
    # `//` yields "custom" for a missing/null key; an unreadable config also
    # falls back to "custom", matching the python3 branch below.
    TECHNIQUE=$(yq -r '.technique // "custom"' "${CONFIG_FILE}" 2>/dev/null) || TECHNIQUE="custom"
  elif command -v python3 &>/dev/null; then
    # The config path is passed via argv instead of being spliced into the
    # Python source. `or` mirrors yq's `//`: a key present with no value would
    # otherwise come back as None and be printed as the string "None".
    TECHNIQUE=$(python3 -c '
import sys, yaml
with open(sys.argv[1]) as f:
    cfg = yaml.safe_load(f) or {}
print(cfg.get("technique") or "custom")
' "${CONFIG_FILE}" 2>/dev/null) || TECHNIQUE="custom"
  else
    TECHNIQUE="custom"
  fi

  # Fallback if empty
  TECHNIQUE="${TECHNIQUE:-custom}"

  # Validate technique directory exists
  TRAIN_SCRIPT_PATH="${training_dir}/${TECHNIQUE}/train.py"
  if [ ! -f "${TRAIN_SCRIPT_PATH}" ]; then
    echo "❌ Training technique '${TECHNIQUE}' not found"
    echo " Expected: ${TRAIN_SCRIPT_PATH}"
    echo ""
    echo " Available techniques: $(_list_techniques)"
    echo ""
    echo " To add a new technique, create: do/training/${TECHNIQUE}/train.py"
    exit 1
  fi

  echo "📋 Technique: ${TECHNIQUE}"
  echo " Script: ${TRAIN_SCRIPT_PATH}"
}
|
|
245
|
+
|
|
246
|
+
# ── _load_technique_defaults() ────────────────────────────────────────────────
# Load technique-specific default hyperparameters from
# training/<technique>/defaults.yaml into TECHNIQUE_DEFAULTS_JSON.
# Globals read: SCRIPT_DIR, TECHNIQUE
# Globals written: TECHNIQUE_DEFAULTS_JSON - always a JSON object; "{}" when
#   the file is missing, unparseable, or not a mapping. Null-valued keys
#   (e.g. `lora_r:` with the value commented out) are dropped by both parsers.
# Always returns 0.
_load_technique_defaults() {
  local defaults_file="${SCRIPT_DIR}/training/${TECHNIQUE}/defaults.yaml"
  local parsed=""
  TECHNIQUE_DEFAULTS_JSON="{}"

  if [ ! -f "${defaults_file}" ]; then
    return 0
  fi

  if command -v yq &>/dev/null; then
    # mikefarah yq v4 syntax. Other yq flavours, or a non-mapping document,
    # fail or print non-object output and fall through to python3 below.
    parsed=$(yq -o=json 'with_entries(select(.value != null))' "${defaults_file}" 2>/dev/null) || parsed=""
  fi

  if [[ "${parsed}" != "{"* ]] && command -v python3 &>/dev/null; then
    # Path passed via argv rather than spliced into the Python source.
    parsed=$(python3 -c '
import json, sys, yaml
with open(sys.argv[1]) as f:
    data = yaml.safe_load(f) or {}
if not isinstance(data, dict):
    data = {}
print(json.dumps({k: v for k, v in data.items() if v is not None}))
' "${defaults_file}" 2>/dev/null) || parsed=""
  fi

  # Only accept a JSON object; anything else keeps the "{}" default.
  if [[ "${parsed}" == "{"* ]]; then
    TECHNIQUE_DEFAULTS_JSON="${parsed}"
  fi
  return 0
}
|
|
271
|
+
|
|
272
|
+
# ── _merge_hyperparameters() ──────────────────────────────────────────────────
# Merge hyperparameters with precedence: CLI flags > config.yaml > technique defaults.
# Globals read: TECHNIQUE_DEFAULTS_JSON, TRAIN_HYPERPARAMS, ARG_LEARNING_RATE,
#   ARG_EPOCHS, ARG_BATCH_SIZE, ARG_LORA_R, ARG_BETA
# Globals written: TRAIN_HYPERPARAMS (JSON object; every value is a string)
# Every input reaches Python through argv instead of being spliced into its
# source. JSON escapes (\" \n \\) therefore survive intact, and flag values
# cannot inject code. Null values are dropped rather than becoming "None".
# If the merge fails (python3 missing, or TRAIN_HYPERPARAMS is not valid
# JSON), TRAIN_HYPERPARAMS is left unchanged and a warning is printed. The old
# behaviour silently replaced it with an empty string.
_merge_hyperparameters() {
  local merged
  if merged=$(python3 -c '
import json
import sys


def as_dict(raw, strict):
    """Parse raw JSON into a dict; blank, null or non-object input yields {}."""
    raw = raw.strip()
    if not raw:
        return {}
    try:
        data = json.loads(raw)
    except ValueError:
        if strict:
            raise
        return {}
    return data if isinstance(data, dict) else {}


def stringify(data):
    """Stringify keys and values, skipping keys whose value is null."""
    return {str(k): str(v) for k, v in data.items() if v is not None}


merged = stringify(as_dict(sys.argv[1], strict=False))       # Layer 1: technique defaults
merged.update(stringify(as_dict(sys.argv[2], strict=True)))  # Layer 2: config.yaml
cli_keys = ("learning_rate", "epochs", "batch_size", "lora_r", "beta")
merged.update({k: v for k, v in zip(cli_keys, sys.argv[3:8]) if v})  # Layer 3: CLI flags
print(json.dumps(merged))
' "${TECHNIQUE_DEFAULTS_JSON:-}" "${TRAIN_HYPERPARAMS:-}" \
      "${ARG_LEARNING_RATE:-}" "${ARG_EPOCHS:-}" "${ARG_BATCH_SIZE:-}" \
      "${ARG_LORA_R:-}" "${ARG_BETA:-}" 2>/dev/null); then
    TRAIN_HYPERPARAMS="${merged}"
  else
    echo "⚠️ Could not merge hyperparameters; using config.yaml values unchanged"
  fi
}
|
|
309
|
+
|
|
310
|
+
# ── _list_datasets_cmd() ─────────────────────────────────────────────────────
# Implements --list-datasets: prints a header, then delegates the listing to
# `.register_helper.py list-datasets`. If the helper is absent or fails, it
# prints a hint instead. Always terminates the script with status 0.
_list_datasets_cmd() {
  local helper="${SCRIPT_DIR}/.register_helper.py"

  printf '%s\n\n' "📋 Registered Datasets"

  if [ ! -f "${helper}" ]; then
    echo " Register helper not available."
    exit 0
  fi

  # Helper stderr is discarded; any non-zero exit is reported as "no datasets".
  if ! python3 "${helper}" list-datasets --region "${AWS_REGION:-us-east-1}" 2>/dev/null; then
    echo " No datasets registered yet."
    echo " Register: ./do/register dataset <name> --s3-uri <uri> --technique <sft|dpo>"
  fi
  exit 0
}
|
|
326
|
+
|
|
327
|
+
# ── _resolve_dataset() ────────────────────────────────────────────────────────
# Resolve the training dataset to an S3 URI. The source is --dataset, falling
# back to `dataset:` from config.yaml (TRAIN_DATASET). Accepted forms:
#   s3://bucket/path                     → used as-is
#   hf://org/name[/split][?file=<path>]  → staged via _stage_hf_dataset
#   name[@v<N>]                          → resolved via _resolve_dataset_from_registry
# The @v<N> suffix is only recognised on registry names, so URIs whose path
# happens to end in "@v2" are not misrouted to the registry.
# Sets RESOLVED_DATASET_S3_URI and persists it to do/config as
# TRAIN_DATASET_S3_URI_<TECHNIQUE>. Characters that are invalid in a shell
# variable name become "_". Exits 1 when a non-custom technique has no dataset.
_resolve_dataset() {
  local dataset="${ARG_DATASET:-}"

  # If no --dataset provided, check config.yaml
  if [ -z "${dataset}" ]; then
    dataset="${TRAIN_DATASET:-}"
  fi

  # Still empty — dataset is optional for custom technique, required for sft/dpo
  if [ -z "${dataset}" ]; then
    if [ "${TECHNIQUE}" = "custom" ]; then
      return 0 # Custom technique doesn't require managed dataset resolution
    fi
    echo "❌ --dataset is required for technique '${TECHNIQUE}'"
    echo " Provide: s3://bucket/path.jsonl, hf://org/name, or a registered name"
    echo " Run ./do/train --list-datasets to see registered datasets."
    exit 1
  fi

  case "${dataset}" in
    s3://*)
      # S3 direct — use as-is
      echo "📂 Dataset: ${dataset} (S3 direct)"
      RESOLVED_DATASET_S3_URI="${dataset}"
      ;;
    hf://*)
      # HuggingFace — stage via helper
      _stage_hf_dataset "${dataset}"
      ;;
    *)
      # Registry name, optionally pinned to a version with @v<N>
      if [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
        _resolve_dataset_from_registry "${BASH_REMATCH[1]}" "${BASH_REMATCH[2]}"
      else
        _resolve_dataset_from_registry "${dataset}" ""
      fi
      ;;
  esac

  # Persist resolved URI
  if [ -n "${RESOLVED_DATASET_S3_URI:-}" ]; then
    local technique_upper
    technique_upper=$(echo "${TECHNIQUE}" | tr '[:lower:]' '[:upper:]')
    # Keep the config key a valid shell identifier (e.g. "my-tech" → MY_TECH).
    technique_upper="${technique_upper//[![:alnum:]_]/_}"
    _update_config_var "TRAIN_DATASET_S3_URI_${technique_upper}" "${RESOLVED_DATASET_S3_URI}"
  fi
}
|
|
386
|
+
|
|
387
|
+
# ── _resolve_dataset_from_registry() ─────────────────────────────────────────
# Look up a registered dataset with `.register_helper.py resolve-dataset` and
# take `s3_uri` from the last stdout line of the helper that starts with '{'.
# Arguments: $1 - dataset name; $2 - version number, or "" for unpinned
# Globals written: RESOLVED_DATASET_S3_URI (only on success)
# Exits 1 with registration hints when no URI can be resolved.
_resolve_dataset_from_registry() {
  local ds_name="$1"
  local ds_version="$2"
  local -a helper_args=("--name" "${ds_name}")
  local helper_out="" uri=""

  echo "🔍 Resolving dataset '${ds_name}' from registry..."
  if [ -n "${ds_version}" ]; then
    helper_args+=("--version" "${ds_version}")
    echo " Version: v${ds_version}"
  fi

  # Helper failures (and their stderr) are folded into "not found" below.
  helper_out=$(python3 "${SCRIPT_DIR}/.register_helper.py" resolve-dataset \
    "${helper_args[@]}" 2>/dev/null) || helper_out=""

  if [ -n "${helper_out}" ]; then
    uri=$(printf '%s\n' "${helper_out}" \
      | grep -E '^\{' \
      | tail -1 \
      | python3 -c "import sys,json; print(json.load(sys.stdin).get('s3_uri',''))" 2>/dev/null) || uri=""
  fi

  if [ -z "${uri}" ]; then
    echo "❌ Dataset '${ds_name}' not found in registry"
    echo " Register it: ./do/register dataset ${ds_name} --s3-uri <uri> --technique ${TECHNIQUE}"
    echo " Or provide directly: --dataset s3://... or --dataset hf://..."
    exit 1
  fi

  echo " Resolved to: ${uri}"
  RESOLVED_DATASET_S3_URI="${uri}"
  return 0
}
|
|
419
|
+
|
|
420
|
+
# ── _stage_hf_dataset() ──────────────────────────────────────────────────────
# Stage a HuggingFace dataset to S3 via .tune_helper.py stage-hf.
# Arguments: $1 - hf://org/name[/split][?file=<path-in-repo>]
# Globals read: SCRIPT_DIR, S3_BUCKET, PROJECT_NAME, AWS_REGION, TECHNIQUE,
#   HF_TOKEN_ARN
# Globals written: RESOLVED_DATASET_S3_URI
# Exits 1 in any of these cases: malformed reference, missing helper or
# bucket, helper failure, helper-reported error, or a response with no s3_uri.
_stage_hf_dataset() {
  local dataset="$1"
  local hf_path="${dataset#hf://}"
  local hf_file=""

  # Extract ?file= parameter. '?' is escaped in both expansions so it matches
  # a literal '?' rather than acting as a single-character wildcard.
  if [[ "${hf_path}" == *"?file="* ]]; then
    hf_file="${hf_path#*\?file=}"
    hf_path="${hf_path%%\?file=*}"
  fi

  # Split org/name[/split] with parameter expansion. `cut -d/ -fN` echoes the
  # whole line when it contains no '/', which let "hf://squad" pass
  # validation as org=squad name=squad.
  local hf_org="${hf_path%%/*}"
  local hf_rest="" hf_name="" hf_split=""
  if [[ "${hf_path}" == */* ]]; then
    hf_rest="${hf_path#*/}"
    hf_name="${hf_rest%%/*}"
    if [[ "${hf_rest}" == */* ]]; then
      hf_split="${hf_rest#*/}"
    fi
  fi

  if [ -z "${hf_org}" ] || [ -z "${hf_name}" ]; then
    echo "❌ Invalid HF dataset reference: ${dataset}"
    echo " Expected format: hf://org/name or hf://org/name/split"
    exit 1
  fi

  # Determine which helper script to use for staging
  local helper_script="${SCRIPT_DIR}/.tune_helper.py"
  if [ ! -f "${helper_script}" ]; then
    echo "❌ Dataset staging helper not available (.tune_helper.py)"
    echo " Stage the dataset manually to S3 and use: --dataset s3://..."
    exit 1
  fi

  echo "📦 Staging HuggingFace dataset: ${hf_org}/${hf_name}"
  if [ -n "${hf_split}" ]; then
    echo " Split: ${hf_split}"
  fi

  # Resolve output bucket: S3_BUCKET first, then s3Bucket from the profile file
  local output_bucket="${S3_BUCKET:-}"
  if [ -z "${output_bucket}" ]; then
    output_bucket=$(python3 -c "
import json, os
config_path = os.path.expanduser('~/.ml-container-creator/config.json')
if os.path.exists(config_path):
    with open(config_path) as f:
        cfg = json.load(f)
    print(cfg.get('s3Bucket', ''))
" 2>/dev/null) || output_bucket=""
  fi

  if [ -z "${output_bucket}" ]; then
    echo "❌ No S3 bucket configured for dataset staging"
    echo " Run ./do/bootstrap to configure, or set S3_BUCKET in do/config"
    exit 1
  fi

  # Build stage-hf arguments
  local stage_args=(
    --hf-org "${hf_org}"
    --hf-name "${hf_name}"
    --output-bucket "${output_bucket}"
    --project-name "${PROJECT_NAME}"
    --region "${AWS_REGION}"
    --technique "${TECHNIQUE}"
  )
  if [ -n "${hf_split}" ]; then
    stage_args+=(--hf-split "${hf_split}")
  fi
  if [ -n "${HF_TOKEN_ARN:-}" ]; then
    stage_args+=(--hf-secret-name "${HF_TOKEN_ARN}")
  fi
  if [ -n "${hf_file}" ]; then
    stage_args+=(--hf-file "${hf_file}")
  fi

  local stage_result
  if ! stage_result=$(python3 "${helper_script}" stage-hf "${stage_args[@]}"); then
    echo "❌ Failed to stage HF dataset"
    exit 1
  fi

  # Check for error in response (unparseable output is treated as an error)
  local has_error
  has_error=$(echo "${stage_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"

  if [ "${has_error}" = "yes" ]; then
    local error_msg
    error_msg=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
    echo "❌ ${error_msg}"
    exit 1
  fi

  RESOLVED_DATASET_S3_URI=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('s3_uri') or '')" 2>/dev/null) || RESOLVED_DATASET_S3_URI=""
  if [ -z "${RESOLVED_DATASET_S3_URI}" ]; then
    # Previously a missing s3_uri went unnoticed and training started without data.
    echo "❌ Staging helper did not return an s3_uri"
    exit 1
  fi

  local row_count
  row_count=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('num_records',0))" 2>/dev/null) || row_count="0"

  echo " ✅ Staged to: ${RESOLVED_DATASET_S3_URI}"
  echo " Records: ${row_count}"
  echo ""
}
|
|
522
|
+
|
|
109
523
|
# ── _parse_config() ───────────────────────────────────────────────────────────
|
|
110
524
|
# Read and parse do/training/config.yaml into bash variables.
|
|
111
525
|
# Uses yq if available, falls back to python3 YAML parsing.
|
|
@@ -135,6 +549,11 @@ _parse_config_yq() {
|
|
|
135
549
|
TRAIN_IMAGE=$(yq -r '.image // ""' "${CONFIG_FILE}")
|
|
136
550
|
TRAIN_SCRIPT=$(yq -r '.script // ""' "${CONFIG_FILE}")
|
|
137
551
|
TRAIN_INSTANCE_TYPE=$(yq -r '.instance_type // ""' "${CONFIG_FILE}")
|
|
552
|
+
|
|
553
|
+
# Resolve shell variables in image URI (backward compat with old-style config)
|
|
554
|
+
if echo "${TRAIN_IMAGE}" | grep -q '^\${\|^\${'; then
|
|
555
|
+
TRAIN_IMAGE=$(eval echo "${TRAIN_IMAGE}")
|
|
556
|
+
fi
|
|
138
557
|
TRAIN_INSTANCE_COUNT=$(yq -r '.instance_count // "1"' "${CONFIG_FILE}")
|
|
139
558
|
TRAIN_DATASET=$(yq -r '.dataset // ""' "${CONFIG_FILE}")
|
|
140
559
|
TRAIN_OUTPUT_PATH=$(yq -r '.output_path // ""' "${CONFIG_FILE}")
|
|
@@ -302,12 +721,22 @@ _validate_config() {
|
|
|
302
721
|
fi
|
|
303
722
|
|
|
304
723
|
if [ -z "${TRAIN_OUTPUT_PATH}" ]; then
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
724
|
+
# Auto-resolve from profile: use benchmarkS3Bucket or construct from project name
|
|
725
|
+
if [ -n "${_PROFILE_benchmarkS3Bucket:-}" ]; then
|
|
726
|
+
TRAIN_OUTPUT_PATH="s3://${_PROFILE_benchmarkS3Bucket}/${PROJECT_NAME}/training-output/"
|
|
727
|
+
elif [ -n "${BENCHMARK_S3_OUTPUT_PATH:-}" ]; then
|
|
728
|
+
# Derive from benchmark output path (replace /benchmarks/ with /training-output/)
|
|
729
|
+
TRAIN_OUTPUT_PATH="${BENCHMARK_S3_OUTPUT_PATH%/benchmarks/*}/training-output/${PROJECT_NAME}/"
|
|
730
|
+
fi
|
|
731
|
+
if [ -z "${TRAIN_OUTPUT_PATH}" ]; then
|
|
732
|
+
echo "❌ Missing required field: output_path"
|
|
733
|
+
echo " The S3 output path is required in do/training/config.yaml"
|
|
734
|
+
echo " Or run 'ml-container-creator bootstrap' to configure an S3 bucket."
|
|
735
|
+
echo ""
|
|
736
|
+
echo " Expected format: output_path: \"s3://my-bucket/output/\""
|
|
737
|
+
echo ""
|
|
738
|
+
has_error=true
|
|
739
|
+
fi
|
|
311
740
|
fi
|
|
312
741
|
|
|
313
742
|
# Spot training requires a checkpoint path for resumption
|
|
@@ -350,29 +779,33 @@ _check_idempotency() {
|
|
|
350
779
|
echo " Checking status..."
|
|
351
780
|
echo ""
|
|
352
781
|
|
|
353
|
-
#
|
|
354
|
-
local
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
--
|
|
358
|
-
describe_exit_code=${describe_exit_code:-0}
|
|
782
|
+
# Query status via SDK v3 helper
|
|
783
|
+
local status_json
|
|
784
|
+
status_json=$(python3 "${SCRIPT_DIR}/.train_helper.py" status \
|
|
785
|
+
--job-name "${TRAIN_JOB_NAME}" \
|
|
786
|
+
--region "${AWS_REGION}" 2>/dev/null | grep -E '^\{' | tail -1) || status_json=""
|
|
359
787
|
|
|
360
|
-
if [ ${
|
|
361
|
-
# If
|
|
788
|
+
if [ -z "${status_json}" ]; then
|
|
789
|
+
# If status query fails, proceed to new job
|
|
790
|
+
echo "⚠️ Could not query existing job: ${TRAIN_JOB_NAME}"
|
|
791
|
+
echo " Proceeding to create a new job."
|
|
792
|
+
echo ""
|
|
793
|
+
return 0
|
|
794
|
+
fi
|
|
795
|
+
|
|
796
|
+
# Check for error in response
|
|
797
|
+
local has_error
|
|
798
|
+
has_error=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if d.get('error') else 'no')" 2>/dev/null) || has_error="no"
|
|
799
|
+
|
|
800
|
+
if [ "${has_error}" = "yes" ]; then
|
|
362
801
|
echo "⚠️ Could not describe existing job: ${TRAIN_JOB_NAME}"
|
|
363
|
-
echo " ${describe_output}"
|
|
364
802
|
echo " Proceeding to create a new job."
|
|
365
803
|
echo ""
|
|
366
804
|
return 0
|
|
367
805
|
fi
|
|
368
806
|
|
|
369
|
-
# Extract status from the JSON response using python3
|
|
370
807
|
local job_status
|
|
371
|
-
job_status=$(echo "${
|
|
372
|
-
import sys, json
|
|
373
|
-
resp = json.load(sys.stdin)
|
|
374
|
-
print(resp.get('TrainingJobStatus', 'Unknown'))
|
|
375
|
-
" 2>/dev/null) || job_status="Unknown"
|
|
808
|
+
job_status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || job_status="Unknown"
|
|
376
809
|
|
|
377
810
|
case "${job_status}" in
|
|
378
811
|
InProgress)
|
|
@@ -386,19 +819,12 @@ print(resp.get('TrainingJobStatus', 'Unknown'))
|
|
|
386
819
|
Completed)
|
|
387
820
|
echo "✅ Training job already completed: ${TRAIN_JOB_NAME}"
|
|
388
821
|
echo ""
|
|
389
|
-
# Pass the describe output to _handle_completion via a temp file
|
|
390
|
-
local describe_file="/tmp/train-describe-${TRAIN_JOB_NAME}.json"
|
|
391
|
-
echo "${describe_output}" > "${describe_file}"
|
|
392
822
|
_handle_completion
|
|
393
823
|
exit 0
|
|
394
824
|
;;
|
|
395
825
|
Failed)
|
|
396
826
|
local failure_reason
|
|
397
|
-
failure_reason=$(echo "${
|
|
398
|
-
import sys, json
|
|
399
|
-
resp = json.load(sys.stdin)
|
|
400
|
-
print(resp.get('FailureReason', 'No failure reason provided'))
|
|
401
|
-
" 2>/dev/null) || failure_reason="No failure reason provided"
|
|
827
|
+
failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'No failure reason provided')" 2>/dev/null) || failure_reason="No failure reason provided"
|
|
402
828
|
|
|
403
829
|
echo "❌ Previous training job failed: ${TRAIN_JOB_NAME}"
|
|
404
830
|
echo " Reason: ${failure_reason}"
|
|
@@ -469,9 +895,177 @@ _build_job_request() {
|
|
|
469
895
|
fi
|
|
470
896
|
}
|
|
471
897
|
|
|
898
|
+
<% if (deploymentTarget === 'hyperpod-eks') { %>
|
|
472
899
|
# ── _submit_job() ─────────────────────────────────────────────────────────────
# HyperPod EKS: generate a K8s Job manifest for single-pod training at
# <project>/.train-job.yaml and apply it with kubectl.
# Multi-node (parallelism > 1) requires PyTorchJob CRD — see Epic 8.
# With --dry-run, prints the manifest, deletes it, and exits 0 without
# applying. Otherwise, on success, persists TRAIN_JOB_NAME via
# _update_config_var. Exits 1 if kubectl apply fails.
_submit_job() {
  local manifest="${SCRIPT_DIR}/../.train-job.yaml"

  if [ "${ARG_DRY_RUN}" != true ]; then
    echo ""
    echo "🚀 Submitting training job to HyperPod EKS..."

    _generate_k8s_manifest "${manifest}"

    if ! kubectl apply -f "${manifest}" 2>&1; then
      echo "❌ Failed to apply K8s Job manifest"
      echo " Check: kubectl access, namespace exists, PVC available"
      exit 1
    fi

    # Persist job name to do/config
    _update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
    echo " ✅ Job submitted: ${JOB_NAME}"
    echo " Namespace: ${HYPERPOD_NAMESPACE:-default}"
    echo ""
    return 0
  fi

  # Dry run: render the manifest, show it, and clean up without applying.
  _generate_k8s_manifest "${manifest}"
  printf '\n%s\n\n' "🔍 Dry run — K8s Job manifest:"
  cat "${manifest}"
  echo ""
  rm -f "${manifest}"
  exit 0
}
|
|
935
|
+
|
|
936
|
+
# ── _generate_k8s_manifest() ─────────────────────────────────────────────────
# Generate a K8s batch/v1 Job YAML for single-node multi-GPU training.
# Arguments: $1 - path of the manifest file to write (overwritten)
# Globals read: JOB_NAME, HYPERPOD_NAMESPACE, PROJECT_NAME, TECHNIQUE,
#   TRAIN_IMAGE, TRAIN_INSTANCE_COUNT, HF_MODEL_ID, DATA_PVC, OUTPUT_PVC,
#   SCRIPT_DIR
# Runs `kubectl get nodes`. EFA resources and NCCL/libfabric env vars are
# emitted only when the first listed node reports a non-zero EFA allocatable.
_generate_k8s_manifest() {
  local manifest_file="$1"
  local technique="${TECHNIQUE:-custom}"
  local workdir="/workspace/training/${technique}"
  # NOTE(review): TRAIN_INSTANCE_COUNT is used as the per-pod GPU limit and
  # NUM_GPUS; its name suggests an instance count — confirm this is intended.
  local gpu_count="${TRAIN_INSTANCE_COUNT:-1}"

  # Pass --config_file only when the technique ships an accelerate config.
  # sft/dpo ship one; a technique without one (custom appears not to) would
  # otherwise make `accelerate launch` fail on a missing file at startup.
  local launch_cmd
  if [ -f "${SCRIPT_DIR}/training/${technique}/accelerate_config.yaml" ]; then
    launch_cmd="[\"accelerate\", \"launch\", \"--config_file\", \"${workdir}/accelerate_config.yaml\", \"${workdir}/train.py\"]"
  else
    launch_cmd="[\"accelerate\", \"launch\", \"${workdir}/train.py\"]"
  fi

  # Detect EFA availability on cluster nodes (first node only)
  local efa_count
  efa_count=$(kubectl get nodes -o jsonpath='{.items[0].status.allocatable.vpc\.amazonaws\.com/efa}' 2>/dev/null) || efa_count="0"

  # Build EFA resource request and NCCL env if available. The embedded
  # indentation must line up with the env list / resource limits below.
  local efa_resource="" nccl_env=""
  if [ -n "${efa_count}" ] && [ "${efa_count}" != "0" ]; then
    efa_resource="              vpc.amazonaws.com/efa: ${efa_count}"
    nccl_env='            - name: NCCL_SOCKET_IFNAME
              value: "eth0"
            - name: FI_PROVIDER
              value: "efa"
            - name: FI_EFA_USE_DEVICE_RDMA
              value: "1"'
  fi

  # NOTE(review): the workspace ConfigMap is mounted flat at /workspace; the
  # training/<technique>/ paths only exist if the ConfigMap uses items[].path.
  cat > "${manifest_file}" <<MANIFEST
# Generated by do/train for HyperPod EKS (single-node)
# Multi-node distributed training requires PyTorchJob CRD — see Epic 8.
apiVersion: batch/v1
kind: Job
metadata:
  name: ${JOB_NAME}
  namespace: ${HYPERPOD_NAMESPACE:-default}
  labels:
    app: ml-container-creator
    project: ${PROJECT_NAME}
    technique: ${technique}
spec:
  parallelism: 1
  completions: 1
  backoffLimit: 0
  template:
    metadata:
      labels:
        app: ml-container-creator
        job-name: ${JOB_NAME}
    spec:
      containers:
        - name: training
          image: ${TRAIN_IMAGE}
          command: ${launch_cmd}
          env:
            - name: DATA_DIR
              value: "/data/training"
            - name: OUTPUT_DIR
              value: "/output/model"
            - name: CHECKPOINT_DIR
              value: "/output/checkpoints"
            - name: HF_MODEL_ID
              value: "${HF_MODEL_ID:-}"
            - name: NUM_GPUS
              value: "${gpu_count}"
            - name: TRAIN_TECHNIQUE
              value: "${technique}"
${nccl_env}
          resources:
            limits:
              nvidia.com/gpu: ${gpu_count}
${efa_resource}
          volumeMounts:
            - name: data
              mountPath: /data
            - name: output
              mountPath: /output
            - name: workspace
              mountPath: /workspace
      volumes:
        - name: data
          persistentVolumeClaim:
            claimName: ${DATA_PVC:-training-data}
        - name: output
          persistentVolumeClaim:
            claimName: ${OUTPUT_PVC:-training-output}
        - name: workspace
          configMap:
            name: ${PROJECT_NAME}-training-scripts
      restartPolicy: Never
MANIFEST
}
|
|
1023
|
+
|
|
1024
|
+
# ── _poll_job() ───────────────────────────────────────────────────────────────
# HyperPod EKS: Poll K8s Job status via kubectl every POLL_INTERVAL seconds.
#
# Globals (read): JOB_NAME (falls back to TRAIN_JOB_NAME), HYPERPOD_NAMESPACE,
#                 POLL_INTERVAL
# On success (succeeded >= 1 or Complete=True): breaks loop and returns.
# On failure (failed >= 1 or Failed=True): prints a log hint and exits 2.
# If the job cannot be queried, warns and retries instead of reporting Running.
_poll_job() {
  local job_name="${JOB_NAME:-$TRAIN_JOB_NAME}"
  local namespace="${HYPERPOD_NAMESPACE:-default}"

  # One kubectl round-trip per poll. Conditions are selected by type rather
  # than by position: conditions[0] is not guaranteed to be Complete/Failed
  # (e.g. newer clusters add SuccessCriteriaMet / FailureTarget first).
  local jsonpath='{.status.succeeded}{"|"}{.status.failed}{"|"}{.status.active}{"|"}{.status.conditions[?(@.type=="Complete")].status}{"|"}{.status.conditions[?(@.type=="Failed")].status}'

  echo "⏳ Polling K8s Job: ${job_name} (namespace: ${namespace})"
  echo " (Ctrl+C to stop polling — job continues in cluster)"
  echo ""

  local raw succeeded failed active complete_cond failed_cond
  while true; do
    if ! raw=$(kubectl get job "${job_name}" -n "${namespace}" -o "jsonpath=${jsonpath}" 2>/dev/null); then
      echo "⚠️ Failed to query job status (will retry)"
      sleep "${POLL_INTERVAL}"
      continue
    fi

    # Non-whitespace IFS keeps empty fields in place.
    IFS='|' read -r succeeded failed active complete_cond failed_cond <<<"${raw}"
    # Counts are absent until non-zero; treat anything non-numeric as 0.
    [[ "${succeeded}" =~ ^[0-9]+$ ]] || succeeded=0
    [[ "${failed}" =~ ^[0-9]+$ ]] || failed=0
    [[ "${active}" =~ ^[0-9]+$ ]] || active=0

    if (( succeeded >= 1 )) || [ "${complete_cond}" = "True" ]; then
      echo " ✅ Completed"
      echo ""
      echo "✅ Training job completed: ${job_name}"
      break
    elif (( failed >= 1 )) || [ "${failed_cond}" = "True" ]; then
      echo " ❌ Failed"
      echo ""
      echo "❌ Training job failed: ${job_name}"
      echo " Logs: kubectl logs job/${job_name} -n ${namespace}"
      exit 2
    else
      echo " 🔄 Running (active: ${active:-0})"
    fi

    sleep "${POLL_INTERVAL}"
  done
}
|
|
1065
|
+
<% } else { %>
|
|
1066
|
+
# ── _submit_job() ─────────────────────────────────────────────────────────────
|
|
1067
|
+
# Submit training job via .train_helper.py (SDK v3).
|
|
1068
|
+
# Handles --dry-run, error detection, and config persistence.
|
|
475
1069
|
_submit_job() {
|
|
476
1070
|
# Handle --dry-run: print the request JSON and exit without submitting
|
|
477
1071
|
if [ "${ARG_DRY_RUN}" = true ]; then
|
|
@@ -487,59 +1081,44 @@ _submit_job() {
|
|
|
487
1081
|
echo ""
|
|
488
1082
|
echo "🚀 Submitting training job..."
|
|
489
1083
|
|
|
490
|
-
# Submit
|
|
1084
|
+
# Submit via SDK v3 helper
|
|
491
1085
|
local submit_output
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
--
|
|
495
|
-
submit_exit_code=${submit_exit_code:-0}
|
|
1086
|
+
submit_output=$(python3 "${SCRIPT_DIR}/.train_helper.py" submit \
|
|
1087
|
+
--config "${JOB_REQUEST_FILE}" \
|
|
1088
|
+
--region "${AWS_REGION}" 2>&1 | grep -E '^\{' | tail -1) || submit_output=""
|
|
496
1089
|
|
|
497
1090
|
# Clean up the temporary request file
|
|
498
1091
|
rm -f "${JOB_REQUEST_FILE}"
|
|
499
1092
|
|
|
500
|
-
if [ ${
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
else
|
|
506
|
-
# Failure — detect error type and provide remediation
|
|
507
|
-
if echo "${submit_output}" | grep -q "AccessDeniedException"; then
|
|
508
|
-
# Extract the permission or action from the error message
|
|
509
|
-
local missing_permission
|
|
510
|
-
missing_permission=$(echo "${submit_output}" | grep -oP '(?<=performing: )[^ ]+' || echo "")
|
|
511
|
-
if [ -z "${missing_permission}" ]; then
|
|
512
|
-
missing_permission=$(echo "${submit_output}" | grep -oP '(?<=action: )[^ ]+' || echo "")
|
|
513
|
-
fi
|
|
514
|
-
if [ -z "${missing_permission}" ]; then
|
|
515
|
-
missing_permission="sagemaker:CreateTrainingJob"
|
|
516
|
-
fi
|
|
1093
|
+
if [ -z "${submit_output}" ]; then
|
|
1094
|
+
echo "❌ Failed to submit training job (no response from helper)"
|
|
1095
|
+
echo " Ensure sagemaker SDK v3 is installed: pip install 'sagemaker>=3.0'"
|
|
1096
|
+
exit 1
|
|
1097
|
+
fi
|
|
517
1098
|
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
echo "❌ Failed to submit training job"
|
|
528
|
-
echo " ${submit_output}"
|
|
529
|
-
echo ""
|
|
530
|
-
echo " Check the error above and verify your configuration in do/training/config.yaml."
|
|
531
|
-
exit 1
|
|
532
|
-
fi
|
|
1099
|
+
# Check for error in response
|
|
1100
|
+
local has_error
|
|
1101
|
+
has_error=$(echo "${submit_output}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if d.get('error') else 'no')" 2>/dev/null) || has_error="yes"
|
|
1102
|
+
|
|
1103
|
+
if [ "${has_error}" = "yes" ]; then
|
|
1104
|
+
local error_msg
|
|
1105
|
+
error_msg=$(echo "${submit_output}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('message','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
|
|
1106
|
+
echo "❌ ${error_msg}"
|
|
1107
|
+
exit 1
|
|
533
1108
|
fi
|
|
1109
|
+
|
|
1110
|
+
# Success — persist job name to do/config
|
|
1111
|
+
_update_config_var "TRAIN_JOB_NAME" "${JOB_NAME}"
|
|
1112
|
+
echo " ✅ Job submitted successfully: ${JOB_NAME}"
|
|
1113
|
+
echo ""
|
|
534
1114
|
}
|
|
535
1115
|
|
|
536
1116
|
# ── _poll_job() ───────────────────────────────────────────────────────────────
|
|
537
|
-
# Poll
|
|
538
|
-
# Displays: job status, secondary status, elapsed time
|
|
1117
|
+
# Poll training job status via .train_helper.py every POLL_INTERVAL seconds.
|
|
1118
|
+
# Displays: job status, secondary status, elapsed time.
|
|
539
1119
|
# On Completed: breaks loop and returns (caller handles completion).
|
|
540
1120
|
# On Failed: displays FailureReason and exits 2.
|
|
541
1121
|
# On Stopped: displays stopped message and exits 2.
|
|
542
|
-
# On spot interruption: explains auto-resume from checkpoint.
|
|
543
1122
|
_poll_job() {
|
|
544
1123
|
local job_name="${JOB_NAME:-$TRAIN_JOB_NAME}"
|
|
545
1124
|
|
|
@@ -548,49 +1127,39 @@ _poll_job() {
|
|
|
548
1127
|
echo ""
|
|
549
1128
|
|
|
550
1129
|
while true; do
|
|
551
|
-
#
|
|
552
|
-
local
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
--
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
echo "⚠️ Failed to describe job (will retry): ${describe_output}"
|
|
1130
|
+
# Query status via SDK v3 helper
|
|
1131
|
+
local status_json
|
|
1132
|
+
status_json=$(python3 "${SCRIPT_DIR}/.train_helper.py" status \
|
|
1133
|
+
--job-name "${job_name}" \
|
|
1134
|
+
--region "${AWS_REGION}" 2>/dev/null | grep -E '^\{' | tail -1)
|
|
1135
|
+
|
|
1136
|
+
if [ -z "${status_json}" ]; then
|
|
1137
|
+
echo "⚠️ Failed to query job status (will retry)"
|
|
560
1138
|
sleep "${POLL_INTERVAL}"
|
|
561
1139
|
continue
|
|
562
1140
|
fi
|
|
563
1141
|
|
|
564
|
-
#
|
|
565
|
-
local
|
|
566
|
-
|
|
567
|
-
local parse_exit_code=$?
|
|
1142
|
+
# Check for error
|
|
1143
|
+
local has_error
|
|
1144
|
+
has_error=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if d.get('error') else 'no')" 2>/dev/null) || has_error="no"
|
|
568
1145
|
|
|
569
|
-
if [ ${
|
|
570
|
-
echo "⚠️
|
|
1146
|
+
if [ "${has_error}" = "yes" ]; then
|
|
1147
|
+
echo "⚠️ Status query error (will retry)"
|
|
571
1148
|
sleep "${POLL_INTERVAL}"
|
|
572
1149
|
continue
|
|
573
1150
|
fi
|
|
574
1151
|
|
|
575
|
-
#
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
# DISPLAY=<formatted display text>
|
|
582
|
-
local job_status=""
|
|
583
|
-
local secondary_status=""
|
|
584
|
-
local display_text=""
|
|
585
|
-
local failure_reason=""
|
|
586
|
-
|
|
587
|
-
job_status=$(echo "${poll_result}" | grep '^STATUS=' | cut -d= -f2-)
|
|
588
|
-
secondary_status=$(echo "${poll_result}" | grep '^SECONDARY=' | cut -d= -f2-)
|
|
589
|
-
failure_reason=$(echo "${poll_result}" | grep '^FAILURE_REASON=' | cut -d= -f2-)
|
|
590
|
-
display_text=$(echo "${poll_result}" | grep '^DISPLAY=' | cut -d= -f2-)
|
|
1152
|
+
# Extract fields
|
|
1153
|
+
local job_status display_text failure_reason secondary_status
|
|
1154
|
+
job_status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || job_status="Unknown"
|
|
1155
|
+
display_text=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('display',''))" 2>/dev/null) || display_text=""
|
|
1156
|
+
failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or '')" 2>/dev/null) || failure_reason=""
|
|
1157
|
+
secondary_status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('secondary_status','') or '')" 2>/dev/null) || secondary_status=""
|
|
591
1158
|
|
|
592
1159
|
# Print the formatted status line
|
|
593
|
-
|
|
1160
|
+
if [ -n "${display_text}" ]; then
|
|
1161
|
+
echo "${display_text}"
|
|
1162
|
+
fi
|
|
594
1163
|
|
|
595
1164
|
# Handle terminal states
|
|
596
1165
|
case "${job_status}" in
|
|
@@ -619,7 +1188,7 @@ _poll_job() {
|
|
|
619
1188
|
;;
|
|
620
1189
|
esac
|
|
621
1190
|
|
|
622
|
-
# Handle spot interruption
|
|
1191
|
+
# Handle spot interruption
|
|
623
1192
|
if echo "${secondary_status}" | grep -qi "interrupted"; then
|
|
624
1193
|
echo ""
|
|
625
1194
|
echo " ℹ️ Spot instance interrupted. The job will automatically resume"
|
|
@@ -631,9 +1200,10 @@ _poll_job() {
|
|
|
631
1200
|
sleep "${POLL_INTERVAL}"
|
|
632
1201
|
done
|
|
633
1202
|
}
|
|
1203
|
+
<% } %>
|
|
634
1204
|
|
|
635
1205
|
# ── _handle_completion() ──────────────────────────────────────────────────────
|
|
636
|
-
# Store output paths and invoke feedback loop.
|
|
1206
|
+
# Store output paths, write TRAIN_* lifecycle vars, and invoke feedback loop.
|
|
637
1207
|
# Extracts model artifacts path, detects output type, and prints next steps.
|
|
638
1208
|
_handle_completion() {
|
|
639
1209
|
local job_name="${JOB_NAME:-$TRAIN_JOB_NAME}"
|
|
@@ -666,8 +1236,15 @@ print(artifacts.get('S3ModelArtifacts', ''))
|
|
|
666
1236
|
return 1
|
|
667
1237
|
fi
|
|
668
1238
|
|
|
669
|
-
# Write
|
|
670
|
-
_update_config_var "
|
|
1239
|
+
# ── Write TRAIN_* lifecycle variables to do/config ────────────────────────
|
|
1240
|
+
_update_config_var "TRAIN_OUTPUT_PATH_LATEST" "${output_path}"
|
|
1241
|
+
_update_config_var "TRAIN_JOB_NAME" "${job_name}"
|
|
1242
|
+
|
|
1243
|
+
# Write technique-specific adapter path
|
|
1244
|
+
local technique_upper
|
|
1245
|
+
technique_upper=$(echo "${TECHNIQUE:-custom}" | tr '[:lower:]' '[:upper:]')
|
|
1246
|
+
_update_config_var "TRAIN_ADAPTER_PATH_${technique_upper}" "${output_path}"
|
|
1247
|
+
_update_config_var "TRAIN_TECHNIQUE" "${TECHNIQUE:-custom}"
|
|
671
1248
|
|
|
672
1249
|
# Detect output type: check for adapter_config.json in output path
|
|
673
1250
|
local output_type="full-model"
|
|
@@ -679,6 +1256,23 @@ print(artifacts.get('S3ModelArtifacts', ''))
|
|
|
679
1256
|
source "${SCRIPT_DIR}/lib/feedback.sh"
|
|
680
1257
|
print_completion_feedback "${output_path}" "${output_type}" "${job_name}"
|
|
681
1258
|
|
|
1259
|
+
# ── Auto-register (unless --no-register) ─────────────────────────────────
|
|
1260
|
+
if [ "${ARG_NO_REGISTER}" != true ] && [ -f "${SCRIPT_DIR}/register" ]; then
|
|
1261
|
+
if [ -n "${RESOLVED_DATASET_S3_URI}" ]; then
|
|
1262
|
+
echo "📝 Auto-registering dataset for technique '${TECHNIQUE}'..."
|
|
1263
|
+
"${SCRIPT_DIR}/register" dataset --from-train "${TECHNIQUE}" 2>/dev/null || {
|
|
1264
|
+
echo " ⚠️ Auto-registration skipped (non-fatal)"
|
|
1265
|
+
}
|
|
1266
|
+
fi
|
|
1267
|
+
fi
|
|
1268
|
+
|
|
1269
|
+
# ── Print next steps ─────────────────────────────────────────────────────
|
|
1270
|
+
if [ "${output_type}" = "adapter" ]; then
|
|
1271
|
+
echo ""
|
|
1272
|
+
echo " Next: ./do/adapter --from-train ${TECHNIQUE}"
|
|
1273
|
+
echo " to stage the adapter for deployment."
|
|
1274
|
+
fi
|
|
1275
|
+
|
|
682
1276
|
# If spot training was enabled, display cost savings
|
|
683
1277
|
if [ "${TRAIN_ENABLE_SPOT:-false}" = "true" ]; then
|
|
684
1278
|
local billable_seconds training_seconds savings_pct
|
|
@@ -735,6 +1329,12 @@ if [ "${ARG_HELP}" = true ]; then
|
|
|
735
1329
|
_show_help
|
|
736
1330
|
fi
|
|
737
1331
|
|
|
1332
|
+
# Handle --list-datasets.
# Listing is a read-only query, so terminate here (as the --status handler
# does) instead of falling through to config parsing and job submission.
# The exit status of _list_datasets_cmd is propagated.
if [ "${ARG_LIST_DATASETS}" = true ]; then
  source "${SCRIPT_DIR}/config"
  _list_datasets_cmd
  exit $?
fi
|
|
1337
|
+
|
|
738
1338
|
if [ "${ARG_STATUS}" = true ]; then
|
|
739
1339
|
# Show status of current tracked job without submitting
|
|
740
1340
|
if [ -z "${TRAIN_JOB_NAME:-}" ]; then
|
|
@@ -746,28 +1346,181 @@ if [ "${ARG_STATUS}" = true ]; then
|
|
|
746
1346
|
echo "📊 Training Job Status"
|
|
747
1347
|
echo " Job: ${TRAIN_JOB_NAME}"
|
|
748
1348
|
|
|
749
|
-
#
|
|
750
|
-
|
|
751
|
-
--
|
|
752
|
-
--region "${AWS_REGION}" 2
|
|
1349
|
+
# Query status via SDK v3 helper
|
|
1350
|
+
STATUS_RESULT=$(python3 "${SCRIPT_DIR}/.train_helper.py" status \
|
|
1351
|
+
--job-name "${TRAIN_JOB_NAME}" \
|
|
1352
|
+
--region "${AWS_REGION}" 2>/dev/null | grep -E '^\{' | tail -1) || {
|
|
753
1353
|
echo ""
|
|
754
|
-
echo "❌ Failed to
|
|
755
|
-
echo "
|
|
756
|
-
echo ""
|
|
757
|
-
echo " The job may have been deleted or the name is incorrect."
|
|
1354
|
+
echo "❌ Failed to query training job status: ${TRAIN_JOB_NAME}"
|
|
1355
|
+
echo " Ensure sagemaker SDK v3 is installed: pip install 'sagemaker>=3.0'"
|
|
758
1356
|
echo " Run ./do/train --force to start a new job."
|
|
759
1357
|
exit 1
|
|
760
1358
|
}
|
|
761
1359
|
|
|
762
|
-
|
|
763
|
-
|
|
1360
|
+
if [ -z "${STATUS_RESULT}" ]; then
|
|
1361
|
+
echo ""
|
|
1362
|
+
echo "❌ No response from status helper for: ${TRAIN_JOB_NAME}"
|
|
1363
|
+
echo " Run ./do/train --force to start a new job."
|
|
1364
|
+
exit 1
|
|
1365
|
+
fi
|
|
1366
|
+
|
|
1367
|
+
# Display the status info
|
|
1368
|
+
DISPLAY_TEXT=$(echo "${STATUS_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('display',''))" 2>/dev/null) || DISPLAY_TEXT=""
|
|
1369
|
+
if [ -n "${DISPLAY_TEXT}" ]; then
|
|
1370
|
+
echo "${DISPLAY_TEXT}"
|
|
1371
|
+
fi
|
|
1372
|
+
|
|
1373
|
+
# Show additional details for completed jobs
|
|
1374
|
+
STATUS_VAL=$(echo "${STATUS_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || STATUS_VAL="Unknown"
|
|
1375
|
+
if [ "${STATUS_VAL}" = "Completed" ]; then
|
|
1376
|
+
ARTIFACTS_VAL=$(echo "${STATUS_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('model_artifacts','') or '')" 2>/dev/null) || ARTIFACTS_VAL=""
|
|
1377
|
+
if [ -n "${ARTIFACTS_VAL}" ]; then
|
|
1378
|
+
echo " 📦 Artifacts: ${ARTIFACTS_VAL}"
|
|
1379
|
+
fi
|
|
1380
|
+
fi
|
|
1381
|
+
echo ""
|
|
764
1382
|
exit 0
|
|
765
1383
|
fi
|
|
766
1384
|
|
|
1385
|
+
# ── Handle --interactive mode ─────────────────────────────────────────────────
# Runs the Node.js train-config builder interactively (prompts render on the
# TTY). The builder's JSON result is written to a temp file and read back; if
# it sets run_now, execution continues into the normal submission flow below,
# otherwise the script exits 0.
if [ "${ARG_INTERACTIVE}" = true ]; then
  TRAINING_DIR="${SCRIPT_DIR}/training"
  _RESULT_FILE=$(mktemp "${TMPDIR:-/tmp}/mlcc-train-interactive.XXXXXX")

  INTERACTIVE_RESULT=""
  _BUILDER_PATH=""
  if command -v node &>/dev/null; then
    _NPM_GLOBAL=$(npm root -g 2>/dev/null || echo /usr/local/lib/node_modules)
    _project_root=$(cd "${SCRIPT_DIR}/.." && pwd)

    # First existing candidate wins: global install, project-local install,
    # then a source checkout of the package.
    for _candidate in \
        "${_NPM_GLOBAL}/@aws/ml-container-creator/src/lib/train-config-builder.js" \
        "${_project_root}/node_modules/@aws/ml-container-creator/src/lib/train-config-builder.js" \
        "${_project_root}/src/lib/train-config-builder.js"; do
      if [ -f "${_candidate}" ]; then
        _BUILDER_PATH="${_candidate}"
        break
      fi
    done

    if [ -n "${_BUILDER_PATH}" ]; then
      # Paths go through the environment instead of being spliced into the JS
      # source, so quotes/backslashes in a path cannot break (or inject code
      # into) the script; pathToFileURL turns the path into a valid import URL.
      # No stdout capture: the builder's prompts must reach the terminal.
      MLCC_BUILDER_PATH="${_BUILDER_PATH}" \
      MLCC_CONFIG_FILE="${CONFIG_FILE}" \
      MLCC_TRAINING_DIR="${TRAINING_DIR}" \
      MLCC_RESULT_FILE="${_RESULT_FILE}" \
      node --input-type=module -e '
import { writeFileSync } from "node:fs";
import { pathToFileURL } from "node:url";
const env = process.env;
const { run } = await import(pathToFileURL(env.MLCC_BUILDER_PATH).href);
const result = await run({ configFile: env.MLCC_CONFIG_FILE, trainingDir: env.MLCC_TRAINING_DIR });
writeFileSync(env.MLCC_RESULT_FILE, JSON.stringify(result));
' || true
      INTERACTIVE_RESULT=$(cat "${_RESULT_FILE}" 2>/dev/null || echo "")
    fi
  fi

  rm -f "${_RESULT_FILE}"

  if [ -z "${INTERACTIVE_RESULT}" ]; then
    if [ -n "${_BUILDER_PATH}" ]; then
      # The builder was found and ran, but produced no result (cancelled or
      # errored) — do not misreport this as a missing installation.
      echo "❌ Interactive configuration did not complete (no result from the config builder)."
      echo " Re-run ./do/train --interactive, or edit do/training/config.yaml directly."
      exit 1
    fi
    echo "❌ Interactive mode requires Node.js and @aws/ml-container-creator installed."
    echo " Install: npm install -g @aws/ml-container-creator"
    echo ""
    echo " Alternatively, edit do/training/config.yaml directly."
    exit 1
  fi

  # Check if user wants to run now (Python prints JSON true as "True").
  RUN_NOW=$(echo "${INTERACTIVE_RESULT}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('run_now', False))" 2>/dev/null) || RUN_NOW="False"

  if [ "${RUN_NOW}" != "True" ]; then
    echo " Run ./do/train to submit the configured job."
    exit 0
  fi

  # Continue to normal submission flow (config was already written by builder)
  echo "🚀 Proceeding to job submission..."
  echo ""
fi
|
|
1440
|
+
|
|
767
1441
|
# Parse and validate configuration
_parse_config

# Resolve shell variables in image URI (backward compat with old-style config.yaml
# that has ${AWS_ACCOUNT_ID}... instead of resolved values).
#
# Only ${NAME}, ${NAME:-default} and ${NAME-default} placeholders are expanded,
# from the current shell's variables. Unlike the former `eval echo`, nothing
# else in the value is interpreted: $(...), backticks, globs and quotes coming
# from config.yaml are left as literal text rather than executed or expanded.

# _expand_braced_vars TEXT
#   Prints TEXT with ${NAME} / ${NAME:-d} / ${NAME-d} placeholders replaced.
#   Unset names without a default expand to the empty string.
_expand_braced_vars() {
  local rest="$1" out="" match name value
  local re='\$\{([A-Za-z_][A-Za-z0-9_]*)((:?)-([^}]*))?\}'
  while [[ "${rest}" =~ ${re} ]]; do
    match="${BASH_REMATCH[0]}"
    name="${BASH_REMATCH[1]}"
    if [ -z "${BASH_REMATCH[2]}" ]; then
      value="${!name-}"
    elif [ "${BASH_REMATCH[3]}" = ":" ]; then
      value="${!name:-${BASH_REMATCH[4]}}"
    else
      value="${!name-${BASH_REMATCH[4]}}"
    fi
    # The leftmost regex match is also the first literal occurrence of its
    # text, so split around that occurrence. Substituted values are never
    # rescanned.
    out+="${rest%%"${match}"*}${value}"
    rest="${rest#*"${match}"}"
  done
  printf '%s' "${out}${rest}"
}

if [[ "${TRAIN_IMAGE}" == *'${'* ]]; then
  TRAIN_IMAGE=$(_expand_braced_vars "${TRAIN_IMAGE}")
fi

_validate_config
|
|
770
1451
|
|
|
1452
|
+
# ── Technique setup ──────────────────────────────────────────────────────────
# The technique is taken from the CLI flag, else config.yaml, else a default.
# Once resolved it is recorded in do/config, and its defaults are loaded and
# merged with the hyperparameters from config.yaml.
_resolve_technique
_update_config_var "TRAIN_TECHNIQUE" "${TECHNIQUE}"
_load_technique_defaults
_merge_hyperparameters

# ── Multi-node networking advisory ───────────────────────────────────────────
# Informational only: when more than one instance is requested, report whether
# the instance type is in a known EFA-capable family. Submission is not blocked.
if [ "${TRAIN_INSTANCE_COUNT:-1}" -gt 1 ]; then
  echo "📡 Multi-node training: ${TRAIN_INSTANCE_COUNT} instances"
  case "${TRAIN_INSTANCE_TYPE}" in
    ml.p4d* | ml.p5* | ml.g5.48xlarge | ml.g6e.48xlarge | ml.trn1*)
      echo " ✅ EFA-capable instance: ${TRAIN_INSTANCE_TYPE}"
      ;;
    *)
      printf '%s\n' \
        " ⚠️ Instance type '${TRAIN_INSTANCE_TYPE}' may not support EFA." \
        " Multi-node training will use TCP networking (slower)." \
        " For best performance, use: p4d.24xlarge, p5.48xlarge, or g5.48xlarge"
      ;;
  esac
  echo ""
fi
|
|
1478
|
+
|
|
1479
|
+
# ── Dataset ──────────────────────────────────────────────────────────────────
# _resolve_dataset handles --dataset or a dataset entry in config.yaml. When it
# yields an S3 URI, that URI replaces TRAIN_DATASET for the job request.
_resolve_dataset
if [ -n "${RESOLVED_DATASET_S3_URI}" ]; then
  TRAIN_DATASET="${RESOLVED_DATASET_S3_URI}"
fi

# ── --resume ─────────────────────────────────────────────────────────────────
# Look up the checkpoint location of a previous job (explicit job name, else
# the tracked TRAIN_JOB_NAME) and reuse it as TRAIN_CHECKPOINT_PATH. A missing
# checkpoint is not fatal; training then starts from scratch. Resuming always
# sets ARG_FORCE so a new job is created.
if [ "${ARG_RESUME}" = true ]; then
  RESUME_JOB="${ARG_RESUME_JOB:-${TRAIN_JOB_NAME:-}}"

  if [ -z "${RESUME_JOB}" ]; then
    echo "❌ --resume requires a previous job name."
    echo " Provide it: ./do/train --resume <job-name>"
    echo " Or run a training job first (TRAIN_JOB_NAME will be set in do/config)."
    exit 1
  fi

  echo "🔄 Resuming from job: ${RESUME_JOB}"

  # The helper may print other output; the last line starting with '{' is
  # taken as its JSON result.
  CHECKPOINT_RESOLVE=$(python3 "${SCRIPT_DIR}/.train_helper.py" resolve \
      --job-name "${RESUME_JOB}" --checkpoints --region "${AWS_REGION}" \
      2>/dev/null | grep -E '^\{' | tail -1) || CHECKPOINT_RESOLVE=""

  if [ -z "${CHECKPOINT_RESOLVE}" ]; then
    echo " ⚠️ Could not resolve checkpoints for job ${RESUME_JOB}"
    echo " Training will start from scratch."
  else
    RESUME_CHECKPOINT_PATH=$(echo "${CHECKPOINT_RESOLVE}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('checkpoint_path',''))" 2>/dev/null) || RESUME_CHECKPOINT_PATH=""
    if [ -z "${RESUME_CHECKPOINT_PATH}" ]; then
      echo " ⚠️ No checkpoint path found for job ${RESUME_JOB}"
      echo " Training will start from scratch."
    else
      echo " Checkpoint: ${RESUME_CHECKPOINT_PATH}"
      TRAIN_CHECKPOINT_PATH="${RESUME_CHECKPOINT_PATH}"
    fi
  fi
  echo ""

  ARG_FORCE=true
fi

# ── Existing-job handling ────────────────────────────────────────────────────
# NOTE(review): ARG_FORCE (set above when resuming) is presumably honored by
# _check_idempotency to force a new submission — confirm in its definition.
_check_idempotency
|
|
773
1526
|
|