@aws/ml-container-creator 0.8.0 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE-THIRD-PARTY +50760 -16218
- package/package.json +3 -1
- package/servers/lib/catalogs/instances.json +52 -1275
- package/servers/lib/catalogs/models.json +0 -132
- package/servers/lib/catalogs/popular-diffusors.json +1 -110
- package/src/app.js +24 -2
- package/src/lib/mcp-client.js +16 -1
- package/src/lib/mcp-command-handler.js +10 -2
- package/src/lib/prompt-runner.js +16 -2
- package/src/lib/train-config-parser.js +136 -0
- package/src/lib/train-config-persistence.js +143 -0
- package/src/lib/train-config-validator.js +112 -0
- package/src/lib/train-feedback.js +46 -0
- package/src/lib/train-idempotency.js +97 -0
- package/src/lib/train-request-builder.js +120 -0
- package/templates/do/.train_build_request.py +141 -0
- package/templates/do/.train_poll_parser.py +135 -0
- package/templates/do/.train_status_parser.py +187 -0
- package/templates/do/lib/feedback.sh +41 -0
- package/templates/do/train +786 -0
- package/templates/do/training/config.yaml +140 -0
- package/templates/do/training/train.py +463 -0
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# do/training/config.yaml — Bespoke Training Job Configuration
|
|
2
|
+
# Edit this file to configure your custom SageMaker training job.
|
|
3
|
+
# Run with: ./do/train
|
|
4
|
+
#
|
|
5
|
+
# This file is read by do/train to construct a CreateTrainingJob request.
|
|
6
|
+
# Required fields are marked with [REQUIRED]. All other fields are optional.
|
|
7
|
+
|
|
8
|
+
# ─── Container Image [REQUIRED] ──────────────────────────────────────────────
|
|
9
|
+
# ECR image URI for the training container.
|
|
10
|
+
# This defaults to your project's own container image (built via do/build + do/push).
|
|
11
|
+
# You can also use a SageMaker pre-built training image or any custom ECR image.
|
|
12
|
+
#
|
|
13
|
+
# Example: 123456789012.dkr.ecr.us-east-1.amazonaws.com/ml-container-creator:latest
|
|
14
|
+
image: "<%= '${AWS_ACCOUNT_ID}.dkr.ecr.${AWS_REGION}.amazonaws.com/${ECR_REPOSITORY_NAME}:latest' %>"
|
|
15
|
+
|
|
16
|
+
# ─── Training Script [REQUIRED] ──────────────────────────────────────────────
|
|
17
|
+
# Path to your training script relative to the project root.
|
|
18
|
+
# SageMaker copies this into the container at /opt/ml/code/ and executes it.
|
|
19
|
+
#
|
|
20
|
+
# Example: do/training/train.py
|
|
21
|
+
script: "do/training/train.py"
|
|
22
|
+
|
|
23
|
+
# ─── Instance Configuration [REQUIRED] ───────────────────────────────────────
|
|
24
|
+
# SageMaker ML instance type for training.
|
|
25
|
+
# Use GPU instances (ml.g5.*, ml.p4d.*) for deep learning workloads.
|
|
26
|
+
# Use CPU instances (ml.m5.*) for traditional ML or preprocessing.
|
|
27
|
+
#
|
|
28
|
+
# Examples:
|
|
29
|
+
# ml.g5.xlarge — 1 GPU (24GB A10G), 4 vCPU, 16 GB RAM
|
|
30
|
+
# ml.g5.2xlarge — 1 GPU (24GB A10G), 8 vCPU, 32 GB RAM
|
|
31
|
+
# ml.p4d.24xlarge — 8 GPUs (40GB A100), 96 vCPU, 1152 GB RAM
|
|
32
|
+
# ml.m5.xlarge — CPU only, 4 vCPU, 16 GB RAM
|
|
33
|
+
instance_type: "ml.g5.xlarge"
|
|
34
|
+
|
|
35
|
+
# Number of training instances. Set > 1 for distributed training.
|
|
36
|
+
# NOTE: When instance_count > 1, your training script must be distribution-aware
|
|
37
|
+
# (e.g., using PyTorch DDP, Horovod, or similar). SageMaker handles inter-node
|
|
38
|
+
# networking, but your script must initialize its own process group.
|
|
39
|
+
#
|
|
40
|
+
# Example: 2
|
|
41
|
+
instance_count: 1
|
|
42
|
+
|
|
43
|
+
# ─── Data Configuration [REQUIRED] ───────────────────────────────────────────
|
|
44
|
+
# S3 URI for the input training dataset.
|
|
45
|
+
# SageMaker downloads this to /opt/ml/input/data/training/ in the container.
|
|
46
|
+
#
|
|
47
|
+
# Examples:
|
|
48
|
+
# s3://my-bucket/datasets/train/
|
|
49
|
+
# s3://my-bucket/datasets/train.jsonl
|
|
50
|
+
dataset: ""
|
|
51
|
+
|
|
52
|
+
# S3 URI for training output artifacts. [REQUIRED]
|
|
53
|
+
# After training, SageMaker packages /opt/ml/model/ as model.tar.gz and uploads it to this location.
|
|
54
|
+
#
|
|
55
|
+
# Example: s3://my-bucket/training-output/
|
|
56
|
+
output_path: ""
|
|
57
|
+
|
|
58
|
+
# ─── Hyperparameters (optional) ──────────────────────────────────────────────
|
|
59
|
+
# Key-value pairs made available to your training script (written to /opt/ml/input/config/hyperparameters.json).
|
|
60
|
+
# All values must be strings. SageMaker passes these as string arguments to
|
|
61
|
+
# the training script entry point.
|
|
62
|
+
#
|
|
63
|
+
# Example:
|
|
64
|
+
# hyperparameters:
|
|
65
|
+
# epochs: "10"
|
|
66
|
+
# batch_size: "32"
|
|
67
|
+
# learning_rate: "5e-5"
|
|
68
|
+
# warmup_steps: "100"
|
|
69
|
+
# weight_decay: "0.01"
|
|
70
|
+
hyperparameters:
|
|
71
|
+
epochs: "10"
|
|
72
|
+
batch_size: "32"
|
|
73
|
+
learning_rate: "5e-5"
|
|
74
|
+
|
|
75
|
+
# ─── Resource Limits (optional) ──────────────────────────────────────────────
|
|
76
|
+
# Maximum training duration in seconds. Job is stopped if it exceeds this limit.
|
|
77
|
+
# Default: 86400 (24 hours)
|
|
78
|
+
#
|
|
79
|
+
# Example: 43200 (12 hours)
|
|
80
|
+
max_runtime_seconds: 86400
|
|
81
|
+
|
|
82
|
+
# EBS volume size in GB attached to each training instance.
|
|
83
|
+
# Must be large enough to hold your dataset, model, and intermediate artifacts.
|
|
84
|
+
# Default: 50
|
|
85
|
+
#
|
|
86
|
+
# Example: 100
|
|
87
|
+
volume_size_gb: 50
|
|
88
|
+
|
|
89
|
+
# ─── Managed Spot Training (optional) ────────────────────────────────────────
|
|
90
|
+
# Use EC2 Spot Instances for training at up to 90% cost savings.
|
|
91
|
+
# When enabled, SageMaker handles interruptions and resumes from checkpoints.
|
|
92
|
+
# Your training script should save checkpoints to /opt/ml/checkpoints/ periodically.
|
|
93
|
+
#
|
|
94
|
+
# IMPORTANT: When enable_spot is true, you must also set max_wait_seconds.
|
|
95
|
+
# The max_wait_seconds value must be >= max_runtime_seconds.
|
|
96
|
+
enable_spot: false
|
|
97
|
+
|
|
98
|
+
# Maximum total time (seconds) to wait for spot capacity + training completion.
|
|
99
|
+
# Only used when enable_spot is true. Must be >= max_runtime_seconds.
|
|
100
|
+
# Default: 172800 (48 hours)
|
|
101
|
+
#
|
|
102
|
+
# Example: 172800
|
|
103
|
+
max_wait_seconds: 172800
|
|
104
|
+
|
|
105
|
+
# ─── Metric Definitions (optional) ───────────────────────────────────────────
|
|
106
|
+
# Custom regex patterns to extract training metrics from container logs.
|
|
107
|
+
# SageMaker uses these to publish metrics to CloudWatch for monitoring.
|
|
108
|
+
# Each entry needs a name and a regex with exactly one capture group.
|
|
109
|
+
#
|
|
110
|
+
# Example:
|
|
111
|
+
# metric_definitions:
|
|
112
|
+
# - name: "train:loss"
|
|
113
|
+
# regex: "loss: ([0-9\\.]+)"
|
|
114
|
+
# - name: "train:epoch"
|
|
115
|
+
# regex: "epoch: ([0-9\\.]+)"
|
|
116
|
+
# - name: "eval:accuracy"
|
|
117
|
+
# regex: "eval_accuracy: ([0-9\\.]+)"
|
|
118
|
+
metric_definitions: []
|
|
119
|
+
|
|
120
|
+
# ─── Environment Variables (optional) ────────────────────────────────────────
|
|
121
|
+
# Additional environment variables set in the training container.
|
|
122
|
+
# Use for configuration that doesn't fit as hyperparameters.
|
|
123
|
+
#
|
|
124
|
+
# Example:
|
|
125
|
+
# environment:
|
|
126
|
+
# NCCL_DEBUG: "INFO"
|
|
127
|
+
# CUDA_VISIBLE_DEVICES: "0,1"
|
|
128
|
+
# TRANSFORMERS_CACHE: "/opt/ml/input/data/cache"
|
|
129
|
+
environment: {}
|
|
130
|
+
|
|
131
|
+
# ─── Tags (optional) ─────────────────────────────────────────────────────────
|
|
132
|
+
# AWS resource tags applied to the training job.
|
|
133
|
+
# Useful for cost allocation, team tracking, and resource organization.
|
|
134
|
+
#
|
|
135
|
+
# Example:
|
|
136
|
+
# tags:
|
|
137
|
+
# team: "ml-platform"
|
|
138
|
+
# project: "fine-tuning-experiment-1"
|
|
139
|
+
# environment: "development"
|
|
140
|
+
tags: {}
|
|
@@ -0,0 +1,463 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""SageMaker Training Script — Placeholder / Skeleton
|
|
6
|
+
|
|
7
|
+
This file demonstrates the contract that SageMaker expects from a custom
|
|
8
|
+
training script. Replace the placeholder logic with your own model training
|
|
9
|
+
code while keeping the I/O paths and conventions intact.
|
|
10
|
+
|
|
11
|
+
SageMaker copies this script into the container at /opt/ml/code/ and invokes
|
|
12
|
+
it as the entry point. The container filesystem is laid out as follows:
|
|
13
|
+
|
|
14
|
+
/opt/ml/
|
|
15
|
+
├── input/
|
|
16
|
+
│ ├── config/ # Training job configuration (JSON files)
|
|
17
|
+
│ │ ├── hyperparameters.json
|
|
18
|
+
│ │ ├── resourceconfig.json
|
|
19
|
+
│ │ └── inputdataconfig.json
|
|
20
|
+
│ └── data/ # Input data channels
|
|
21
|
+
│ └── training/ # Default channel name (configurable)
|
|
22
|
+
│ └── ... # Your training data files
|
|
23
|
+
├── model/ # Write final model artifacts here
|
|
24
|
+
│ └── ... # Everything here is packaged as model.tar.gz
|
|
25
|
+
├── checkpoints/ # Save/restore checkpoints here (spot training)
|
|
26
|
+
│ └── ... # Persisted to S3 between interruptions
|
|
27
|
+
└── output/
|
|
28
|
+
└── failure # Write failure reason here on error
|
|
29
|
+
|
|
30
|
+
Key conventions:
|
|
31
|
+
- Hyperparameters are passed as string key-value pairs (always strings!)
|
|
32
|
+
- Training data is downloaded to /opt/ml/input/data/<channel_name>/
|
|
33
|
+
- Final model artifacts MUST be written to /opt/ml/model/
|
|
34
|
+
- Checkpoints in /opt/ml/checkpoints/ survive spot interruptions
|
|
35
|
+
- Stdout/stderr are captured to CloudWatch Logs automatically
|
|
36
|
+
- Exit code 0 = success, non-zero = failure
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
import argparse
|
|
40
|
+
import json
|
|
41
|
+
import logging
|
|
42
|
+
import os
|
|
43
|
+
import sys
|
|
44
|
+
|
|
45
|
+
# ── Logging setup ─────────────────────────────────────────────────────────────
# Everything the container writes to stdout/stderr ends up in CloudWatch Logs,
# so a timestamped, level-tagged format keeps job logs easy to search.

