@aws/ml-container-creator 0.7.1 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. package/LICENSE-THIRD-PARTY +50760 -16218
  2. package/bin/cli.js +1 -1
  3. package/infra/ci-harness/buildspec.yml +4 -0
  4. package/package.json +3 -1
  5. package/servers/lib/catalogs/instances.json +52 -1275
  6. package/servers/lib/catalogs/model-servers.json +80 -0
  7. package/servers/lib/catalogs/models.json +0 -132
  8. package/servers/lib/catalogs/popular-diffusors.json +1 -110
  9. package/servers/model-picker/index.js +27 -16
  10. package/src/app.js +113 -23
  11. package/src/lib/cli-handler.js +1 -1
  12. package/src/lib/config-manager.js +39 -2
  13. package/src/lib/cross-cutting-checker.js +146 -33
  14. package/src/lib/deployment-config-resolver.js +10 -4
  15. package/src/lib/e2e-bootstrap.js +227 -0
  16. package/src/lib/e2e-catalog-validator.js +103 -0
  17. package/src/lib/e2e-quota-validator.js +135 -0
  18. package/src/lib/mcp-client.js +16 -1
  19. package/src/lib/mcp-command-handler.js +10 -2
  20. package/src/lib/prompt-runner.js +306 -24
  21. package/src/lib/prompts.js +9 -3
  22. package/src/lib/template-manager.js +10 -4
  23. package/src/lib/train-config-parser.js +136 -0
  24. package/src/lib/train-config-persistence.js +143 -0
  25. package/src/lib/train-config-validator.js +112 -0
  26. package/src/lib/train-feedback.js +46 -0
  27. package/src/lib/train-idempotency.js +97 -0
  28. package/src/lib/train-request-builder.js +120 -0
  29. package/src/lib/tune-catalog-validator.js +5 -5
  30. package/templates/code/serve +2 -2
  31. package/templates/code/serving.properties +2 -2
  32. package/templates/diffusors/serve +3 -3
  33. package/templates/do/.train_build_request.py +141 -0
  34. package/templates/do/.train_poll_parser.py +135 -0
  35. package/templates/do/.train_status_parser.py +187 -0
  36. package/templates/do/.tune_helper.py +2 -2
  37. package/templates/do/lib/feedback.sh +41 -0
  38. package/templates/do/register +8 -2
  39. package/templates/do/test +5 -5
  40. package/templates/do/train +786 -0
  41. package/templates/do/training/config.yaml +140 -0
  42. package/templates/do/training/train.py +463 -0
  43. package/templates/do/tune +2 -2
  44. package/templates/marketplace/config +118 -0
  45. package/templates/marketplace/deploy +890 -0
  46. package/templates/marketplace/test +453 -0
