@aws/ml-container-creator 1.0.3 → 1.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. package/README.md +1 -1
  2. package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
  3. package/package.json +2 -2
  4. package/servers/base-image-picker/index.js +65 -18
  5. package/servers/instance-sizer/index.js +32 -0
  6. package/servers/lib/catalogs/fleet-drivers.json +38 -0
  7. package/servers/lib/catalogs/model-arch-support.json +51 -0
  8. package/servers/lib/catalogs/model-servers.json +2842 -1730
  9. package/servers/lib/schemas/image-catalog.schema.json +12 -0
  10. package/src/app.js +6 -4
  11. package/src/lib/generated/cli-options.js +1 -1
  12. package/src/lib/generated/parameter-matrix.js +1 -1
  13. package/src/lib/generated/validation-rules.js +1 -1
  14. package/src/lib/mcp-query-runner.js +110 -3
  15. package/src/lib/prompt-runner.js +66 -22
  16. package/src/lib/template-variable-resolver.js +8 -0
  17. package/src/lib/train-config-builder.js +339 -0
  18. package/templates/do/.benchmark_writer.py +3 -0
  19. package/templates/do/.eval_helper.py +409 -0
  20. package/templates/do/.register_helper.py +185 -11
  21. package/templates/do/.train_build_request.py +102 -113
  22. package/templates/do/.train_helper.py +433 -0
  23. package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
  24. package/templates/do/adapter +157 -0
  25. package/templates/do/benchmark +60 -3
  26. package/templates/do/deploy.d/managed-inference.ejs +83 -0
  27. package/templates/do/evaluate +272 -0
  28. package/templates/do/lib/resolve-instance.sh +155 -0
  29. package/templates/do/register +5 -0
  30. package/templates/do/test +1 -0
  31. package/templates/do/train +879 -126
  32. package/templates/do/training/config.yaml +83 -11
  33. package/templates/do/training/dpo/accelerate_config.yaml +24 -0
  34. package/templates/do/training/dpo/defaults.yaml +26 -0
  35. package/templates/do/training/dpo/prompts.json +8 -0
  36. package/templates/do/training/dpo/train.py +363 -0
  37. package/templates/do/training/sft/accelerate_config.yaml +22 -0
  38. package/templates/do/training/sft/defaults.yaml +18 -0
  39. package/templates/do/training/sft/prompts.json +7 -0
  40. package/templates/do/training/sft/train.py +310 -0
  41. package/templates/do/tune +11 -2
  42. package/templates/do/.train_poll_parser.py +0 -135
  43. package/templates/do/.train_status_parser.py +0 -187
  44. /package/templates/do/training/{train.py → custom/train.py} +0 -0
@@ -5,20 +5,33 @@
5
5
  # This file is read by do/train to construct a CreateTrainingJob request.
6
6
  # Required fields are marked with [REQUIRED]. All other fields are optional.
7
7
 
8
+ # ─── Training Technique ───────────────────────────────────────────────────────
9
+ # Which training technique to use. Each technique has its own training script
10
+ # in do/training/<technique>/train.py with appropriate defaults.
11
+ #
12
+ # Available techniques:
13
+ # custom — Your own training script (default, user-provided logic)
14
+ # sft — Supervised Fine-Tuning using TRL SFTTrainer + PEFT LoRA
15
+ # dpo — Direct Preference Optimization using TRL DPOTrainer + PEFT LoRA
16
+ #
17
+ # Override with CLI: ./do/train --technique sft
18
+ technique: "custom"
19
+
8
20
  # ─── Container Image [REQUIRED] ──────────────────────────────────────────────
9
21
  # ECR image URI for the training container.
10
22
  # This defaults to your project's own container image (built via do/build + do/push).
11
23
  # You can also use a SageMaker pre-built training image or any custom ECR image.
12
24
  #
13
25
  # Example: 123456789012.dkr.ecr.us-east-1.amazonaws.com/ml-container-creator:latest
14
- image: "<%= '${AWS_ACCOUNT_ID}.dkr.ecr.${AWS_REGION}.amazonaws.com/${ECR_REPOSITORY_NAME}:latest' %>"
26
+ image: "${_PROFILE_accountId}.dkr.ecr.<%= awsRegion %>.amazonaws.com/${_PROFILE_ecrRepositoryName:-ml-container-creator}:<%= projectName %>-latest"
15
27
 
16
- # ─── Training Script [REQUIRED] ──────────────────────────────────────────────
28
+ # ─── Training Script [REQUIRED for custom technique] ─────────────────────────
17
29
  # Path to your training script relative to the project root.
18
- # SageMaker copies this into the container at /opt/ml/code/ and executes it.
30
+ # For technique-based training (sft, dpo), the script is auto-selected from
31
+ # do/training/<technique>/train.py — this field is only used for 'custom'.
19
32
  #
20
- # Example: do/training/train.py
21
- script: "do/training/train.py"
33
+ # Example: do/training/custom/train.py
34
+ script: "do/training/custom/train.py"
22
35
 
23
36
  # ─── Instance Configuration [REQUIRED] ───────────────────────────────────────
24
37
  # SageMaker ML instance type for training.
@@ -37,9 +50,39 @@ instance_type: "ml.g5.xlarge"
37
50
  # (e.g., using PyTorch DDP, Horovod, or similar). SageMaker handles inter-node
38
51
  # communication setup automatically.
39
52
  #
53
+ # For multi-node training (instance_count > 1):
54
+ # - Use EFA-capable instances for best performance: ml.p4d.24xlarge, ml.p5.48xlarge, ml.g5.48xlarge
55
+ # - SageMaker auto-configures NCCL/EFA for inter-node communication
56
+ # - Accelerate auto-detects WORLD_SIZE > local_gpu_count and enables multi-node FSDP
57
+ # - No changes needed to training scripts or accelerate_config.yaml
58
+ #
40
59
  # Example: 2
41
60
  instance_count: 1
42
61
 
62
+ # Networking backend for distributed training (multi-node).
63
+ # auto = detect EFA presence and use NCCL if available, gloo otherwise
64
+ # nccl = force NCCL (requires EFA-capable instances for multi-node)
65
+ # gloo = force gloo (CPU-based, works on any instance but slower)
66
+ networking_backend: "auto"
67
+
68
+ # ─── Launcher (optional) ─────────────────────────────────────────────────────
69
+ # How the training script is invoked. Default: accelerate (recommended for TRL).
70
+ #
71
+ # Options:
72
+ # accelerate — `accelerate launch --config_file <config> train.py` (default)
73
+ # Handles single-GPU, multi-GPU (FSDP), and multi-node automatically.
74
+ # torchrun — `torchrun --nproc_per_node <gpus> train.py`
75
+ # Use for scripts that manage distributed setup internally.
76
+ # deepspeed — `deepspeed --num_gpus <gpus> train.py`
77
+ # Use for ZeRO-3 offloading with >70B models.
78
+ # custom — Direct `python train.py` invocation (no distributed wrapper).
79
+ # Use only for scripts that handle everything internally.
80
+ #
81
+ # Note: SFT and DPO techniques include accelerate_config.yaml and are designed
82
+ # for the `accelerate` launcher. Using a different launcher with these techniques
83
+ # may require modifying the training script.
84
+ launcher: "accelerate"
85
+
43
86
  # ─── Data Configuration [REQUIRED] ───────────────────────────────────────────
44
87
  # S3 URI for the input training dataset.
45
88
  # SageMaker downloads this to /opt/ml/input/data/training/ in the container.
@@ -60,13 +103,27 @@ output_path: ""
60
103
  # All values must be strings. SageMaker passes these as string arguments to
61
104
  # the training script entry point.
62
105
  #
63
- # Example:
106
+ # For 'custom' technique: define your own hyperparameters.
107
+ # For 'sft' technique: the following are recognized:
108
+ # lora_r, lora_alpha, lora_dropout, learning_rate, epochs,
109
+ # batch_size, max_seq_length, gradient_accumulation_steps,
110
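+ # warmup_ratio, dataset_text_field, model_id
+ # For 'dpo' technique: the following are recognized:
+ # beta, lora_r, lora_alpha, lora_dropout, learning_rate, epochs,
+ # batch_size, max_length, max_prompt_length, gradient_accumulation_steps,
+ # warmup_ratio, chosen_field, rejected_field, prompt_field,
+ # reference_free, model_id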
111
+ #
112
+ # Example (custom):
64
113
  # hyperparameters:
65
114
  # epochs: "10"
66
115
  # batch_size: "32"
67
116
  # learning_rate: "5e-5"
68
- # warmup_steps: "100"
69
- # weight_decay: "0.01"
117
+ #
118
+ # Example (sft):
119
+ # hyperparameters:
120
+ # lora_r: "16"
121
+ # lora_alpha: "32"
122
+ # learning_rate: "2e-4"
123
+ # epochs: "3"
124
+ # batch_size: "4"
125
+ # max_seq_length: "2048"
126
+ # dataset_text_field: "text"
70
127
  hyperparameters:
71
128
  epochs: "10"
72
129
  batch_size: "32"
@@ -102,19 +159,34 @@ enable_spot: false
102
159
  # Example: 172800
103
160
  max_wait_seconds: 172800
104
161
 
162
+ # S3 URI for checkpoint storage during spot training.
163
+ # SageMaker syncs /opt/ml/checkpoints/ to this location between interruptions.
164
+ # If empty, auto-derived from output_path: <output_path>/checkpoints/
165
+ # Only used when enable_spot is true.
166
+ #
167
+ # Example: s3://my-bucket/checkpoints/
168
+ checkpoint_s3_uri: ""
169
+
105
170
  # ─── Metric Definitions (optional) ───────────────────────────────────────────
106
171
  # Custom regex patterns to extract training metrics from container logs.
107
172
  # SageMaker uses these to publish metrics to CloudWatch for monitoring.
108
173
  # Each entry needs a name and a regex with exactly one capture group.
109
174
  #
110
- # Example:
175
+ # Example (custom):
111
176
  # metric_definitions:
112
177
  # - name: "train:loss"
113
178
  # regex: "loss: ([0-9\\.]+)"
114
179
  # - name: "train:epoch"
115
180
  # regex: "epoch: ([0-9\\.]+)"
116
- # - name: "eval:accuracy"
117
- # regex: "eval_accuracy: ([0-9\\.]+)"
181
+ #
182
+ # Example (sft — matches output from training/sft/train.py):
183
+ # metric_definitions:
184
+ # - name: "train:loss"
185
+ # regex: "train_loss: ([0-9\\.]+)"
186
+ # - name: "train:samples_per_second"
187
+ # regex: "train_samples_per_second: ([0-9\\.]+)"
188
+ # - name: "train:epoch"
189
+ # regex: "epochs: ([0-9\\.]+)"
118
190
  metric_definitions: []
119
191
 
120
192
  # ─── Environment Variables (optional) ────────────────────────────────────────
@@ -0,0 +1,24 @@
1
+ # Accelerate configuration for DPO training.
2
+ # FSDP-ready: works on both single-GPU and multi-GPU automatically.
3
+ #
4
+ # NOTE: DPO requires a reference model (frozen copy of base model), which
5
+ # doubles memory usage. For models >7B, consider:
6
+ # - reference_free: true (in hyperparameters) — skips reference model
7
+ # - Using a larger instance (ml.g5.12xlarge for 4× A10G = 96GB total)
8
+ # - FSDP offloading: set fsdp_offload_params: true below
9
+ #
10
+ # Single-GPU (ml.g5.xlarge): accelerate detects 1 GPU, runs single-process.
11
+ # Multi-GPU (ml.g5.12xlarge): accelerate detects 4 GPUs, runs FSDP.
12
+
13
+ compute_environment: LOCAL_MACHINE
14
+ distributed_type: FSDP
15
+ fsdp_config:
16
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
17
+ fsdp_backward_prefetch: BACKWARD_PRE
18
+ fsdp_offload_params: false
19
+ fsdp_sharding_strategy: FULL_SHARD
20
+ fsdp_state_dict_type: SHARDED_STATE_DICT
21
+ fsdp_sync_module_states: true
22
+ fsdp_use_orig_params: true
23
+ mixed_precision: bf16
24
+ num_processes: auto
@@ -0,0 +1,26 @@
1
+ # Default hyperparameters for DPO (Direct Preference Optimization) training.
2
+ #
3
+ # These are the lowest-priority defaults. Override via:
4
+ # 1. CLI flags (--beta, --learning-rate, etc.) — highest priority
5
+ # 2. training/config.yaml hyperparameters section
6
+ # 3. Interactive mode answers (written to config.yaml)
7
+ # 4. These defaults — lowest priority
8
+ #
9
+ # NOTE: DPO creates a frozen reference model internally, doubling memory.
10
+ # For models >7B on single GPU, set reference_free: true or use multi-GPU.
11
+
12
+ beta: 0.1
13
+ lora_r: 16
14
+ lora_alpha: 32
15
+ lora_dropout: 0.05
16
+ learning_rate: 5e-7
17
+ epochs: 1
18
+ batch_size: 2
19
+ max_length: 1024
20
+ max_prompt_length: 512
21
+ gradient_accumulation_steps: 4
22
+ warmup_ratio: 0.03
23
+ chosen_field: chosen
24
+ rejected_field: rejected
25
+ prompt_field: prompt
26
+ reference_free: false
@@ -0,0 +1,8 @@
1
+ {
2
+ "section_title": "DPO-specific settings",
3
+ "prompts": [
4
+ {"name": "beta", "type": "input", "message": "Beta (KL penalty coefficient)?", "default": "0.1", "validate": "float"},
5
+ {"name": "chosen_field", "type": "input", "message": "Chosen response column name?", "default": "chosen", "validate": "string"},
6
+ {"name": "rejected_field", "type": "input", "message": "Rejected response column name?", "default": "rejected", "validate": "string"}
7
+ ]
8
+ }
@@ -0,0 +1,363 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ """Unmanaged DPO fine-tuning using TRL DPOTrainer + PEFT LoRA.
6
+
7
+ Direct Preference Optimization trains a model to prefer chosen responses
8
+ over rejected responses without explicit reward modeling. This script uses
9
+ TRL's DPOTrainer with PEFT LoRA for parameter-efficient training.
10
+
11
+ Portable env-var contract (works on SageMaker AI and HyperPod EKS):
12
+ DATA_DIR / SM_CHANNEL_TRAINING -> training data path
13
+ OUTPUT_DIR / SM_MODEL_DIR -> model artifact output
14
+ CHECKPOINT_DIR / SM_CHECKPOINT_DIR -> checkpoint path for spot resume
15
+ HF_MODEL_ID / SM_HP_MODEL_ID -> base model HuggingFace ID
16
+ SM_HPS -> JSON blob of all hyperparameters
17
+
18
+ Dataset format:
19
+ JSONL with fields: prompt, chosen, rejected
20
+ Example: {"prompt": "Explain X", "chosen": "Good answer", "rejected": "Bad answer"}
21
+
22
+ Output:
23
+ LoRA adapter saved to OUTPUT_DIR (adapter_model.safetensors + adapter_config.json)
24
+ Metrics logged to stdout in SageMaker-parseable format
25
+ """
26
+
27
+ import glob
28
+ import json
29
+ import logging
30
+ import os
31
+ import sys
32
+
33
# ── Logging ───────────────────────────────────────────────────────────────────
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("dpo-trainer")


# ── Portable Path Resolution ─────────────────────────────────────────────────
def _first_env_value(names, fallback):
    """Return the first non-empty value among environment variables *names*.

    Unset and empty-string variables are both skipped, so e.g. ``DATA_DIR=""``
    falls through to the next candidate. Returns *fallback* when none is set.
    """
    for name in names:
        value = os.environ.get(name)
        if value:
            return value
    return fallback


# Each setting checks a generic name first (portable across platforms), then
# the SageMaker-provided name, then the SageMaker container default path.
DATA_DIR = _first_env_value(("DATA_DIR", "SM_CHANNEL_TRAINING"), "/opt/ml/input/data/training")
OUTPUT_DIR = _first_env_value(("OUTPUT_DIR", "SM_MODEL_DIR"), "/opt/ml/model")
CHECKPOINT_DIR = _first_env_value(("CHECKPOINT_DIR", "SM_CHECKPOINT_DIR"), "/opt/ml/checkpoints")
MODEL_ID = _first_env_value(("HF_MODEL_ID", "SM_HP_MODEL_ID"), "")
66
+
67
+
68
+ # ── Hyperparameter Loading ────────────────────────────────────────────────────
69
+
70
def load_hyperparameters():
    """Load hyperparameters from SageMaker SM_HPS env var or individual SM_HP_* vars.

    Resolution order:
      1. ``SM_HPS``: a JSON object holding all hyperparameters. It is used only
         if it parses to a JSON *object* and every recognized value casts
         cleanly. Otherwise a warning is logged and step 2 runs.
      2. Individual ``SM_HP_<KEY>`` environment variables (e.g. ``SM_HP_BETA``).

    Keys missing from the chosen source keep the built-in defaults below. Each
    value is cast to the Python type of its default via ``_cast``.

    Returns:
        dict with typed hyperparameter values.

    Raises:
        ValueError: if an ``SM_HP_<KEY>`` value cannot be cast to the numeric
            type of its default.
    """
    defaults = {
        "model_id": MODEL_ID,
        "beta": 0.1,
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "learning_rate": 5e-7,
        "epochs": 1,
        "batch_size": 2,
        "max_length": 1024,
        "max_prompt_length": 512,
        "gradient_accumulation_steps": 4,
        "warmup_ratio": 0.03,
        "chosen_field": "chosen",
        "rejected_field": "rejected",
        "prompt_field": "prompt",
        "reference_free": False,
    }

    # Try SM_HPS (JSON blob of all hyperparameters)
    sm_hps = os.environ.get("SM_HPS")
    if sm_hps:
        try:
            raw = json.loads(sm_hps)
            # A JSON list or string would make `key in raw` / `raw[key]`
            # misbehave (substring matches, TypeError on indexing).
            if not isinstance(raw, dict):
                raise ValueError(f"expected a JSON object, got {type(raw).__name__}")
            # Fill a copy so a bad value cannot leave `defaults` half-updated
            # before the SM_HP_* fallback below runs.
            resolved = dict(defaults)
            for key, default_val in defaults.items():
                if key in raw:
                    resolved[key] = _cast(raw[key], type(default_val))
            return resolved
        except (json.JSONDecodeError, ValueError, TypeError) as e:
            logger.warning("Failed to parse SM_HPS: %s", e)

    # Fallback: individual SM_HP_* env vars
    for key, default_val in defaults.items():
        env_val = os.environ.get(f"SM_HP_{key.upper()}")
        if env_val is not None:
            defaults[key] = _cast(env_val, type(default_val))

    return defaults
115
+
116
+
117
+ def _cast(value, target_type):
118
+ """Cast a string value to the target type."""
119
+ if target_type == bool:
120
+ return str(value).lower() in ("true", "1", "yes")
121
+ if target_type == int:
122
+ return int(float(value))
123
+ if target_type == float:
124
+ return float(value)
125
+ return str(value)
126
+
127
+
128
+ # ── Dataset Loading ───────────────────────────────────────────────────────────
129
+
130
def load_preference_dataset(data_dir, chosen_field, rejected_field, prompt_field):
    """Load a DPO preference dataset from *data_dir* and normalize column names.

    Files ending in .jsonl, .json, .parquet or .csv are collected recursively.
    A single ``datasets`` builder must read all of them, so only files that use
    the builder of the first file (in sorted path order) are loaded. Any others
    are skipped with a warning.

    TRL's DPOTrainer reads the fixed columns ``prompt``, ``chosen`` and
    ``rejected``, so the configured field names are renamed to those. If an
    unrelated existing column already has one of those names, it is dropped
    first so the rename cannot clash.

    Args:
        data_dir: Path to directory containing training data files.
        chosen_field: Column name for preferred responses.
        rejected_field: Column name for dispreferred responses.
        prompt_field: Column name for prompts (optional; when absent, no
            prompt column is produced).

    Returns:
        A Hugging Face Dataset with ``chosen`` and ``rejected`` columns, plus
        ``prompt`` when *prompt_field* is present.

    Exits with status 1 when no data files are found, a required field is
    missing, or two roles are mapped to the same column.
    """
    from datasets import load_dataset as hf_load_dataset

    # File extension -> `datasets` builder name.
    format_map = {"jsonl": "json", "json": "json", "parquet": "parquet", "csv": "csv"}
    extensions = list(format_map)

    # With recursive=True, "**" matches zero or more directories, so this also
    # finds files placed directly inside data_dir.
    found = set()
    for ext in extensions:
        found.update(glob.glob(os.path.join(data_dir, f"**/*.{ext}"), recursive=True))

    if not found:
        logger.error("No data files found in %s (searched: %s)", data_dir, extensions)
        sys.exit(1)

    data_files = sorted(found)

    def _builder(path):
        return format_map.get(path.rsplit(".", 1)[-1].lower(), "json")

    # Determine format from first file extension; files needing a different
    # builder would make load_dataset fail, so drop them explicitly.
    file_format = _builder(data_files[0])
    skipped = [path for path in data_files if _builder(path) != file_format]
    if skipped:
        logger.warning(
            "Skipping %d file(s) not readable by the '%s' loader: %s",
            len(skipped), file_format, skipped,
        )
        data_files = [path for path in data_files if _builder(path) == file_format]
    logger.info("Found %d data file(s) in %s", len(data_files), data_dir)

    dataset = hf_load_dataset(file_format, data_files=data_files, split="train")
    logger.info("Loaded dataset: %d rows, columns: %s", len(dataset), dataset.column_names)

    # Verify required fields exist
    columns = dataset.column_names
    missing = []
    if chosen_field not in columns:
        missing.append(f"chosen_field='{chosen_field}'")
    if rejected_field not in columns:
        missing.append(f"rejected_field='{rejected_field}'")

    if missing:
        logger.error(
            "Required fields not found in dataset: %s. Available columns: %s",
            ", ".join(missing), columns,
        )
        sys.exit(1)

    # Check for optional prompt field
    has_prompt = prompt_field in columns
    if has_prompt:
        logger.info("Prompt field '%s' found — using conditional DPO", prompt_field)
    else:
        logger.info("No prompt field '%s' — using unconditional DPO", prompt_field)

    # Configured source column -> canonical column name read by DPOTrainer.
    roles = [(chosen_field, "chosen"), (rejected_field, "rejected")]
    if has_prompt:
        roles.append((prompt_field, "prompt"))
    sources = [source for source, _ in roles]
    if len(set(sources)) != len(sources):
        logger.error(
            "chosen_field, rejected_field and prompt_field must name different columns (got %s)",
            sources,
        )
        sys.exit(1)

    mapping = {source: target for source, target in roles if source != target}
    # A target name that already exists and is not itself being renamed would
    # end up duplicated; a target that is also a source (e.g. swapped fields)
    # is handled by rename_columns renaming everything at once.
    clashes = [target for target in mapping.values() if target in columns and target not in mapping]
    if clashes:
        logger.warning(
            "Dropping existing column(s) %s; replaced by the configured fields", clashes,
        )
        dataset = dataset.remove_columns(clashes)
    if mapping:
        logger.info("Renaming columns for DPOTrainer: %s", mapping)
        dataset = dataset.rename_columns(mapping)

    return dataset
192
+
193
+
194
+ # ── Main Training Function ────────────────────────────────────────────────────
195
+
196
def _find_latest_checkpoint(trainer_state_dir):
    """Return the highest-step ``checkpoint-*`` path under *trainer_state_dir*, or None.

    Entries whose suffix after the last "-" is not an integer sort as step 0.
    """
    if not os.path.isdir(trainer_state_dir):
        return None

    def _step(path):
        suffix = path.rsplit("-", 1)[-1]
        return int(suffix) if suffix.isdigit() else 0

    checkpoints = sorted(
        glob.glob(os.path.join(trainer_state_dir, "checkpoint-*")),
        key=_step,
    )
    return checkpoints[-1] if checkpoints else None


def _find_reward_metric(metrics, log_history, name):
    """Return the most recent value recorded for metric *name*, or None.

    The summary dict returned by ``trainer.train()`` is checked first (both
    ``name`` and ``train_<name>``). DPOTrainer's per-step reward logs are
    recorded in ``trainer.state.log_history``, so that list is then searched
    newest-first.
    """
    for key in (name, f"train_{name}"):
        value = metrics.get(key)
        if value is not None:
            return value
    for entry in reversed(log_history or []):
        value = entry.get(name)
        if value is not None:
            return value
    return None


def _report_final_metrics(metrics, log_history, epochs):
    """Print final metrics as ``name: value`` lines and log a short summary.

    The stdout lines are intended for SageMaker ``metric_definitions`` regexes
    such as ``train_loss: ([0-9\\.]+)``.
    """
    train_loss = metrics.get("train_loss", 0.0)
    train_runtime = metrics.get("train_runtime", 0.0)
    print(f"train_loss: {train_loss:.4f}")
    print(f"train_runtime: {train_runtime:.1f}")
    print(f"train_samples_per_second: {metrics.get('train_samples_per_second', 0.0):.2f}")

    # DPO-specific metrics: rewards/chosen, rewards/rejected, rewards/margins
    rewards_chosen = _find_reward_metric(metrics, log_history, "rewards/chosen")
    rewards_rejected = _find_reward_metric(metrics, log_history, "rewards/rejected")
    if rewards_chosen is not None:
        print(f"rewards_chosen: {rewards_chosen:.4f}")
    if rewards_rejected is not None:
        print(f"rewards_rejected: {rewards_rejected:.4f}")
    # The margin needs both values; computing it with one missing would raise.
    has_margin = rewards_chosen is not None and rewards_rejected is not None
    if has_margin:
        print(f"rewards_margin: {rewards_chosen - rewards_rejected:.4f}")

    print(f"epochs: {epochs}")

    logger.info("Training complete!")
    logger.info(" Loss: %.4f", train_loss)
    logger.info(" Runtime: %.1fs", train_runtime)
    if has_margin:
        logger.info(" Reward margin: %.4f (chosen=%.4f, rejected=%.4f)",
                    rewards_chosen - rewards_rejected, rewards_chosen, rewards_rejected)


def main():
    """Run DPO training with TRL DPOTrainer + PEFT LoRA.

    Pipeline:
      1. Resolve hyperparameters.
      2. Load the tokenizer and base model.
      3. Build a LoRA config.
      4. Load the preference dataset from DATA_DIR.
      5. Train, resuming from the newest checkpoint under
         CHECKPOINT_DIR/trainer-state when one exists.
      6. Save the adapter and tokenizer to OUTPUT_DIR.
      7. Print final metrics on the main process.

    Exits with status 1 when no model ID is configured.
    """
    from accelerate import Accelerator
    from peft import LoraConfig, TaskType
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    # Initialize accelerator (handles distributed setup)
    accelerator = Accelerator()
    is_main = accelerator.is_main_process

    hparams = load_hyperparameters()
    model_id = hparams["model_id"]

    if not model_id:
        logger.error("No model ID specified. Set HF_MODEL_ID env var or model_id hyperparameter.")
        sys.exit(1)

    if is_main:
        logger.info("=" * 60)
        logger.info("DPO Training Configuration")
        logger.info("=" * 60)
        logger.info(" Model: %s", model_id)
        logger.info(" Data dir: %s", DATA_DIR)
        logger.info(" Output dir: %s", OUTPUT_DIR)
        logger.info(" Checkpoint dir: %s", CHECKPOINT_DIR)
        logger.info(" Beta: %s", hparams["beta"])
        logger.info(" LoRA r: %d", hparams["lora_r"])
        logger.info(" LoRA alpha: %d", hparams["lora_alpha"])
        logger.info(" Learning rate: %s", hparams["learning_rate"])
        logger.info(" Epochs: %d", hparams["epochs"])
        logger.info(" Batch size: %d", hparams["batch_size"])
        logger.info(" Max length: %d", hparams["max_length"])
        logger.info(" Chosen field: %s", hparams["chosen_field"])
        logger.info(" Rejected field: %s", hparams["rejected_field"])
        logger.info(" Prompt field: %s", hparams["prompt_field"])
        logger.info(" Reference free: %s", hparams["reference_free"])
        logger.info("=" * 60)

    # ── Load tokenizer and model ──────────────────────────────────────────────
    if is_main:
        logger.info("Loading tokenizer and model: %s", model_id)

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        trust_remote_code=True,
    )

    # ── Configure LoRA ────────────────────────────────────────────────────────
    lora_config = LoraConfig(
        r=hparams["lora_r"],
        lora_alpha=hparams["lora_alpha"],
        lora_dropout=hparams["lora_dropout"],
        target_modules="all-linear",
        task_type=TaskType.CAUSAL_LM,
    )

    # ── Load preference dataset ───────────────────────────────────────────────
    dataset = load_preference_dataset(
        DATA_DIR,
        hparams["chosen_field"],
        hparams["rejected_field"],
        hparams["prompt_field"],
    )

    # ── DPO training configuration ───────────────────────────────────────────
    # Trainer state/checkpoints live under CHECKPOINT_DIR so spot restarts can
    # resume from them.
    trainer_state_dir = os.path.join(CHECKPOINT_DIR, "trainer-state")
    training_args = DPOConfig(
        output_dir=trainer_state_dir,
        num_train_epochs=hparams["epochs"],
        per_device_train_batch_size=hparams["batch_size"],
        gradient_accumulation_steps=hparams["gradient_accumulation_steps"],
        learning_rate=hparams["learning_rate"],
        warmup_ratio=hparams["warmup_ratio"],
        beta=hparams["beta"],
        max_length=hparams["max_length"],
        max_prompt_length=hparams["max_prompt_length"],
        # Reference-free DPO is a DPOConfig option. Passing ref_model=None to
        # DPOTrainer is already its default and does not change the loss.
        reference_free=hparams["reference_free"],
        bf16=True,
        logging_steps=10,
        save_strategy="epoch",
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
    )

    # ── Check for existing checkpoint (spot resume) ───────────────────────────
    resume_from_checkpoint = _find_latest_checkpoint(trainer_state_dir)
    if resume_from_checkpoint and is_main:
        logger.info("Resuming from checkpoint: %s", resume_from_checkpoint)

    if hparams["reference_free"] and is_main:
        logger.info("Reference-free mode: skipping reference model")

    # ── Initialize DPOTrainer ─────────────────────────────────────────────────
    trainer = DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    # ── Train ─────────────────────────────────────────────────────────────────
    if is_main:
        logger.info("Starting DPO training...")

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # ── Save adapter (all ranks) ──────────────────────────────────────────────
    # save_model() must run on every rank. Under FSDP it gathers the sharded
    # weights with collective ops, so a main-process-only call would hang the
    # other ranks. Trainer itself writes files only from the main process.
    # Switch to a full state dict first, because the bundled
    # accelerate_config.yaml selects SHARDED_STATE_DICT and the adapter should
    # be saved as regular, unsharded files.
    if getattr(trainer, "is_fsdp_enabled", False):
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    if is_main:
        logger.info("Saving LoRA adapter to: %s", OUTPUT_DIR)
    trainer.save_model(OUTPUT_DIR)

    if is_main:
        tokenizer.save_pretrained(OUTPUT_DIR)
        # ── Log final metrics ─────────────────────────────────────────────────
        _report_final_metrics(train_result.metrics, trainer.state.log_history, hparams["epochs"])

    # Wait for all processes before exit
    accelerator.wait_for_everyone()
358
+
359
+
360
+ # ── Entry Point ───────────────────────────────────────────────────────────────
361
+
362
# Launch training only when executed as a script; importing the module
# (e.g. from tests or other tooling) has no side effects beyond definitions.
if __name__ == "__main__":
    main()
@@ -0,0 +1,22 @@
1
+ # Accelerate configuration for SFT training.
2
+ # FSDP-ready: works on both single-GPU and multi-GPU automatically.
3
+ #
4
+ # - Single-GPU (ml.g5.xlarge): accelerate detects 1 GPU, runs single-process.
5
+ # FSDP is effectively a no-op with 1 rank.
6
+ # - Multi-GPU (ml.g5.12xlarge, 4× A10G): accelerate detects 4 GPUs,
7
+ # runs FSDP with full sharding across all devices.
8
+ #
9
+ # No code or config changes needed between environments.
10
+
11
+ compute_environment: LOCAL_MACHINE
12
+ distributed_type: FSDP
13
+ fsdp_config:
14
+ fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
15
+ fsdp_backward_prefetch: BACKWARD_PRE
16
+ fsdp_offload_params: false
17
+ fsdp_sharding_strategy: FULL_SHARD
18
+ fsdp_state_dict_type: SHARDED_STATE_DICT
19
+ fsdp_sync_module_states: true
20
+ fsdp_use_orig_params: true
21
+ mixed_precision: bf16
22
+ num_processes: auto
@@ -0,0 +1,18 @@
1
+ # Default hyperparameters for SFT (Supervised Fine-Tuning) training.
2
+ #
3
+ # These are the lowest-priority defaults. Override via:
4
+ # 1. CLI flags (--learning-rate, --epochs, etc.) — highest priority
5
+ # 2. training/config.yaml hyperparameters section
6
+ # 3. Interactive mode answers (written to config.yaml)
7
+ # 4. These defaults — lowest priority
8
+
9
+ lora_r: 16
10
+ lora_alpha: 32
11
+ lora_dropout: 0.05
12
+ learning_rate: 2e-4
13
+ epochs: 3
14
+ batch_size: 4
15
+ max_seq_length: 2048
16
+ gradient_accumulation_steps: 4
17
+ warmup_ratio: 0.03
18
+ dataset_text_field: text
@@ -0,0 +1,7 @@
1
+ {
2
+ "section_title": "SFT-specific settings",
3
+ "prompts": [
4
+ {"name": "dataset_text_field", "type": "input", "message": "Text/instruction column name?", "default": "text", "validate": "string"},
5
+ {"name": "max_seq_length", "type": "input", "message": "Max sequence length?", "default": "2048", "validate": "int"}
6
+ ]
7
+ }