@aws/ml-container-creator 1.0.2 → 1.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +1 -1
- package/bin/cli.js +1 -1
- package/config/tune-catalog.json +303 -1
- package/infra/ci-harness/lib/ci-harness-stack.ts +43 -0
- package/package.json +3 -2
- package/servers/base-image-picker/index.js +65 -18
- package/servers/instance-sizer/index.js +32 -0
- package/servers/lib/catalogs/fleet-drivers.json +38 -0
- package/servers/lib/catalogs/model-arch-support.json +51 -0
- package/servers/lib/catalogs/model-servers.json +2842 -1516
- package/servers/lib/schemas/image-catalog.schema.json +12 -0
- package/src/app.js +6 -4
- package/src/lib/bootstrap-command-handler.js +12 -2
- package/src/lib/bootstrap-profile-manager.js +16 -0
- package/src/lib/cross-cutting-checker.js +6 -1
- package/src/lib/generated/cli-options.js +1 -1
- package/src/lib/generated/parameter-matrix.js +1 -1
- package/src/lib/generated/validation-rules.js +1 -1
- package/src/lib/mcp-query-runner.js +110 -3
- package/src/lib/prompt-runner.js +66 -22
- package/src/lib/template-variable-resolver.js +8 -0
- package/src/lib/train-config-builder.js +339 -0
- package/templates/do/.benchmark_writer.py +3 -0
- package/templates/do/.eval_helper.py +409 -0
- package/templates/do/.register_helper.py +185 -11
- package/templates/do/.train_build_request.py +102 -113
- package/templates/do/.train_helper.py +433 -0
- package/templates/do/__pycache__/.register_helper.cpython-312.pyc +0 -0
- package/templates/do/adapter +157 -0
- package/templates/do/benchmark +60 -3
- package/templates/do/deploy.d/managed-inference.ejs +83 -0
- package/templates/do/evaluate +272 -0
- package/templates/do/lib/resolve-instance.sh +155 -0
- package/templates/do/register +5 -0
- package/templates/do/test +1 -0
- package/templates/do/train +879 -126
- package/templates/do/training/config.yaml +83 -11
- package/templates/do/training/dpo/accelerate_config.yaml +24 -0
- package/templates/do/training/dpo/defaults.yaml +26 -0
- package/templates/do/training/dpo/prompts.json +8 -0
- package/templates/do/training/dpo/train.py +363 -0
- package/templates/do/training/sft/accelerate_config.yaml +22 -0
- package/templates/do/training/sft/defaults.yaml +18 -0
- package/templates/do/training/sft/prompts.json +7 -0
- package/templates/do/training/sft/train.py +310 -0
- package/templates/do/tune +11 -2
- package/templates/do/.train_poll_parser.py +0 -135
- package/templates/do/.train_status_parser.py +0 -187
- /package/templates/do/training/{train.py → custom/train.py} +0 -0
|
@@ -5,20 +5,33 @@
|
|
|
5
5
|
# This file is read by do/train to construct a CreateTrainingJob request.
|
|
6
6
|
# Required fields are marked with [REQUIRED]. All other fields are optional.
|
|
7
7
|
|
|
8
|
+
# ─── Training Technique ───────────────────────────────────────────────────────
|
|
9
|
+
# Which training technique to use. Each technique has its own training script
|
|
10
|
+
# in do/training/<technique>/train.py with appropriate defaults.
|
|
11
|
+
#
|
|
12
|
+
# Available techniques:
|
|
13
|
+
# custom — Your own training script (default, user-provided logic)
|
|
14
|
+
# sft — Supervised Fine-Tuning using TRL SFTTrainer + PEFT LoRA
|
|
15
|
+
# dpo — Direct Preference Optimization using TRL DPOTrainer + PEFT LoRA
|
|
16
|
+
#
|
|
17
|
+
# Override with CLI: ./do/train --technique sft
|
|
18
|
+
technique: "custom"
|
|
19
|
+
|
|
8
20
|
# ─── Container Image [REQUIRED] ──────────────────────────────────────────────
|
|
9
21
|
# ECR image URI for the training container.
|
|
10
22
|
# This defaults to your project's own container image (built via do/build + do/push).
|
|
11
23
|
# You can also use a SageMaker pre-built training image or any custom ECR image.
|
|
12
24
|
#
|
|
13
25
|
# Example: 123456789012.dkr.ecr.us-east-1.amazonaws.com/ml-container-creator:latest
|
|
14
|
-
image: "
|
|
26
|
+
image: "${_PROFILE_accountId}.dkr.ecr.<%= awsRegion %>.amazonaws.com/${_PROFILE_ecrRepositoryName:-ml-container-creator}:<%= projectName %>-latest"
|
|
15
27
|
|
|
16
|
-
# ─── Training Script [REQUIRED]
|
|
28
|
+
# ─── Training Script [REQUIRED for custom technique] ─────────────────────────
|
|
17
29
|
# Path to your training script relative to the project root.
|
|
18
|
-
#
|
|
30
|
+
# For technique-based training (sft, dpo), the script is auto-selected from
|
|
31
|
+
# do/training/<technique>/train.py — this field is only used for 'custom'.
|
|
19
32
|
#
|
|
20
|
-
# Example: do/training/train.py
|
|
21
|
-
script: "do/training/train.py"
|
|
33
|
+
# Example: do/training/custom/train.py
|
|
34
|
+
script: "do/training/custom/train.py"
|
|
22
35
|
|
|
23
36
|
# ─── Instance Configuration [REQUIRED] ───────────────────────────────────────
|
|
24
37
|
# SageMaker ML instance type for training.
|
|
@@ -37,9 +50,39 @@ instance_type: "ml.g5.xlarge"
|
|
|
37
50
|
# (e.g., using PyTorch DDP, Horovod, or similar). SageMaker handles inter-node
|
|
38
51
|
# communication setup automatically.
|
|
39
52
|
#
|
|
53
|
+
# For multi-node training (instance_count > 1):
|
|
54
|
+
# - Use EFA-capable instances for best performance: ml.p4d.24xlarge, ml.p5.48xlarge, ml.g5.48xlarge
|
|
55
|
+
# - SageMaker auto-configures NCCL/EFA for inter-node communication
|
|
56
|
+
# - Accelerate auto-detects WORLD_SIZE > local_gpu_count and enables multi-node FSDP
|
|
57
|
+
# - No changes needed to training scripts or accelerate_config.yaml
|
|
58
|
+
#
|
|
40
59
|
# Example: 2
|
|
41
60
|
instance_count: 1
|
|
42
61
|
|
|
62
|
+
# Networking backend for distributed training (multi-node).
|
|
63
|
+
# auto = detect EFA presence and use NCCL if available, gloo otherwise
|
|
64
|
+
# nccl = force NCCL (requires EFA-capable instances for multi-node)
|
|
65
|
+
# gloo = force gloo (CPU-based, works on any instance but slower)
|
|
66
|
+
networking_backend: "auto"
|
|
67
|
+
|
|
68
|
+
# ─── Launcher (optional) ─────────────────────────────────────────────────────
|
|
69
|
+
# How the training script is invoked. Default: accelerate (recommended for TRL).
|
|
70
|
+
#
|
|
71
|
+
# Options:
|
|
72
|
+
# accelerate — `accelerate launch --config_file <config> train.py` (default)
|
|
73
|
+
# Handles single-GPU, multi-GPU (FSDP), and multi-node automatically.
|
|
74
|
+
# torchrun — `torchrun --nproc_per_node <gpus> train.py`
|
|
75
|
+
# Use for scripts that manage distributed setup internally.
|
|
76
|
+
# deepspeed — `deepspeed --num_gpus <gpus> train.py`
|
|
77
|
+
# Use for ZeRO-3 offloading with >70B models.
|
|
78
|
+
# custom — Direct `python train.py` invocation (no distributed wrapper).
|
|
79
|
+
# Use only for scripts that handle everything internally.
|
|
80
|
+
#
|
|
81
|
+
# Note: SFT and DPO techniques include accelerate_config.yaml and are designed
|
|
82
|
+
# for the `accelerate` launcher. Using a different launcher with these techniques
|
|
83
|
+
# may require modifying the training script.
|
|
84
|
+
launcher: "accelerate"
|
|
85
|
+
|
|
43
86
|
# ─── Data Configuration [REQUIRED] ───────────────────────────────────────────
|
|
44
87
|
# S3 URI for the input training dataset.
|
|
45
88
|
# SageMaker downloads this to /opt/ml/input/data/training/ in the container.
|
|
@@ -60,13 +103,27 @@ output_path: ""
|
|
|
60
103
|
# All values must be strings. SageMaker passes these as string arguments to
|
|
61
104
|
# the training script entry point.
|
|
62
105
|
#
|
|
63
|
-
#
|
|
106
|
+
# For 'custom' technique: define your own hyperparameters.
|
|
107
|
+
# For 'sft' technique: the following are recognized:
|
|
108
|
+
# lora_r, lora_alpha, lora_dropout, learning_rate, epochs,
|
|
109
|
+
# batch_size, max_seq_length, gradient_accumulation_steps,
|
|
110
|
+
# warmup_ratio, dataset_text_field, model_id
|
|
111
|
+
#
|
|
112
|
+
# Example (custom):
|
|
64
113
|
# hyperparameters:
|
|
65
114
|
# epochs: "10"
|
|
66
115
|
# batch_size: "32"
|
|
67
116
|
# learning_rate: "5e-5"
|
|
68
|
-
#
|
|
69
|
-
#
|
|
117
|
+
#
|
|
118
|
+
# Example (sft):
|
|
119
|
+
# hyperparameters:
|
|
120
|
+
# lora_r: "16"
|
|
121
|
+
# lora_alpha: "32"
|
|
122
|
+
# learning_rate: "2e-4"
|
|
123
|
+
# epochs: "3"
|
|
124
|
+
# batch_size: "4"
|
|
125
|
+
# max_seq_length: "2048"
|
|
126
|
+
# dataset_text_field: "text"
|
|
70
127
|
hyperparameters:
|
|
71
128
|
epochs: "10"
|
|
72
129
|
batch_size: "32"
|
|
@@ -102,19 +159,34 @@ enable_spot: false
|
|
|
102
159
|
# Example: 172800
|
|
103
160
|
max_wait_seconds: 172800
|
|
104
161
|
|
|
162
|
+
# S3 URI for checkpoint storage during spot training.
|
|
163
|
+
# SageMaker syncs /opt/ml/checkpoints/ to this location between interruptions.
|
|
164
|
+
# If empty, auto-derived from output_path: <output_path>/checkpoints/
|
|
165
|
+
# Only used when enable_spot is true.
|
|
166
|
+
#
|
|
167
|
+
# Example: s3://my-bucket/checkpoints/
|
|
168
|
+
checkpoint_s3_uri: ""
|
|
169
|
+
|
|
105
170
|
# ─── Metric Definitions (optional) ───────────────────────────────────────────
|
|
106
171
|
# Custom regex patterns to extract training metrics from container logs.
|
|
107
172
|
# SageMaker uses these to publish metrics to CloudWatch for monitoring.
|
|
108
173
|
# Each entry needs a name and a regex with exactly one capture group.
|
|
109
174
|
#
|
|
110
|
-
# Example:
|
|
175
|
+
# Example (custom):
|
|
111
176
|
# metric_definitions:
|
|
112
177
|
# - name: "train:loss"
|
|
113
178
|
# regex: "loss: ([0-9\\.]+)"
|
|
114
179
|
# - name: "train:epoch"
|
|
115
180
|
# regex: "epoch: ([0-9\\.]+)"
|
|
116
|
-
#
|
|
117
|
-
#
|
|
181
|
+
#
|
|
182
|
+
# Example (sft — matches output from training/sft/train.py):
|
|
183
|
+
# metric_definitions:
|
|
184
|
+
# - name: "train:loss"
|
|
185
|
+
# regex: "train_loss: ([0-9\\.]+)"
|
|
186
|
+
# - name: "train:samples_per_second"
|
|
187
|
+
# regex: "train_samples_per_second: ([0-9\\.]+)"
|
|
188
|
+
# - name: "train:epoch"
|
|
189
|
+
# regex: "epochs: ([0-9\\.]+)"
|
|
118
190
|
metric_definitions: []
|
|
119
191
|
|
|
120
192
|
# ─── Environment Variables (optional) ────────────────────────────────────────
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
# Accelerate configuration for DPO training.
|
|
2
|
+
# FSDP-ready: works on both single-GPU and multi-GPU automatically.
|
|
3
|
+
#
|
|
4
|
+
# NOTE: DPO requires a reference model (frozen copy of base model), which
|
|
5
|
+
# doubles memory usage. For models >7B, consider:
|
|
6
|
+
# - reference_free: true (in hyperparameters) — skips reference model
|
|
7
|
+
# - Using a larger instance (ml.g5.12xlarge for 4× A10G = 96GB total)
|
|
8
|
+
# - FSDP offloading: set fsdp_offload_params: true below
|
|
9
|
+
#
|
|
10
|
+
# Single-GPU (ml.g5.xlarge): accelerate detects 1 GPU, runs single-process.
|
|
11
|
+
# Multi-GPU (ml.g5.12xlarge): accelerate detects 4 GPUs, runs FSDP.
|
|
12
|
+
|
|
13
|
+
compute_environment: LOCAL_MACHINE
|
|
14
|
+
distributed_type: FSDP
|
|
15
|
+
fsdp_config:
|
|
16
|
+
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
17
|
+
fsdp_backward_prefetch: BACKWARD_PRE
|
|
18
|
+
fsdp_offload_params: false
|
|
19
|
+
fsdp_sharding_strategy: FULL_SHARD
|
|
20
|
+
fsdp_state_dict_type: SHARDED_STATE_DICT
|
|
21
|
+
fsdp_sync_module_states: true
|
|
22
|
+
fsdp_use_orig_params: true
|
|
23
|
+
mixed_precision: bf16
|
|
24
|
+
num_processes: auto
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Default hyperparameters for DPO (Direct Preference Optimization) training.
|
|
2
|
+
#
|
|
3
|
+
# These are the lowest-priority defaults. Override via:
|
|
4
|
+
# 1. CLI flags (--beta, --learning-rate, etc.) — highest priority
|
|
5
|
+
# 2. training/config.yaml hyperparameters section
|
|
6
|
+
# 3. Interactive mode answers (written to config.yaml)
|
|
7
|
+
# 4. These defaults — lowest priority
|
|
8
|
+
#
|
|
9
|
+
# NOTE: DPO creates a frozen reference model internally, doubling memory.
|
|
10
|
+
# For models >7B on single GPU, set reference_free: true or use multi-GPU.
|
|
11
|
+
|
|
12
|
+
beta: 0.1
|
|
13
|
+
lora_r: 16
|
|
14
|
+
lora_alpha: 32
|
|
15
|
+
lora_dropout: 0.05
|
|
16
|
+
learning_rate: 5e-7
|
|
17
|
+
epochs: 1
|
|
18
|
+
batch_size: 2
|
|
19
|
+
max_length: 1024
|
|
20
|
+
max_prompt_length: 512
|
|
21
|
+
gradient_accumulation_steps: 4
|
|
22
|
+
warmup_ratio: 0.03
|
|
23
|
+
chosen_field: chosen
|
|
24
|
+
rejected_field: rejected
|
|
25
|
+
prompt_field: prompt
|
|
26
|
+
reference_free: false
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
{
|
|
2
|
+
"section_title": "DPO-specific settings",
|
|
3
|
+
"prompts": [
|
|
4
|
+
{"name": "beta", "type": "input", "message": "Beta (KL penalty coefficient)?", "default": "0.1", "validate": "float"},
|
|
5
|
+
{"name": "chosen_field", "type": "input", "message": "Chosen response column name?", "default": "chosen", "validate": "string"},
|
|
6
|
+
{"name": "rejected_field", "type": "input", "message": "Rejected response column name?", "default": "rejected", "validate": "string"}
|
|
7
|
+
]
|
|
8
|
+
}
|
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
"""Unmanaged DPO fine-tuning using TRL DPOTrainer + PEFT LoRA.
|
|
6
|
+
|
|
7
|
+
Direct Preference Optimization trains a model to prefer chosen responses
|
|
8
|
+
over rejected responses without explicit reward modeling. This script uses
|
|
9
|
+
TRL's DPOTrainer with PEFT LoRA for parameter-efficient training.
|
|
10
|
+
|
|
11
|
+
Portable env-var contract (works on SageMaker AI and HyperPod EKS):
|
|
12
|
+
DATA_DIR / SM_CHANNEL_TRAINING -> training data path
|
|
13
|
+
OUTPUT_DIR / SM_MODEL_DIR -> model artifact output
|
|
14
|
+
CHECKPOINT_DIR / SM_CHECKPOINT_DIR -> checkpoint path for spot resume
|
|
15
|
+
HF_MODEL_ID / SM_HP_MODEL_ID -> base model HuggingFace ID
|
|
16
|
+
SM_HPS -> JSON blob of all hyperparameters
|
|
17
|
+
|
|
18
|
+
Dataset format:
|
|
19
|
+
JSONL with fields: prompt, chosen, rejected
|
|
20
|
+
Example: {"prompt": "Explain X", "chosen": "Good answer", "rejected": "Bad answer"}
|
|
21
|
+
|
|
22
|
+
Output:
|
|
23
|
+
LoRA adapter saved to OUTPUT_DIR (adapter_model.safetensors + adapter_config.json)
|
|
24
|
+
Metrics logged to stdout in SageMaker-parseable format
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
import glob
|
|
28
|
+
import json
|
|
29
|
+
import logging
|
|
30
|
+
import os
|
|
31
|
+
import sys
|
|
32
|
+
|
|
33
|
+
# ── Logging ───────────────────────────────────────────────────────────────────
|
|
34
|
+
# Timestamped INFO-level logging for the whole process.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger("dpo-trainer")

# ── Portable Path Resolution ─────────────────────────────────────────────────
# Fallback chain: generic env var -> SageMaker env var -> default path


def _first_env(names, fallback):
    """Return the first non-empty environment value among *names*, else *fallback*.

    Unset and empty variables are both skipped, so an empty generic variable
    still falls through to the SageMaker-specific one.
    """
    for name in names:
        value = os.environ.get(name)
        if value:
            return value
    return fallback


# Directory holding the training data channel.
DATA_DIR = _first_env(("DATA_DIR", "SM_CHANNEL_TRAINING"), "/opt/ml/input/data/training")

# Directory the final model artifacts are written to.
OUTPUT_DIR = _first_env(("OUTPUT_DIR", "SM_MODEL_DIR"), "/opt/ml/model")

# Directory used for trainer checkpoints (spot resume).
CHECKPOINT_DIR = _first_env(("CHECKPOINT_DIR", "SM_CHECKPOINT_DIR"), "/opt/ml/checkpoints")

# Base model identifier; empty string when neither variable is set.
MODEL_ID = _first_env(("HF_MODEL_ID", "SM_HP_MODEL_ID"), "")
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
# ── Hyperparameter Loading ────────────────────────────────────────────────────
|
|
69
|
+
|
|
70
|
+
def load_hyperparameters():
    """Load hyperparameters from SageMaker SM_HPS env var or individual SM_HP_* vars.

    Resolution order:
      1. ``SM_HPS`` — a JSON object of hyperparameters. When it parses and
         every recognised value casts cleanly, the result is returned
         immediately and the ``SM_HP_*`` variables are not consulted.
      2. ``SM_HP_<NAME>`` — one variable per hyperparameter, used when
         ``SM_HPS`` is absent or unusable.
      3. The built-in defaults below.

    Each value is cast to the type of its built-in default. Unknown keys in
    ``SM_HPS`` are ignored.

    Returns:
        dict with typed hyperparameter values.

    Raises:
        ValueError: If an individual ``SM_HP_*`` value cannot be cast.
    """
    defaults = {
        "model_id": MODEL_ID,
        "beta": 0.1,
        "lora_r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "learning_rate": 5e-7,
        "epochs": 1,
        "batch_size": 2,
        "max_length": 1024,
        "max_prompt_length": 512,
        "gradient_accumulation_steps": 4,
        "warmup_ratio": 0.03,
        "chosen_field": "chosen",
        "rejected_field": "rejected",
        "prompt_field": "prompt",
        "reference_free": False,
    }

    # Try SM_HPS (JSON blob of all hyperparameters)
    sm_hps = os.environ.get("SM_HPS")
    if sm_hps:
        try:
            raw = json.loads(sm_hps)
            if not isinstance(raw, dict):
                raise ValueError("SM_HPS must be a JSON object")
            # Build a fresh dict so a cast failure part-way through cannot
            # leave half-applied overrides behind for the fallback path.
            overrides = {
                key: _cast(raw[key], type(default_val))
                for key, default_val in defaults.items()
                if key in raw
            }
            return {**defaults, **overrides}
        except (TypeError, ValueError) as e:
            # json.JSONDecodeError is a ValueError; TypeError covers values
            # such as null that cannot be converted to a number.
            logger.warning("Failed to parse SM_HPS: %s", e)

    # Fallback: individual SM_HP_* env vars
    resolved = dict(defaults)
    for key, default_val in defaults.items():
        env_val = os.environ.get(f"SM_HP_{key.upper()}")
        if env_val is not None:
            resolved[key] = _cast(env_val, type(default_val))

    return resolved
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _cast(value, target_type):
|
|
118
|
+
"""Cast a string value to the target type."""
|
|
119
|
+
if target_type == bool:
|
|
120
|
+
return str(value).lower() in ("true", "1", "yes")
|
|
121
|
+
if target_type == int:
|
|
122
|
+
return int(float(value))
|
|
123
|
+
if target_type == float:
|
|
124
|
+
return float(value)
|
|
125
|
+
return str(value)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
# ── Dataset Loading ───────────────────────────────────────────────────────────
|
|
129
|
+
|
|
130
|
+
def load_preference_dataset(data_dir, chosen_field, rejected_field, prompt_field):
    """Load DPO preference dataset from data directory.

    Searches ``data_dir`` (including subdirectories) for ``.jsonl``, ``.json``,
    ``.parquet`` and ``.csv`` files and loads them as a single ``train`` split.
    All discovered files must share one loader format; ``.json`` and
    ``.jsonl`` both use the JSON loader.

    The process exits with status 1 when no data files are found, when files
    of different formats are mixed, or when the chosen/rejected columns are
    missing. The prompt column is optional and only affects the log message.

    Args:
        data_dir: Path to directory containing training data files.
        chosen_field: Column name for preferred responses.
        rejected_field: Column name for dispreferred responses.
        prompt_field: Column name for prompts (optional).

    Returns:
        A Hugging Face Dataset object.
    """
    from datasets import load_dataset as hf_load_dataset

    # File extension -> `datasets` loader name.
    format_map = {"jsonl": "json", "json": "json", "parquet": "parquet", "csv": "csv"}
    extensions = list(format_map)

    # With recursive=True, "**" also matches zero directories, so a single
    # pattern per extension covers top-level and nested files alike.
    data_files = sorted(
        {
            path
            for ext in extensions
            for path in glob.glob(os.path.join(data_dir, "**", f"*.{ext}"), recursive=True)
        }
    )

    if not data_files:
        logger.error("No data files found in %s (searched: %s)", data_dir, extensions)
        sys.exit(1)

    logger.info("Found %d data file(s) in %s", len(data_files), data_dir)

    # One loader handles every file, so refuse a mix of formats rather than
    # parsing some files with the wrong reader.
    formats = sorted(
        {
            format_map.get(os.path.splitext(path)[1].lstrip(".").lower(), "json")
            for path in data_files
        }
    )
    if len(formats) > 1:
        logger.error(
            "Mixed data formats found in %s: %s. Provide files of a single format.",
            data_dir, formats,
        )
        sys.exit(1)
    file_format = formats[0]

    dataset = hf_load_dataset(file_format, data_files=data_files, split="train")
    logger.info("Loaded dataset: %d rows, columns: %s", len(dataset), dataset.column_names)

    # Verify required fields exist
    missing = []
    if chosen_field not in dataset.column_names:
        missing.append(f"chosen_field='{chosen_field}'")
    if rejected_field not in dataset.column_names:
        missing.append(f"rejected_field='{rejected_field}'")

    if missing:
        logger.error(
            "Required fields not found in dataset: %s. Available columns: %s",
            ", ".join(missing), dataset.column_names,
        )
        sys.exit(1)

    # Check for optional prompt field
    if prompt_field in dataset.column_names:
        logger.info("Prompt field '%s' found — using conditional DPO", prompt_field)
    else:
        logger.info("No prompt field '%s' — using unconditional DPO", prompt_field)

    return dataset
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
# ── Main Training Function ────────────────────────────────────────────────────
|
|
195
|
+
|
|
196
|
+
def main():
    """Run DPO training with TRL DPOTrainer + PEFT LoRA.

    Steps:
      1. Resolve hyperparameters and fail fast when no model ID is set.
      2. Load the tokenizer and base model, and build the LoRA config.
      3. Load the preference dataset from ``DATA_DIR``.
      4. Train, resuming from the newest ``checkpoint-*`` directory under
         ``CHECKPOINT_DIR/trainer-state`` when one exists.
      5. Save the adapter and tokenizer to ``OUTPUT_DIR`` and print final
         metrics as ``name: value`` lines on stdout.

    Exits with status 1 when no model ID is configured.
    """
    from accelerate import Accelerator
    from peft import LoraConfig, TaskType
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    # Initialize accelerator (handles distributed setup)
    accelerator = Accelerator()

    # Load hyperparameters
    hparams = load_hyperparameters()
    model_id = hparams["model_id"]

    if not model_id:
        logger.error("No model ID specified. Set HF_MODEL_ID env var or model_id hyperparameter.")
        sys.exit(1)

    if accelerator.is_main_process:
        logger.info("=" * 60)
        logger.info("DPO Training Configuration")
        logger.info("=" * 60)
        logger.info(" Model: %s", model_id)
        logger.info(" Data dir: %s", DATA_DIR)
        logger.info(" Output dir: %s", OUTPUT_DIR)
        logger.info(" Checkpoint dir: %s", CHECKPOINT_DIR)
        logger.info(" Beta: %s", hparams["beta"])
        logger.info(" LoRA r: %d", hparams["lora_r"])
        logger.info(" LoRA alpha: %d", hparams["lora_alpha"])
        logger.info(" Learning rate: %s", hparams["learning_rate"])
        logger.info(" Epochs: %d", hparams["epochs"])
        logger.info(" Batch size: %d", hparams["batch_size"])
        logger.info(" Max length: %d", hparams["max_length"])
        logger.info(" Chosen field: %s", hparams["chosen_field"])
        logger.info(" Rejected field: %s", hparams["rejected_field"])
        logger.info(" Prompt field: %s", hparams["prompt_field"])
        logger.info(" Reference free: %s", hparams["reference_free"])
        logger.info("=" * 60)

    # ── Load tokenizer and model ──────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Loading tokenizer and model: %s", model_id)

    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # model repository; only point model_id at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        trust_remote_code=True,
    )

    # ── Configure LoRA ────────────────────────────────────────────────────────
    lora_config = LoraConfig(
        r=hparams["lora_r"],
        lora_alpha=hparams["lora_alpha"],
        lora_dropout=hparams["lora_dropout"],
        target_modules="all-linear",
        task_type=TaskType.CAUSAL_LM,
    )

    # ── Load preference dataset ───────────────────────────────────────────────
    dataset = load_preference_dataset(
        DATA_DIR,
        hparams["chosen_field"],
        hparams["rejected_field"],
        hparams["prompt_field"],
    )

    # ── DPO training configuration ───────────────────────────────────────────
    # Trainer checkpoints live under CHECKPOINT_DIR so a restarted job can
    # find them again (see the resume logic below).
    trainer_state_dir = os.path.join(CHECKPOINT_DIR, "trainer-state")

    config_kwargs = {
        "output_dir": trainer_state_dir,
        "num_train_epochs": hparams["epochs"],
        "per_device_train_batch_size": hparams["batch_size"],
        "gradient_accumulation_steps": hparams["gradient_accumulation_steps"],
        "learning_rate": hparams["learning_rate"],
        "warmup_ratio": hparams["warmup_ratio"],
        "beta": hparams["beta"],
        "max_length": hparams["max_length"],
        "max_prompt_length": hparams["max_prompt_length"],
        "bf16": True,
        "logging_steps": 10,
        "save_strategy": "epoch",
        "save_total_limit": 2,
        "report_to": "none",
        "remove_unused_columns": False,
    }
    if hparams["reference_free"]:
        # ref_model=None alone is DPOTrainer's default and does not switch
        # the loss to reference-free; the DPOConfig flag does. It is passed
        # only when requested so the default path is unchanged.
        config_kwargs["reference_free"] = True

    training_args = DPOConfig(**config_kwargs)

    # ── Check for existing checkpoint (spot resume) ───────────────────────────
    resume_from_checkpoint = None
    if os.path.isdir(trainer_state_dir):

        def _checkpoint_step(path):
            # "checkpoint-<step>" -> step; non-numeric suffixes sort first.
            suffix = path.rsplit("-", 1)[-1]
            return int(suffix) if suffix.isdigit() else 0

        checkpoints = sorted(
            glob.glob(os.path.join(trainer_state_dir, "checkpoint-*")),
            key=_checkpoint_step,
        )
        if checkpoints:
            resume_from_checkpoint = checkpoints[-1]
            if accelerator.is_main_process:
                logger.info("Resuming from checkpoint: %s", resume_from_checkpoint)

    # ── Initialize DPOTrainer ─────────────────────────────────────────────────
    # NOTE(review): with peft_config set and no ref_model, TRL is expected to
    # take reference log-probs from the base weights with the adapter
    # disabled rather than from a second model copy — confirm against the
    # installed TRL version.
    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": dataset,
        "processing_class": tokenizer,
        "peft_config": lora_config,
    }

    # Handle reference-free mode (skips frozen reference model to save memory)
    if hparams["reference_free"]:
        trainer_kwargs["ref_model"] = None
        if accelerator.is_main_process:
            logger.info("Reference-free mode: skipping reference model (saves ~50%% memory)")

    trainer = DPOTrainer(**trainer_kwargs)

    # ── Train ─────────────────────────────────────────────────────────────────
    if accelerator.is_main_process:
        logger.info("Starting DPO training...")

    train_result = trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # ── Save adapter ──────────────────────────────────────────────────────────
    # Under FSDP the weights are sharded across ranks and gathering them is a
    # collective operation, so save_model must run on every rank; the Trainer
    # itself restricts the actual file write to the saving process. The
    # state-dict type is switched to FULL_STATE_DICT first, following the
    # Hugging Face FSDP guidance for exporting a consolidated model.
    if getattr(trainer, "is_fsdp_enabled", False):
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

    if accelerator.is_main_process:
        logger.info("Saving LoRA adapter to: %s", OUTPUT_DIR)
    trainer.save_model(OUTPUT_DIR)

    if accelerator.is_main_process:
        tokenizer.save_pretrained(OUTPUT_DIR)

        # ── Log final metrics ─────────────────────────────────────────────────
        # Printed as "name: value" lines so log-based metric regexes can
        # pick them up from stdout.
        metrics = train_result.metrics
        print(f"train_loss: {metrics.get('train_loss', 0.0):.4f}")
        print(f"train_runtime: {metrics.get('train_runtime', 0.0):.1f}")
        print(f"train_samples_per_second: {metrics.get('train_samples_per_second', 0.0):.2f}")

        # DPO reward metrics (logged by DPOTrainer during training)
        rewards_chosen = metrics.get("rewards/chosen", metrics.get("train_rewards/chosen", None))
        rewards_rejected = metrics.get("rewards/rejected", metrics.get("train_rewards/rejected", None))
        # The margin needs both rewards; either may be absent from the
        # final metrics dict.
        have_both_rewards = rewards_chosen is not None and rewards_rejected is not None
        margin = rewards_chosen - rewards_rejected if have_both_rewards else None

        if rewards_chosen is not None:
            print(f"rewards_chosen: {rewards_chosen:.4f}")
        if rewards_rejected is not None:
            print(f"rewards_rejected: {rewards_rejected:.4f}")
        if have_both_rewards:
            print(f"rewards_margin: {margin:.4f}")

        print(f"epochs: {hparams['epochs']}")

        logger.info("Training complete!")
        logger.info(" Loss: %.4f", metrics.get("train_loss", 0.0))
        logger.info(" Runtime: %.1fs", metrics.get("train_runtime", 0.0))
        if have_both_rewards:
            logger.info(" Reward margin: %.4f (chosen=%.4f, rejected=%.4f)",
                        margin, rewards_chosen, rewards_rejected)

    # Wait for all processes before exit
    accelerator.wait_for_everyone()
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# ── Entry Point ───────────────────────────────────────────────────────────────
|
|
361
|
+
|
|
362
|
+
if __name__ == "__main__":
|
|
363
|
+
main()
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Accelerate configuration for SFT training.
|
|
2
|
+
# FSDP-ready: works on both single-GPU and multi-GPU automatically.
|
|
3
|
+
#
|
|
4
|
+
# - Single-GPU (ml.g5.xlarge): accelerate detects 1 GPU, runs single-process.
|
|
5
|
+
# FSDP is effectively a no-op with 1 rank.
|
|
6
|
+
# - Multi-GPU (ml.g5.12xlarge, 4× A10G): accelerate detects 4 GPUs,
|
|
7
|
+
# runs FSDP with full sharding across all devices.
|
|
8
|
+
#
|
|
9
|
+
# No code or config changes needed between environments.
|
|
10
|
+
|
|
11
|
+
compute_environment: LOCAL_MACHINE
|
|
12
|
+
distributed_type: FSDP
|
|
13
|
+
fsdp_config:
|
|
14
|
+
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
15
|
+
fsdp_backward_prefetch: BACKWARD_PRE
|
|
16
|
+
fsdp_offload_params: false
|
|
17
|
+
fsdp_sharding_strategy: FULL_SHARD
|
|
18
|
+
fsdp_state_dict_type: SHARDED_STATE_DICT
|
|
19
|
+
fsdp_sync_module_states: true
|
|
20
|
+
fsdp_use_orig_params: true
|
|
21
|
+
mixed_precision: bf16
|
|
22
|
+
num_processes: auto
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Default hyperparameters for SFT (Supervised Fine-Tuning) training.
|
|
2
|
+
#
|
|
3
|
+
# These are the lowest-priority defaults. Override via:
|
|
4
|
+
# 1. CLI flags (--learning-rate, --epochs, etc.) — highest priority
|
|
5
|
+
# 2. training/config.yaml hyperparameters section
|
|
6
|
+
# 3. Interactive mode answers (written to config.yaml)
|
|
7
|
+
# 4. These defaults — lowest priority
|
|
8
|
+
|
|
9
|
+
lora_r: 16
|
|
10
|
+
lora_alpha: 32
|
|
11
|
+
lora_dropout: 0.05
|
|
12
|
+
learning_rate: 2e-4
|
|
13
|
+
epochs: 3
|
|
14
|
+
batch_size: 4
|
|
15
|
+
max_seq_length: 2048
|
|
16
|
+
gradient_accumulation_steps: 4
|
|
17
|
+
warmup_ratio: 0.03
|
|
18
|
+
dataset_text_field: text
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
{
|
|
2
|
+
"section_title": "SFT-specific settings",
|
|
3
|
+
"prompts": [
|
|
4
|
+
{"name": "dataset_text_field", "type": "input", "message": "Text/instruction column name?", "default": "text", "validate": "string"},
|
|
5
|
+
{"name": "max_seq_length", "type": "input", "message": "Max sequence length?", "default": "2048", "validate": "int"}
|
|
6
|
+
]
|
|
7
|
+
}
|