@@ -0,0 +1,140 @@
1
+ # do/training/config.yaml — Bespoke Training Job Configuration
2
+ # Edit this file to configure your custom SageMaker training job.
3
+ # Run with: ./do/train
4
+ #
5
+ # This file is read by do/train to construct a CreateTrainingJob request.
6
+ # Required fields are marked with [REQUIRED]. All other fields are optional.
7
+
8
+ # ─── Container Image [REQUIRED] ──────────────────────────────────────────────
9
+ # ECR image URI for the training container.
10
+ # This defaults to your project's own container image (built via do/build + do/push).
11
+ # You can also use a SageMaker pre-built training image or any custom ECR image.
12
+ #
13
+ # Example: 123456789012.dkr.ecr.us-east-1.amazonaws.com/ml-container-creator:latest
14
+ image: "<%= '${AWS_ACCOUNT_ID}.dkr.ecr.${AWS_REGION}.amazonaws.com/${ECR_REPOSITORY_NAME}:latest' %>"
15
+
16
+ # ─── Training Script [REQUIRED] ──────────────────────────────────────────────
17
+ # Path to your training script relative to the project root.
18
+ # SageMaker copies this into the container at /opt/ml/code/ and executes it.
19
+ #
20
+ # Example: do/training/train.py
21
+ script: "do/training/train.py"
22
+
23
+ # ─── Instance Configuration [REQUIRED] ───────────────────────────────────────
24
+ # SageMaker ML instance type for training.
25
+ # Use GPU instances (ml.g5.*, ml.p4d.*) for deep learning workloads.
26
+ # Use CPU instances (ml.m5.*) for traditional ML or preprocessing.
27
+ #
28
+ # Examples:
29
+ # ml.g5.xlarge — 1 GPU (24GB A10G), 4 vCPU, 16 GB RAM
30
+ # ml.g5.2xlarge — 1 GPU (24GB A10G), 8 vCPU, 32 GB RAM
31
+ # ml.p4d.24xlarge — 8 GPUs (40GB A100), 96 vCPU, 1152 GB RAM
32
+ # ml.m5.xlarge — CPU only, 4 vCPU, 16 GB RAM
33
+ instance_type: "ml.g5.xlarge"
34
+
35
+ # Number of training instances. Set > 1 for distributed training.
36
+ # NOTE: When instance_count > 1, your training script must be distribution-aware
37
+ # (e.g., using PyTorch DDP, Horovod, or similar). SageMaker handles inter-node
38
+ # communication setup automatically.
39
+ #
40
+ # Example: 2
41
+ instance_count: 1
42
+
43
+ # ─── Data Configuration [REQUIRED] ───────────────────────────────────────────
44
+ # S3 URI for the input training dataset.
45
+ # SageMaker downloads this to /opt/ml/input/data/training/ in the container.
46
+ #
47
+ # Examples:
48
+ # s3://my-bucket/datasets/train/
49
+ # s3://my-bucket/datasets/train.jsonl
50
+ dataset: ""
51
+
52
+ # S3 URI for training output artifacts. [REQUIRED]
53
+ # SageMaker uploads the contents of /opt/ml/model/ to this location after training.
54
+ #
55
+ # Example: s3://my-bucket/training-output/
56
+ output_path: ""
57
+
58
+ # ─── Hyperparameters (optional) ──────────────────────────────────────────────
59
+ # Key-value pairs made available to your training script: SageMaker writes them
60
+ # to /opt/ml/input/config/hyperparameters.json (read by do/training/train.py).
61
+ # All values must be strings — cast them to numbers inside the script.
62
+ #
63
+ # Example:
64
+ # hyperparameters:
65
+ # epochs: "10"
66
+ # batch_size: "32"
67
+ # learning_rate: "5e-5"
68
+ # warmup_steps: "100"
69
+ # weight_decay: "0.01"
70
+ hyperparameters:
71
+ epochs: "10"
72
+ batch_size: "32"
73
+ learning_rate: "5e-5"
74
+
75
+ # ─── Resource Limits (optional) ──────────────────────────────────────────────
76
+ # Maximum training duration in seconds. Job is stopped if it exceeds this limit.
77
+ # Default: 86400 (24 hours)
78
+ #
79
+ # Example: 43200 (12 hours)
80
+ max_runtime_seconds: 86400
81
+
82
+ # EBS volume size in GB attached to each training instance.
83
+ # Must be large enough to hold your dataset, model, and intermediate artifacts.
84
+ # Default: 50
85
+ #
86
+ # Example: 100
87
+ volume_size_gb: 50
88
+
89
+ # ─── Managed Spot Training (optional) ────────────────────────────────────────
90
+ # Use EC2 Spot Instances for training at up to 90% cost savings.
91
+ # When enabled, SageMaker handles interruptions and resumes from checkpoints.
92
+ # Your training script should save checkpoints to /opt/ml/checkpoints/ periodically.
93
+ #
94
+ # IMPORTANT: When enable_spot is true, you must also set max_wait_seconds.
95
+ # The max_wait_seconds value must be >= max_runtime_seconds.
96
+ enable_spot: false
97
+
98
+ # Maximum total time (seconds) to wait for spot capacity + training completion.
99
+ # Only used when enable_spot is true. Must be >= max_runtime_seconds.
100
+ # Default: 172800 (48 hours)
101
+ #
102
+ # Example: 172800
103
+ max_wait_seconds: 172800
104
+
105
+ # ─── Metric Definitions (optional) ───────────────────────────────────────────
106
+ # Custom regex patterns to extract training metrics from container logs.
107
+ # SageMaker uses these to publish metrics to CloudWatch for monitoring.
108
+ # Each entry needs a name and a regex with exactly one capture group.
109
+ #
110
+ # Example:
111
+ # metric_definitions:
112
+ # - name: "train:loss"
113
+ # regex: "loss: ([0-9\\.]+)"
114
+ # - name: "train:epoch"
115
+ # regex: "epoch: ([0-9\\.]+)"
116
+ # - name: "eval:accuracy"
117
+ # regex: "eval_accuracy: ([0-9\\.]+)"
118
+ metric_definitions: []
119
+
120
+ # ─── Environment Variables (optional) ────────────────────────────────────────
121
+ # Additional environment variables set in the training container.
122
+ # Use for configuration that doesn't fit as hyperparameters.
123
+ #
124
+ # Example:
125
+ # environment:
126
+ # NCCL_DEBUG: "INFO"
127
+ # CUDA_VISIBLE_DEVICES: "0,1"
128
+ # HF_HOME: "/opt/ml/input/data/cache"   # replaces the deprecated TRANSFORMERS_CACHE
129
+ environment: {}
130
+
131
+ # ─── Tags (optional) ─────────────────────────────────────────────────────────
132
+ # AWS resource tags applied to the training job.
133
+ # Useful for cost allocation, team tracking, and resource organization.
134
+ #
135
+ # Example:
136
+ # tags:
137
+ # team: "ml-platform"
138
+ # project: "fine-tuning-experiment-1"
139
+ # environment: "development"
140
+ tags: {}
@@ -0,0 +1,463 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """SageMaker Training Script — Placeholder / Skeleton
6
+
7
+ This file demonstrates the contract that SageMaker expects from a custom
8
+ training script. Replace the placeholder logic with your own model training
9
+ code while keeping the I/O paths and conventions intact.
10
+
11
+ SageMaker copies this script into the container at /opt/ml/code/ and invokes
12
+ it as the entry point. The container filesystem is laid out as follows:
13
+
14
+ /opt/ml/
15
+ ├── input/
16
+ │ ├── config/ # Training job configuration (JSON files)
17
+ │ │ ├── hyperparameters.json
18
+ │ │ ├── resourceconfig.json
19
+ │ │ └── inputdataconfig.json
20
+ │ └── data/ # Input data channels
21
+ │ └── training/ # Default channel name (configurable)
22
+ │ └── ... # Your training data files
23
+ ├── model/ # Write final model artifacts here
24
+ │ └── ... # Everything here is packaged as model.tar.gz
25
+ ├── checkpoints/ # Save/restore checkpoints here (spot training)
26
+ │ └── ... # Persisted to S3 between interruptions
27
+ └── output/
28
+ └── failure # Write failure reason here on error
29
+
30
+ Key conventions:
31
+ - Hyperparameters are passed as string key-value pairs (always strings!)
32
+ - Training data is downloaded to /opt/ml/input/data/<channel_name>/
33
+ - Final model artifacts MUST be written to /opt/ml/model/
34
+ - Checkpoints in /opt/ml/checkpoints/ survive spot interruptions
35
+ - Stdout/stderr are captured to CloudWatch Logs automatically
36
+ - Exit code 0 = success, non-zero = failure
37
+ """
38
+
39
+ import argparse
40
+ import json
41
+ import logging
42
+ import os
43
+ import sys
44
+
45
+ # ── Logging setup ─────────────────────────────────────────────────────────────
46
+ # SageMaker captures stdout/stderr to CloudWatch Logs automatically.
47
+ # Use structured logging for easier debugging in production.
48
+
49
# Root logging configuration: INFO level and above, each record stamped with
# time, level and logger name so CloudWatch output is easy to scan.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)
54
+
55
+
56
+ # ── SageMaker environment paths ──────────────────────────────────────────────
57
+ # These paths are fixed by the SageMaker training container contract.
58
+
59
# Input-channel and model-output directories: the SM_* environment variables
# win when present, otherwise the standard container locations are used.
INPUT_DATA_DIR = os.getenv("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training")
MODEL_DIR = os.getenv("SM_MODEL_DIR", "/opt/ml/model")

