@aws/ml-container-creator 0.5.0 → 0.6.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/bin/cli.js +9 -0
- package/config/bootstrap-stack.json +106 -9
- package/infra/ci-harness/package-lock.json +5 -1
- package/package.json +1 -1
- package/servers/instance-sizer/index.js +4 -4
- package/servers/instance-sizer/lib/model-resolver.js +1 -1
- package/servers/lib/catalogs/model-sizes.json +135 -90
- package/servers/lib/catalogs/models.json +483 -411
- package/src/app.js +29 -1
- package/src/lib/bootstrap-command-handler.js +71 -23
- package/src/lib/cli-handler.js +1 -1
- package/src/lib/config-manager.js +1 -1
- package/src/lib/mcp-client.js +3 -3
- package/src/lib/prompt-runner.js +5 -5
- package/src/lib/prompts.js +31 -5
- package/src/lib/tune-catalog-validator.js +143 -0
- package/src/lib/tune-config-state.js +116 -0
- package/src/lib/tune-dataset-validator.js +279 -0
- package/src/lib/tune-output-resolver.js +66 -0
- package/templates/do/.tune_helper.py +768 -0
- package/templates/do/adapter +128 -17
- package/templates/do/add-ic +155 -19
- package/templates/do/config +11 -4
- package/templates/do/tune +1143 -0
|
@@ -0,0 +1,1143 @@
|
|
|
1
|
+
#!/bin/bash
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
# do/tune — SageMaker AI Managed Model Customization
|
|
6
|
+
# Wraps SageMaker managed fine-tuning for supported foundation models.
|
|
7
|
+
# Supports SFT, DPO, RLAIF, and RLVR techniques.
|
|
8
|
+
|
|
9
|
+
set -euo pipefail

# ── Source project configuration ──────────────────────────────────────────────
# Locate the directory holding this script so sibling files resolve no matter
# where the caller runs it from, then load the project's settings.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
source "${SCRIPT_DIR}/config"

# ── Constants ─────────────────────────────────────────────────────────────────
# Assigned after do/config is sourced, so these values always win.
CATALOG_FILE="${SCRIPT_DIR}/.tune_catalog.json"   # supported-model catalog (JSON)
HELPER_SCRIPT="${SCRIPT_DIR}/.tune_helper.py"     # Python helper invoked below
POLL_INTERVAL=60                                  # poll interval; unit presumably seconds — confirm at use site

