@aws/ml-container-creator 0.5.0 → 0.6.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1143 @@
1
+ #!/bin/bash
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ # do/tune — SageMaker AI Managed Model Customization
6
+ # Wraps SageMaker managed fine-tuning for supported foundation models.
7
+ # Supports SFT, DPO, RLAIF, and RLVR techniques.
8
+
9
set -euo pipefail

# ── Source project configuration ──────────────────────────────────────────────
# do/config sits next to this script; resolve the script's own directory so the
# command behaves the same from any working directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# shellcheck source=/dev/null
source "${SCRIPT_DIR}/config"

# ── Constants ─────────────────────────────────────────────────────────────────
CATALOG_FILE="${SCRIPT_DIR}/.tune_catalog.json"  # supported-model catalog (JSON)
HELPER_SCRIPT="${SCRIPT_DIR}/.tune_helper.py"    # Python helper invoked for job operations
POLL_INTERVAL=60                                 # seconds between status polls

# ── CLI Variables (set by _parse_args) ────────────────────────────────────────
# Value options: an empty string means "not given" (training type defaults to lora).
ARG_TECHNIQUE="" ARG_DATASET="" ARG_TRAINING_TYPE="lora" ARG_MODEL=""
ARG_EPOCHS="" ARG_LEARNING_RATE="" ARG_MAX_SEQ_LENGTH=""
ARG_LORA_RANK="" ARG_LORA_ALPHA="" ARG_BATCH_SIZE=""
ARG_REWARD_FUNCTION="" ARG_REWARD_PROMPT=""
ARG_OUTPUT_BUCKET="" ARG_ROLE=""
# Boolean switches.
ARG_FORCE=false ARG_NO_WAIT=false ARG_STATUS=false
ARG_HELP=false ARG_DRY_RUN=false ARG_LIST_MODELS=false
43
+
44
+
45
# ── _parse_args() ─────────────────────────────────────────────────────────────
# Parse the CLI arguments ("$@") into the ARG_* globals.
# Value-taking options consume the next argument and abort with exit 1 when it
# is missing or empty; boolean switches set their ARG_* variable to true.
# Any unrecognised argument aborts with exit 1.
_parse_args() {
  local opt dest requirement
  while (( $# > 0 )); do
    opt="$1"
    dest=""
    case "${opt}" in
      --technique)       dest=ARG_TECHNIQUE;       requirement="requires a value (sft, dpo, rlaif, rlvr)" ;;
      --dataset)         dest=ARG_DATASET;         requirement="requires a value (s3://... or hf://...)" ;;
      --training-type)   dest=ARG_TRAINING_TYPE;   requirement="requires a value (lora, full-rank)" ;;
      --model)           dest=ARG_MODEL;           requirement="requires a JumpStart model ID" ;;
      --epochs)          dest=ARG_EPOCHS;          requirement="requires an integer value" ;;
      --learning-rate)   dest=ARG_LEARNING_RATE;   requirement="requires a float value" ;;
      --max-seq-length)  dest=ARG_MAX_SEQ_LENGTH;  requirement="requires an integer value" ;;
      --lora-rank)       dest=ARG_LORA_RANK;       requirement="requires an integer value" ;;
      --lora-alpha)      dest=ARG_LORA_ALPHA;      requirement="requires an integer value" ;;
      --batch-size)      dest=ARG_BATCH_SIZE;      requirement="requires an integer value" ;;
      --reward-function) dest=ARG_REWARD_FUNCTION; requirement="requires a Lambda ARN" ;;
      --reward-prompt)   dest=ARG_REWARD_PROMPT;   requirement="requires an S3 URI" ;;
      --output-bucket)   dest=ARG_OUTPUT_BUCKET;   requirement="requires a bucket name" ;;
      --role)            dest=ARG_ROLE;            requirement="requires an IAM role ARN" ;;
      --force)           ARG_FORCE=true ;;
      --no-wait)         ARG_NO_WAIT=true ;;
      --status)          ARG_STATUS=true ;;
      --help|-h)         ARG_HELP=true ;;
      --dry-run)         ARG_DRY_RUN=true ;;
      --list-models)     ARG_LIST_MODELS=true ;;
      *)
        echo "❌ Unknown option: ${opt}"
        echo " Run ./do/tune --help for usage."
        exit 1
        ;;
    esac

    # Boolean switch: consume just the flag.
    if [[ -z "${dest}" ]]; then
      shift
      continue
    fi

    # Value option: the next argument must be present and non-empty.
    if [[ -z "${2:-}" ]]; then
      echo "❌ ${opt} ${requirement}"
      exit 1
    fi
    printf -v "${dest}" '%s' "$2"
    shift 2
  done
}
148
+
149
+
150
# ── _show_help() ──────────────────────────────────────────────────────────────
# Print CLI usage to stdout, one entry per line, and terminate with status 0.
_show_help() {
  printf '%s\n' \
    'Usage: ./do/tune --technique <technique> --dataset <source> [options]' \
    ' ./do/tune --status' \
    ' ./do/tune --list-models' \
    ' ./do/tune --help' \
    '' \
    'SageMaker AI Managed Model Customization — fine-tune supported foundation' \
    'models using SFT, DPO, RLAIF, or RLVR without managing infrastructure.' \
    '' \
    'Required:' \
    ' --technique <t> Customization technique: sft, dpo, rlaif, rlvr' \
    ' --dataset <source> Dataset: s3://bucket/path.jsonl or hf://org/name[/split]' \
    '' \
    'Training type:' \
    ' --training-type <t> lora (default) or full-rank' \
    '' \
    'Hyperparameter overrides (optional):' \
    ' --epochs <n> Number of training epochs' \
    ' --learning-rate <f> Learning rate (e.g., 2e-4)' \
    ' --max-seq-length <n> Maximum sequence length in tokens' \
    ' --lora-rank <n> LoRA rank (e.g., 16, 32, 64)' \
    ' --lora-alpha <n> LoRA alpha scaling factor' \
    ' --batch-size <n> Global batch size' \
    '' \
    'Evaluator (RLVR/RLAIF only):' \
    ' --reward-function <arn> Lambda ARN for reward function' \
    ' --reward-prompt <uri> S3 URI for reward prompt file' \
    '' \
    'Overrides:' \
    ' --model <id> Override model (defaults to MODEL_ID from do/config)' \
    ' --output-bucket <b> Override output bucket (defaults to TUNE_S3_BUCKET)' \
    ' --role <arn> Override execution role (defaults to ROLE_ARN)' \
    '' \
    'Job control:' \
    ' --force Force new job even if one exists for this technique' \
    ' --no-wait Submit and exit without polling for completion' \
    ' --status Show status of all tracked tune jobs' \
    '' \
    'Informational:' \
    ' --help, -h Show this help message' \
    ' --dry-run Validate inputs and show what would be submitted' \
    ' --list-models Print supported models, techniques, and training types' \
    '' \
    'Examples:' \
    ' ./do/tune --technique sft --dataset s3://my-bucket/train.jsonl' \
    ' ./do/tune --technique dpo --dataset hf://my-org/pref-data --learning-rate 1e-5' \
    ' ./do/tune --technique sft --dataset s3://bucket/data.jsonl --training-type full-rank' \
    ' ./do/tune --status' \
    ' ./do/tune --technique sft --dataset s3://bucket/data.jsonl --dry-run'
  exit 0
}
202
+
203
# ── _show_status() ────────────────────────────────────────────────────────────
# Display status of all tracked tune jobs from do/config, then exit 0.
# For each technique whose TUNE_JOB_NAME_<TECHNIQUE> is set, queries the Python
# helper ("status" sub-command) and prints the status, the elapsed time (when
# non-zero), and the recorded output path: TUNE_ADAPTER_PATH_<TECHNIQUE> is
# preferred over TUNE_MODEL_PATH_<TECHNIQUE>.
_show_status() {
  echo "📊 Tune Job Status"
  echo ""

  local found_any=false
  local technique technique_upper var_name job_name
  local status_json status elapsed mins secs output_var model_var
  for technique in sft dpo rlaif rlvr; do
    technique_upper="${technique^^}"
    var_name="TUNE_JOB_NAME_${technique_upper}"
    job_name="${!var_name:-}"
    [[ -n "${job_name}" ]] || continue

    found_any=true
    echo " ${technique_upper}:"
    echo " Job: ${job_name}"

    status_json=$(python3 "${HELPER_SCRIPT}" status \
      --job-name "${job_name}" \
      --region "${AWS_REGION}" 2>/dev/null) || status_json='{"status":"Unknown","error":"Failed to query"}'

    status=$(printf '%s' "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('status','Unknown'))" 2>/dev/null) || status="Unknown"

    # The helper may report elapsed time as a float or null; shell arithmetic only
    # handles integers, so truncate to whole seconds and fall back to 0.
    elapsed=$(printf '%s' "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(int(float(d.get('elapsed_seconds') or 0)))" 2>/dev/null) || elapsed="0"
    [[ "${elapsed}" =~ ^[0-9]+$ ]] || elapsed="0"

    echo " Status: ${status}"
    if [[ "${elapsed}" != "0" ]]; then
      mins=$(( elapsed / 60 ))
      secs=$(( elapsed % 60 ))
      echo " Elapsed: ${mins}m ${secs}s"
    fi

    # Show output path if one has been recorded for this technique.
    output_var="TUNE_ADAPTER_PATH_${technique_upper}"
    model_var="TUNE_MODEL_PATH_${technique_upper}"
    if [[ -n "${!output_var:-}" ]]; then
      echo " Output (adapter): ${!output_var}"
    elif [[ -n "${!model_var:-}" ]]; then
      echo " Output (model): ${!model_var}"
    fi
    echo ""
  done

  if [[ "${found_any}" == false ]]; then
    echo " No tune jobs tracked. Run ./do/tune --technique <t> --dataset <d> to start."
  fi

  exit 0
}
256
+
257
# ── _list_models() ────────────────────────────────────────────────────────────
# Print the Supported Model Catalog grouped by family, then exit 0.
# Exits 1 if CATALOG_FILE is missing or cannot be parsed.
_list_models() {
  if [[ ! -f "${CATALOG_FILE}" ]]; then
    echo "❌ Catalog file not found: ${CATALOG_FILE}"
    exit 1
  fi

  echo "📋 Supported Models for Managed Customization"
  echo ""

  # The catalog path is passed as argv[1] rather than spliced into the Python
  # source, so a path containing quotes or backslashes cannot break the snippet.
  python3 -c '
import json, sys

with open(sys.argv[1]) as f:
    catalog = json.load(f)

# Group entries by family, keeping catalog order within a family.
families = {}
for entry in catalog.get("models", {}).values():
    families.setdefault(entry.get("family", "unknown"), []).append(entry)

for family in sorted(families):
    entries = families[family]
    provider = entries[0].get("provider", "")
    print(f" {family} ({provider}):")
    for entry in entries:
        display_name = entry["displayName"]
        model_id = entry["jumpStartModelId"]
        print(f" • {display_name}")
        print(f" ID: {model_id}")
        for technique, config in entry.get("techniques", {}).items():
            types = ", ".join(config.get("trainingTypes", []))
            print(f" {technique}: [{types}]")
    print()
' "${CATALOG_FILE}" 2>/dev/null || {
    echo "❌ Failed to parse catalog. Ensure python3 is available."
    exit 1
  }

  exit 0
}
303
+
304
+
305
# ── _update_config_var() ──────────────────────────────────────────────────────
# Write or update a variable in do/config as: export NAME="value"
# Usage: _update_config_var VAR_NAME "value"
# Every existing `export NAME=` line is rewritten in place; otherwise the line is
# appended (creating the file if needed). The value is escaped for a
# double-quoted shell string so it reads back verbatim when do/config is sourced.
_update_config_var() {
  local var_name="$1"
  local var_value="$2"
  local config_file="${SCRIPT_DIR}/config"

  # Backslash-escape \ " $ ` — the only characters special inside double quotes.
  local escaped
  escaped=$(printf '%s' "${var_value}" | sed 's/[\\"$`]/\\&/g')
  local new_line="export ${var_name}=\"${escaped}\""

  if ! grep -q "^export ${var_name}=" "${config_file}" 2>/dev/null; then
    printf '%s\n' "${new_line}" >> "${config_file}"
    return 0
  fi

  # Rewrite via a temp file. awk takes the replacement line from the environment,
  # so characters such as | & \ in the value are never treated as sed/awk syntax.
  local tmp_file
  tmp_file=$(mktemp "${config_file}.XXXXXX")
  NEW_LINE="${new_line}" awk -v prefix="export ${var_name}=" \
    'index($0, prefix) == 1 { print ENVIRON["NEW_LINE"]; next } { print }' \
    "${config_file}" > "${tmp_file}"
  # Copy back over the original (rather than mv) so its permissions are kept.
  cat "${tmp_file}" > "${config_file}"
  rm -f "${tmp_file}"
}
320
+
321
# ── _validate_model() ─────────────────────────────────────────────────────────
# Resolve the model to customize and check it against the Supported Model Catalog.
# Precedence: --model, then MODEL_ID, then MODEL_NAME from do/config.
# Sets RESOLVED_MODEL_ID. Exits 1 if no model is configured, the catalog is
# missing or unreadable, or the model is not listed in the catalog.
_validate_model() {
  if [[ -n "${ARG_MODEL}" ]]; then
    RESOLVED_MODEL_ID="${ARG_MODEL}"
  elif [[ -n "${MODEL_ID:-}" ]]; then
    RESOLVED_MODEL_ID="${MODEL_ID}"
  elif [[ -n "${MODEL_NAME:-}" ]]; then
    RESOLVED_MODEL_ID="${MODEL_NAME}"
  else
    echo "❌ No model configured"
    echo " Set MODEL_ID in do/config or use --model <id>"
    exit 1
  fi

  if [[ ! -f "${CATALOG_FILE}" ]]; then
    echo "❌ Catalog file not found: ${CATALOG_FILE}"
    echo " The tune catalog is required for model validation."
    exit 1
  fi

  # The catalog path and model ID go in as argv rather than being interpolated
  # into the Python source: a user-supplied --model containing a quote would
  # otherwise break (or inject code into) the snippet. Unsupported families are
  # joined with ", " in Python.
  local result
  result=$(python3 -c '
import json, sys

catalog_path, model_id = sys.argv[1], sys.argv[2]
with open(catalog_path) as f:
    catalog = json.load(f)

models = catalog.get("models", {})
if model_id in models:
    print("SUPPORTED")
else:
    families = sorted({e.get("family", "") for e in models.values() if e.get("family")})
    print("UNSUPPORTED|" + ", ".join(families))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" 2>/dev/null) || {
    echo "❌ Failed to validate model against catalog"
    echo " Ensure python3 is available."
    exit 1
  }

  if [[ "${result}" == "SUPPORTED" ]]; then
    return 0
  fi

  echo "❌ Model \"${RESOLVED_MODEL_ID}\" is not yet supported for managed serverless customization."
  echo " Supported model families: ${result#UNSUPPORTED|}"
  echo ""
  echo " Additional model support and custom training workflows are expected in future releases."
  echo " For custom training workflows, see \`do/train\`."
  exit 1
}
382
+
383
# ── _validate_technique() ─────────────────────────────────────────────────────
# Check that ARG_TECHNIQUE is a known technique and that the catalog entry for
# RESOLVED_MODEL_ID lists it. Expects _validate_model to have succeeded first.
# Exits 1 on an invalid/unsupported technique or a catalog read failure.
_validate_technique() {
  local technique="${ARG_TECHNIQUE}"

  case "${technique}" in
    sft|dpo|rlaif|rlvr) ;;
    *)
      echo "❌ Invalid technique: ${technique}"
      echo " Valid techniques: sft, dpo, rlaif, rlvr"
      exit 1
      ;;
  esac

  # Values are passed as argv (never interpolated into Python source); the
  # supported list is joined with ", " in Python.
  local result
  result=$(python3 -c '
import json, sys

catalog_path, model_id, technique = sys.argv[1:4]
with open(catalog_path) as f:
    catalog = json.load(f)

techniques = catalog["models"][model_id].get("techniques", {})
if technique in techniques:
    print("SUPPORTED")
else:
    print("UNSUPPORTED|" + ", ".join(techniques))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${technique}" 2>/dev/null) || {
    echo "❌ Failed to validate technique against catalog"
    exit 1
  }

  if [[ "${result}" == "SUPPORTED" ]]; then
    return 0
  fi

  echo "❌ Technique \"${technique}\" is not supported for model \"${RESOLVED_MODEL_ID}\"."
  echo " Supported techniques: ${result#UNSUPPORTED|}"
  exit 1
}
432
+
433
# ── _validate_training_type() ─────────────────────────────────────────────────
# Check that ARG_TRAINING_TYPE is lora or full-rank and that the catalog lists it
# for RESOLVED_MODEL_ID + ARG_TECHNIQUE. Expects the model and technique to have
# been validated already. Exits 1 on an invalid/unsupported type or read failure.
_validate_training_type() {
  local technique="${ARG_TECHNIQUE}"
  local training_type="${ARG_TRAINING_TYPE}"

  case "${training_type}" in
    lora|full-rank) ;;
    *)
      echo "❌ Invalid training type: ${training_type}"
      echo " Valid training types: lora, full-rank"
      exit 1
      ;;
  esac

  # Values are passed as argv (never interpolated into Python source); the
  # supported list is joined with ", " in Python.
  local result
  result=$(python3 -c '
import json, sys

catalog_path, model_id, technique, training_type = sys.argv[1:5]
with open(catalog_path) as f:
    catalog = json.load(f)

technique_entry = catalog["models"][model_id]["techniques"][technique]
training_types = technique_entry.get("trainingTypes", [])
if training_type in training_types:
    print("SUPPORTED")
else:
    print("UNSUPPORTED|" + ", ".join(training_types))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${technique}" "${training_type}" 2>/dev/null) || {
    echo "❌ Failed to validate training type against catalog"
    exit 1
  }

  if [[ "${result}" == "SUPPORTED" ]]; then
    return 0
  fi

  echo "❌ Training type \"${training_type}\" is not supported for model \"${RESOLVED_MODEL_ID}\" with technique \"${technique}\"."
  echo " Supported training types: ${result#UNSUPPORTED|}"
  exit 1
}
484
+
485
+
486
# ── _validate_dataset() ───────────────────────────────────────────────────────
# Resolve ARG_DATASET to an S3 URI the customization job can read.
#   s3://...               checked with `aws s3 ls`; the first 10 lines are checked
#                          against the catalog datasetSchema via the helper's
#                          "validate" sub-command.
#   hf://org/name[/split]  staged to S3 via the helper's "stage-hf" sub-command.
# Sets RESOLVED_DATASET_S3_URI on success; prints an error and exits 1 otherwise.
_validate_dataset() {
  local dataset="${ARG_DATASET}"

  if [[ -z "${dataset}" ]]; then
    echo "❌ --dataset is required"
    echo " Provide an S3 URI (s3://bucket/path.jsonl) or HF reference (hf://org/name)"
    exit 1
  fi

  if [[ "${dataset}" == s3://* ]]; then
    # S3 dataset — verify existence
    if ! aws s3 ls "${dataset}" --region "${AWS_REGION}" >/dev/null 2>&1; then
      echo "❌ Dataset not found or not accessible: ${dataset}"
      echo " Verify the S3 URI is correct and you have read permissions."
      echo " Check: aws s3 ls ${dataset} --region ${AWS_REGION}"
      exit 1
    fi
    RESOLVED_DATASET_S3_URI="${dataset}"

    local schema_json
    schema_json=$(_get_dataset_schema) || {
      echo "❌ Failed to read dataset schema from catalog"
      exit 1
    }

    # Sample the first 10 lines. `head` closes the pipe early, so `aws s3 cp` is
    # normally killed by SIGPIPE; with pipefail + errexit that would abort the
    # whole script silently, so the pipeline status is deliberately ignored.
    local sample_data
    sample_data=$(aws s3 cp "${dataset}" - --region "${AWS_REGION}" 2>/dev/null | head -10) || true

    if [[ -n "${sample_data}" ]]; then
      local validate_result
      validate_result=$(printf '%s\n' "${sample_data}" | python3 "${HELPER_SCRIPT}" validate \
        --schema "${schema_json}" 2>/dev/null) || validate_result='{"valid":false,"error":"Validation failed"}'

      local is_valid
      is_valid=$(printf '%s' "${validate_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('valid', False))" 2>/dev/null) || is_valid="False"

      if [[ "${is_valid}" != "True" ]]; then
        local error_msg malformed_line expected_format
        error_msg=$(printf '%s' "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('error','Unknown format error'))" 2>/dev/null) || error_msg="Unknown format error"
        malformed_line=$(printf '%s' "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('malformed_line','') or '')" 2>/dev/null) || malformed_line=""
        expected_format=$(printf '%s' "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('expected_format','') or '')" 2>/dev/null) || expected_format=""

        echo "❌ Dataset format validation failed"
        echo " ${error_msg}"
        if [[ -n "${malformed_line}" ]]; then
          echo ""
          echo " Malformed line: ${malformed_line}"
        fi
        if [[ -n "${expected_format}" ]]; then
          echo ""
          echo " ${expected_format}"
        fi
        exit 1
      fi
    else
      # `aws s3 ls` matches prefixes, so an inexact key can pass the check above
      # yet be unreadable here; make the skipped validation visible.
      echo "⚠️ Could not read a sample of ${dataset} — skipping format validation"
    fi

  elif [[ "${dataset}" == hf://* ]]; then
    # Hugging Face dataset — parse org/name[/split] and stage to S3.
    # Parameter expansion is used instead of `cut`, which echoes the whole input
    # when it contains no "/" (so "hf://org" would otherwise pass as org=name).
    local hf_path="${dataset#hf://}"
    local hf_org="${hf_path%%/*}"
    local hf_name="" hf_split="" hf_rest=""
    if [[ "${hf_path}" == */* ]]; then
      hf_rest="${hf_path#*/}"
      hf_name="${hf_rest%%/*}"
      if [[ "${hf_rest}" == */* ]]; then
        hf_split="${hf_rest#*/}"
      fi
    fi

    if [[ -z "${hf_org}" || -z "${hf_name}" ]]; then
      echo "❌ Invalid HF dataset reference: ${dataset}"
      echo " Expected format: hf://org/name or hf://org/name/split"
      exit 1
    fi

    local output_bucket
    output_bucket=$(_resolve_output_bucket)

    echo "📦 Staging Hugging Face dataset: ${hf_org}/${hf_name}"
    if [[ -n "${hf_split}" ]]; then
      echo " Split: ${hf_split}"
    else
      echo " Split: train (default)"
    fi

    local stage_args=(
      --hf-org "${hf_org}"
      --hf-name "${hf_name}"
      --output-bucket "${output_bucket}"
      --project-name "${PROJECT_NAME}"
      --region "${AWS_REGION}"
    )
    if [[ -n "${hf_split}" ]]; then
      stage_args+=(--hf-split "${hf_split}")
    fi
    if [[ -n "${HF_TOKEN_ARN:-}" ]]; then
      stage_args+=(--hf-secret-name "${HF_TOKEN_ARN}")
    fi

    local stage_result error_msg
    stage_result=$(python3 "${HELPER_SCRIPT}" stage-hf "${stage_args[@]}" 2>/dev/null) || {
      # The helper's stdout is still captured on failure; prefer its message.
      error_msg=$(printf '%s' "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Failed to stage dataset'))" 2>/dev/null) || error_msg="Failed to stage HF dataset"
      echo "❌ ${error_msg}"
      exit 1
    }

    local has_error
    has_error=$(printf '%s' "${stage_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"
    if [[ "${has_error}" == "yes" ]]; then
      error_msg=$(printf '%s' "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
      echo "❌ ${error_msg}"
      exit 1
    fi

    RESOLVED_DATASET_S3_URI=$(printf '%s' "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('s3_uri') or '')" 2>/dev/null) || RESOLVED_DATASET_S3_URI=""
    if [[ -z "${RESOLVED_DATASET_S3_URI}" ]]; then
      echo "❌ Dataset staging did not return an S3 URI"
      exit 1
    fi

    local num_records
    num_records=$(printf '%s' "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('num_records',0))" 2>/dev/null) || num_records="0"

    echo " ✅ Staged to: ${RESOLVED_DATASET_S3_URI}"
    echo " Records: ${num_records}"
    echo ""

  else
    echo "❌ Invalid dataset format: ${dataset}"
    echo " Expected: s3://bucket/path.jsonl or hf://org/name[/split]"
    exit 1
  fi
}
620
+
621
# ── _get_dataset_schema() ─────────────────────────────────────────────────────
# Print (as JSON) the datasetSchema for RESOLVED_MODEL_ID + ARG_TECHNIQUE from the
# catalog, or {} when the technique entry has none. Non-zero status if the
# catalog cannot be read or the model/technique is absent.
_get_dataset_schema() {
  # Identifiers are passed as argv so quotes in them cannot break the Python code.
  python3 -c '
import json, sys

catalog_path, model_id, technique = sys.argv[1:4]
with open(catalog_path) as f:
    catalog = json.load(f)

schema = catalog["models"][model_id]["techniques"][technique].get("datasetSchema", {})
print(json.dumps(schema))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${ARG_TECHNIQUE}" 2>/dev/null
}
637
+
638
# ── _resolve_output_bucket() ──────────────────────────────────────────────────
# Print the S3 output bucket: the first non-empty of --output-bucket,
# TUNE_S3_BUCKET, ADAPTER_S3_BUCKET; otherwise a derived name
# mlcc-tune-<account>-<region> (account falls back to UNKNOWN if STS fails).
_resolve_output_bucket() {
  local candidate
  for candidate in "${ARG_OUTPUT_BUCKET}" "${TUNE_S3_BUCKET:-}" "${ADAPTER_S3_BUCKET:-}"; do
    if [[ -n "${candidate}" ]]; then
      echo "${candidate}"
      return 0
    fi
  done

  local account_id
  account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null || echo 'UNKNOWN')
  echo "mlcc-tune-${account_id}-${AWS_REGION}"
}
651
+
652
+
653
# ── _check_idempotency() ──────────────────────────────────────────────────────
# Reconcile with a previously submitted job for ARG_TECHNIQUE, if one is tracked
# as TUNE_JOB_NAME_<TECHNIQUE>.
# Returns 0 when the caller should submit a new job: nothing tracked, --force,
#   status unavailable, or an unrecognised status.
# Returns 1 when the existing job was handled here (polled and/or completed).
# Exits 2 when the tracked job Failed or was Stopped.
_check_idempotency() {
  local job_var="TUNE_JOB_NAME_${ARG_TECHNIQUE^^}"
  local previous_job="${!job_var:-}"

  # Only an existing job without --force needs reconciling.
  if [[ -z "${previous_job}" || "${ARG_FORCE}" == true ]]; then
    return 0
  fi

  echo "🔍 Found existing ${ARG_TECHNIQUE^^} job: ${previous_job}"

  local job_info
  if ! job_info=$(python3 "${HELPER_SCRIPT}" status \
      --job-name "${previous_job}" \
      --region "${AWS_REGION}" 2>/dev/null); then
    echo " ⚠️ Could not query job status — proceeding with new job"
    return 0
  fi

  # An unparseable reply counts as "no error"; the status lookup below then
  # falls back to Unknown.
  local reported_error
  reported_error=$(printf '%s\n' "${job_info}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || reported_error="no"
  if [[ "${reported_error}" == "yes" ]]; then
    echo " ⚠️ Could not query job status — proceeding with new job"
    return 0
  fi

  local job_status
  job_status=$(printf '%s\n' "${job_info}" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])" 2>/dev/null) || job_status="Unknown"

  case "${job_status}" in
    InProgress|Starting)
      echo " Status: ${job_status}"
      echo " (use --force to start a new job instead)"
      echo ""
      _poll_job "${previous_job}"
      _handle_completion "${previous_job}"
      return 1
      ;;
    Completed)
      echo " Status: Completed"
      echo " (use --force to start a new job)"
      echo ""
      _handle_completion "${previous_job}"
      return 1
      ;;
    Failed)
      local reason
      reason=$(printf '%s\n' "${job_info}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'Unknown')" 2>/dev/null) || reason="Unknown"
      echo " Status: Failed"
      echo " Reason: ${reason}"
      echo ""
      echo " Re-run with --force to start a new job."
      exit 2
      ;;
    Stopped)
      echo " Status: Stopped"
      echo " Re-run with --force to start a new job."
      exit 2
      ;;
    *)
      echo " Status: ${job_status}"
      echo " Proceeding with new job..."
      return 0
      ;;
  esac
}
727
+
728
# ── _submit_job() ─────────────────────────────────────────────────────────────
# Submit the customization job via the Python helper ("submit" sub-command).
# Resolves the output bucket and execution role (--role, ROLE_ARN, then
# SAGEMAKER_ROLE_ARN), generates a timestamped job name, and records the job in
# do/config (TUNE_JOB_NAME_<TECHNIQUE>, TUNE_TECHNIQUE, TUNE_TRAINING_TYPE,
# TUNE_DATASET_PATH).
# Sets JOB_NAME on success; a job_name returned by the helper takes precedence.
# Exits 1 if no role is configured or the submission fails.
_submit_job() {
  local output_bucket
  output_bucket=$(_resolve_output_bucket)

  local role_arn="${ARG_ROLE:-${ROLE_ARN:-${SAGEMAKER_ROLE_ARN:-}}}"
  if [[ -z "${role_arn}" ]]; then
    echo "❌ No execution role configured"
    echo " Set ROLE_ARN in do/config or use --role <arn>"
    exit 1
  fi

  local timestamp
  timestamp=$(date +%Y%m%d-%H%M%S)
  JOB_NAME="${PROJECT_NAME}-tune-${ARG_TECHNIQUE}-${timestamp}"

  echo "🚀 Submitting ${ARG_TECHNIQUE^^} customization job"
  echo " Job name: ${JOB_NAME}"
  echo " Model: ${RESOLVED_MODEL_ID}"
  echo " Technique: ${ARG_TECHNIQUE}"
  echo " Training type: ${ARG_TRAINING_TYPE}"
  echo " Dataset: ${RESOLVED_DATASET_S3_URI}"
  echo " Output bucket: ${output_bucket}"
  echo ""

  local submit_args=(
    --model-id "${RESOLVED_MODEL_ID}"
    --technique "${ARG_TECHNIQUE}"
    --training-type "${ARG_TRAINING_TYPE}"
    --dataset-s3-uri "${RESOLVED_DATASET_S3_URI}"
    --output-bucket "${output_bucket}"
    --role-arn "${role_arn}"
    --job-name "${JOB_NAME}"
    --project-name "${PROJECT_NAME}"
    --model-package-group "${PROJECT_NAME}-tune-models"
  )

  # Optional overrides, as (flag, value) pairs; forwarded only when non-empty.
  local optional_args=(
    --epochs "${ARG_EPOCHS}"
    --learning-rate "${ARG_LEARNING_RATE}"
    --max-seq-length "${ARG_MAX_SEQ_LENGTH}"
    --lora-rank "${ARG_LORA_RANK}"
    --lora-alpha "${ARG_LORA_ALPHA}"
    --batch-size "${ARG_BATCH_SIZE}"
    --reward-function "${ARG_REWARD_FUNCTION}"
    --reward-prompt "${ARG_REWARD_PROMPT}"
  )
  local i
  for (( i = 0; i < ${#optional_args[@]}; i += 2 )); do
    if [[ -n "${optional_args[i+1]}" ]]; then
      submit_args+=("${optional_args[i]}" "${optional_args[i+1]}")
    fi
  done

  local submit_result submit_status=0
  submit_result=$(python3 "${HELPER_SCRIPT}" submit "${submit_args[@]}" 2>/dev/null) || submit_status=$?

  # Classify the helper's reply: "ok", "error:<message>" when it carries an
  # "error" key, or "unparseable".
  local verdict
  verdict=$(printf '%s' "${submit_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('error:' + str(d['error']) if 'error' in d else 'ok')" 2>/dev/null) || verdict="unparseable"

  # Surface the helper's own error even when it exited non-zero; the generic
  # SDK hint is only a fallback when it reported nothing usable.
  if [[ "${verdict}" == error:* ]]; then
    echo "❌ ${verdict#error:}"
    exit 1
  fi
  if (( submit_status != 0 )); then
    echo "❌ Failed to submit customization job"
    echo " Ensure the SageMaker Python SDK is installed: pip install 'sagemaker>=2.232.0'"
    exit 1
  fi
  if [[ "${verdict}" != "ok" ]]; then
    echo "❌ Unknown error"
    exit 1
  fi

  # The helper may report a different job name than the one generated above.
  local returned_job_name
  returned_job_name=$(printf '%s' "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('job_name',''))" 2>/dev/null) || returned_job_name=""
  if [[ -n "${returned_job_name}" ]]; then
    JOB_NAME="${returned_job_name}"
  fi

  local mlflow_url
  mlflow_url=$(printf '%s' "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('mlflow_url','') or '')" 2>/dev/null) || mlflow_url=""
  if [[ -n "${mlflow_url}" ]]; then
    echo " 📈 MLflow tracking: ${mlflow_url}"
  fi

  echo "✅ Job submitted: ${JOB_NAME}"
  echo ""

  _update_config_var "TUNE_JOB_NAME_${ARG_TECHNIQUE^^}" "${JOB_NAME}"
  _update_config_var "TUNE_TECHNIQUE" "${ARG_TECHNIQUE}"
  _update_config_var "TUNE_TRAINING_TYPE" "${ARG_TRAINING_TYPE}"
  _update_config_var "TUNE_DATASET_PATH" "${ARG_DATASET}"
}
849
+
850
+
851
# ── _poll_job() ───────────────────────────────────────────────────────────────
# Poll a job every POLL_INTERVAL seconds, printing time, status, and elapsed time.
# Arguments: $1 - job name (defaults to JOB_NAME)
# Returns 0 once the job is Completed; exits 2 if it Failed or was Stopped.
# Ctrl+C runs _handle_interrupt, which exits without stopping the remote job.
# Status-query failures are reported and retried indefinitely.
_poll_job() {
  local job_name="${1:-${JOB_NAME}}"

  echo "⏳ Polling job status every ${POLL_INTERVAL}s..."
  echo " (Ctrl+C to exit — job continues in background)"
  echo ""

  # Bake the shell-quoted job name into the handler now, so it does not rely on
  # this function's local variable being in scope when SIGINT arrives.
  # shellcheck disable=SC2064
  trap "_handle_interrupt $(printf '%q' "${job_name}")" INT

  local status_json status elapsed mins secs failure_reason
  while true; do
    if ! status_json=$(python3 "${HELPER_SCRIPT}" status \
        --job-name "${job_name}" \
        --region "${AWS_REGION}" 2>/dev/null); then
      echo " ⚠️ Failed to query status (will retry)"
      sleep "${POLL_INTERVAL}"
      continue
    fi

    status=$(printf '%s' "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || status="Unknown"

    # elapsed_seconds may be a float or null; shell arithmetic needs an integer.
    elapsed=$(printf '%s' "${status_json}" | python3 -c "import sys,json; print(int(float(json.load(sys.stdin).get('elapsed_seconds') or 0)))" 2>/dev/null) || elapsed="0"
    [[ "${elapsed}" =~ ^[0-9]+$ ]] || elapsed="0"

    mins=$(( elapsed / 60 ))
    secs=$(( elapsed % 60 ))

    case "${status}" in
      Completed)
        echo " ✅ $(date +%H:%M:%S) Status: Completed (${mins}m ${secs}s)"
        # Restore default signal handling
        trap - INT
        return 0
        ;;
      Failed)
        failure_reason=$(printf '%s' "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'Unknown')" 2>/dev/null) || failure_reason="Unknown"
        echo " ❌ $(date +%H:%M:%S) Status: Failed (${mins}m ${secs}s)"
        echo " Reason: ${failure_reason}"
        trap - INT
        exit 2
        ;;
      Stopped)
        echo " ⚠️ $(date +%H:%M:%S) Status: Stopped (${mins}m ${secs}s)"
        trap - INT
        exit 2
        ;;
      *)
        echo " $(date +%H:%M:%S) Status: ${status} (${mins}m ${secs}s)"
        sleep "${POLL_INTERVAL}"
        ;;
    esac
  done
}
910
+
911
# ── _handle_interrupt() ───────────────────────────────────────────────────────
# SIGINT handler installed by _poll_job. Tells the user the remote job keeps
# running (it is never stopped here), shows how to resume monitoring, and
# exits with status 130 (128 + SIGINT).
#
# Globals:   ARG_TECHNIQUE, ARG_DATASET (read)
# Arguments: $1 - job name to report (optional)
_handle_interrupt() {
  local interrupted_job="${1:-}"

  printf '%s\n' \
    "" \
    "" \
    "⚠️ Interrupted — job continues running in background" \
    " Job: ${interrupted_job}" \
    "" \
    " Resume monitoring: ./do/tune --technique ${ARG_TECHNIQUE} --dataset ${ARG_DATASET}" \
    " Check status: ./do/tune --status"

  exit 130
}
924
+
925
# ── _handle_completion() ──────────────────────────────────────────────────────
# Resolve the finished job's output artifacts via the Python helper, record
# them in do/config, and print next-step deployment commands.
#
# Globals:   JOB_NAME, ARG_TECHNIQUE, ARG_TRAINING_TYPE, AWS_REGION,
#            PROJECT_NAME, HELPER_SCRIPT (read); config values written via
#            _update_config_var
# Arguments: $1 - job name (optional, defaults to JOB_NAME)
# Returns:   0 always; resolution problems are reported as warnings and leave
#            the stored config untouched.
_handle_completion() {
  local job_name="${1:-${JOB_NAME}}"
  local technique_upper
  technique_upper=$(echo "${ARG_TECHNIQUE}" | tr '[:lower:]' '[:upper:]')

  # Resolve artifact path via Python helper
  local resolve_args=(
    --job-name "${job_name}"
    --region "${AWS_REGION}"
    --training-type "${ARG_TRAINING_TYPE}"
    --model-package-group "${PROJECT_NAME}-tune-models"
  )
  # Shown to the user whenever automatic resolution does not yield a path.
  local manual_cmd="python3 ${HELPER_SCRIPT} resolve --job-name ${job_name} --region ${AWS_REGION} --training-type ${ARG_TRAINING_TYPE}"

  local resolve_result
  resolve_result=$(python3 "${HELPER_SCRIPT}" resolve "${resolve_args[@]}" 2>/dev/null) || {
    echo "⚠️ Could not resolve output artifacts"
    echo " Check job output manually:"
    echo " ${manual_cmd}"
    return 0
  }

  # Anything other than a JSON object without an "error" key is an error.
  # isinstance() avoids a substring match if the helper ever emits a string.
  local has_error
  has_error=$(printf '%s' "${resolve_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('no' if isinstance(d, dict) and 'error' not in d else 'yes')" 2>/dev/null) || has_error="yes"

  if [ "${has_error}" = "yes" ]; then
    local error_msg
    error_msg=$(printf '%s' "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error') or '')" 2>/dev/null) || error_msg=""
    if [ -n "${error_msg}" ]; then
      echo " ⚠️ ${error_msg}"
    else
      # Unparseable or message-less result: still tell the user something.
      echo " ⚠️ Could not resolve output artifacts (unexpected helper output)"
      echo " Check job output manually:"
      echo " ${manual_cmd}"
    fi
    return 0
  fi

  # Extract results; JSON nulls become empty strings rather than "None".
  local artifact_path output_type model_package_arn
  artifact_path=$(printf '%s' "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('artifact_path') or '')" 2>/dev/null) || artifact_path=""
  output_type=$(printf '%s' "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('output_type') or '')" 2>/dev/null) || output_type=""
  model_package_arn=$(printf '%s' "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('model_package_arn','') or '')" 2>/dev/null) || model_package_arn=""

  # Display results
  echo "🎉 Customization complete!"
  echo ""
  echo " Output type: ${output_type}"
  echo " Artifact: ${artifact_path}"
  if [ -n "${model_package_arn}" ]; then
    echo " Model Package: ${model_package_arn}"
  fi
  echo ""

  # Without an artifact path there is nothing to record or deploy; do not
  # overwrite previously stored paths with an empty value.
  if [ -z "${artifact_path}" ]; then
    echo " ⚠️ No artifact path was resolved; do/config was left unchanged"
    echo " Check job output manually:"
    echo " ${manual_cmd}"
    echo ""
    return 0
  fi

  # Store output paths in config
  if [ "${output_type}" = "adapter" ]; then
    _update_config_var "TUNE_ADAPTER_PATH_${technique_upper}" "${artifact_path}"
  else
    _update_config_var "TUNE_MODEL_PATH_${technique_upper}" "${artifact_path}"
  fi
  _update_config_var "TUNE_OUTPUT_PATH_LATEST" "${artifact_path}"
  _update_config_var "TUNE_OUTPUT_TYPE_LATEST" "${output_type}"

  # Print next-step commands
  echo "📋 Next steps:"
  echo ""
  if [ "${output_type}" = "adapter" ]; then
    echo " Deploy as LoRA adapter:"
    echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune"
    echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune ${ARG_TECHNIQUE}"
    echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --weights ${artifact_path}"
  else
    echo " Deploy as new inference component:"
    echo " ./do/add-ic tuned-v1 --from-tune"
    echo " ./do/add-ic tuned-v1 --model-data ${artifact_path}"
    echo " ./do/deploy --force-ic --model-data ${artifact_path}"
  fi
  echo ""
}
1006
+
1007
+
1008
# ── _dry_run() ────────────────────────────────────────────────────────────────
# Print the job that would be submitted (name, model, dataset, output bucket,
# role, package group and any explicitly supplied hyperparameters) without
# creating anything, then exit the script with status 0. It performs no
# validation itself; it reports that validation has already passed.
#
# Globals: PROJECT_NAME, ARG_*, RESOLVED_MODEL_ID, RESOLVED_DATASET_S3_URI,
#          ROLE_ARN, SAGEMAKER_ROLE_ARN (read)
_dry_run() {
  local output_bucket
  output_bucket=$(_resolve_output_bucket)

  # Role precedence: --role flag, then ROLE_ARN, then SAGEMAKER_ROLE_ARN.
  local role_arn="${ARG_ROLE:-${ROLE_ARN:-${SAGEMAKER_ROLE_ARN:-<not configured>}}}"

  local job_name
  job_name="${PROJECT_NAME}-tune-${ARG_TECHNIQUE}-$(date +%Y%m%d-%H%M%S)"

  printf '%s\n' \
    "🔍 Dry run — validation passed, would submit:" \
    "" \
    " Job name: ${job_name}" \
    " Model: ${RESOLVED_MODEL_ID}" \
    " Technique: ${ARG_TECHNIQUE}" \
    " Training type: ${ARG_TRAINING_TYPE}" \
    " Dataset: ${RESOLVED_DATASET_S3_URI}" \
    " Output bucket: ${output_bucket}" \
    " Role: ${role_arn}" \
    " Package group: ${PROJECT_NAME}-tune-models"

  # Optional hyperparameters as (label, value) pairs; printed only when set.
  local -a optional=(
    " Epochs: "        "${ARG_EPOCHS}"
    " Learning rate: " "${ARG_LEARNING_RATE}"
    " Max seq len: "   "${ARG_MAX_SEQ_LENGTH}"
    " LoRA rank: "     "${ARG_LORA_RANK}"
    " LoRA alpha: "    "${ARG_LORA_ALPHA}"
    " Batch size: "    "${ARG_BATCH_SIZE}"
    " Reward fn: "     "${ARG_REWARD_FUNCTION}"
    " Reward prompt: " "${ARG_REWARD_PROMPT}"
  )
  local i
  for (( i = 0; i < ${#optional[@]}; i += 2 )); do
    if [[ -n "${optional[i+1]}" ]]; then
      printf '%s%s\n' "${optional[i]}" "${optional[i+1]}"
    fi
  done

  printf '\n%s\n' " ✅ All validations passed. Remove --dry-run to submit."
  exit 0
}
1067
+
1068
# ══════════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════════
# Flow: parse flags → informational modes → required arguments → environment
# checks → input validation → idempotency check → dry run or submission →
# optional polling until the job finishes.

_parse_args "$@"

# Informational flags are honoured before anything else.
# NOTE(review): _show_help/_list_models/_show_status are defined outside this
# view; presumably each exits the script — confirm, otherwise execution
# falls through to the submission checks below.
if [[ "${ARG_HELP}" == true ]]; then
  _show_help
fi
if [[ "${ARG_LIST_MODELS}" == true ]]; then
  _list_models
fi
if [[ "${ARG_STATUS}" == true ]]; then
  _show_status
fi

# Both --technique and --dataset are mandatory for a submission.
if [[ -z "${ARG_TECHNIQUE}" ]]; then
  printf '%s\n' \
    "❌ --technique is required" \
    " Usage: ./do/tune --technique <sft|dpo|rlaif|rlvr> --dataset <source>" \
    " Run ./do/tune --help for full usage."
  exit 1
fi
if [[ -z "${ARG_DATASET}" ]]; then
  printf '%s\n' \
    "❌ --dataset is required" \
    " Usage: ./do/tune --technique ${ARG_TECHNIQUE} --dataset <s3://... or hf://...>" \
    " Run ./do/tune --help for full usage."
  exit 1
fi

# A config-level "unsupported" marker only warns here; it does not abort.
if [[ "${TUNE_SUPPORTED:-}" == "false" ]]; then
  printf '%s\n' \
    "⚠️ Managed customization is not supported for the configured model." \
    " Checking catalog for current support..." \
    ""
fi

# The helper script and the JSON parsing below both need python3.
if ! command -v python3 >/dev/null 2>&1; then
  printf '%s\n' \
    "❌ python3 is required but not found" \
    " Install Python 3 to use managed model customization."
  exit 1
fi

printf '%s\n' "🔧 SageMaker AI Managed Model Customization" ""

_validate_model
_validate_technique
_validate_training_type
_validate_dataset

# _check_idempotency succeeds when there is no existing job (or --force was
# given). Otherwise the existing job has been dealt with and the script ends
# successfully, exactly as falling off the end of the file would.
if ! _check_idempotency; then
  exit 0
fi

# _dry_run prints the plan and exits 0, so nothing is submitted.
if [[ "${ARG_DRY_RUN}" == true ]]; then
  _dry_run
fi

_submit_job

if [[ "${ARG_NO_WAIT}" == true ]]; then
  printf '%s\n' \
    " --no-wait specified. Job running in background." \
    " Check status: ./do/tune --status"
  exit 0
fi

_poll_job
_handle_completion