# Fixed locations with no environment override.
CHECKPOINT_DIR = "/opt/ml/checkpoints"  # checkpoint save/restore directory
OUTPUT_DIR = "/opt/ml/output"  # the "failure" file is written here on error
HYPERPARAMS_FILE = "/opt/ml/input/config/hyperparameters.json"
RESOURCE_CONFIG_FILE = "/opt/ml/input/config/resourceconfig.json"
65
+
66
+
67
+ # ── Hyperparameter loading ────────────────────────────────────────────────────
68
+
69
+
70
def load_hyperparameters():
    """Read the job's hyperparameters from SageMaker's JSON config file.

    Every value in the file is a STRING; convert values to the types your
    code needs (e.g. ``int(params["epochs"])``).

    Returns:
        dict: The parsed contents of ``HYPERPARAMS_FILE``, or an empty dict
        (after logging a warning) when that file does not exist.

    Example hyperparameters.json:
        {"epochs": "10", "batch_size": "32", "learning_rate": "0.001"}
    """
    # Guard clause: no config file means no hyperparameters.
    if not os.path.exists(HYPERPARAMS_FILE):
        logger.warning("No hyperparameters file found at %s", HYPERPARAMS_FILE)
        return {}

    with open(HYPERPARAMS_FILE, "r") as handle:
        loaded = json.load(handle)
    logger.info("Loaded hyperparameters: %s", json.dumps(loaded, indent=2))
    return loaded
