@aws/ml-container-creator 1.0.3 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (44)
  1. package/README.md +1 -1
  2. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  3. package/package.json +2 -2
  4. package/servers/base-image-picker/index.js +65 -18
  5. package/servers/instance-sizer/index.js +32 -0
  6. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  7. package/servers/lib/catalogs/model-arch-support.json +51 -0
  8. package/servers/lib/catalogs/model-servers.json +2842 -1730
  9. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  10. package/src/app.js +6 -4
  11. package/src/lib/generated/cli-options.js +1 -1
  12. package/src/lib/generated/parameter-matrix.js +1 -1
  13. package/src/lib/generated/validation-rules.js +1 -1
  14. package/src/lib/mcp-query-runner.js +110 -3
  15. package/src/lib/prompt-runner.js +66 -22
  16. package/src/lib/template-variable-resolver.js +8 -0
  17. package/src/lib/train-config-builder.js +339 -0
  18. package/templates/do/.benchmark_writer.py +3 -0
  19. package/templates/do/.eval_helper.py +409 -0
  20. package/templates/do/.register_helper.py +185 -11
  21. package/templates/do/.train_build_request.py +102 -113
  22. package/templates/do/.train_helper.py +433 -0
  23. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  24. package/templates/do/adapter +157 -0
  25. package/templates/do/benchmark +60 -3
  26. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  27. package/templates/do/evaluate +272 -0
  28. package/templates/do/lib/resolve-instance.sh +155 -0
  29. package/templates/do/register +5 -0
  30. package/templates/do/test +1 -0
  31. package/templates/do/train +879 -126
  32. package/templates/do/training/config.yaml +83 -11
  33. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  34. package/templates/do/training/dpo/defaults.yaml +26 -0
  35. package/templates/do/training/dpo/prompts.json +8 -0
  36. package/templates/do/training/dpo/train.py +363 -0
  37. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  38. package/templates/do/training/sft/defaults.yaml +18 -0
  39. package/templates/do/training/sft/prompts.json +7 -0
  40. package/templates/do/training/sft/train.py +310 -0
  41. package/templates/do/tune +11 -2
  42. package/templates/do/.train_poll_parser.py +0 -135
  43. package/templates/do/.train_status_parser.py +0 -187
  44. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -0,0 +1,310 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Unmanaged SFT fine-tuning using TRL SFTTrainer + PEFT LoRA.