# ── CLI Variables (set by _parse_args) ────────────────────────────────────────
# Value-taking options: an empty string means "not supplied on the CLI".
ARG_TECHNIQUE="" ARG_DATASET="" ARG_MODEL=""
ARG_EPOCHS="" ARG_LEARNING_RATE="" ARG_MAX_SEQ_LENGTH=""
ARG_LORA_RANK="" ARG_LORA_ALPHA="" ARG_BATCH_SIZE=""
ARG_REWARD_FUNCTION="" ARG_REWARD_PROMPT=""
ARG_OUTPUT_BUCKET="" ARG_ROLE=""
# The only value option with a non-empty default.
ARG_TRAINING_TYPE="lora"
# Boolean switches.
ARG_FORCE=false ARG_NO_WAIT=false ARG_STATUS=false
ARG_HELP=false ARG_DRY_RUN=false ARG_LIST_MODELS=false
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ── _parse_args() ─────────────────────────────────────────────────────────────
# Parse all CLI flags into the ARG_* globals.
#
# Value-taking options (e.g. `--technique sft`) store their argument in the
# matching ARG_* variable; switches (--force, --no-wait, ...) set theirs to
# true. Exits 1 with a message on an unknown option or a missing value.
# A "value" that itself starts with "--" is treated as missing, so
# `--model --force` is reported instead of silently storing "--force" as the
# model ID and losing the flag.
#
# Arguments: the script's command-line arguments ("$@").
# Globals:   ARG_* (written).
_parse_args() {
  local opt target hint
  while [ $# -gt 0 ]; do
    opt="$1"
    case "${opt}" in
      --technique)       target=ARG_TECHNIQUE;       hint="a value (sft, dpo, rlaif, rlvr)" ;;
      --dataset)         target=ARG_DATASET;         hint="a value (s3://... or hf://...)" ;;
      --training-type)   target=ARG_TRAINING_TYPE;   hint="a value (lora, full-rank)" ;;
      --model)           target=ARG_MODEL;           hint="a JumpStart model ID" ;;
      --epochs)          target=ARG_EPOCHS;          hint="an integer value" ;;
      --learning-rate)   target=ARG_LEARNING_RATE;   hint="a float value" ;;
      --max-seq-length)  target=ARG_MAX_SEQ_LENGTH;  hint="an integer value" ;;
      --lora-rank)       target=ARG_LORA_RANK;       hint="an integer value" ;;
      --lora-alpha)      target=ARG_LORA_ALPHA;      hint="an integer value" ;;
      --batch-size)      target=ARG_BATCH_SIZE;      hint="an integer value" ;;
      --reward-function) target=ARG_REWARD_FUNCTION; hint="a Lambda ARN" ;;
      --reward-prompt)   target=ARG_REWARD_PROMPT;   hint="an S3 URI" ;;
      --output-bucket)   target=ARG_OUTPUT_BUCKET;   hint="a bucket name" ;;
      --role)            target=ARG_ROLE;            hint="an IAM role ARN" ;;
      --force)           ARG_FORCE=true;       shift; continue ;;
      --no-wait)         ARG_NO_WAIT=true;     shift; continue ;;
      --status)          ARG_STATUS=true;      shift; continue ;;
      --help|-h)         ARG_HELP=true;        shift; continue ;;
      --dry-run)         ARG_DRY_RUN=true;     shift; continue ;;
      --list-models)     ARG_LIST_MODELS=true; shift; continue ;;
      *)
        echo "❌ Unknown option: ${opt}"
        echo " Run ./do/tune --help for usage."
        exit 1
        ;;
    esac

    # Only value-taking options reach this point.
    if [ -z "${2:-}" ] || [[ "$2" == --* ]]; then
      echo "❌ ${opt} requires ${hint}"
      exit 1
    fi
    printf -v "${target}" '%s' "$2"
    shift 2
  done
}
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# ── _show_help() ──────────────────────────────────────────────────────────────
# Print CLI usage to stdout, then exit 0. The quoted heredoc delimiter keeps
# the text literal (no expansions).
_show_help() {
  cat <<'EOF'
Usage: ./do/tune --technique <technique> --dataset <source> [options]
 ./do/tune --status
 ./do/tune --list-models
 ./do/tune --help

SageMaker AI Managed Model Customization — fine-tune supported foundation
models using SFT, DPO, RLAIF, or RLVR without managing infrastructure.

Required:
 --technique <t> Customization technique: sft, dpo, rlaif, rlvr
 --dataset <source> Dataset: s3://bucket/path.jsonl or hf://org/name[/split]

Training type:
 --training-type <t> lora (default) or full-rank

Hyperparameter overrides (optional):
 --epochs <n> Number of training epochs
 --learning-rate <f> Learning rate (e.g., 2e-4)
 --max-seq-length <n> Maximum sequence length in tokens
 --lora-rank <n> LoRA rank (e.g., 16, 32, 64)
 --lora-alpha <n> LoRA alpha scaling factor
 --batch-size <n> Global batch size

Evaluator (RLVR/RLAIF only):
 --reward-function <arn> Lambda ARN for reward function
 --reward-prompt <uri> S3 URI for reward prompt file

Overrides:
 --model <id> Override model (defaults to MODEL_ID from do/config)
 --output-bucket <b> Override output bucket (defaults to TUNE_S3_BUCKET)
 --role <arn> Override execution role (defaults to ROLE_ARN)

Job control:
 --force Force new job even if one exists for this technique
 --no-wait Submit and exit without polling for completion
 --status Show status of all tracked tune jobs

Informational:
 --help, -h Show this help message
 --dry-run Validate inputs and show what would be submitted
 --list-models Print supported models, techniques, and training types

Examples:
 ./do/tune --technique sft --dataset s3://my-bucket/train.jsonl
 ./do/tune --technique dpo --dataset hf://my-org/pref-data --learning-rate 1e-5
 ./do/tune --technique sft --dataset s3://bucket/data.jsonl --training-type full-rank
 ./do/tune --status
 ./do/tune --technique sft --dataset s3://bucket/data.jsonl --dry-run
EOF
  exit 0
}
|
|
202
|
+
|
|
203
|
+
# ── _show_status() ────────────────────────────────────────────────────────────
# Display status of all tracked tune jobs from do/config, then exit 0.
#
# For each technique with a TUNE_JOB_NAME_<TECHNIQUE> set, queries the job via
# the Python helper's "status" subcommand and prints its status, elapsed time
# (when positive), and any recorded TUNE_ADAPTER_PATH_* / TUNE_MODEL_PATH_*.
#
# Globals: TUNE_JOB_NAME_*, TUNE_ADAPTER_PATH_*, TUNE_MODEL_PATH_*,
#          HELPER_SCRIPT, AWS_REGION (read).
_show_status() {
  echo "📊 Tune Job Status"
  echo ""

  local found_any=false
  local nl=$'\n'
  local technique tech_upper job_var job_name output_var model_var
  local status_json parsed status elapsed mins secs
  for technique in sft dpo rlaif rlvr; do
    tech_upper="${technique^^}"
    job_var="TUNE_JOB_NAME_${tech_upper}"
    job_name="${!job_var:-}"
    [ -n "${job_name}" ] || continue

    found_any=true
    echo " ${tech_upper}:"
    echo " Job: ${job_name}"

    # Query status via Python helper; fall back to a synthetic "Unknown" record.
    status_json=$(python3 "${HELPER_SCRIPT}" status \
      --job-name "${job_name}" \
      --region "${AWS_REGION}" 2>/dev/null) || status_json='{"status":"Unknown","error":"Failed to query"}'

    # One parse yields two lines: the status, then elapsed_seconds.
    status="Unknown"
    elapsed="0"
    if parsed=$(printf '%s' "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('status','Unknown')); print(d.get('elapsed_seconds',0))" 2>/dev/null); then
      status=${parsed%%"${nl}"*}
      elapsed=${parsed#*"${nl}"}
    fi

    echo " Status: ${status}"

    # $(( )) only accepts integers. If the helper reports a fractional value
    # (e.g. 125.7) the original arithmetic failed, so truncate the fraction and
    # skip anything that is still not a positive integer.
    elapsed="${elapsed%%.*}"
    if [[ "${elapsed}" =~ ^[0-9]+$ ]] && (( 10#${elapsed} > 0 )); then
      elapsed=$(( 10#${elapsed} ))
      mins=$(( elapsed / 60 ))
      secs=$(( elapsed % 60 ))
      echo " Elapsed: ${mins}m ${secs}s"
    fi

    # Show the recorded output path, preferring the adapter over a full model.
    output_var="TUNE_ADAPTER_PATH_${tech_upper}"
    model_var="TUNE_MODEL_PATH_${tech_upper}"
    if [ -n "${!output_var:-}" ]; then
      echo " Output (adapter): ${!output_var}"
    elif [ -n "${!model_var:-}" ]; then
      echo " Output (model): ${!model_var}"
    fi
    echo ""
  done

  if [ "${found_any}" = false ]; then
    echo " No tune jobs tracked. Run ./do/tune --technique <t> --dataset <d> to start."
  fi

  exit 0
}
|
|
256
|
+
|
|
257
|
+
# ── _list_models() ────────────────────────────────────────────────────────────
# Print the Supported Model Catalog grouped by family, then exit 0.
#
# Reads CATALOG_FILE (JSON with a top-level "models" object keyed by model ID).
# Exits 1 if the file is missing or python3 fails to parse/print it.
_list_models() {
  if [ ! -f "${CATALOG_FILE}" ]; then
    echo "❌ Catalog file not found: ${CATALOG_FILE}"
    exit 1
  fi

  echo "📋 Supported Models for Managed Customization"
  echo ""

  # The catalog path is passed as argv[1] rather than spliced into the Python
  # source, so a path containing quotes cannot break the program. Entries
  # missing displayName/jumpStartModelId fall back to their catalog key
  # instead of aborting mid-listing with a KeyError.
  python3 -c '
import json, sys

with open(sys.argv[1]) as f:
    catalog = json.load(f)

models = catalog.get("models", {})
# Group by family, keeping catalog order within each family.
families = {}
for model_id, entry in models.items():
    families.setdefault(entry.get("family", "unknown"), []).append((model_id, entry))

for family in sorted(families):
    entries = families[family]
    provider = entries[0][1].get("provider", "")
    print(f" {family} ({provider}):")
    for model_id, entry in entries:
        name = entry.get("displayName", model_id)
        js_id = entry.get("jumpStartModelId", model_id)
        print(f" • {name}")
        print(f" ID: {js_id}")
        for technique, tc in entry.get("techniques", {}).items():
            types = ", ".join(tc.get("trainingTypes", []))
            print(f" {technique}: [{types}]")
    print()
' "${CATALOG_FILE}" 2>/dev/null || {
    echo "❌ Failed to parse catalog. Ensure python3 is available."
    exit 1
  }

  exit 0
}
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
# ── _update_config_var() ──────────────────────────────────────────────────────
# Write or update a variable in do/config.
# Usage: _update_config_var VAR_NAME "value"
#
# Every existing `export VAR_NAME=...` line is replaced in place; if there is
# none, the assignment is appended. The value is written inside double quotes
# with \ " $ and ` backslash-escaped, so sourcing do/config yields the exact
# value. sed is avoided on purpose: its replacement text would treat the
# delimiter `|`, `&`, and `\` in the value specially.
#
# Globals: SCRIPT_DIR (read).
_update_config_var() {
  local var_name="$1"
  local var_value="$2"
  local config_file="${SCRIPT_DIR}/config"

  # Escape characters that stay special inside a double-quoted shell string.
  local escaped="" ch i
  for (( i = 0; i < ${#var_value}; i++ )); do
    ch="${var_value:i:1}"
    case "${ch}" in
      '\' | '"' | '$' | '`') escaped+="\\${ch}" ;;
      *) escaped+="${ch}" ;;
    esac
  done
  local new_line="export ${var_name}=\"${escaped}\""

  # Rewrite matching lines in memory; writing back through ">" truncates the
  # existing file rather than replacing it, so its permissions are kept.
  local found=false updated="" line
  if [ -f "${config_file}" ]; then
    while IFS= read -r line || [ -n "${line}" ]; do
      if [[ "${line}" == "export ${var_name}="* ]]; then
        line="${new_line}"
        found=true
      fi
      updated+="${line}"$'\n'
    done < "${config_file}"
  fi

  if [ "${found}" = true ]; then
    printf '%s' "${updated}" > "${config_file}"
  else
    printf '%s\n' "${new_line}" >> "${config_file}"
  fi
}
|
|
320
|
+
|
|
321
|
+
# ── _validate_model() ─────────────────────────────────────────────────────────
# Read MODEL_ID from do/config (or --model override), check against catalog.
# Sets RESOLVED_MODEL_ID on success.
#
# Resolution order: --model, then MODEL_ID, then MODEL_NAME. Returns 0 when
# the ID is a key of the catalog's "models" object; otherwise prints the
# supported families (comma+space separated) and exits 1.
#
# Globals: ARG_MODEL, MODEL_ID, MODEL_NAME, CATALOG_FILE (read);
#          RESOLVED_MODEL_ID (written).
_validate_model() {
  if [ -n "${ARG_MODEL}" ]; then
    RESOLVED_MODEL_ID="${ARG_MODEL}"
  elif [ -n "${MODEL_ID:-}" ]; then
    RESOLVED_MODEL_ID="${MODEL_ID}"
  elif [ -n "${MODEL_NAME:-}" ]; then
    RESOLVED_MODEL_ID="${MODEL_NAME}"
  else
    echo "❌ No model configured"
    echo " Set MODEL_ID in do/config or use --model <id>"
    exit 1
  fi

  if [ ! -f "${CATALOG_FILE}" ]; then
    echo "❌ Catalog file not found: ${CATALOG_FILE}"
    echo " The tune catalog is required for model validation."
    exit 1
  fi

  # Inputs go in via argv, never spliced into the Python source: a --model
  # value containing a quote would otherwise break (or inject into) the code.
  # Output: "SUPPORTED" or "UNSUPPORTED|<families joined by ', '>".
  local result
  result=$(python3 -c '
import json, sys

catalog_file, model_id = sys.argv[1], sys.argv[2]
with open(catalog_file) as f:
    catalog = json.load(f)

models = catalog.get("models", {})
if model_id in models:
    print("SUPPORTED")
else:
    families = sorted(set(e.get("family", "") for e in models.values() if e.get("family")))
    print("UNSUPPORTED|" + ", ".join(families))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" 2>/dev/null) || {
    echo "❌ Failed to validate model against catalog"
    echo " Ensure python3 is available."
    exit 1
  }

  if [ "${result}" = "SUPPORTED" ]; then
    return 0
  fi

  # Joined in Python: `tr '|' ', '` only maps "|" to "," and dropped the space.
  local families="${result#UNSUPPORTED|}"

  echo "❌ Model \"${RESOLVED_MODEL_ID}\" is not yet supported for managed serverless customization."
  echo " Supported model families: ${families}"
  echo ""
  echo " Additional model support and custom training workflows are expected in future releases."
  echo " For custom training workflows, see \`do/train\`."
  exit 1
}
|
|
382
|
+
|
|
383
|
+
# ── _validate_technique() ─────────────────────────────────────────────────────
# Check that the technique is supported for the resolved model.
#
# First checks ARG_TECHNIQUE against the fixed list sft|dpo|rlaif|rlvr, then
# against the model's "techniques" object in CATALOG_FILE. Expects
# RESOLVED_MODEL_ID to already be a catalog key (see _validate_model).
# Returns 0 on success; otherwise prints the supported techniques and exits 1.
#
# Globals: ARG_TECHNIQUE, RESOLVED_MODEL_ID, CATALOG_FILE (read).
_validate_technique() {
  local technique="${ARG_TECHNIQUE}"

  case "${technique}" in
    sft|dpo|rlaif|rlvr) ;;
    *)
      echo "❌ Invalid technique: ${technique}"
      echo " Valid techniques: sft, dpo, rlaif, rlvr"
      exit 1
      ;;
  esac

  # Inputs go in via argv rather than being spliced into the Python source.
  # Output: "SUPPORTED" or "UNSUPPORTED|<techniques joined by ', '>".
  local result
  result=$(python3 -c '
import json, sys

catalog_file, model_id, technique = sys.argv[1:4]
with open(catalog_file) as f:
    catalog = json.load(f)

techniques = catalog["models"][model_id].get("techniques", {})
if technique in techniques:
    print("SUPPORTED")
else:
    print("UNSUPPORTED|" + ", ".join(techniques.keys()))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${technique}" 2>/dev/null) || {
    echo "❌ Failed to validate technique against catalog"
    exit 1
  }

  if [ "${result}" = "SUPPORTED" ]; then
    return 0
  fi

  # Joined in Python: `tr '|' ', '` only maps "|" to "," and dropped the space.
  local supported="${result#UNSUPPORTED|}"

  echo "❌ Technique \"${technique}\" is not supported for model \"${RESOLVED_MODEL_ID}\"."
  echo " Supported techniques: ${supported}"
  exit 1
}
|
|
432
|
+
|
|
433
|
+
# ── _validate_training_type() ─────────────────────────────────────────────────
# Check that the training type is supported for the model+technique.
#
# First checks ARG_TRAINING_TYPE against the fixed list lora|full-rank, then
# against the catalog entry's "trainingTypes" list for the resolved model and
# technique. Expects both to be valid catalog keys already (see
# _validate_model / _validate_technique). Returns 0 on success; otherwise
# prints the supported training types and exits 1.
#
# Globals: ARG_TECHNIQUE, ARG_TRAINING_TYPE, RESOLVED_MODEL_ID,
#          CATALOG_FILE (read).
_validate_training_type() {
  local technique="${ARG_TECHNIQUE}"
  local training_type="${ARG_TRAINING_TYPE}"

  case "${training_type}" in
    lora|full-rank) ;;
    *)
      echo "❌ Invalid training type: ${training_type}"
      echo " Valid training types: lora, full-rank"
      exit 1
      ;;
  esac

  # Inputs go in via argv rather than being spliced into the Python source.
  # Output: "SUPPORTED" or "UNSUPPORTED|<training types joined by ', '>".
  local result
  result=$(python3 -c '
import json, sys

catalog_file, model_id, technique, training_type = sys.argv[1:5]
with open(catalog_file) as f:
    catalog = json.load(f)

training_types = catalog["models"][model_id]["techniques"][technique].get("trainingTypes", [])
if training_type in training_types:
    print("SUPPORTED")
else:
    print("UNSUPPORTED|" + ", ".join(training_types))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${technique}" "${training_type}" 2>/dev/null) || {
    echo "❌ Failed to validate training type against catalog"
    exit 1
  }

  if [ "${result}" = "SUPPORTED" ]; then
    return 0
  fi

  # Joined in Python: `tr '|' ', '` only maps "|" to "," and dropped the space.
  local supported="${result#UNSUPPORTED|}"

  echo "❌ Training type \"${training_type}\" is not supported for model \"${RESOLVED_MODEL_ID}\" with technique \"${technique}\"."
  echo " Supported training types: ${supported}"
  exit 1
}
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
# ── _validate_dataset() ───────────────────────────────────────────────────────
# Check S3 existence or delegate HF staging to Python helper.
# Sets RESOLVED_DATASET_S3_URI on success.
#
#   s3://...               verified with `aws s3 ls`; the first 10 lines are
#                          then checked by the helper's "validate" subcommand
#                          against the catalog schema (_get_dataset_schema).
#   hf://org/name[/split]  staged to S3 by the helper's "stage-hf" subcommand.
# A missing --dataset, any other scheme, or any failure prints a message and
# exits 1.
#
# Globals: ARG_DATASET, AWS_REGION, PROJECT_NAME, HF_TOKEN_ARN, HELPER_SCRIPT
#          (read); RESOLVED_DATASET_S3_URI (written).
_validate_dataset() {
  local dataset="${ARG_DATASET}"

  if [ -z "${dataset}" ]; then
    echo "❌ --dataset is required"
    echo " Provide an S3 URI (s3://bucket/path.jsonl) or HF reference (hf://org/name)"
    exit 1
  fi

  if [[ "${dataset}" == s3://* ]]; then
    # S3 dataset — verify existence
    if ! aws s3 ls "${dataset}" --region "${AWS_REGION}" >/dev/null 2>&1; then
      echo "❌ Dataset not found or not accessible: ${dataset}"
      echo " Verify the S3 URI is correct and you have read permissions."
      echo " Check: aws s3 ls ${dataset} --region ${AWS_REGION}"
      exit 1
    fi
    RESOLVED_DATASET_S3_URI="${dataset}"

    # Validate format by downloading first 10 lines
    local schema_json
    schema_json=$(_get_dataset_schema)

    # `head` closes the pipe after 10 lines, so for any object larger than the
    # pipe buffer `aws s3 cp` exits non-zero on a broken pipe. Under
    # `set -o pipefail` + `set -e` that killed the script silently; the sample
    # is best-effort, so tolerate the pipeline status.
    local sample_data
    sample_data=$(aws s3 cp "${dataset}" - --region "${AWS_REGION}" 2>/dev/null | head -10) || true

    if [ -n "${sample_data}" ]; then
      local validate_result
      validate_result=$(echo "${sample_data}" | python3 "${HELPER_SCRIPT}" validate \
        --schema "${schema_json}" 2>/dev/null) || validate_result='{"valid":false,"error":"Validation failed"}'

      local is_valid
      is_valid=$(echo "${validate_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('valid', False))" 2>/dev/null) || is_valid="False"

      if [ "${is_valid}" != "True" ]; then
        local error_msg malformed_line expected_format
        error_msg=$(echo "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('error','Unknown format error'))" 2>/dev/null) || error_msg="Unknown format error"
        malformed_line=$(echo "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('malformed_line','') or '')" 2>/dev/null) || malformed_line=""
        expected_format=$(echo "${validate_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('expected_format','') or '')" 2>/dev/null) || expected_format=""

        echo "❌ Dataset format validation failed"
        echo " ${error_msg}"
        if [ -n "${malformed_line}" ]; then
          echo ""
          echo " Malformed line: ${malformed_line}"
        fi
        if [ -n "${expected_format}" ]; then
          echo ""
          echo " ${expected_format}"
        fi
        exit 1
      fi
    fi

  elif [[ "${dataset}" == hf://* ]]; then
    # Hugging Face dataset — parse reference and stage to S3
    local hf_path="${dataset#hf://}"
    local hf_org="" hf_name="" hf_split=""

    # Split org/name[/split]; hf_split receives everything after the second
    # "/". (`cut -d/ -fN` prints the whole line when it has no "/", so a bare
    # hf://name used to pass as org=name, name=name, split=name.)
    IFS='/' read -r hf_org hf_name hf_split <<<"${hf_path}"

    if [ -z "${hf_org}" ] || [ -z "${hf_name}" ]; then
      echo "❌ Invalid HF dataset reference: ${dataset}"
      echo " Expected format: hf://org/name or hf://org/name/split"
      exit 1
    fi

    local output_bucket
    output_bucket=$(_resolve_output_bucket)

    echo "📦 Staging Hugging Face dataset: ${hf_org}/${hf_name}"
    if [ -n "${hf_split}" ]; then
      echo " Split: ${hf_split}"
    else
      echo " Split: train (default)"
    fi

    # Build stage-hf arguments
    local stage_args=(
      --hf-org "${hf_org}"
      --hf-name "${hf_name}"
      --output-bucket "${output_bucket}"
      --project-name "${PROJECT_NAME}"
      --region "${AWS_REGION}"
    )
    if [ -n "${hf_split}" ]; then
      stage_args+=(--hf-split "${hf_split}")
    fi
    if [ -n "${HF_TOKEN_ARN:-}" ]; then
      stage_args+=(--hf-secret-name "${HF_TOKEN_ARN}")
    fi

    local stage_result
    stage_result=$(python3 "${HELPER_SCRIPT}" stage-hf "${stage_args[@]}" 2>/dev/null) || {
      local error_msg
      error_msg=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Failed to stage dataset'))" 2>/dev/null) || error_msg="Failed to stage HF dataset"
      echo "❌ ${error_msg}"
      exit 1
    }

    # Check for error in response
    local has_error
    has_error=$(echo "${stage_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"

    if [ "${has_error}" = "yes" ]; then
      local error_msg
      error_msg=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
      echo "❌ ${error_msg}"
      exit 1
    fi

    # A result without s3_uri used to abort silently via set -e (or leave the
    # URI empty); report it explicitly instead.
    if ! RESOLVED_DATASET_S3_URI=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin)['s3_uri'])" 2>/dev/null) \
        || [ -z "${RESOLVED_DATASET_S3_URI}" ]; then
      echo "❌ Dataset staging did not return an S3 URI"
      exit 1
    fi

    local num_records
    num_records=$(echo "${stage_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('num_records',0))" 2>/dev/null) || num_records="0"

    echo " ✅ Staged to: ${RESOLVED_DATASET_S3_URI}"
    echo " Records: ${num_records}"
    echo ""

  else
    echo "❌ Invalid dataset format: ${dataset}"
    echo " Expected: s3://bucket/path.jsonl or hf://org/name[/split]"
    exit 1
  fi
}
|
|
620
|
+
|
|
621
|
+
# ── _get_dataset_schema() ─────────────────────────────────────────────────────
# Get the dataset schema JSON for the current model+technique from the catalog.
#
# Prints the entry's "datasetSchema" object as JSON ("{}" when absent).
# Catalog path, model ID and technique are passed as argv instead of being
# spliced into the Python source, so quotes in any of them cannot break it.
#
# Globals: CATALOG_FILE, RESOLVED_MODEL_ID, ARG_TECHNIQUE (read).
_get_dataset_schema() {
  python3 -c '
import json, sys

catalog_file, model_id, technique = sys.argv[1:4]
with open(catalog_file) as f:
    catalog = json.load(f)

entry = catalog["models"][model_id]
schema = entry["techniques"][technique].get("datasetSchema", {})
print(json.dumps(schema))
' "${CATALOG_FILE}" "${RESOLVED_MODEL_ID}" "${ARG_TECHNIQUE}" 2>/dev/null
}
|
|
637
|
+
|
|
638
|
+
# ── _resolve_output_bucket() ──────────────────────────────────────────────────
# Print the S3 bucket that tuning output should go to.
# Precedence: --output-bucket (ARG_OUTPUT_BUCKET), then TUNE_S3_BUCKET, then
# ADAPTER_S3_BUCKET. If none is set, derive "mlcc-tune-<account>-<region>" from
# `aws sts get-caller-identity`; the account part becomes the literal "UNKNOWN"
# when that call fails.
_resolve_output_bucket() {
    local candidate
    for candidate in "${ARG_OUTPUT_BUCKET}" "${TUNE_S3_BUCKET:-}" "${ADAPTER_S3_BUCKET:-}"; do
        if [ -n "${candidate}" ]; then
            echo "${candidate}"
            return 0
        fi
    done

    local account_id
    account_id=$(aws sts get-caller-identity --query Account --output text 2>/dev/null || echo 'UNKNOWN')
    echo "mlcc-tune-${account_id}-${AWS_REGION}"
}
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
# ── _check_idempotency() ──────────────────────────────────────────────────────
# Decide whether a new job must be submitted for the current technique.
# Looks up the job name recorded in TUNE_JOB_NAME_<TECHNIQUE> (e.g.
# TUNE_JOB_NAME_SFT). With no recorded job, or with --force, nothing else happens.
# Otherwise the helper's `status` subcommand is queried and:
#   Failed / Stopped      → details are printed and the script exits 2
#   Completed             → _handle_completion runs, returns 1
#   InProgress / Starting → _poll_job then _handle_completion run, returns 1
#   any other status, or a failed/erroring status lookup → returns 0
# Returns 0 if a new job should be created, 1 if the existing job was handled.
_check_idempotency() {
    local job_var="TUNE_JOB_NAME_${ARG_TECHNIQUE^^}"
    local prior_job="${!job_var:-}"

    # Nothing recorded, or the caller explicitly wants a fresh job.
    if [[ -z "${prior_job}" || "${ARG_FORCE}" == "true" ]]; then
        return 0
    fi

    echo "🔍 Found existing ${ARG_TECHNIQUE^^} job: ${prior_job}"

    # A non-zero helper exit, or a JSON payload carrying an "error" key, both
    # count as a failed lookup. Output that is not parseable JSON does not;
    # it falls through to status "Unknown" below.
    local status_json lookup_failed
    if status_json=$(python3 "${HELPER_SCRIPT}" status \
        --job-name "${prior_job}" \
        --region "${AWS_REGION}" 2>/dev/null); then
        lookup_failed=$(echo "${status_json}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || lookup_failed="no"
    else
        lookup_failed="yes"
    fi

    if [[ "${lookup_failed}" == "yes" ]]; then
        echo " ⚠️ Could not query job status — proceeding with new job"
        return 0
    fi

    local status
    status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])" 2>/dev/null) || status="Unknown"

    case "${status}" in
        Failed)
            local failure_reason
            failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'Unknown')" 2>/dev/null) || failure_reason="Unknown"
            echo " Status: Failed"
            echo " Reason: ${failure_reason}"
            echo ""
            echo " Re-run with --force to start a new job."
            exit 2
            ;;
        Stopped)
            echo " Status: Stopped"
            echo " Re-run with --force to start a new job."
            exit 2
            ;;
        Completed)
            echo " Status: Completed"
            echo " (use --force to start a new job)"
            echo ""
            _handle_completion "${prior_job}"
            return 1
            ;;
        InProgress | Starting)
            echo " Status: ${status}"
            echo " (use --force to start a new job instead)"
            echo ""
            # Re-attach to the running job instead of submitting a duplicate.
            _poll_job "${prior_job}"
            _handle_completion "${prior_job}"
            return 1
            ;;
        *)
            echo " Status: ${status}"
            echo " Proceeding with new job..."
            return 0
            ;;
    esac
}
|
|
727
|
+
|
|
728
|
+
# ── _submit_job() ─────────────────────────────────────────────────────────────
# Invoke the Python helper's `submit` subcommand to create the customization job.
# Globals read: ARG_ROLE, ROLE_ARN, SAGEMAKER_ROLE_ARN, PROJECT_NAME,
#   ARG_TECHNIQUE, ARG_TRAINING_TYPE, ARG_DATASET, RESOLVED_MODEL_ID,
#   RESOLVED_DATASET_S3_URI, HELPER_SCRIPT and the optional ARG_* hyperparameters.
# Globals written: JOB_NAME (replaced by the helper's `job_name` if it returns one).
# Persists TUNE_JOB_NAME_<TECHNIQUE>, TUNE_TECHNIQUE, TUNE_TRAINING_TYPE and
# TUNE_DATASET_PATH via _update_config_var. Exits 1 on any failure.
_submit_job() {
    local output_bucket
    output_bucket=$(_resolve_output_bucket)

    # Execution role precedence: --role, then ROLE_ARN, then SAGEMAKER_ROLE_ARN.
    local role_arn
    if [ -n "${ARG_ROLE}" ]; then
        role_arn="${ARG_ROLE}"
    elif [ -n "${ROLE_ARN:-}" ]; then
        role_arn="${ROLE_ARN}"
    else
        role_arn="${SAGEMAKER_ROLE_ARN:-}"
        if [ -z "${role_arn}" ]; then
            echo "❌ No execution role configured"
            echo " Set ROLE_ARN in do/config or use --role <arn>"
            exit 1
        fi
    fi

    # Generate a job name (second resolution); the helper may return a different one.
    local timestamp
    timestamp=$(date +%Y%m%d-%H%M%S)
    JOB_NAME="${PROJECT_NAME}-tune-${ARG_TECHNIQUE}-${timestamp}"

    echo "🚀 Submitting ${ARG_TECHNIQUE^^} customization job"
    echo " Job name: ${JOB_NAME}"
    echo " Model: ${RESOLVED_MODEL_ID}"
    echo " Technique: ${ARG_TECHNIQUE}"
    echo " Training type: ${ARG_TRAINING_TYPE}"
    echo " Dataset: ${RESOLVED_DATASET_S3_URI}"
    echo " Output bucket: ${output_bucket}"
    echo ""

    # Build submit arguments as an array so values with spaces stay intact.
    local submit_args=(
        --model-id "${RESOLVED_MODEL_ID}"
        --technique "${ARG_TECHNIQUE}"
        --training-type "${ARG_TRAINING_TYPE}"
        --dataset-s3-uri "${RESOLVED_DATASET_S3_URI}"
        --output-bucket "${output_bucket}"
        --role-arn "${role_arn}"
        --job-name "${JOB_NAME}"
        --project-name "${PROJECT_NAME}"
        --model-package-group "${PROJECT_NAME}-tune-models"
    )

    # Add optional hyperparameters only when provided.
    if [ -n "${ARG_EPOCHS}" ]; then
        submit_args+=(--epochs "${ARG_EPOCHS}")
    fi
    if [ -n "${ARG_LEARNING_RATE}" ]; then
        submit_args+=(--learning-rate "${ARG_LEARNING_RATE}")
    fi
    if [ -n "${ARG_MAX_SEQ_LENGTH}" ]; then
        submit_args+=(--max-seq-length "${ARG_MAX_SEQ_LENGTH}")
    fi
    if [ -n "${ARG_LORA_RANK}" ]; then
        submit_args+=(--lora-rank "${ARG_LORA_RANK}")
    fi
    if [ -n "${ARG_LORA_ALPHA}" ]; then
        submit_args+=(--lora-alpha "${ARG_LORA_ALPHA}")
    fi
    if [ -n "${ARG_BATCH_SIZE}" ]; then
        submit_args+=(--batch-size "${ARG_BATCH_SIZE}")
    fi
    if [ -n "${ARG_REWARD_FUNCTION}" ]; then
        submit_args+=(--reward-function "${ARG_REWARD_FUNCTION}")
    fi
    if [ -n "${ARG_REWARD_PROMPT}" ]; then
        submit_args+=(--reward-prompt "${ARG_REWARD_PROMPT}")
    fi

    # Invoke Python helper. On a non-zero exit, report the helper's own JSON
    # "error" (as the stage-hf path does) instead of always blaming the SDK
    # install. The install hint is only a fallback when no error text is available.
    local submit_result
    if ! submit_result=$(python3 "${HELPER_SCRIPT}" submit "${submit_args[@]}" 2>/dev/null); then
        local failure_detail
        failure_detail=$(printf '%s' "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error',''))" 2>/dev/null) || failure_detail=""
        echo "❌ Failed to submit customization job"
        if [ -n "${failure_detail}" ]; then
            echo " ${failure_detail}"
        else
            echo " Ensure the SageMaker Python SDK is installed: pip install 'sagemaker>=2.232.0'"
        fi
        exit 1
    fi

    # A zero exit can still carry an error payload (or unparseable output).
    local has_error
    has_error=$(echo "${submit_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"

    if [ "${has_error}" = "yes" ]; then
        local error_msg
        error_msg=$(echo "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error','Unknown error'))" 2>/dev/null) || error_msg="Unknown error"
        echo "❌ ${error_msg}"
        exit 1
    fi

    # Extract job name from response (may differ from our generated name)
    local returned_job_name
    returned_job_name=$(echo "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('job_name',''))" 2>/dev/null) || returned_job_name=""
    if [ -n "${returned_job_name}" ]; then
        JOB_NAME="${returned_job_name}"
    fi

    # Display MLflow URL if available
    local mlflow_url
    mlflow_url=$(echo "${submit_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('mlflow_url','') or '')" 2>/dev/null) || mlflow_url=""
    if [ -n "${mlflow_url}" ]; then
        echo " 📈 MLflow tracking: ${mlflow_url}"
    fi

    echo "✅ Job submitted: ${JOB_NAME}"
    echo ""

    # Store state in do/config so later runs can find (and resume) this job.
    local technique_upper
    technique_upper=$(echo "${ARG_TECHNIQUE}" | tr '[:lower:]' '[:upper:]')
    _update_config_var "TUNE_JOB_NAME_${technique_upper}" "${JOB_NAME}"
    _update_config_var "TUNE_TECHNIQUE" "${ARG_TECHNIQUE}"
    _update_config_var "TUNE_TRAINING_TYPE" "${ARG_TRAINING_TYPE}"
    _update_config_var "TUNE_DATASET_PATH" "${ARG_DATASET}"
}
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
# ── _poll_job() ───────────────────────────────────────────────────────────────
# Poll the job every POLL_INTERVAL seconds until it reaches a terminal state,
# printing status and elapsed time on each iteration.
# Arguments: $1 - job name (defaults to JOB_NAME)
# Returns 0 on Completed; exits 2 on Failed/Stopped. Status-query failures are
# retried indefinitely. Ctrl+C is routed to _handle_interrupt, which exits
# without stopping the remote job.
_poll_job() {
    local job_name="${1:-${JOB_NAME}}"

    echo "⏳ Polling job status every ${POLL_INTERVAL}s..."
    echo " (Ctrl+C to exit — job continues in background)"
    echo ""

    # Trap SIGINT for graceful exit (job_name is expanded when the trap fires).
    trap '_handle_interrupt "${job_name}"' INT

    while true; do
        local status_json
        status_json=$(python3 "${HELPER_SCRIPT}" status \
            --job-name "${job_name}" \
            --region "${AWS_REGION}" 2>/dev/null) || {
            echo " ⚠️ Failed to query status (will retry)"
            sleep "${POLL_INTERVAL}"
            continue
        }

        local status
        status=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('status','Unknown'))" 2>/dev/null) || status="Unknown"

        # elapsed_seconds may arrive as a float or null. Truncate it to a
        # non-negative integer so the $(( )) arithmetic below cannot hit a
        # syntax error and break the loop.
        local elapsed
        elapsed=$(echo "${status_json}" | python3 -c "import sys,json; print(int(float(json.load(sys.stdin).get('elapsed_seconds') or 0)))" 2>/dev/null) || elapsed="0"
        if ! [[ "${elapsed}" =~ ^[0-9]+$ ]]; then
            elapsed="0"
        fi

        local mins=$((elapsed / 60))
        local secs=$((elapsed % 60))

        case "${status}" in
            Completed)
                echo " ✅ $(date +%H:%M:%S) Status: Completed (${mins}m ${secs}s)"
                # Restore default signal handling
                trap - INT
                return 0
                ;;
            Failed)
                local failure_reason
                failure_reason=$(echo "${status_json}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('failure_reason','') or 'Unknown')" 2>/dev/null) || failure_reason="Unknown"
                echo " ❌ $(date +%H:%M:%S) Status: Failed (${mins}m ${secs}s)"
                echo " Reason: ${failure_reason}"
                trap - INT
                exit 2
                ;;
            Stopped)
                echo " ⚠️ $(date +%H:%M:%S) Status: Stopped (${mins}m ${secs}s)"
                trap - INT
                exit 2
                ;;
            *)
                echo " $(date +%H:%M:%S) Status: ${status} (${mins}m ${secs}s)"
                sleep "${POLL_INTERVAL}"
                ;;
        esac
    done
}
|
|
910
|
+
|
|
911
|
+
# ── _handle_interrupt() ───────────────────────────────────────────────────────
# SIGINT handler installed by _poll_job. It only stops local monitoring: it
# prints how to resume or check on the job, then exits with status 130
# (the conventional 128 + SIGINT). No call is made to stop the remote job.
# Arguments: $1 - job name being monitored (may be empty)
_handle_interrupt() {
    local interrupted_job="${1:-}"
    printf '%s\n' \
        "" \
        "" \
        "⚠️ Interrupted — job continues running in background" \
        " Job: ${interrupted_job}" \
        "" \
        " Resume monitoring: ./do/tune --technique ${ARG_TECHNIQUE} --dataset ${ARG_DATASET}" \
        " Check status: ./do/tune --status"
    exit 130
}
|
|
924
|
+
|
|
925
|
+
# ── _handle_completion() ──────────────────────────────────────────────────────
# Resolve a finished job's output artifact via the helper's `resolve` command,
# record it in do/config and print next-step deploy commands.
# Arguments: $1 - job name (defaults to JOB_NAME)
# Config written via _update_config_var:
#   TUNE_ADAPTER_PATH_<TECHNIQUE> (output_type "adapter") or
#   TUNE_MODEL_PATH_<TECHNIQUE> (anything else), plus
#   TUNE_OUTPUT_PATH_LATEST and TUNE_OUTPUT_TYPE_LATEST.
# Returns 0 when artifacts cannot be resolved; those cases print a warning and
# leave do/config untouched.
_handle_completion() {
    local job_name="${1:-${JOB_NAME}}"
    local technique_upper
    technique_upper=$(echo "${ARG_TECHNIQUE}" | tr '[:lower:]' '[:upper:]')

    # Resolve artifact path via Python helper
    local resolve_args=(
        --job-name "${job_name}"
        --region "${AWS_REGION}"
        --training-type "${ARG_TRAINING_TYPE}"
        --model-package-group "${PROJECT_NAME}-tune-models"
    )
    # Command the user can run by hand when automatic resolution fails.
    local manual_cmd="python3 ${HELPER_SCRIPT} resolve --job-name ${job_name} --region ${AWS_REGION} --training-type ${ARG_TRAINING_TYPE}"

    local resolve_result
    resolve_result=$(python3 "${HELPER_SCRIPT}" resolve "${resolve_args[@]}" 2>/dev/null) || {
        echo "⚠️ Could not resolve output artifacts"
        echo " Check job output manually:"
        echo " ${manual_cmd}"
        return 0
    }

    # Check for error (unparseable output is treated as an error too)
    local has_error
    has_error=$(echo "${resolve_result}" | python3 -c "import sys,json; d=json.load(sys.stdin); print('yes' if 'error' in d else 'no')" 2>/dev/null) || has_error="yes"

    if [ "${has_error}" = "yes" ]; then
        local error_msg
        error_msg=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('error',''))" 2>/dev/null) || error_msg=""
        if [ -n "${error_msg}" ]; then
            echo " ⚠️ ${error_msg}"
        fi
        return 0
    fi

    # Extract results
    local artifact_path
    artifact_path=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('artifact_path',''))" 2>/dev/null) || artifact_path=""

    # Without an artifact path there is nothing to record or deploy. Writing
    # empty TUNE_*_PATH values would clobber state that later commands
    # (e.g. --from-tune) read, so report the problem and stop here.
    if [ -z "${artifact_path}" ]; then
        echo "⚠️ Job completed but no output artifact path was returned"
        echo " Check job output manually:"
        echo " ${manual_cmd}"
        return 0
    fi

    local output_type
    output_type=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('output_type',''))" 2>/dev/null) || output_type=""

    local model_package_arn
    model_package_arn=$(echo "${resolve_result}" | python3 -c "import sys,json; print(json.load(sys.stdin).get('model_package_arn','') or '')" 2>/dev/null) || model_package_arn=""

    # Display results
    echo "🎉 Customization complete!"
    echo ""
    echo " Output type: ${output_type}"
    echo " Artifact: ${artifact_path}"
    if [ -n "${model_package_arn}" ]; then
        echo " Model Package: ${model_package_arn}"
    fi
    echo ""

    # Store output paths in config
    if [ "${output_type}" = "adapter" ]; then
        _update_config_var "TUNE_ADAPTER_PATH_${technique_upper}" "${artifact_path}"
    else
        _update_config_var "TUNE_MODEL_PATH_${technique_upper}" "${artifact_path}"
    fi
    _update_config_var "TUNE_OUTPUT_PATH_LATEST" "${artifact_path}"
    _update_config_var "TUNE_OUTPUT_TYPE_LATEST" "${output_type}"

    # Print next-step commands
    echo "📋 Next steps:"
    echo ""
    if [ "${output_type}" = "adapter" ]; then
        echo " Deploy as LoRA adapter:"
        echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune"
        echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --from-tune ${ARG_TECHNIQUE}"
        echo " ./do/adapter add tuned-${ARG_TECHNIQUE} --weights ${artifact_path}"
    else
        echo " Deploy as new inference component:"
        echo " ./do/add-ic tuned-v1 --from-tune"
        echo " ./do/add-ic tuned-v1 --model-data ${artifact_path}"
        echo " ./do/deploy --force-ic --model-data ${artifact_path}"
    fi
    echo ""
}
|
|
1006
|
+
|
|
1007
|
+
|
|
1008
|
+
# ── _dry_run() ────────────────────────────────────────────────────────────────
# Show the job that would be submitted, then exit 0 without creating anything.
# Prints the generated job name, model, technique, training type, dataset,
# output bucket, execution role and model package group, followed by any
# optional hyperparameters that were set. The role falls back to
# "<not configured>" instead of failing the way _submit_job does.
_dry_run() {
    local bucket
    bucket=$(_resolve_output_bucket)

    # First non-empty of --role, ROLE_ARN, SAGEMAKER_ROLE_ARN.
    local role
    if [[ -n "${ARG_ROLE}" ]]; then
        role="${ARG_ROLE}"
    else
        role="${ROLE_ARN:-${SAGEMAKER_ROLE_ARN:-<not configured>}}"
    fi

    local stamp
    stamp=$(date +%Y%m%d-%H%M%S)

    printf '%s\n' \
        "🔍 Dry run — validation passed, would submit:" \
        "" \
        " Job name: ${PROJECT_NAME}-tune-${ARG_TECHNIQUE}-${stamp}" \
        " Model: ${RESOLVED_MODEL_ID}" \
        " Technique: ${ARG_TECHNIQUE}" \
        " Training type: ${ARG_TRAINING_TYPE}" \
        " Dataset: ${RESOLVED_DATASET_S3_URI}" \
        " Output bucket: ${bucket}" \
        " Role: ${role}" \
        " Package group: ${PROJECT_NAME}-tune-models"

    # "Label|VARIABLE" rows; a row is printed only when its variable is non-empty.
    local -a optional_rows=(
        "Epochs|ARG_EPOCHS"
        "Learning rate|ARG_LEARNING_RATE"
        "Max seq len|ARG_MAX_SEQ_LENGTH"
        "LoRA rank|ARG_LORA_RANK"
        "LoRA alpha|ARG_LORA_ALPHA"
        "Batch size|ARG_BATCH_SIZE"
        "Reward fn|ARG_REWARD_FUNCTION"
        "Reward prompt|ARG_REWARD_PROMPT"
    )
    local row label source_var
    for row in "${optional_rows[@]}"; do
        label="${row%%|*}"
        source_var="${row#*|}"
        if [[ -n "${!source_var}" ]]; then
            echo " ${label}: ${!source_var}"
        fi
    done

    printf '%s\n' "" " ✅ All validations passed. Remove --dry-run to submit."
    exit 0
}
|
|
1067
|
+
|
|
1068
|
+
# ══════════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════════