94
+
95
+
96
+ # ── Data loading ──────────────────────────────────────────────────────────────
97
+
98
+
99
def load_training_data(data_dir):
    """Load training data from the input channel directory.

    SageMaker downloads your dataset from S3 to this directory before
    training starts. The directory structure mirrors your S3 prefix.

    Args:
        data_dir: Path to the input data channel (e.g., /opt/ml/input/data/training).

    Returns:
        Your training data in whatever format your model expects. The
        placeholder implementation returns None.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist or is not a
            directory. A regular exception (instead of ``sys.exit``) lets the
            script's top-level ``except Exception`` handler record the reason
            in /opt/ml/output/failure; ``SystemExit`` would bypass it.

    Example directory contents:
        /opt/ml/input/data/training/
        ├── train.csv
        ├── train.jsonl
        └── data_part_001.parquet
    """
    logger.info("Loading training data from: %s", data_dir)

    if not os.path.isdir(data_dir):
        logger.error("Data directory does not exist: %s", data_dir)
        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")

    # os.listdir() order is arbitrary; sort so the log is stable across runs.
    files = sorted(os.listdir(data_dir))
    logger.info("Found %d file(s): %s", len(files), files)

    # ─── Replace this with your actual data loading logic ───
    # Examples:
    #
    # For CSV:
    #   import pandas as pd
    #   df = pd.read_csv(os.path.join(data_dir, "train.csv"))
    #
    # For JSONL:
    #   records = []
    #   with open(os.path.join(data_dir, "train.jsonl")) as f:
    #       for line in f:
    #           records.append(json.loads(line))
    #
    # For PyTorch datasets:
    #   from torch.utils.data import DataLoader
    #   dataset = MyDataset(data_dir)
    #   dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    #
    # For Hugging Face datasets:
    #   from datasets import load_from_disk
    #   dataset = load_from_disk(data_dir)

    return None  # Replace with your loaded data