_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"

logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# ── SageMaker environment paths ──────────────────────────────────────────────
# Locations defined by the SageMaker training container contract.
# The data-channel and model directories may be overridden through the
# SM_CHANNEL_TRAINING / SM_MODEL_DIR environment variables; the rest are fixed.

INPUT_DATA_DIR = os.getenv("SM_CHANNEL_TRAINING", "/opt/ml/input/data/training")
MODEL_DIR = os.getenv("SM_MODEL_DIR", "/opt/ml/model")

CHECKPOINT_DIR = "/opt/ml/checkpoints"  # persisted to S3 across spot interruptions
OUTPUT_DIR = "/opt/ml/output"           # the "failure" file is written here on error
HYPERPARAMS_FILE = "/opt/ml/input/config/hyperparameters.json"
RESOURCE_CONFIG_FILE = "/opt/ml/input/config/resourceconfig.json"
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# ── Hyperparameter loading ────────────────────────────────────────────────────
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def load_hyperparameters():
    """Return the job's hyperparameters as a dict of strings.

    SageMaker serializes hyperparameters into a JSON file, and every value in
    it is a STRING. Numeric settings must be converted by the caller
    (``train`` does this with ``int(...)`` / ``float(...)``).

    Returns:
        dict: The parsed JSON object from HYPERPARAMS_FILE, or an empty dict
        when that file does not exist.

    Example hyperparameters.json:
        {
            "epochs": "10",
            "batch_size": "32",
            "learning_rate": "0.001"
        }
    """
    if not os.path.exists(HYPERPARAMS_FILE):
        # A missing file is not fatal: callers use .get() with their own defaults.
        logger.warning("No hyperparameters file found at %s", HYPERPARAMS_FILE)
        return {}

    with open(HYPERPARAMS_FILE, "r") as handle:
        loaded = json.load(handle)

    logger.info("Loaded hyperparameters: %s", json.dumps(loaded, indent=2))
    return loaded
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# ── Data loading ──────────────────────────────────────────────────────────────
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def load_training_data(data_dir):
    """Load training data from the input channel directory.

    SageMaker downloads your dataset from S3 to this directory before
    training starts. The directory structure mirrors your S3 prefix.

    Args:
        data_dir: Path to the input data channel (e.g., /opt/ml/input/data/training).

    Returns:
        Your training data in whatever format your model expects. The
        placeholder implementation returns None.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist or is not a
            directory. Raising (rather than calling ``sys.exit``) lets the
            script's top-level ``except Exception`` handler record the reason
            in /opt/ml/output/failure. ``SystemExit`` would bypass that handler.

    Example directory contents:
        /opt/ml/input/data/training/
        ├── train.csv
        ├── train.jsonl
        └── data_part_001.parquet
    """
    logger.info("Loading training data from: %s", data_dir)

    if not os.path.isdir(data_dir):
        logger.error("Data directory does not exist: %s", data_dir)
        raise FileNotFoundError(f"Data directory does not exist: {data_dir}")

    # List available files
    files = os.listdir(data_dir)
    logger.info("Found %d file(s): %s", len(files), files)
    if not files:
        # Not fatal here, but an empty channel may mean the dataset S3 URI
        # matched no objects. Surface it early rather than failing later.
        logger.warning("Data directory is empty: %s", data_dir)

    # ─── Replace this with your actual data loading logic ───
    # Examples:
    #
    # For CSV:
    #   import pandas as pd
    #   df = pd.read_csv(os.path.join(data_dir, "train.csv"))
    #
    # For JSONL:
    #   records = []
    #   with open(os.path.join(data_dir, "train.jsonl")) as f:
    #       for line in f:
    #           records.append(json.loads(line))
    #
    # For PyTorch datasets:
    #   from torch.utils.data import DataLoader
    #   dataset = MyDataset(data_dir)
    #   dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    #
    # For Hugging Face datasets:
    #   from datasets import load_from_disk
    #   dataset = load_from_disk(data_dir)

    return None  # Replace with your loaded data
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
# ── Checkpoint management ─────────────────────────────────────────────────────
|
|
153
|
+
# Checkpoints are CRITICAL for managed spot training. When SageMaker interrupts
|
|
154
|
+
# a spot instance, it saves /opt/ml/checkpoints/ to S3. When the job resumes,
|
|
155
|
+
# it restores the checkpoint directory before re-running your script.
|
|
156
|
+
#
|
|
157
|
+
# Best practices:
|
|
158
|
+
# - Save checkpoints periodically (e.g., every N epochs or steps)
|
|
159
|
+
# - Include enough state to fully resume: model weights, optimizer state, epoch
|
|
160
|
+
# - On startup, check if a checkpoint exists and resume from it
|
|
161
|
+
# - Use atomic writes (write to temp file, then rename) to avoid corruption
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def save_checkpoint(model_state, optimizer_state, epoch, step):
    """Save a training checkpoint for spot training resumption.

    The checkpoint is written as JSON to
    ``CHECKPOINT_DIR/checkpoint_latest.json``, replacing any previous one.

    Args:
        model_state: Model weights/parameters to save. Not stored by this
            placeholder; see the commented-out keys below.
        optimizer_state: Optimizer state for resuming training. Not stored by
            this placeholder.
        epoch: Current epoch number (``train`` passes the count of completed
            epochs).
        step: Current global step number.

    Raises:
        TypeError: If the checkpoint dict holds values ``json`` cannot
            serialize (e.g. tensors added without conversion).
        OSError: If the checkpoint cannot be written.

    The checkpoint directory (/opt/ml/checkpoints/) is synced to S3 by
    SageMaker when checkpointing is configured for the job. On spot
    interruption and restart, the contents are restored before your script
    runs again.
    """
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "step": step,
        # Add your model and optimizer state here:
        # "model_state_dict": model_state,
        # "optimizer_state_dict": optimizer_state,
    }

    checkpoint_path = os.path.join(CHECKPOINT_DIR, "checkpoint_latest.json")

    # Atomic write: write a temp file, force it to disk, then rename it over
    # the previous checkpoint. os.replace() swaps the file in one step, so
    # readers see either the old or the new checkpoint. The fsync() makes sure
    # the renamed file's contents are durable, not possibly empty after a crash.
    tmp_path = checkpoint_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(checkpoint, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, checkpoint_path)
    except Exception:
        # Don't leave a half-written temp file in the checkpoint directory.
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise

    logger.info("Saved checkpoint at epoch %d, step %d", epoch, step)

    # ─── For PyTorch, you would typically do: ───
    # import torch
    # torch.save({
    #     "epoch": epoch,
    #     "step": step,
    #     "model_state_dict": model.state_dict(),
    #     "optimizer_state_dict": optimizer.state_dict(),
    #     "loss": current_loss,
    # }, os.path.join(CHECKPOINT_DIR, "checkpoint_latest.pt"))
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def load_checkpoint():
    """Return the most recent checkpoint, or None when there is none.

    Reads ``CHECKPOINT_DIR/checkpoint_latest.json`` (the file written by
    ``save_checkpoint``). Call this at the start of training so a job
    restarted after a spot interruption continues where it left off.

    Returns:
        dict or None: The decoded checkpoint if the file exists, else None.
    """
    latest = os.path.join(CHECKPOINT_DIR, "checkpoint_latest.json")

    if not os.path.exists(latest):
        logger.info("No checkpoint found — starting from scratch")
        return None

    with open(latest, "r") as handle:
        state = json.load(handle)

    logger.info(
        "Restored checkpoint: epoch %d, step %d",
        state.get("epoch", 0),
        state.get("step", 0),
    )
    return state

    # ─── For PyTorch, you would typically do: ───
    # import torch
    # ckpt_path = os.path.join(CHECKPOINT_DIR, "checkpoint_latest.pt")
    # if os.path.exists(ckpt_path):
    #     checkpoint = torch.load(ckpt_path)
    #     model.load_state_dict(checkpoint["model_state_dict"])
    #     optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    #     start_epoch = checkpoint["epoch"]
    #     return checkpoint
    # return None
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# ── Model saving ──────────────────────────────────────────────────────────────
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def save_model(model, model_dir):
    """Write the final model artifacts into ``model_dir``.

    IMPORTANT: Everything written to /opt/ml/model/ is packaged into a
    model.tar.gz and uploaded to S3 at the output path you configured.
    This is what gets deployed to a SageMaker endpoint.

    Args:
        model: Your trained model object (ignored by this placeholder).
        model_dir: Directory to write artifacts into (default: /opt/ml/model/).
            Created if it does not already exist.

    Common patterns:
        - sklearn: joblib.dump(model, os.path.join(model_dir, "model.pkl"))
        - PyTorch: torch.save(model.state_dict(), os.path.join(model_dir, "model.pt"))
        - Hugging Face: model.save_pretrained(model_dir)
        - XGBoost: model.save_model(os.path.join(model_dir, "model.json"))
    """
    os.makedirs(model_dir, exist_ok=True)
    logger.info("Saving model artifacts to: %s", model_dir)

    # ─── Replace with your model saving logic ───
    # For scikit-learn:
    #   import joblib
    #   joblib.dump(model, os.path.join(model_dir, "model.pkl"))
    #
    # For PyTorch:
    #   import torch
    #   torch.save(model.state_dict(), os.path.join(model_dir, "model.pt"))
    #   # Also save model config/architecture for inference
    #   with open(os.path.join(model_dir, "config.json"), "w") as f:
    #       json.dump(model_config, f)
    #
    # For Hugging Face transformers:
    #   model.save_pretrained(model_dir)
    #   tokenizer.save_pretrained(model_dir)
    #
    # For LoRA adapters (PEFT):
    #   model.save_pretrained(model_dir)
    #   # This creates adapter_config.json + adapter weights
    #   # The feedback loop will detect this and suggest ./do/adapter add

    # Placeholder artifact: a small marker file so the output is never empty.
    marker = {"status": "placeholder", "message": "Replace with real model"}
    marker_path = os.path.join(model_dir, "model_info.json")
    with open(marker_path, "w") as out:
        json.dump(marker, out)

    logger.info("Model artifacts saved successfully")
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
# ── Training loop ─────────────────────────────────────────────────────────────
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def train(hyperparams, data_dir, model_dir, checkpoint_dir):
    """Main training loop.

    This is where your actual training logic goes. The structure below
    shows the recommended pattern for SageMaker compatibility:

    1. Load hyperparameters (cast strings to proper types)
    2. Load training data from the input channel
    3. Check for existing checkpoint (for spot training resumption)
    4. Run training loop with periodic checkpoint saves
    5. Save final model to the output directory

    Args:
        hyperparams: Dict of hyperparameters (all values are strings).
            Recognized keys and defaults:
              - ``epochs`` ("10")
              - ``batch_size`` ("32")
              - ``learning_rate`` ("0.001")
              - ``checkpoint_frequency`` ("1"): save every N epochs. A
                value <= 0 disables periodic checkpoints.
        data_dir: Path to input training data.
        model_dir: Path to write final model artifacts.
        checkpoint_dir: Path for checkpoint save/restore. NOTE: not used by
            this skeleton. save_checkpoint() and load_checkpoint() always use
            the module-level CHECKPOINT_DIR.

    Raises:
        ValueError: If a hyperparameter cannot be cast to its numeric type.
    """
    # ── Step 1: Parse hyperparameters (cast from strings) ──
    epochs = int(hyperparams.get("epochs", "10"))
    batch_size = int(hyperparams.get("batch_size", "32"))
    learning_rate = float(hyperparams.get("learning_rate", "0.001"))
    checkpoint_frequency = int(hyperparams.get("checkpoint_frequency", "1"))

    # %g rather than %f, so small learning rates such as 1e-7 are not
    # printed as 0.000000.
    logger.info(
        "Training config: epochs=%d, batch_size=%d, lr=%g",
        epochs, batch_size, learning_rate,
    )
    if checkpoint_frequency <= 0:
        # Previously a 0 here crashed with ZeroDivisionError in the modulo below.
        logger.warning(
            "checkpoint_frequency=%d: periodic checkpointing disabled",
            checkpoint_frequency,
        )

    # ── Step 2: Load training data ──
    training_data = load_training_data(data_dir)

    # ── Step 3: Check for existing checkpoint (spot resumption) ──
    checkpoint = load_checkpoint()
    start_epoch = 0
    global_step = 0  # optimizer steps taken so far; increment it in your loop
    if checkpoint:
        # save_checkpoint() below stores the number of COMPLETED epochs, which
        # is also the zero-based index of the next epoch to run.
        start_epoch = int(checkpoint.get("epoch", 0))
        global_step = int(checkpoint.get("step", 0))
        # Restore model and optimizer state from checkpoint here
        logger.info("Resuming training from epoch %d", start_epoch)

    # ── Step 4: Training loop ──
    model = None  # Replace with your model initialization

    for epoch in range(start_epoch, epochs):
        logger.info("Epoch %d/%d", epoch + 1, epochs)

        # ─── Replace with your actual training step ───
        # Example (PyTorch):
        # model.train()
        # for batch_idx, (data, target) in enumerate(dataloader):
        #     optimizer.zero_grad()
        #     output = model(data)
        #     loss = criterion(output, target)
        #     loss.backward()
        #     optimizer.step()
        #     global_step += 1
        #
        #     if batch_idx % log_interval == 0:
        #         logger.info("  Step %d, Loss: %.4f", batch_idx, loss.item())

        # Print metrics in a format SageMaker can parse for CloudWatch
        # Use the regex pattern defined in config.yaml metric_definitions
        train_loss = 0.0  # Replace with actual loss
        print(f"loss: {train_loss:.4f}")
        print(f"epoch: {epoch + 1}")

        # ── Save checkpoint periodically ──
        if checkpoint_frequency > 0 and (epoch + 1) % checkpoint_frequency == 0:
            save_checkpoint(
                model_state=None,  # Replace with model.state_dict()
                optimizer_state=None,  # Replace with optimizer.state_dict()
                epoch=epoch + 1,
                # The real global step count (not epochs * batch_size, which
                # counts neither steps nor samples).
                step=global_step,
            )

    # ── Step 5: Save final model ──
    save_model(model, model_dir)

    logger.info("Training complete!")
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
# ── Distributed training helpers ──────────────────────────────────────────────
|
|
382
|
+
# When instance_count > 1 in config.yaml, SageMaker launches multiple instances
|
|
383
|
+
# and sets up inter-node communication. Your script must be distribution-aware.
|
|
384
|
+
#
|
|
385
|
+
# Images with the SageMaker Training Toolkit export these variables (otherwise read /opt/ml/input/config/resourceconfig.json):
|
|
386
|
+
# SM_HOSTS - JSON list of all host names
|
|
387
|
+
# SM_CURRENT_HOST - This instance's host name
|
|
388
|
+
# SM_NUM_GPUS - Number of GPUs on this instance
|
|
389
|
+
#
|
|
390
|
+
# Example (PyTorch DDP):
|
|
391
|
+
# import torch.distributed as dist
|
|
392
|
+
# dist.init_process_group(backend="nccl")
|
|
393
|
+
# local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
|
394
|
+
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def get_distributed_info(resource_config_file=None):
    """Get distributed training configuration from the SageMaker environment.

    Host information is resolved per field, in this order:

    1. The ``SM_HOSTS`` / ``SM_CURRENT_HOST`` environment variables.
    2. The ``hosts`` / ``current_host`` keys of SageMaker's
       resourceconfig.json. This file is read only when one of those
       variables is missing. That covers custom containers that do not
       export the ``SM_*`` variables and would otherwise always look
       single-host.
    3. A single-host ``localhost`` default.

    ``num_gpus`` comes only from ``SM_NUM_GPUS`` (default 0).

    Args:
        resource_config_file: Path to the resource config JSON. Defaults to
            the module-level RESOURCE_CONFIG_FILE.

    Returns:
        dict: ``hosts`` (list of host names), ``current_host``, ``num_gpus``,
        ``num_hosts``, and ``is_leader`` (True when this host is ``hosts[0]``).

    Raises:
        ValueError: If ``SM_HOSTS`` is set but is not valid JSON, or
            ``SM_NUM_GPUS`` is not an integer.
    """
    hosts_env = os.environ.get("SM_HOSTS")
    current_host_env = os.environ.get("SM_CURRENT_HOST")

    resource_config = {}
    if not (hosts_env and current_host_env):
        path = RESOURCE_CONFIG_FILE if resource_config_file is None else resource_config_file
        if os.path.exists(path):
            try:
                with open(path, "r", encoding="utf-8") as f:
                    loaded = json.load(f)
            except (OSError, ValueError) as err:
                # Best effort: fall back to the defaults below, but say why.
                logging.getLogger(__name__).warning("Could not read %s: %s", path, err)
            else:
                if isinstance(loaded, dict):
                    resource_config = loaded

    hosts = json.loads(hosts_env) if hosts_env else resource_config.get("hosts")
    current_host = current_host_env or resource_config.get("current_host") or "localhost"
    if not hosts:
        # Never return an empty list: hosts[0] is used for leader election.
        hosts = [current_host]

    num_gpus = int(os.environ.get("SM_NUM_GPUS", "0"))

    return {
        "hosts": hosts,
        "current_host": current_host,
        "num_gpus": num_gpus,
        "num_hosts": len(hosts),
        "is_leader": current_host == hosts[0],
    }
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
# ── Entry point ───────────────────────────────────────────────────────────────
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
if __name__ == "__main__":
    # Hyperparameters are read from /opt/ml/input/config/hyperparameters.json
    # (see load_hyperparameters). Some launchers, such as the SageMaker
    # Training Toolkit, also pass them as command-line arguments.
    # parse_known_args() below ignores any arguments not declared here.
    parser = argparse.ArgumentParser(description="SageMaker Training Script")

    # SageMaker standard arguments
    parser.add_argument("--model-dir", type=str, default=MODEL_DIR,
                        help="Directory to save model artifacts")
    parser.add_argument("--data-dir", type=str, default=INPUT_DATA_DIR,
                        help="Directory containing training data")

    # Add your custom hyperparameters as CLI args if preferred:
    # parser.add_argument("--epochs", type=int, default=10)
    # parser.add_argument("--batch-size", type=int, default=32)
    # parser.add_argument("--learning-rate", type=float, default=0.001)

    args, _ = parser.parse_known_args()

    try:
        # Setup runs inside the try so that, e.g., a malformed
        # hyperparameters.json is also recorded in the failure file.
        hyperparams = load_hyperparameters()

        # Log distributed training info
        dist_info = get_distributed_info()
        if dist_info["num_hosts"] > 1:
            logger.info("Distributed training: %d hosts, %d GPUs per host",
                        dist_info["num_hosts"], dist_info["num_gpus"])
            logger.info("Current host: %s (leader: %s)",
                        dist_info["current_host"], dist_info["is_leader"])

        # Run training
        train(
            hyperparams=hyperparams,
            data_dir=args.data_dir,
            model_dir=args.model_dir,
            checkpoint_dir=CHECKPOINT_DIR,
        )
    except Exception as e:
        # Log first (with traceback) so the cause is visible even if writing
        # the failure file below fails.
        logger.exception("Training failed: %s", e)

        # Write failure reason to /opt/ml/output/failure. Prefix the exception
        # type: str(e) alone is empty or cryptic for many exceptions
        # (e.g. KeyError('x') -> "'x'").
        failure_path = os.path.join(OUTPUT_DIR, "failure")
        try:
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            with open(failure_path, "w", encoding="utf-8") as f:
                f.write(f"{type(e).__name__}: {e}")
        except OSError as write_err:
            logger.error("Could not write failure file %s: %s", failure_path, write_err)
        sys.exit(1)
|