_parse_args "$@"

# Handle informational flags first
# NOTE(review): these handlers are presumably expected to exit; if one returns,
# execution continues into the job-submission path below.
if [ "${ARG_HELP}" = true ]; then
    _show_help
fi

if [ "${ARG_LIST_MODELS}" = true ]; then
    _list_models
fi

if [ "${ARG_STATUS}" = true ]; then
    _show_status
fi

# Validate required arguments for job submission
if [ -z "${ARG_TECHNIQUE}" ]; then
    echo "❌ --technique is required"
    echo " Usage: ./do/tune --technique <sft|dpo|rlaif|rlvr> --dataset <source>"
    echo " Run ./do/tune --help for full usage."
    exit 1
fi

if [ -z "${ARG_DATASET}" ]; then
    echo "❌ --dataset is required"
    echo " Usage: ./do/tune --technique ${ARG_TECHNIQUE} --dataset <s3://... or hf://...>"
    echo " Run ./do/tune --help for full usage."
    exit 1
fi

# Check runtime support (warning only; the catalog validation below decides)
if [ "${TUNE_SUPPORTED:-}" = "false" ]; then
    echo "⚠️ Managed customization is not supported for the configured model."
    echo " Checking catalog for current support..."
    echo ""
fi

# Validate Python availability
if ! command -v python3 &>/dev/null; then
    echo "❌ python3 is required but not found"
    echo " Install Python 3 to use managed model customization."
    exit 1