150
+
151
+
152
+ # ── Checkpoint management ─────────────────────────────────────────────────────
153
+ # Checkpoints are CRITICAL for managed spot training. When SageMaker interrupts
154
+ # a spot instance, it saves /opt/ml/checkpoints/ to S3. When the job resumes,
155
+ # it restores the checkpoint directory before re-running your script.
156
+ #
157
+ # Best practices:
158
+ # - Save checkpoints periodically (e.g., every N epochs or steps)
159
+ # - Include enough state to fully resume: model weights, optimizer state, epoch
160
+ # - On startup, check if a checkpoint exists and resume from it
161
+ # - Use atomic writes (write to temp file, then rename) to avoid corruption
162
+
163
+
164
def save_checkpoint(model_state, optimizer_state, epoch, step, checkpoint_dir=None):
    """Save a training checkpoint so an interrupted (e.g. spot) job can resume.

    The checkpoint is written atomically. It goes to a temporary file, is
    forced to disk, and is then renamed over ``checkpoint_latest.json``, so a
    reader never sees a half-written file.

    Args:
        model_state: Model weights/parameters to save. Unused by this
            placeholder; add it to the ``checkpoint`` dict below.
        optimizer_state: Optimizer state for resuming training. Unused by
            this placeholder.
        epoch: Number of epochs completed so far.
        step: Global optimizer-step count so far.
        checkpoint_dir: Directory to write into. ``None`` (the default) means
            CHECKPOINT_DIR (/opt/ml/checkpoints).

    When the training job has checkpointing configured, SageMaker syncs this
    directory to S3. It restores the directory before the script is re-run
    after a spot interruption.
    """
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "step": step,
        # Add your model and optimizer state here:
        # "model_state_dict": model_state,
        # "optimizer_state_dict": optimizer_state,
    }

    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_latest.json")

    # Atomic write: write a temp file, fsync it, then rename. os.replace() is
    # atomic on POSIX. The fsync makes sure the renamed file actually holds the
    # data if the instance is reclaimed right after the rename.
    tmp_path = checkpoint_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(checkpoint, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, checkpoint_path)

    logger.info("Saved checkpoint at epoch %d, step %d", epoch, step)

    # ─── For PyTorch, you would typically do: ───
    # import torch
    # torch.save({
    #     "epoch": epoch,
    #     "step": step,
    #     "model_state_dict": model.state_dict(),
    #     "optimizer_state_dict": optimizer.state_dict(),
    #     "loss": current_loss,
    # }, os.path.join(checkpoint_dir, "checkpoint_latest.pt"))


def load_checkpoint(checkpoint_dir=None):
    """Restore training state from the latest checkpoint, if one exists.

    Call this at the start of training to resume where a previous
    (interrupted) run left off.

    Args:
        checkpoint_dir: Directory to read from. ``None`` (the default) means
            CHECKPOINT_DIR (/opt/ml/checkpoints).

    Returns:
        dict or None: The parsed ``checkpoint_latest.json`` if present,
        otherwise None.
    """
    # ─── For PyTorch, you would typically do: ───
    # import torch
    # ckpt_path = os.path.join(checkpoint_dir, "checkpoint_latest.pt")
    # if os.path.exists(ckpt_path):
    #     checkpoint = torch.load(ckpt_path)
    #     model.load_state_dict(checkpoint["model_state_dict"])
    #     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    #     start_epoch = checkpoint["epoch"]
    #     return checkpoint
    # return None
    if checkpoint_dir is None:
        checkpoint_dir = CHECKPOINT_DIR
    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint_latest.json")

    if not os.path.exists(checkpoint_path):
        logger.info("No checkpoint found — starting from scratch")
        return None

    with open(checkpoint_path, "r", encoding="utf-8") as f:
        checkpoint = json.load(f)
    logger.info(
        "Restored checkpoint: epoch %d, step %d",
        checkpoint.get("epoch", 0),
        checkpoint.get("step", 0),
    )
    return checkpoint


# ── Model saving ──────────────────────────────────────────────────────────────


