@aws/ml-container-creator 1.0.3 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +2 -2
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1730
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/app.js +6 -4
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Unmanaged SFT fine-tuning using TRL SFTTrainer + PEFT LoRA.
|
|
6
|
+
|
|
7
|
+
This script performs LoRA-based supervised fine-tuning on a causal language
|
|
8
|
+
model. It is invoked via `accelerate launch` and works on both single-GPU
|
|
9
|
+
and multi-GPU instances without code changes.
|
|
10
|
+
|
|
11
|
+
Portable env-var contract (works on SageMaker AI and HyperPod EKS):
|
|
12
|
+
DATA_DIR / SM_CHANNEL_TRAINING -> training data path
|
|
13
|
+
OUTPUT_DIR / SM_MODEL_DIR -> model artifact output
|
|
14
|
+
CHECKPOINT_DIR / SM_CHECKPOINT_DIR -> checkpoint path for spot resume
|
|
15
|
+
HF_MODEL_ID / SM_HP_MODEL_ID -> base model HuggingFace ID
|
|
16
|
+
SM_HPS -> JSON blob of all hyperparameters
|
|
17
|
+
|
|
18
|
+
Output:
|
|
19
|
+
LoRA adapter saved to OUTPUT_DIR (adapter_model.safetensors + adapter_config.json)
|
|
20
|
+
Metrics logged to stdout in SageMaker-parseable format
|
|
21
|
+
"""
|
|
22
|
+
|
|
23
|
+
import glob
|
|
24
|
+
import json
|
|
25
|
+
import logging
|
|
26
|
+
import os
|
|
27
|
+
import sys
|
|
28
|
+
|
|
29
|
+
# ── Logging ───────────────────────────────────────────────────────────────────
|
|
30
|
+
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("sft-trainer")

# ── Portable Path Resolution ─────────────────────────────────────────────────
# Each setting is resolved as: generic env var -> SageMaker env var -> default.
# This lets the same script run on SageMaker, HyperPod, or locally.


def _first_env(names, fallback):
    """Return the first non-empty environment variable in *names*, else *fallback*."""
    for name in names:
        candidate = os.environ.get(name)
        if candidate:
            return candidate
    return fallback


# Directory holding the training data files.
DATA_DIR = _first_env(
    ("DATA_DIR", "SM_CHANNEL_TRAINING"),
    "/opt/ml/input/data/training",
)

# Directory the final adapter and tokenizer are written to.
OUTPUT_DIR = _first_env(("OUTPUT_DIR", "SM_MODEL_DIR"), "/opt/ml/model")

# Directory used for intermediate checkpoints (spot resume).
CHECKPOINT_DIR = _first_env(
    ("CHECKPOINT_DIR", "SM_CHECKPOINT_DIR"),
    "/opt/ml/checkpoints",
)

# Base model identifier; empty string means "not configured".
MODEL_ID = _first_env(("HF_MODEL_ID", "SM_HP_MODEL_ID"), "")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# ── Hyperparameter Loading ────────────────────────────────────────────────────
|
|
66
|
+
|
|
67
|
+
def load_hyperparameters():
    """Load hyperparameters from the SM_HPS env var or individual SM_HP_* vars.

    Resolution order:
      1. ``SM_HPS`` — a JSON object; every key that matches a known
         hyperparameter is cast to the type of its default value.
      2. ``SM_HP_<NAME>`` — one env var per hyperparameter, used when
         ``SM_HPS`` is unset or cannot be parsed/cast.
      3. The built-in defaults below.

    ``SM_HPS`` is applied atomically: if any value in it fails to parse or
    cast, none of its values are used and a warning is logged, so the
    fallback never starts from a half-overridden configuration.

    Returns:
        dict with typed hyperparameter values.

    Raises:
        ValueError: if an individual ``SM_HP_*`` value cannot be cast.
    """
    defaults = {
        "model_id": MODEL_ID,
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "learning_rate": 2e-4,
        "epochs": 3,
        "batch_size": 4,
        "max_seq_length": 2048,
        "gradient_accumulation_steps": 4,
        "warmup_ratio": 0.03,
        "dataset_text_field": "text",
    }

    # Try SM_HPS (JSON blob of all hyperparameters)
    sm_hps = os.environ.get("SM_HPS")
    if sm_hps:
        try:
            raw = json.loads(sm_hps)
            if not isinstance(raw, dict):
                raise ValueError(
                    f"expected a JSON object, got {type(raw).__name__}"
                )
            # Values may arrive as strings — cast them to the default's type.
            # Build the overrides separately so a failure leaves `defaults` intact.
            overrides = {
                key: _cast(raw[key], type(default_val))
                for key, default_val in defaults.items()
                if key in raw
            }
        except (ValueError, TypeError, OverflowError) as e:
            # json.JSONDecodeError is a ValueError; TypeError covers null/list values.
            logger.warning("Failed to parse SM_HPS: %s", e)
        else:
            defaults.update(overrides)
            return defaults

    # Fallback: individual SM_HP_* env vars
    for key, default_val in list(defaults.items()):
        env_val = os.environ.get(f"SM_HP_{key.upper()}")
        if env_val is not None:
            defaults[key] = _cast(env_val, type(default_val))

    return defaults
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _cast(value, target_type):
|
|
111
|
+
"""Cast a string value to the target type."""
|
|
112
|
+
if target_type == bool:
|
|
113
|
+
return str(value).lower() in ("true", "1", "yes")
|
|
114
|
+
if target_type == int:
|
|
115
|
+
return int(float(value))
|
|
116
|
+
if target_type == float:
|
|
117
|
+
return float(value)
|
|
118
|
+
return str(value)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# ── Dataset Loading ───────────────────────────────────────────────────────────
|
|
122
|
+
|
|
123
|
+
def load_dataset(data_dir, text_field):
    """Load the training dataset from a data directory.

    Searches ``data_dir`` recursively for .jsonl, .json, .parquet and .csv
    files. The format of the first file (in sorted order) decides which
    loader is used; files of any other format are skipped with a warning
    instead of being fed to the wrong loader.

    Exits the process with status 1 when no data file is found or when
    ``text_field`` is not a column of the loaded dataset.

    Args:
        data_dir: Path to directory containing training data files.
        text_field: Name of the text column in the dataset.

    Returns:
        A Hugging Face Dataset object.
    """
    from datasets import load_dataset as hf_load_dataset

    extensions = ["jsonl", "json", "parquet", "csv"]
    format_map = {"jsonl": "json", "json": "json", "parquet": "parquet", "csv": "csv"}

    def _format_of(path):
        """Map a file path to its loader name via the file extension."""
        ext = os.path.splitext(path)[1].lstrip(".").lower()
        return format_map.get(ext, "json")

    # With recursive=True, "**" also matches zero directories, so a single
    # pattern per extension covers top-level and nested files. The set removes
    # duplicates; sorting keeps the file order deterministic.
    data_files = sorted(
        {
            path
            for ext in extensions
            for path in glob.glob(
                os.path.join(data_dir, "**", f"*.{ext}"), recursive=True
            )
            if os.path.isfile(path)
        }
    )

    if not data_files:
        logger.error("No data files found in %s (searched: %s)", data_dir, extensions)
        sys.exit(1)

    logger.info("Found %d data file(s) in %s", len(data_files), data_dir)

    # Determine format from the first file, then keep only files of that format.
    file_format = _format_of(data_files[0])
    skipped = [path for path in data_files if _format_of(path) != file_format]
    if skipped:
        logger.warning(
            "Ignoring %d file(s) that are not in '%s' format: %s",
            len(skipped), file_format, skipped,
        )
        data_files = [path for path in data_files if _format_of(path) == file_format]

    dataset = hf_load_dataset(file_format, data_files=data_files, split="train")
    logger.info("Loaded dataset: %d rows, columns: %s", len(dataset), dataset.column_names)

    # Verify text field exists
    if text_field not in dataset.column_names:
        logger.error(
            "Text field '%s' not found in dataset. Available columns: %s",
            text_field, dataset.column_names,
        )
        sys.exit(1)

    return dataset
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ── Main Training Function ────────────────────────────────────────────────────
|
|
172
|
+
|
|
173
|
+
def _find_latest_checkpoint(trainer_state_dir):
    """Return the ``checkpoint-<step>`` directory with the highest step, or None.

    Only directories whose suffix is purely numeric are considered, so stray
    files or oddly named entries can never be picked as the resume point.
    """
    candidates = []
    for path in glob.glob(os.path.join(trainer_state_dir, "checkpoint-*")):
        suffix = path.rsplit("-", 1)[-1]
        if suffix.isdigit() and os.path.isdir(path):
            candidates.append((int(suffix), path))
    if not candidates:
        return None
    return max(candidates)[1]


def main():
    """Run SFT training with TRL SFTTrainer + PEFT LoRA.

    Steps: load hyperparameters, load tokenizer and base model, build the
    LoRA config, load the dataset, resume from the newest checkpoint under
    ``CHECKPOINT_DIR/trainer-state`` when one exists, train, then save the
    adapter and tokenizer to ``OUTPUT_DIR`` and print final metrics from the
    main process.

    Exits with status 1 when no model ID is configured.
    """
    from accelerate import Accelerator
    from peft import LoraConfig, TaskType
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
    )
    from trl import SFTTrainer

    # Initialize accelerator (handles distributed setup)
    accelerator = Accelerator()

    # Load hyperparameters
    hparams = load_hyperparameters()
    model_id = hparams["model_id"]

    if not model_id:
        logger.error("No model ID specified. Set HF_MODEL_ID env var or model_id hyperparameter.")
        sys.exit(1)

    if accelerator.is_main_process:
        logger.info("=" * 60)
        logger.info("SFT Training Configuration")
        logger.info("=" * 60)
        logger.info(" Model: %s", model_id)
        logger.info(" Data dir: %s", DATA_DIR)
        logger.info(" Output dir: %s", OUTPUT_DIR)
        logger.info(" Checkpoint dir: %s", CHECKPOINT_DIR)
        logger.info(" LoRA r: %d", hparams["lora_r"])
        logger.info(" LoRA alpha: %d", hparams["lora_alpha"])
        logger.info(" Learning rate: %s", hparams["learning_rate"])
        logger.info(" Epochs: %d", hparams["epochs"])
        logger.info(" Batch size: %d", hparams["batch_size"])
        logger.info(" Max seq len: %d", hparams["max_seq_length"])
        logger.info(" Text field: %s", hparams["dataset_text_field"])
        logger.info("=" * 60)

    # ── Load tokenizer and model ──────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Loading tokenizer and model: %s", model_id)

    # NOTE(review): trust_remote_code=True executes code shipped with the model
    # repository; only point this script at model IDs you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        trust_remote_code=True,
    )

    # ── Configure LoRA ────────────────────────────────────────────────────────
    lora_config = LoraConfig(
        r=hparams["lora_r"],
        lora_alpha=hparams["lora_alpha"],
        lora_dropout=hparams["lora_dropout"],
        target_modules="all-linear",
        task_type=TaskType.CAUSAL_LM,
    )

    # ── Load dataset ──────────────────────────────────────────────────────────
    dataset = load_dataset(DATA_DIR, hparams["dataset_text_field"])

    # ── Training arguments ────────────────────────────────────────────────────
    # Trainer checkpoints are written here and looked up again on resume.
    trainer_state_dir = os.path.join(CHECKPOINT_DIR, "trainer-state")

    # NOTE(review): bf16 is hard-coded; presumably every target instance type
    # supports bfloat16 — confirm before running on older GPUs.
    training_args = TrainingArguments(
        output_dir=trainer_state_dir,
        num_train_epochs=hparams["epochs"],
        per_device_train_batch_size=hparams["batch_size"],
        gradient_accumulation_steps=hparams["gradient_accumulation_steps"],
        learning_rate=hparams["learning_rate"],
        warmup_ratio=hparams["warmup_ratio"],
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
    )

    # ── Check for existing checkpoint (spot resume) ───────────────────────────
    resume_from_checkpoint = _find_latest_checkpoint(trainer_state_dir)
    if resume_from_checkpoint and accelerator.is_main_process:
        logger.info("Resuming from checkpoint: %s", resume_from_checkpoint)

    # ── Initialize SFTTrainer ─────────────────────────────────────────────────
    # NOTE(review): tokenizer/max_seq_length/dataset_text_field are passed as
    # SFTTrainer keyword arguments; confirm the pinned TRL version accepts them.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        peft_config=lora_config,
        args=training_args,
        max_seq_length=hparams["max_seq_length"],
        dataset_text_field=hparams["dataset_text_field"],
    )

    # ── Train ─────────────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Starting SFT training...")

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # ── Save adapter (rank 0 only) ───────────────────────────────────────────
    # NOTE(review): saving only on the main process assumes weights are not
    # sharded across ranks — confirm against the accelerate config in use.
    if accelerator.is_main_process:
        logger.info("Saving LoRA adapter to: %s", OUTPUT_DIR)
        trainer.save_model(OUTPUT_DIR)
        tokenizer.save_pretrained(OUTPUT_DIR)

        # ── Log final metrics ─────────────────────────────────────────────────
        # Format: metric_name: value (SageMaker captures via regex in config.yaml)
        metrics = train_result.metrics
        print(f"train_loss: {metrics.get('train_loss', 0.0):.4f}")
        print(f"train_runtime: {metrics.get('train_runtime', 0.0):.1f}")
        print(f"train_samples_per_second: {metrics.get('train_samples_per_second', 0.0):.2f}")
        print(f"epochs: {hparams['epochs']}")

        logger.info("Training complete!")
        logger.info(" Loss: %.4f", metrics.get("train_loss", 0.0))
        logger.info(" Runtime: %.1fs", metrics.get("train_runtime", 0.0))
        logger.info(" Samples/s: %.2f", metrics.get("train_samples_per_second", 0.0))

    # Wait for all processes before exit
    accelerator.wait_for_everyone()
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
# ── Entry Point ───────────────────────────────────────────────────────────────
|
|
308
|
+
|
|
309
|
+
if __name__ == "__main__":
|
|
310
|
+
main()
|
package/templates/do/tune
CHANGED
|
@@ -796,9 +796,16 @@ else:
|
|
|
796
796
|
_validate_dataset() {
|
|
797
797
|
local dataset="${ARG_DATASET}"
|
|
798
798
|
|
|
799
|
-
# ── Parse @v<
|
|
799
|
+
# ── Parse @v<version> suffix (AC-2.1, AC-2.3) ───────────────────────────
|
|
800
800
|
# Syntax: dataset-name@v2 → name="dataset-name", version ordinal=2
|
|
801
|
-
|
|
801
|
+
# dataset-name@v1.0.0 → name="dataset-name", version="1.0.0" (semver)
|
|
802
|
+
if [[ "${dataset}" =~ ^(.+)@v([0-9]+\.[0-9]+\.[0-9]+)$ ]]; then
|
|
803
|
+
# Semver syntax: @v1.0.0
|
|
804
|
+
ARG_DATASET_NAME="${BASH_REMATCH[1]}"
|
|
805
|
+
ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
|
|
806
|
+
dataset="" # Clear so name-based resolution takes over
|
|
807
|
+
elif [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
|
|
808
|
+
# Ordinal syntax: @v1, @v2
|
|
802
809
|
ARG_DATASET_NAME="${BASH_REMATCH[1]}"
|
|
803
810
|
ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
|
|
804
811
|
dataset="" # Clear so name-based resolution takes over
|
|
@@ -824,6 +831,8 @@ _validate_dataset() {
|
|
|
824
831
|
echo " Resolved to: ${resolved_uri}"
|
|
825
832
|
dataset="${resolved_uri}"
|
|
826
833
|
ARG_DATASET="${resolved_uri}"
|
|
834
|
+
RESOLVED_DATASET_S3_URI="${resolved_uri}"
|
|
835
|
+
return 0
|
|
827
836
|
else
|
|
828
837
|
echo "❌ Dataset '${ARG_DATASET_NAME}' not found in registry"
|
|
829
838
|
echo " Register it first: ./do/register --dataset --dataset-name ${ARG_DATASET_NAME} --dataset-s3-uri s3://..."
|
|
@@ -1,135 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
|
|
5
|
-
"""
|
|
6
|
-
Parse DescribeTrainingJob JSON for the polling loop in do/train.
|
|
7
|
-
|
|
8
|
-
Reads JSON from stdin and outputs structured key=value lines for bash consumption:
|
|
9
|
-
STATUS=<TrainingJobStatus>
|
|
10
|
-
SECONDARY=<SecondaryStatus>
|
|
11
|
-
FAILURE_REASON=<FailureReason or empty>
|
|
12
|
-
DISPLAY=<formatted single-line status display>
|
|
13
|
-
|
|
14
|
-
This keeps the bash poll loop simple while handling JSON parsing in Python.
|
|
15
|
-
"""
|
|
16
|
-
|
|
17
|
-
import json
|
|
18
|
-
import sys
|
|
19
|
-
from datetime import datetime, timezone
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
def format_duration(seconds):
    """Render a duration in seconds as ``Xh Ym Zs``, ``Ym Zs`` or ``Zs``.

    Returns 'N/A' when *seconds* is None or negative.
    """
    if seconds is None or seconds < 0:
        return 'N/A'

    hrs = int(seconds // 3600)
    mins = int((seconds % 3600) // 60)
    rem = int(seconds % 60)

    # Larger units are shown only when non-zero; guard clauses replace elif.
    if hrs > 0:
        return f'{hrs}h {mins}m {rem}s'
    if mins > 0:
        return f'{mins}m {rem}s'
    return f'{rem}s'
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def parse_iso_time(time_str):
    """Parse a timestamp into a datetime object.

    Accepts an ISO 8601 string (a trailing 'Z' is read as UTC), a numeric
    Unix epoch value (interpreted as UTC), or an existing datetime.

    Returns:
        A datetime, or None when the input is empty or cannot be parsed.
    """
    if not time_str:
        return None
    if isinstance(time_str, datetime):
        return time_str
    try:
        if isinstance(time_str, (int, float)):
            # Timestamps may arrive as epoch seconds rather than ISO strings
            # (e.g. when the AWS CLI emits its "wire" timestamp format).
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        if time_str.endswith('Z'):
            # fromisoformat on older Pythons rejects 'Z'; swap only the suffix.
            time_str = time_str[:-1] + '+00:00'
        return datetime.fromisoformat(time_str)
    except (ValueError, TypeError, AttributeError, OverflowError, OSError):
        return None
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
def calculate_elapsed(start_time_str):
    """Calculate elapsed seconds from the given start time to now.

    Returns:
        Seconds elapsed (never negative), or None when the start time is
        missing or cannot be parsed.
    """
    start = parse_iso_time(start_time_str)
    if not start:
        return None
    if start.tzinfo is None:
        # A timestamp without an offset cannot be subtracted from an aware
        # "now"; treat it as UTC rather than raising TypeError.
        start = start.replace(tzinfo=timezone.utc)
    now = datetime.now(timezone.utc)
    elapsed = (now - start).total_seconds()
    # Clamp so clock skew never yields a negative duration.
    return max(0, elapsed)
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
def format_metrics(final_metrics):
    """Join FinalMetricDataList entries into one ``name=value, ...`` string.

    Float values get 6, 4 or 2 decimals depending on magnitude; other values
    are rendered as-is. Returns '' for an empty or missing list.
    """
    if not final_metrics:
        return ''

    def _render(entry):
        label = entry.get('MetricName', 'unknown')
        val = entry.get('Value', 0)
        if not isinstance(val, float):
            return f'{label}={val}'
        magnitude = abs(val)
        if magnitude < 0.001:
            return f'{label}={val:.6f}'
        if magnitude < 1:
            return f'{label}={val:.4f}'
        return f'{label}={val:.2f}'

    return ', '.join(_render(entry) for entry in final_metrics)
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
# Status emoji mapping
|
|
79
|
-
# Maps a TrainingJobStatus value to the emoji shown in the display line.
# Statuses missing from this table are rendered with '❓' by main().
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
def main():
    """Read DescribeTrainingJob JSON on stdin and print key=value lines.

    Emits STATUS, SECONDARY, FAILURE_REASON and DISPLAY, one per line.
    Exits with status 1 when stdin is not valid JSON.
    """
    try:
        job = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'Error parsing JSON: {e}', file=sys.stderr)
        sys.exit(1)

    status = job.get('TrainingJobStatus', 'Unknown')
    secondary = job.get('SecondaryStatus', '')
    failure = job.get('FailureReason', '')
    started = job.get('TrainingStartTime', '')

    # Elapsed time is shown only when a start time exists and parses.
    elapsed_text = ''
    if started:
        elapsed_seconds = calculate_elapsed(started)
        if elapsed_seconds is not None:
            elapsed_text = format_duration(elapsed_seconds)

    metrics_text = format_metrics(job.get('FinalMetricDataList', []))

    # Assemble the display line: the status first, then each non-empty extra
    # prefixed with a pipe.
    segments = [f' {STATUS_EMOJI.get(status, "❓")} {status}']
    extras = (
        secondary,
        f'elapsed: {elapsed_text}' if elapsed_text else '',
        metrics_text,
    )
    segments.extend(f'| {extra}' for extra in extras if extra)

    # Output structured lines for bash
    for key, value in (
        ('STATUS', status),
        ('SECONDARY', secondary),
        ('FAILURE_REASON', failure),
        ('DISPLAY', ' '.join(segments)),
    ):
        print(f'{key}={value}')
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
if __name__ == '__main__':
|
|
135
|
-
main()
|
|
@@ -1,187 +0,0 @@
|
|
|
1
|
-
#!/usr/bin/env python3
|
|
2
|
-
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
-
|
|
5
|
-
"""
|
|
6
|
-
Parse DescribeTrainingJob JSON response and display formatted status.
|
|
7
|
-
|
|
8
|
-
This helper is called by do/train --status to parse the AWS CLI JSON output
|
|
9
|
-
from DescribeTrainingJob and display a user-friendly status summary.
|
|
10
|
-
"""
|
|
11
|
-
|
|
12
|
-
import json
|
|
13
|
-
import sys
|
|
14
|
-
import time
|
|
15
|
-
from datetime import datetime, timezone
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
# Status emoji mapping
|
|
19
|
-
# Maps a TrainingJobStatus value to the emoji printed on the status line.
# Statuses missing from this table are rendered with '❓' by display_status().
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}
|
|
26
|
-
|
|
27
|
-
# Secondary status descriptions
|
|
28
|
-
# Maps a SecondaryStatus value to the human-readable description printed in
# parentheses after the phase name. A status missing from this table is
# printed without a description.
SECONDARY_DESCRIPTIONS = {
    'Starting': 'Preparing training instance',
    'LaunchingMLInstances': 'Launching ML instances',
    'PreparingTrainingStack': 'Preparing training stack',
    'Downloading': 'Downloading training data',
    'DownloadingTrainingImage': 'Downloading training image',
    'Training': 'Training in progress',
    'Uploading': 'Uploading model artifacts',
    'Completed': 'Training completed',
    'MaxRuntimeExceeded': 'Max runtime exceeded',
    'Stopped': 'Training stopped',
    'MaxWaitTimeExceeded': 'Max wait time exceeded (spot)',
    'Interrupted': 'Spot instance interrupted'
}
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def format_duration(seconds):
    """Turn a second count into a short label such as ``1h 2m 3s``.

    Hours are shown only when non-zero; minutes are shown when hours or
    minutes are non-zero. None or a negative input gives 'N/A'.
    """
    if seconds is None or seconds < 0:
        return 'N/A'

    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)

    # Build the label right-to-left, prepending larger units as needed.
    pieces = [f'{secs}s']
    if hours > 0 or minutes > 0:
        pieces.insert(0, f'{minutes}m')
    if hours > 0:
        pieces.insert(0, f'{hours}h')
    return ' '.join(pieces)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def parse_iso_time(time_str):
    """Parse a timestamp into a datetime object.

    Accepts an ISO 8601 string (a trailing 'Z' is read as UTC), a numeric
    Unix epoch value (interpreted as UTC), or an existing datetime.

    Returns:
        A datetime, or None when the input is empty or cannot be parsed.
    """
    if not time_str:
        return None
    if isinstance(time_str, datetime):
        return time_str
    try:
        # Handle the timestamp formats the AWS CLI may produce.
        if isinstance(time_str, (int, float)):
            # Epoch seconds (e.g. the CLI "wire" timestamp format).
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        # Replace only a trailing 'Z' with +00:00 for fromisoformat.
        if time_str.endswith('Z'):
            time_str = time_str[:-1] + '+00:00'
        return datetime.fromisoformat(time_str)
    except (ValueError, TypeError, AttributeError, OverflowError, OSError):
        return None
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def calculate_elapsed(start_time_str):
    """Calculate elapsed seconds from the given start time to now.

    Returns:
        Seconds elapsed (never negative), or None when the start time is
        missing or cannot be parsed.
    """
    start = parse_iso_time(start_time_str)
    if not start:
        return None
    if start.tzinfo is None:
        # Subtracting a naive datetime from an aware one raises TypeError;
        # treat offset-less timestamps as UTC.
        start = start.replace(tzinfo=timezone.utc)
    now = datetime.now(timezone.utc)
    elapsed = (now - start).total_seconds()
    # Clamp so clock skew never yields a negative duration.
    return max(0, elapsed)
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def display_status(job_data):
    """Print a formatted summary of a DescribeTrainingJob response.

    Sections are printed in order — status, phase, timing, instance, spot
    savings, metrics, artifacts, failure reason, spot-interruption hint —
    and each appears only when its data is present.
    """
    status = job_data.get('TrainingJobStatus', 'Unknown')
    phase = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    started_at = job_data.get('TrainingStartTime', '')
    billable_seconds = job_data.get('BillableTimeInSeconds')
    training_seconds = job_data.get('TrainingTimeInSeconds')
    final_metrics = job_data.get('FinalMetricDataList', [])
    output_path = job_data.get('OutputDataConfig', {}).get('S3OutputPath', '')
    model_artifacts = job_data.get('ModelArtifacts', {}).get('S3ModelArtifacts', '')
    resources = job_data.get('ResourceConfig', {})
    instance_type = resources.get('InstanceType', '')
    instance_count = resources.get('InstanceCount', 1)
    spot_enabled = job_data.get('EnableManagedSpotTraining', False)
    completed = status == 'Completed'

    print()
    print(f' {STATUS_EMOJI.get(status, "❓")} Status: {status}')

    # Secondary status with description
    if phase:
        description = SECONDARY_DESCRIPTIONS.get(phase, '')
        suffix = f' ({description})' if description else ''
        print(f' 📍 Phase: {phase}{suffix}')

    # Elapsed time while running, total training time otherwise
    if status == 'InProgress' and started_at:
        elapsed = calculate_elapsed(started_at)
        if elapsed is not None:
            print(f' ⏱️ Elapsed: {format_duration(elapsed)}')
    elif training_seconds is not None:
        print(f' ⏱️ Training time: {format_duration(training_seconds)}')

    # Instance info
    if instance_type:
        label = [f'{instance_type}']
        if instance_count and instance_count > 1:
            label.append(f' x {instance_count}')
        if spot_enabled:
            label.append(' (spot)')
        print(f' 🖥️ Instance: {"".join(label)}')

    # Billable time and cost savings (for completed spot jobs)
    has_timing = billable_seconds is not None and training_seconds is not None
    if completed and spot_enabled and has_timing and training_seconds > 0:
        saved = training_seconds - billable_seconds
        saved_pct = (saved / training_seconds) * 100
        print(f' 💰 Spot savings: {format_duration(saved)} saved ({saved_pct:.0f}% discount)')
        print(f' Billable: {format_duration(billable_seconds)} / Total: {format_duration(training_seconds)}')

    # Training metrics
    if final_metrics:
        print(' 📈 Metrics:')
        for entry in final_metrics:
            name = entry.get('MetricName', 'unknown')
            value = entry.get('Value', 0)
            # Floats get more decimals the smaller they are.
            if not isinstance(value, float):
                print(f' {name}: {value}')
            elif abs(value) < 0.001:
                print(f' {name}: {value:.6f}')
            elif abs(value) < 1:
                print(f' {name}: {value:.4f}')
            else:
                print(f' {name}: {value:.2f}')

    # Output artifacts (for completed jobs)
    if completed:
        if model_artifacts:
            print(f' 📦 Artifacts: {model_artifacts}')
        elif output_path:
            print(f' 📦 Output: {output_path}')

    # Failure reason
    if status == 'Failed' and failure_reason:
        print(f' 💥 Reason: {failure_reason}')
        print()
        print(' To start a new job: ./do/train --force')

    # Spot interruption guidance
    if phase == 'Interrupted':
        print()
        print(' ℹ️ Spot instance was interrupted. The job will automatically')
        print(' resume from the last checkpoint. Re-run ./do/train to poll.')

    print()
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
def main():
    """Entry point: parse DescribeTrainingJob JSON from stdin and display it.

    Exits with status 1 (message on stderr) when stdin is not valid JSON.
    """
    try:
        response = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'❌ Failed to parse DescribeTrainingJob response: {e}', file=sys.stderr)
        sys.exit(1)
    else:
        display_status(response)
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
if __name__ == '__main__':
|
|
187
|
-
main()
|
|
File without changes
|