fi

# Run validations
echo "🔧 SageMaker AI Managed Model Customization"
echo ""

_validate_model
_validate_technique
_validate_training_type
_validate_dataset

# --dry-run must be side-effect free, so it is handled BEFORE the idempotency
# check. That check may resume polling an in-flight job (blocking until it
# finishes), write completion state into do/config, or exit 2 for a previously
# failed job, none of which a dry run should do. A recorded job is only reported.
if [ "${ARG_DRY_RUN}" = true ]; then
    _dry_run_job_var="TUNE_JOB_NAME_${ARG_TECHNIQUE^^}"
    if [ -n "${!_dry_run_job_var:-}" ] && [ "${ARG_FORCE}" != true ]; then
        echo "ℹ️ Existing ${ARG_TECHNIQUE^^} job recorded: ${!_dry_run_job_var}"
        echo " A real run checks its status first (use --force to always submit a new job)."
        echo ""
    fi
    _dry_run
fi

# Check idempotency (may exit early if existing job is handled)
if _check_idempotency; then
    # No existing job or --force: proceed with submission
    _submit_job

    if [ "${ARG_NO_WAIT}" = true ]; then
        echo " --no-wait specified. Job running in background."
        echo " Check status: ./do/tune --status"
        exit 0
    fi

    _poll_job
    _handle_completion
fi
|