def save_model(model, model_dir):
    """Save final model artifacts to the output directory.

    IMPORTANT: Everything written to /opt/ml/model/ is packaged into a
    model.tar.gz and uploaded to S3 at the output path you configured.
    This is what gets deployed to a SageMaker endpoint.

    Args:
        model: Your trained model object (unused by this placeholder).
        model_dir: Path to write artifacts (default: /opt/ml/model/).

    Common patterns:
        - sklearn: joblib.dump(model, os.path.join(model_dir, "model.pkl"))
        - PyTorch: torch.save(model.state_dict(), os.path.join(model_dir, "model.pt"))
        - Hugging Face: model.save_pretrained(model_dir)
        - XGBoost: model.save_model(os.path.join(model_dir, "model.json"))
    """
    os.makedirs(model_dir, exist_ok=True)

    logger.info("Saving model artifacts to: %s", model_dir)

    # ─── Replace with your model saving logic ───
    # Examples:
    #
    # For scikit-learn:
    #   import joblib
    #   joblib.dump(model, os.path.join(model_dir, "model.pkl"))
    #
    # For PyTorch:
    #   import torch
    #   torch.save(model.state_dict(), os.path.join(model_dir, "model.pt"))
    #   # Also save model config/architecture for inference
    #   with open(os.path.join(model_dir, "config.json"), "w") as f:
    #       json.dump(model_config, f)
    #
    # For Hugging Face transformers:
    #   model.save_pretrained(model_dir)
    #   tokenizer.save_pretrained(model_dir)
    #
    # For LoRA adapters (PEFT):
    #   model.save_pretrained(model_dir)
    #   # This creates adapter_config.json + adapter weights
    #   # The feedback loop will detect this and suggest ./do/adapter add

    # Placeholder: save a marker file
    with open(os.path.join(model_dir, "model_info.json"), "w", encoding="utf-8") as f:
        json.dump({"status": "placeholder", "message": "Replace with real model"}, f)

    logger.info("Model artifacts saved successfully")


# ── Training loop ─────────────────────────────────────────────────────────────


def train(hyperparams, data_dir, model_dir, checkpoint_dir):
    """Main training loop.

    This is where your actual training logic goes. The structure below
    shows the recommended pattern for SageMaker compatibility:

    1. Load hyperparameters (cast strings to proper types)
    2. Load training data from the input channel
    3. Check for existing checkpoint (for spot training resumption)
    4. Run training loop with periodic checkpoint saves
    5. Save final model to the output directory

    Args:
        hyperparams: Dict of hyperparameters (all values are strings).
            Recognised keys: epochs, batch_size, learning_rate and
            checkpoint_frequency (save every N epochs; <= 0 disables
            periodic checkpoints).
        data_dir: Path to input training data.
        model_dir: Path to write final model artifacts.
        checkpoint_dir: Directory used for checkpoint save/restore.
    """
    # ── Step 1: Parse hyperparameters (cast from strings) ──
    epochs = int(hyperparams.get("epochs", "10"))
    batch_size = int(hyperparams.get("batch_size", "32"))
    learning_rate = float(hyperparams.get("learning_rate", "0.001"))
    checkpoint_frequency = int(hyperparams.get("checkpoint_frequency", "1"))

    # %g rather than %f: typical learning rates (e.g. 1e-7) would print as
    # 0.000000 with %f.
    logger.info(
        "Training config: epochs=%d, batch_size=%d, lr=%g",
        epochs, batch_size, learning_rate,
    )

    # ── Step 2: Load training data ──
    training_data = load_training_data(data_dir)  # consumed by your training step

    # ── Step 3: Check for existing checkpoint (spot resumption) ──
    checkpoint = load_checkpoint(checkpoint_dir)
    start_epoch = 0
    global_step = 0
    if checkpoint:
        start_epoch = int(checkpoint.get("epoch", 0))
        global_step = int(checkpoint.get("step", 0))
        # Restore model and optimizer state from checkpoint here
        logger.info("Resuming training from epoch %d", start_epoch)

    # ── Step 4: Training loop ──
    model = None  # Replace with your model initialization

    for epoch in range(start_epoch, epochs):
        logger.info("Epoch %d/%d", epoch + 1, epochs)

        # ─── Replace with your actual training step ───
        # Example (PyTorch):
        #   model.train()
        #   for batch_idx, (data, target) in enumerate(dataloader):
        #       optimizer.zero_grad()
        #       output = model(data)
        #       loss = criterion(output, target)
        #       loss.backward()
        #       optimizer.step()
        #       global_step += 1
        #
        #       if batch_idx % log_interval == 0:
        #           logger.info("  Step %d, Loss: %.4f", batch_idx, loss.item())

        # Placeholder: count one step per epoch. With a real batch loop, move
        # this increment next to optimizer.step() as shown above.
        global_step += 1

        # Print metrics in a format SageMaker can parse for CloudWatch.
        # Use the regex pattern defined in config.yaml metric_definitions.
        # flush=True: stdout is block-buffered when not attached to a terminal,
        # so unflushed metric lines could be delayed or lost on a crash.
        train_loss = 0.0  # Replace with actual loss
        print(f"loss: {train_loss:.4f}", flush=True)
        print(f"epoch: {epoch + 1}", flush=True)

        # ── Save checkpoint periodically ──
        # A frequency <= 0 disables periodic checkpoints; this also avoids a
        # ZeroDivisionError in the modulo.
        if checkpoint_frequency > 0 and (epoch + 1) % checkpoint_frequency == 0:
            save_checkpoint(
                model_state=None,      # Replace with model.state_dict()
                optimizer_state=None,  # Replace with optimizer.state_dict()
                epoch=epoch + 1,
                step=global_step,
                checkpoint_dir=checkpoint_dir,
            )

    # ── Step 5: Save final model ──
    save_model(model, model_dir)

    logger.info("Training complete!")