6
+
7
+ This script performs LoRA-based supervised fine-tuning on a causal language
8
+ model. It is invoked via `accelerate launch` and works on both single-GPU
9
+ and multi-GPU instances without code changes.
10
+
11
+ Portable env-var contract (works on SageMaker AI and HyperPod EKS):
12
+ DATA_DIR / SM_CHANNEL_TRAINING -> training data path
13
+ OUTPUT_DIR / SM_MODEL_DIR -> model artifact output
14
+ CHECKPOINT_DIR / SM_CHECKPOINT_DIR -> checkpoint path for spot resume
15
+ HF_MODEL_ID / SM_HP_MODEL_ID -> base model HuggingFace ID
16
+ SM_HPS -> JSON blob of all hyperparameters
17
+
18
+ Output:
19
+ LoRA adapter saved to OUTPUT_DIR (adapter_model.safetensors + adapter_config.json)
20
+ Metrics logged to stdout in SageMaker-parseable format
21
+ """
22
+
23
+ import glob
24
+ import json
25
+ import logging
26
+ import os
27
+ import sys
28
+
29
# ── Logging ───────────────────────────────────────────────────────────────────
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("sft-trainer")


# ── Portable Path Resolution ─────────────────────────────────────────────────
def _first_env(names, fallback):
    """Return the first non-empty value among the environment variables *names*.

    Equivalent to ``os.environ.get(a) or os.environ.get(b) or fallback``:
    variables that are unset or set to the empty string are both skipped.
    This lets the same script run on SageMaker (SM_* variables), HyperPod
    (generic variables), or locally (hard-coded fallback paths).
    """
    for name in names:
        value = os.environ.get(name)
        if value:
            return value
    return fallback


# Fallback order for each setting: generic env var -> SageMaker env var -> default.
DATA_DIR = _first_env(("DATA_DIR", "SM_CHANNEL_TRAINING"), "/opt/ml/input/data/training")
OUTPUT_DIR = _first_env(("OUTPUT_DIR", "SM_MODEL_DIR"), "/opt/ml/model")
CHECKPOINT_DIR = _first_env(("CHECKPOINT_DIR", "SM_CHECKPOINT_DIR"), "/opt/ml/checkpoints")
# "" means no model was configured via env vars; it can still arrive later
# through the model_id hyperparameter.
MODEL_ID = _first_env(("HF_MODEL_ID", "SM_HP_MODEL_ID"), "")
63
+
64
+
65
+ # ── Hyperparameter Loading ────────────────────────────────────────────────────
66
+
67
def load_hyperparameters():
    """Resolve typed hyperparameters from the environment.

    Resolution order:
      1. ``SM_HPS`` -- a JSON object holding all hyperparameters. When it
         parses to an object and every recognised key casts cleanly, the
         result is returned immediately and SM_HP_* vars are not consulted.
      2. ``SM_HP_<KEY>`` -- one env var per hyperparameter (upper-cased key),
         used when SM_HPS is unset, empty, or unusable.

    Every value is cast to the type of its default, because SageMaker passes
    hyperparameter values as strings. Keys that are not in the defaults
    below are ignored. The ``model_id`` default is the module-level MODEL_ID.

    Returns:
        dict mapping each hyperparameter name to a typed value.

    Raises:
        ValueError: if an individual SM_HP_* value cannot be cast to its
            default's type.
    """
    defaults = {
        "model_id": MODEL_ID,
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "learning_rate": 2e-4,
        "epochs": 3,
        "batch_size": 4,
        "max_seq_length": 2048,
        "gradient_accumulation_steps": 4,
        "warmup_ratio": 0.03,
        "dataset_text_field": "text",
    }

    # Preferred source: SM_HPS (JSON blob of all hyperparameters).
    sm_hps = os.environ.get("SM_HPS")
    if sm_hps:
        try:
            raw = json.loads(sm_hps)
            if not isinstance(raw, dict):
                raise ValueError(f"expected a JSON object, got {type(raw).__name__}")
            # Build a new dict instead of mutating `defaults`. If a value fails
            # to cast, no half-applied override is left behind for the
            # SM_HP_* fallback below.
            return {
                key: _cast(raw[key], type(default_val)) if key in raw else default_val
                for key, default_val in defaults.items()
            }
        except (ValueError, TypeError) as e:
            # json.JSONDecodeError subclasses ValueError. TypeError covers
            # values such as JSON null that cannot be cast to int/float.
            logger.warning("Failed to parse SM_HPS: %s", e)

    # Fallback: individual SM_HP_* env vars.
    for key, default_val in defaults.items():
        env_val = os.environ.get(f"SM_HP_{key.upper()}")
        if env_val is not None:
            defaults[key] = _cast(env_val, type(default_val))

    return defaults
108
+
109
+
110
+ def _cast(value, target_type):
111
+ """Cast a string value to the target type."""
112
+ if target_type == bool:
113
+ return str(value).lower() in ("true", "1", "yes")
114
+ if target_type == int:
115
+ return int(float(value))
116
+ if target_type == float:
117
+ return float(value)
118
+ return str(value)
119
+
120
+
121
+ # ── Dataset Loading ───────────────────────────────────────────────────────────
122
+
123
def load_dataset(data_dir, text_field):
    """Load the training dataset from ``data_dir``.

    Files ending in .jsonl, .json, .parquet or .csv are collected
    recursively, including the top level of ``data_dir``. All files must map
    to a single Hugging Face loader format (.jsonl and .json both use
    "json"). Mixing formats, e.g. CSV with JSONL, is rejected up front
    rather than failing inside the loader.

    Args:
        data_dir: Path to directory containing training data files.
        text_field: Name of the text column that must exist in the dataset.

    Returns:
        A Hugging Face Dataset object (the "train" split).

    Exits:
        Logs an error and calls ``sys.exit(1)`` in any of these cases:
        no data files are found, formats are mixed, ``text_field`` is
        missing, or the dataset has no rows.
    """
    extensions = ["jsonl", "json", "parquet", "csv"]
    format_map = {"jsonl": "json", "json": "json", "parquet": "parquet", "csv": "csv"}

    # With recursive=True, "**" also matches zero directories. One pattern per
    # extension therefore covers both top-level and nested files.
    data_files = sorted({
        path
        for ext in extensions
        for path in glob.glob(os.path.join(data_dir, "**", f"*.{ext}"), recursive=True)
    })

    if not data_files:
        logger.error("No data files found in %s (searched: %s)", data_dir, extensions)
        sys.exit(1)

    logger.info("Found %d data file(s) in %s", len(data_files), data_dir)

    # The loader takes a single builder name, so every file must agree on format.
    formats = {
        format_map.get(path.rsplit(".", 1)[-1].lower(), "json") for path in data_files
    }
    if len(formats) > 1:
        logger.error(
            "Mixed data file formats in %s (%s); use a single format per dataset.",
            data_dir, sorted(formats),
        )
        sys.exit(1)
    file_format = formats.pop()

    # Deferred until discovery succeeds, so the checks above need no `datasets` install.
    from datasets import load_dataset as hf_load_dataset

    dataset = hf_load_dataset(file_format, data_files=data_files, split="train")
    logger.info("Loaded dataset: %d rows, columns: %s", len(dataset), dataset.column_names)

    # Verify text field exists
    if text_field not in dataset.column_names:
        logger.error(
            "Text field '%s' not found in dataset. Available columns: %s",
            text_field, dataset.column_names,
        )
        sys.exit(1)

    if len(dataset) == 0:
        logger.error("Dataset loaded from %s has no rows", data_dir)
        sys.exit(1)

    return dataset
169
+
170
+
171
+ # ── Main Training Function ────────────────────────────────────────────────────
172
+
173
def _log_training_config(hparams, model_id):
    """Log the resolved run configuration as a framed block (caller gates to rank 0)."""
    logger.info("=" * 60)
    logger.info("SFT Training Configuration")
    logger.info("=" * 60)
    logger.info(" Model: %s", model_id)
    logger.info(" Data dir: %s", DATA_DIR)
    logger.info(" Output dir: %s", OUTPUT_DIR)
    logger.info(" Checkpoint dir: %s", CHECKPOINT_DIR)
    logger.info(" LoRA r: %d", hparams["lora_r"])
    logger.info(" LoRA alpha: %d", hparams["lora_alpha"])
    logger.info(" Learning rate: %s", hparams["learning_rate"])
    logger.info(" Epochs: %d", hparams["epochs"])
    logger.info(" Batch size: %d", hparams["batch_size"])
    logger.info(" Max seq len: %d", hparams["max_seq_length"])
    logger.info(" Text field: %s", hparams["dataset_text_field"])
    logger.info("=" * 60)


def _find_latest_checkpoint(trainer_state_dir):
    """Return the ``checkpoint-<step>`` directory with the highest step, or None.

    Only directories whose name ends in an integer step are considered.
    Stray entries such as ``checkpoint-tmp``, or plain files, are never
    picked for resume.
    """
    if not os.path.isdir(trainer_state_dir):
        return None
    candidates = []
    for path in glob.glob(os.path.join(trainer_state_dir, "checkpoint-*")):
        step = os.path.basename(path).rsplit("-", 1)[-1]
        if step.isdigit() and os.path.isdir(path):
            candidates.append((int(step), path))
    return max(candidates)[1] if candidates else None


def main():
    """Run LoRA SFT with TRL SFTTrainer + PEFT under ``accelerate launch``.

    Steps:
      1. Resolve hyperparameters.
      2. Load the tokenizer and model.
      3. Build a LoRA config targeting all linear layers.
      4. Load the dataset from DATA_DIR.
      5. Train, resuming from the newest numbered checkpoint under
         CHECKPOINT_DIR/trainer-state if one exists.
      6. Save the adapter to OUTPUT_DIR. The main process also saves the
         tokenizer there and prints the final metrics.

    Exits with status 1 if no model ID is configured.
    """
    from accelerate import Accelerator
    from peft import LoraConfig, TaskType
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        TrainingArguments,
    )
    from trl import SFTTrainer

    # Initialize accelerator (handles distributed setup)
    accelerator = Accelerator()

    hparams = load_hyperparameters()
    model_id = hparams["model_id"]
    if not model_id:
        logger.error("No model ID specified. Set HF_MODEL_ID env var or model_id hyperparameter.")
        sys.exit(1)

    if accelerator.is_main_process:
        _log_training_config(hparams, model_id)

    # ── Load tokenizer and model ──────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Loading tokenizer and model: %s", model_id)

    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the model repository. Consider making it an opt-in hyperparameter.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Some tokenizers define no pad token; reuse EOS so batches can be padded.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        trust_remote_code=True,
    )

    # ── Configure LoRA ────────────────────────────────────────────────────────
    lora_config = LoraConfig(
        r=hparams["lora_r"],
        lora_alpha=hparams["lora_alpha"],
        lora_dropout=hparams["lora_dropout"],
        target_modules="all-linear",
        task_type=TaskType.CAUSAL_LM,
    )

    # ── Load dataset ──────────────────────────────────────────────────────────
    dataset = load_dataset(DATA_DIR, hparams["dataset_text_field"])

    # ── Training arguments ────────────────────────────────────────────────────
    # Trainer checkpoints live under CHECKPOINT_DIR, so a restarted job
    # (e.g. after a spot interruption) can find and resume from them.
    trainer_state_dir = os.path.join(CHECKPOINT_DIR, "trainer-state")
    training_args = TrainingArguments(
        output_dir=trainer_state_dir,
        num_train_epochs=hparams["epochs"],
        per_device_train_batch_size=hparams["batch_size"],
        gradient_accumulation_steps=hparams["gradient_accumulation_steps"],
        learning_rate=hparams["learning_rate"],
        warmup_ratio=hparams["warmup_ratio"],
        # NOTE(review): hard-coded; confirm every target instance's GPU supports bf16.
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
    )

    # ── Check for existing checkpoint (spot resume) ───────────────────────────
    resume_from_checkpoint = _find_latest_checkpoint(trainer_state_dir)
    if resume_from_checkpoint and accelerator.is_main_process:
        logger.info("Resuming from checkpoint: %s", resume_from_checkpoint)

    # ── Initialize SFTTrainer ─────────────────────────────────────────────────
    # NOTE(review): passing tokenizer / max_seq_length / dataset_text_field
    # straight to SFTTrainer matches older TRL releases. Newer releases expect
    # them via SFTConfig / processing_class, so confirm the pinned TRL version.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        peft_config=lora_config,
        args=training_args,
        max_seq_length=hparams["max_seq_length"],
        dataset_text_field=hparams["dataset_text_field"],
    )

    # ── Train ─────────────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Starting SFT training...")

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # ── Save adapter ──────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Saving LoRA adapter to: %s", OUTPUT_DIR)
    # save_model is called on EVERY rank. Under FSDP / DeepSpeed ZeRO-3 it
    # gathers sharded weights collectively, so a rank-0-only call can hang.
    # Per the Transformers docs, only the main process writes the files.
    trainer.save_model(OUTPUT_DIR)

    if accelerator.is_main_process:
        tokenizer.save_pretrained(OUTPUT_DIR)

        # ── Log final metrics ─────────────────────────────────────────────────
        # Format: metric_name: value (SageMaker captures via regex in config.yaml)
        metrics = train_result.metrics
        print(f"train_loss: {metrics.get('train_loss', 0.0):.4f}")
        print(f"train_runtime: {metrics.get('train_runtime', 0.0):.1f}")
        print(f"train_samples_per_second: {metrics.get('train_samples_per_second', 0.0):.2f}")
        print(f"epochs: {hparams['epochs']}")

        logger.info("Training complete!")
        logger.info(" Loss: %.4f", metrics.get("train_loss", 0.0))
        logger.info(" Runtime: %.1fs", metrics.get("train_runtime", 0.0))
        logger.info(" Samples/s: %.2f", metrics.get("train_samples_per_second", 0.0))

    # Wait for all processes before exit
    accelerator.wait_for_everyone()


# ── Entry Point ───────────────────────────────────────────────────────────────

if __name__ == "__main__":
    main()
package/templates/do/tune CHANGED
@@ -796,9 +796,16 @@ else:
796
796
  _validate_dataset() {
797
797
  local dataset="${ARG_DATASET}"
798
798
 
799
- # ── Parse @v<N> version suffix (AC-2.1, AC-2.3) ──────────────────────────
799
+ # ── Parse @v<version> suffix (AC-2.1, AC-2.3) ───────────────────────────
800
800
  # Syntax: dataset-name@v2 → name="dataset-name", version ordinal=2
801
- if [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
801
+ # dataset-name@v1.0.0 name="dataset-name", version="1.0.0" (semver)
802
+ if [[ "${dataset}" =~ ^(.+)@v([0-9]+\.[0-9]+\.[0-9]+)$ ]]; then
803
+ # Semver syntax: @v1.0.0
804
+ ARG_DATASET_NAME="${BASH_REMATCH[1]}"
805
+ ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
806
+ dataset="" # Clear so name-based resolution takes over
807
+ elif [[ "${dataset}" =~ ^(.+)@v([0-9]+)$ ]]; then
808
+ # Ordinal syntax: @v1, @v2
802
809
  ARG_DATASET_NAME="${BASH_REMATCH[1]}"
803
810
  ARG_DATASET_VERSION="${BASH_REMATCH[2]}"
804
811
  dataset="" # Clear so name-based resolution takes over
@@ -824,6 +831,8 @@ _validate_dataset() {
824
831
  echo " Resolved to: ${resolved_uri}"
825
832
  dataset="${resolved_uri}"
826
833
  ARG_DATASET="${resolved_uri}"
834
+ RESOLVED_DATASET_S3_URI="${resolved_uri}"
835
+ return 0
827
836
  else
828
837
  echo "❌ Dataset '${ARG_DATASET_NAME}' not found in registry"
829
838
  echo " Register it first: ./do/register --dataset --dataset-name ${ARG_DATASET_NAME} --dataset-s3-uri s3://..."
@@ -1,135 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
- # SPDX-License-Identifier: Apache-2.0
4
-
5
- """
6
- Parse DescribeTrainingJob JSON for the polling loop in do/train.
7
-
8
- Reads JSON from stdin and outputs structured key=value lines for bash consumption:
9
- STATUS=<TrainingJobStatus>
10
- SECONDARY=<SecondaryStatus>
11
- FAILURE_REASON=<FailureReason or empty>
12
- DISPLAY=<formatted single-line status display>
13
-
14
- This keeps the bash poll loop simple while handling JSON parsing in Python.
15
- """
16
-
17
- import json
18
- import sys
19
- from datetime import datetime, timezone
20
-
21
-
22
def format_duration(seconds):
    """Render a duration in seconds as e.g. '1h 2m 3s', '4m 5s' or '6s'.

    Returns 'N/A' for None or negative input. Fractional seconds are
    truncated, and leading zero units (hours, then minutes) are omitted.
    """
    if seconds is None or seconds < 0:
        return 'N/A'
    total_minutes, secs = divmod(int(seconds), 60)
    hours, minutes = divmod(total_minutes, 60)
    if hours:
        return f'{hours}h {minutes}m {secs}s'
    if minutes:
        return f'{minutes}m {secs}s'
    return f'{secs}s'
35
-
36
-
37
def parse_iso_time(time_str):
    """Parse a DescribeTrainingJob timestamp into a timezone-aware datetime.

    Accepted inputs:
      - ISO 8601 strings with a 'Z' suffix or an explicit UTC offset.
      - Epoch seconds as int/float. AWS CLI v1's default
        ``cli_timestamp_format`` passes such raw values through.

    Naive results are assumed to be UTC so they can be safely subtracted from
    ``datetime.now(timezone.utc)``.

    Returns:
        An aware datetime, or None for empty or unparseable input.
    """
    if not time_str:
        return None
    if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
        try:
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            return None
    try:
        # fromisoformat only accepts a trailing 'Z' from Python 3.11 onwards.
        parsed = datetime.fromisoformat(time_str.replace('Z', '+00:00'))
    except (AttributeError, ValueError, TypeError):
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
46
-
47
-
48
def calculate_elapsed(start_time_str):
    """Return seconds between *start_time_str* and now (UTC), never negative.

    Returns None when parse_iso_time cannot parse the start time.
    """
    started_at = parse_iso_time(start_time_str)
    if started_at is None:
        return None
    seconds = (datetime.now(timezone.utc) - started_at).total_seconds()
    return seconds if seconds > 0 else 0
56
-
57
-
58
def format_metrics(final_metrics):
    """Join FinalMetricDataList entries into 'name=value, name=value'.

    Float values are formatted by magnitude:
      - below 0.001: 6 decimals
      - below 1: 4 decimals
      - otherwise: 2 decimals

    Non-float values are formatted as-is. A missing name becomes 'unknown'
    and a missing value becomes 0. Returns '' for an empty or missing list.
    """
    if not final_metrics:
        return ''

    def render(metric):
        name = metric.get('MetricName', 'unknown')
        value = metric.get('Value', 0)
        if not isinstance(value, float):
            return f'{name}={value}'
        magnitude = abs(value)
        if magnitude < 0.001:
            return f'{name}={value:.6f}'
        if magnitude < 1:
            return f'{name}={value:.4f}'
        return f'{name}={value:.2f}'

    return ', '.join(render(metric) for metric in final_metrics)
76
-
77
-
78
# Status emoji mapping
STATUS_EMOJI = {
    'InProgress': '🔄',
    'Completed': '✅',
    'Failed': '❌',
    'Stopping': '⏸️',
    'Stopped': '⏹️'
}


def _single_line(value):
    """Flatten *value* to one line so it cannot break the KEY=value protocol."""
    return ' '.join(str(value).splitlines())


def main():
    """Parse DescribeTrainingJob JSON from stdin and print KEY=value lines for bash.

    Emits exactly four lines: STATUS, SECONDARY, FAILURE_REASON and DISPLAY.
    Each value is flattened to a single line, because free-text fields such
    as FailureReason may contain newlines. An embedded newline would
    otherwise produce stray lines for a line-oriented bash reader.

    Exits with status 1, writing to stderr, if stdin is not a JSON object.
    """
    try:
        job_data = json.load(sys.stdin)
    except json.JSONDecodeError as e:
        print(f'Error parsing JSON: {e}', file=sys.stderr)
        sys.exit(1)
    if not isinstance(job_data, dict):
        print(f'Error parsing JSON: expected an object, got {type(job_data).__name__}',
              file=sys.stderr)
        sys.exit(1)

    status = _single_line(job_data.get('TrainingJobStatus', 'Unknown'))
    secondary_status = _single_line(job_data.get('SecondaryStatus', ''))
    failure_reason = _single_line(job_data.get('FailureReason', ''))
    training_start = job_data.get('TrainingStartTime', '')
    final_metrics = job_data.get('FinalMetricDataList', [])

    # Calculate elapsed time
    elapsed_str = ''
    if training_start:
        elapsed = calculate_elapsed(training_start)
        if elapsed is not None:
            elapsed_str = format_duration(elapsed)

    # Format metrics
    metrics_str = format_metrics(final_metrics)

    # Build display line
    emoji = STATUS_EMOJI.get(status, '❓')
    display_parts = [f' {emoji} {status}']

    if secondary_status:
        display_parts.append(f'| {secondary_status}')

    if elapsed_str:
        display_parts.append(f'| elapsed: {elapsed_str}')

    if metrics_str:
        display_parts.append(f'| {metrics_str}')

    display_line = _single_line(' '.join(display_parts))

    # Output structured lines for bash
    print(f'STATUS={status}')
    print(f'SECONDARY={secondary_status}')
    print(f'FAILURE_REASON={failure_reason}')
    print(f'DISPLAY={display_line}')


if __name__ == '__main__':
    main()
@@ -1,187 +0,0 @@
1
- #!/usr/bin/env python3
2
- # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
- # SPDX-License-Identifier: Apache-2.0
4
-
5
- """
6
- Parse DescribeTrainingJob JSON response and display formatted status.
7
-
8
- This helper is called by do/train --status to parse the AWS CLI JSON output
9
- from DescribeTrainingJob and display a user-friendly status summary.
10
- """
11
-
12
- import json
13
- import sys
14
- import time
15
- from datetime import datetime, timezone
16
-
17
-
18
# Emoji prefix shown for each top-level TrainingJobStatus value.
STATUS_EMOJI = dict(
    InProgress='🔄',
    Completed='✅',
    Failed='❌',
    Stopping='⏸️',
    Stopped='⏹️',
)

# Human-readable explanation for each SecondaryStatus value.
SECONDARY_DESCRIPTIONS = dict(
    Starting='Preparing training instance',
    LaunchingMLInstances='Launching ML instances',
    PreparingTrainingStack='Preparing training stack',
    Downloading='Downloading training data',
    DownloadingTrainingImage='Downloading training image',
    Training='Training in progress',
    Uploading='Uploading model artifacts',
    Completed='Training completed',
    MaxRuntimeExceeded='Max runtime exceeded',
    Stopped='Training stopped',
    MaxWaitTimeExceeded='Max wait time exceeded (spot)',
    Interrupted='Spot instance interrupted',
)
42
-
43
-
44
def format_duration(seconds):
    """Render a duration in seconds as e.g. '1h 2m 3s', '4m 5s' or '6s'.

    Returns 'N/A' for None or negative input. Fractional seconds are
    truncated, and leading zero units (hours, then minutes) are omitted.
    """
    if seconds is None or seconds < 0:
        return 'N/A'
    total_minutes, secs = divmod(int(seconds), 60)
    hours, minutes = divmod(total_minutes, 60)
    if hours:
        return f'{hours}h {minutes}m {secs}s'
    if minutes:
        return f'{minutes}m {secs}s'
    return f'{secs}s'
57
-
58
-
59
def parse_iso_time(time_str):
    """Parse a DescribeTrainingJob timestamp into a timezone-aware datetime.

    Accepted inputs:
      - ISO 8601 strings with a 'Z' suffix or an explicit UTC offset.
      - Epoch seconds as int/float. AWS CLI v1's default
        ``cli_timestamp_format`` passes such raw values through.

    Naive results are assumed to be UTC so they can be safely subtracted from
    ``datetime.now(timezone.utc)``.

    Returns:
        An aware datetime, or None for empty or unparseable input.
    """
    if not time_str:
        return None
    if isinstance(time_str, (int, float)) and not isinstance(time_str, bool):
        try:
            return datetime.fromtimestamp(time_str, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            return None
    try:
        # Handle various AWS timestamp formats. fromisoformat only accepts a
        # trailing 'Z' from Python 3.11 onwards, so normalise it to +00:00.
        parsed = datetime.fromisoformat(time_str.replace('Z', '+00:00'))
    except (AttributeError, ValueError, TypeError):
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed
70
-
71
-
72
def calculate_elapsed(start_time_str):
    """Return seconds between *start_time_str* and now (UTC), never negative.

    Returns None when parse_iso_time cannot parse the start time.
    """
    started_at = parse_iso_time(start_time_str)
    if started_at is None:
        return None
    seconds = (datetime.now(timezone.utc) - started_at).total_seconds()
    return seconds if seconds > 0 else 0
80
-
81
-
82
def display_status(job_data):
    """Print a human-readable summary of a DescribeTrainingJob response.

    The status line is always printed. The following sections appear only
    when the relevant fields are present, and some only for particular
    statuses:
      - phase
      - elapsed or training time
      - instance
      - spot savings
      - metrics
      - artifacts
      - failure reason
      - spot-interruption hint

    The output is framed by blank lines.
    """
    resource = job_data.get('ResourceConfig', {})
    status = job_data.get('TrainingJobStatus', 'Unknown')
    phase = job_data.get('SecondaryStatus', '')
    failure_reason = job_data.get('FailureReason', '')
    started_at = job_data.get('TrainingStartTime', '')
    billable_secs = job_data.get('BillableTimeInSeconds')
    trained_secs = job_data.get('TrainingTimeInSeconds')
    metrics = job_data.get('FinalMetricDataList', [])
    output_path = job_data.get('OutputDataConfig', {}).get('S3OutputPath', '')
    artifacts = job_data.get('ModelArtifacts', {}).get('S3ModelArtifacts', '')
    instance_type = resource.get('InstanceType', '')
    instance_count = resource.get('InstanceCount', 1)
    is_spot = job_data.get('EnableManagedSpotTraining', False)

    emoji = STATUS_EMOJI.get(status, '❓')
    print()
    print(f' {emoji} Status: {status}')

    # Phase (SecondaryStatus), with a description when one is known.
    if phase:
        description = SECONDARY_DESCRIPTIONS.get(phase, '')
        suffix = f' ({description})' if description else ''
        print(f' 📍 Phase: {phase}{suffix}')

    # Live elapsed time while running; otherwise the reported training time.
    if status == 'InProgress' and started_at:
        elapsed = calculate_elapsed(started_at)
        if elapsed is not None:
            print(f' ⏱️ Elapsed: {format_duration(elapsed)}')
    elif trained_secs is not None:
        print(f' ⏱️ Training time: {format_duration(trained_secs)}')

    if instance_type:
        label_parts = [str(instance_type)]
        if instance_count and instance_count > 1:
            label_parts.append(f' x {instance_count}')
        if is_spot:
            label_parts.append(' (spot)')
        print(f' 🖥️ Instance: {"".join(label_parts)}')

    # Spot savings, only for completed spot jobs with both timings reported.
    if (status == 'Completed' and is_spot
            and billable_secs is not None and trained_secs is not None
            and trained_secs > 0):
        saved = trained_secs - billable_secs
        pct = saved / trained_secs * 100
        print(f' 💰 Spot savings: {format_duration(saved)} saved ({pct:.0f}% discount)')
        print(f' Billable: {format_duration(billable_secs)} / Total: {format_duration(trained_secs)}')

    if metrics:
        print(' 📈 Metrics:')
        for metric in metrics:
            name = metric.get('MetricName', 'unknown')
            value = metric.get('Value', 0)
            if isinstance(value, float):
                magnitude = abs(value)
                precision = 6 if magnitude < 0.001 else 4 if magnitude < 1 else 2
                print(f' {name}: {value:.{precision}f}')
            else:
                print(f' {name}: {value}')

    # Prefer the model artifact URI over the generic output path.
    if status == 'Completed':
        if artifacts:
            print(f' 📦 Artifacts: {artifacts}')
        elif output_path:
            print(f' 📦 Output: {output_path}')

    if status == 'Failed' and failure_reason:
        print(f' 💥 Reason: {failure_reason}')
        print()
        print(' To start a new job: ./do/train --force')

    if phase == 'Interrupted':
        print()
        print(' ℹ️ Spot instance was interrupted. The job will automatically')
        print(' resume from the last checkpoint. Re-run ./do/train to poll.')

    print()
173
-
174
-
175
def main():
    """Entry point: read DescribeTrainingJob JSON from stdin and display it.

    Exits with status 1, writing to stderr, if stdin is not valid JSON.
    """
    try:
        job_data = json.load(sys.stdin)
    except json.JSONDecodeError as err:
        print(f'❌ Failed to parse DescribeTrainingJob response: {err}', file=sys.stderr)
        sys.exit(1)
    else:
        display_status(job_data)


if __name__ == '__main__':
    main()