379
+
380
+
381
+ # ── Distributed training helpers ──────────────────────────────────────────────
382
+ # When instance_count > 1 in config.yaml, SageMaker launches multiple instances
383
+ # and sets up inter-node communication. Your script must be distribution-aware.
384
+ #
385
+ # The SageMaker training toolkit (when installed in the container) sets these env vars:
386
+ # SM_HOSTS - JSON list of all host names
387
+ # SM_CURRENT_HOST - This instance's host name
388
+ # SM_NUM_GPUS - Number of GPUs on this instance
389
+ #
390
+ # Example (PyTorch DDP):
391
+ # import torch.distributed as dist
392
+ # dist.init_process_group(backend="nccl")
393
+ # local_rank = int(os.environ.get("LOCAL_RANK", 0))
394
+ # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
395
+
396
+
397
def _read_resource_config():
    """Return the parsed SageMaker resource config, or {} if it is unavailable.

    Reads RESOURCE_CONFIG_FILE (/opt/ml/input/config/resourceconfig.json).
    A missing, unreadable or malformed file (e.g. when running locally)
    yields {}.
    """
    try:
        with open(RESOURCE_CONFIG_FILE, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError):
        return {}
    return config if isinstance(config, dict) else {}


def get_distributed_info():
    """Get distributed training configuration for this container.

    Each value is resolved in this order:
      1. The SM_HOSTS / SM_CURRENT_HOST / SM_NUM_GPUS environment variables
         (set by the SageMaker training toolkit when it is installed).
      2. /opt/ml/input/config/resourceconfig.json (``hosts`` and
         ``current_host`` only), for containers without the toolkit.
      3. A single-host "localhost" fallback for local runs.

    Returns:
        dict: ``hosts`` (list of host names), ``current_host``, ``num_gpus``,
        ``num_hosts`` and ``is_leader`` (True when this host is ``hosts[0]``).
    """
    env_hosts = os.environ.get("SM_HOSTS")
    env_current_host = os.environ.get("SM_CURRENT_HOST")

    # Only touch the filesystem when the environment is incomplete.
    resource_config = {}
    if not env_hosts or not env_current_host:
        resource_config = _read_resource_config()

    hosts = json.loads(env_hosts) if env_hosts else resource_config.get("hosts")
    current_host = env_current_host or resource_config.get("current_host") or "localhost"

    # An empty or missing host list would make hosts[0] raise IndexError.
    # Treat it as a single-host job consisting of this host.
    if not hosts:
        hosts = [current_host]

    # `or "0"` also covers SM_NUM_GPUS being set to an empty string.
    num_gpus = int(os.environ.get("SM_NUM_GPUS") or "0")

    return {
        "hosts": hosts,
        "current_host": current_host,
        "num_gpus": num_gpus,
        "num_hosts": len(hosts),
        "is_leader": current_host == hosts[0],
    }
414
+
415
+
416
+ # ── Entry point ───────────────────────────────────────────────────────────────
417
+
418
+
419
if __name__ == "__main__":
    import traceback  # used only to build the failure report below

    # Containers built on the SageMaker training toolkit also receive the
    # hyperparameters as command-line flags. parse_known_args() tolerates
    # flags this parser does not declare, so either source can be used.
    parser = argparse.ArgumentParser(description="SageMaker Training Script")

    # SageMaker standard arguments
    parser.add_argument("--model-dir", type=str, default=MODEL_DIR,
                        help="Directory to save model artifacts")
    parser.add_argument("--data-dir", type=str, default=INPUT_DATA_DIR,
                        help="Directory containing training data")

    # Add your custom hyperparameters as CLI args if preferred:
    # parser.add_argument("--epochs", type=int, default=10)
    # parser.add_argument("--batch-size", type=int, default=32)
    # parser.add_argument("--learning-rate", type=float, default=0.001)

    args, _ = parser.parse_known_args()

    # Load hyperparameters from SageMaker config file
    hyperparams = load_hyperparameters()

    # Log distributed training info
    dist_info = get_distributed_info()
    if dist_info["num_hosts"] > 1:
        logger.info("Distributed training: %d hosts, %d GPUs per host",
                    dist_info["num_hosts"], dist_info["num_gpus"])
        logger.info("Current host: %s (leader: %s)",
                    dist_info["current_host"], dist_info["is_leader"])

    # Run training
    try:
        train(
            hyperparams=hyperparams,
            data_dir=args.data_dir,
            model_dir=args.model_dir,
            checkpoint_dir=CHECKPOINT_DIR,
        )
    except Exception as e:
        # Log first, so the error is visible even if the failure file below
        # cannot be written.
        logger.error("Training failed: %s", e, exc_info=True)

        # Write the failure reason to /opt/ml/output/failure for SageMaker to
        # report. Type and message go first (str(e) alone can be empty, e.g.
        # `raise RuntimeError()`), followed by the full traceback.
        failure_path = os.path.join(OUTPUT_DIR, "failure")
        failure_message = f"{type(e).__name__}: {e}\n\n{traceback.format_exc()}"
        try:
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            with open(failure_path, "w", encoding="utf-8") as f:
                f.write(failure_message)
        except OSError as write_error:
            # e.g. /opt/ml is not writable when running outside SageMaker.
            # Do not let this mask the original training error.
            logger.warning("Could not write failure file %s: %s",
                           failure_path, write_error)
        sys.exit(1)
package/templates/do/tune CHANGED
@@ -67,7 +67,7 @@ _parse_args() {
67
67
  ARG_TRAINING_TYPE="$2"; shift 2 ;;
68
68
  --model)
69
69
  if [ -z "${2:-}" ]; then
70
- echo "❌ --model requires a JumpStart model ID"
70
+ echo "❌ --model requires a model ID"
71
71
  exit 1
72
72
  fi
73
73
  ARG_MODEL="$2"; shift 2 ;;
@@ -287,7 +287,7 @@ for family in sorted(families.keys()):
287
287
  for entry in entries:
288
288
  techniques = list(entry.get('techniques', {}).keys())
289
289
  print(f' • {entry[\"displayName\"]}')
290
- print(f' ID: {entry[\"jumpStartModelId\"]}')
290
+ print(f' ID: {entry[\"modelId\"]}')
291
291
  for t in techniques:
292
292
  tc = entry['techniques'][t]
293
293
  types = ', '.join(tc.get('trainingTypes', []))