mlx_forge-0.2.0-py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- mlx_forge/__init__.py +456 -0
- mlx_forge/_version.py +1 -0
- mlx_forge/adapters/__init__.py +0 -0
- mlx_forge/adapters/lora.py +287 -0
- mlx_forge/adapters/targeting.py +162 -0
- mlx_forge/cli/__init__.py +0 -0
- mlx_forge/cli/data_cmd.py +220 -0
- mlx_forge/cli/generate_cmd.py +99 -0
- mlx_forge/cli/main.py +254 -0
- mlx_forge/cli/prepare_cmd.py +23 -0
- mlx_forge/cli/studio_cmd.py +21 -0
- mlx_forge/cli/train_cmd.py +11 -0
- mlx_forge/config.py +176 -0
- mlx_forge/data/__init__.py +0 -0
- mlx_forge/data/backend.py +157 -0
- mlx_forge/data/batching.py +238 -0
- mlx_forge/data/catalog.py +370 -0
- mlx_forge/data/converter.py +204 -0
- mlx_forge/data/formats.py +191 -0
- mlx_forge/data/mixing.py +61 -0
- mlx_forge/data/packing.py +93 -0
- mlx_forge/data/preprocessing.py +198 -0
- mlx_forge/data/registry.py +221 -0
- mlx_forge/data/validate.py +227 -0
- mlx_forge/inference/__init__.py +1 -0
- mlx_forge/inference/cache.py +122 -0
- mlx_forge/inference/engine.py +232 -0
- mlx_forge/inference/sampling.py +88 -0
- mlx_forge/logging/__init__.py +0 -0
- mlx_forge/logging/metrics.py +52 -0
- mlx_forge/losses/__init__.py +6 -0
- mlx_forge/losses/dpo.py +98 -0
- mlx_forge/losses/sft.py +77 -0
- mlx_forge/manifest.py +164 -0
- mlx_forge/models/__init__.py +0 -0
- mlx_forge/models/_base/__init__.py +12 -0
- mlx_forge/models/_base/activations.py +25 -0
- mlx_forge/models/_base/args.py +38 -0
- mlx_forge/models/_base/attention.py +101 -0
- mlx_forge/models/_base/rope.py +276 -0
- mlx_forge/models/architectures/__init__.py +4 -0
- mlx_forge/models/architectures/gemma.py +352 -0
- mlx_forge/models/architectures/llama.py +236 -0
- mlx_forge/models/architectures/phi3.py +261 -0
- mlx_forge/models/architectures/phi4.py +221 -0
- mlx_forge/models/architectures/qwen2.py +221 -0
- mlx_forge/models/architectures/qwen3.py +231 -0
- mlx_forge/models/architectures/qwen3_5.py +752 -0
- mlx_forge/models/loader.py +143 -0
- mlx_forge/models/memory.py +376 -0
- mlx_forge/models/quantize.py +39 -0
- mlx_forge/models/registry.py +108 -0
- mlx_forge/models/resolve.py +205 -0
- mlx_forge/recipes/__init__.py +5 -0
- mlx_forge/recipes/auto_config.py +104 -0
- mlx_forge/recipes/built_in/chat_sft.yaml +42 -0
- mlx_forge/recipes/built_in/instruction_sft.yaml +42 -0
- mlx_forge/recipes/built_in/preference_dpo.yaml +46 -0
- mlx_forge/recipes/built_in/writing_style.yaml +42 -0
- mlx_forge/recipes/registry.py +90 -0
- mlx_forge/studio/__init__.py +4 -0
- mlx_forge/studio/api/__init__.py +1 -0
- mlx_forge/studio/api/config_schema.py +152 -0
- mlx_forge/studio/api/data_library.py +76 -0
- mlx_forge/studio/api/datasets.py +44 -0
- mlx_forge/studio/api/inference.py +62 -0
- mlx_forge/studio/api/memory.py +58 -0
- mlx_forge/studio/api/models.py +46 -0
- mlx_forge/studio/api/queue.py +63 -0
- mlx_forge/studio/api/recipes.py +67 -0
- mlx_forge/studio/api/runs.py +73 -0
- mlx_forge/studio/api/training.py +48 -0
- mlx_forge/studio/frontend/assets/index-DfE9wCUu.js +46 -0
- mlx_forge/studio/frontend/assets/index-DoKRRrtV.css +1 -0
- mlx_forge/studio/frontend/index.html +14 -0
- mlx_forge/studio/server.py +210 -0
- mlx_forge/studio/services/__init__.py +1 -0
- mlx_forge/studio/services/data_library_service.py +46 -0
- mlx_forge/studio/services/dataset_service.py +73 -0
- mlx_forge/studio/services/memory_service.py +71 -0
- mlx_forge/studio/services/metrics_watcher.py +56 -0
- mlx_forge/studio/services/model_library_service.py +77 -0
- mlx_forge/studio/services/model_service.py +107 -0
- mlx_forge/studio/services/queue_service.py +178 -0
- mlx_forge/studio/services/recipe_service.py +47 -0
- mlx_forge/studio/services/run_service.py +242 -0
- mlx_forge/studio/services/training_service.py +113 -0
- mlx_forge/trainer/__init__.py +0 -0
- mlx_forge/trainer/callbacks.py +150 -0
- mlx_forge/trainer/checkpoint.py +187 -0
- mlx_forge/trainer/dpo_trainer.py +118 -0
- mlx_forge/trainer/optimizer.py +123 -0
- mlx_forge/trainer/state.py +20 -0
- mlx_forge/trainer/trainer.py +319 -0
- mlx_forge-0.2.0.dist-info/METADATA +246 -0
- mlx_forge-0.2.0.dist-info/RECORD +100 -0
- mlx_forge-0.2.0.dist-info/WHEEL +5 -0
- mlx_forge-0.2.0.dist-info/entry_points.txt +2 -0
- mlx_forge-0.2.0.dist-info/licenses/LICENSE +21 -0
- mlx_forge-0.2.0.dist-info/top_level.txt +1 -0
mlx_forge/__init__.py
ADDED
@@ -0,0 +1,456 @@
"""MLX Forge — LoRA SFT training framework for MLX on Apple Silicon."""

import json
from pathlib import Path

from transformers import AutoTokenizer

from mlx_forge._version import __version__ as __version__
from mlx_forge.data.formats import detect_format, validate_samples
from mlx_forge.data.preprocessing import tokenize_dataset
from mlx_forge.inference.engine import GenerationResult

def prepare(
    data_path: str,
    model: str,
    output: str | None = None,
    *,
    name: str | None = None,
    trust_remote_code: bool = False,
    max_seq_length: int = 2048,
    mask_prompt: bool = True,
    revision: str | None = None,
) -> dict:
    """Pre-tokenize a dataset and save as Arrow dataset for memory-mapped access.

    Args:
        data_path: Path to JSONL data file
        model: HuggingFace model ID or local path (for tokenizer)
        output: Ignored (kept for CLI compat). Storage is now in ~/.mlxforge/datasets/
        name: Dataset name for the registry. If omitted, derived from filename.
        trust_remote_code: Trust remote code when loading tokenizer
        max_seq_length: Maximum sequence length
        mask_prompt: Mask prompt tokens from loss
        revision: Optional HF revision/commit hash

    Returns:
        Dict of statistics (sample count, total tokens, etc.)
    """
    from mlx_forge.data import backend
    from mlx_forge.models.resolve import resolve_model

    # Resolve model (HF repo ID -> local path)
    print(f"Resolving model: {model}...")
    resolved = resolve_model(
        model,
        revision=revision,
        trust_remote_code=trust_remote_code,
    )
    print()

    # Load tokenizer
    print(f"Loading tokenizer from {resolved.local_path}...")
    tokenizer = AutoTokenizer.from_pretrained(
        resolved.local_path,
        trust_remote_code=trust_remote_code,
    )

    # Read JSONL
    print(f"Reading {data_path}...")
    data_path_obj = Path(data_path)
    if not data_path_obj.exists():
        raise FileNotFoundError(f"Data file not found: {data_path}")

    with open(data_path_obj) as f:
        samples = [json.loads(line) for line in f if line.strip()]

    if not samples:
        raise ValueError(f"No samples found in {data_path}")

    # Detect format
    fmt = detect_format(samples)
    print(f"Detected format: {fmt}")

    # Validate samples
    print(f"Validating {len(samples)} samples...")
    errors = validate_samples(samples, fmt)
    if errors:
        error_msg = "\n".join(errors[:10])
        if len(errors) > 10:
            error_msg += f"\n... and {len(errors) - 10} more errors"
        raise ValueError(f"Validation failed:\n{error_msg}")

    # Derive dataset name from filename if not provided
    dataset_name = name or data_path_obj.stem

    # Check if already processed
    if backend.tokenized_exists(dataset_name, model):
        print(f"Already processed: {dataset_name} for {model}")
        path = backend.get_processed_path(dataset_name, model)
        meta_path = path / "meta.json"
        with open(meta_path) as f:
            meta = json.load(f)
        print(f" {meta['num_samples']} samples, {meta['total_tokens']} tokens")
        return meta

    # Tokenize
    print(f"Tokenizing {len(samples)} samples...")
    tokenized = tokenize_dataset(
        samples,
        tokenizer,
        fmt,
        mask_prompt=mask_prompt,
        max_seq_length=max_seq_length,
    )

    # Save via datasets backend
    print("Saving to datasets backend...")
    path = backend.save_tokenized(dataset_name, model, tokenized)

    meta_path = path / "meta.json"
    with open(meta_path) as f:
        meta = json.load(f)

    print(f" Preprocessed {meta['num_samples']} samples")
    print(f" Total tokens: {meta['total_tokens']}")
    print(f" Min/mean/max length: {meta['min_length']}/{meta['mean_length']:.1f}/{meta['max_length']}")

    return meta
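
A minimal usage sketch of prepare() based on the signature above; the JSONL path, model ID, and dataset name are illustrative placeholders, not part of the package:

# Hypothetical usage -- file path, model ID, and name are placeholders.
import mlx_forge

stats = mlx_forge.prepare(
    "data/chat.jsonl",                 # one JSON sample per line
    "Qwen/Qwen2.5-0.5B-Instruct",      # tokenizer source: HF repo ID or local path
    name="chat",                       # registry name; defaults to the file stem
    max_seq_length=2048,
    mask_prompt=True,                  # exclude prompt tokens from the loss
)
print(stats["num_samples"], stats["total_tokens"])

The returned dict is the meta.json written by the datasets backend, so the num_samples and total_tokens keys above match the fields the function itself prints.
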

def train(config, resume: str | None = None):  # -> TrainState
    """Run LoRA SFT training from a config file or TrainingConfig object.

    Args:
        config: Path to a YAML config file (str) or a TrainingConfig instance.
        resume: Path to checkpoint directory to resume from.

    Returns:
        Final TrainState after training completes.
    """
    import yaml

    from mlx_forge.adapters.lora import apply_lora
    from mlx_forge.adapters.targeting import get_patterns, resolve_targets
    from mlx_forge.config import TrainingConfig
    from mlx_forge.data import backend
    from mlx_forge.manifest import write_manifest
    from mlx_forge.models.loader import load_model
    from mlx_forge.models.resolve import resolve_model
    from mlx_forge.trainer.callbacks import ConsoleCallback, MetricsLoggerCallback
    from mlx_forge.trainer.trainer import Trainer

    # Load config if it's a path
    if isinstance(config, str):
        config = TrainingConfig.from_yaml(config)

    print("MLX Forge v0 — Training")
    print(f"Model: {config.model.path}")
    print(f"Adapter: {config.adapter.method} (rank={config.adapter.rank})")
    print()

    # Resolve model (HF repo ID -> local path)
    print("Resolving model...")
    resolved_model = resolve_model(
        config.model.path,
        revision=config.model.revision,
        trust_remote_code=config.model.trust_remote_code,
    )
    print()

    # Resolve tokenizer if separate path specified
    if config.model.tokenizer_path:
        print("Resolving tokenizer...")
        resolved_tokenizer = resolve_model(
            config.model.tokenizer_path,
            trust_remote_code=config.model.trust_remote_code,
        )
        tokenizer_path = resolved_tokenizer.local_path
        print()
    else:
        tokenizer_path = None

    # Create run directory
    from mlx_forge.trainer.checkpoint import CheckpointManager
    manager = CheckpointManager(config)
    run_dir = manager.run_dir
    run_dir.mkdir(parents=True, exist_ok=True)

    print(f"Run directory: {run_dir}")
    print()

    # Write config.yaml
    (run_dir / "config.yaml").write_text(yaml.dump(config.model_dump(), default_flow_style=False))

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    model, tokenizer = load_model(
        resolved_model.local_path,
        tokenizer_path=tokenizer_path,
        trust_remote_code=config.model.trust_remote_code,
    )
    print(f"Model loaded: {type(model).__name__}")
    print()

    # Quantize model if configured (QLoRA: quantize THEN apply LoRA)
    if config.model.quantization:
        from mlx_forge.models.quantize import quantize_model
        quantize_model(model, config.model.quantization)
        print(f"Quantized to {config.model.quantization.bits}-bit "
              f"(group_size={config.model.quantization.group_size})")
        print()

    # Apply LoRA adapters
    print("Applying LoRA adapters...")
    patterns = get_patterns(config.adapter)
    targets = resolve_targets(model, patterns, config.adapter.num_layers)
    print(f"Matched {len(targets)} modules")

    apply_lora(model, targets, config.adapter)

    # Count parameters
    from mlx.utils import tree_flatten
    trainable_params = sum(p.size for _, p in tree_flatten(model.trainable_parameters()))
    total_params = sum(p.size for _, p in tree_flatten(model.parameters()))
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
    print()

    # Enable gradient checkpointing if configured
    if config.training.gradient_checkpointing:
        _enable_gradient_checkpointing(model)
        print("Gradient checkpointing enabled")
        print()

    # Load or prepare training and validation data
    tokenizer_for_data = config.model.tokenizer_path or config.model.path

    def _load_or_prepare(data_path: str, label: str):
        """Load data from backend or run prepare if not cached."""
        print(f"Loading {label} data...")
        dataset_name = Path(data_path).stem

        if backend.tokenized_exists(dataset_name, config.model.path):
            print(f"Cache hit: {dataset_name}")
            ds = backend.load_tokenized(dataset_name, config.model.path)
            print(f" {len(ds)} samples (memory-mapped)")
            return dataset_name, ds
        else:
            print(f"Cache miss for {data_path}. Running prepare...")
            prepare(
                data_path,
                tokenizer_for_data,
                name=dataset_name,
                trust_remote_code=config.model.trust_remote_code,
                max_seq_length=config.data.max_seq_length,
                mask_prompt=config.data.mask_prompt,
            )
            ds = backend.load_tokenized(dataset_name, config.model.path)
            return dataset_name, ds

    # Multi-source mixing or single dataset
    if config.data.sources:
        from mlx_forge.data.mixing import MixedDatasetIterator

        source_datasets = []
        source_weights = []
        for src in config.data.sources:
            data_path = src.path or src.dataset
            _, ds = _load_or_prepare(data_path, f"source ({data_path})")
            source_datasets.append(ds)
            source_weights.append(src.weight)

        train_dataset = MixedDatasetIterator(
            source_datasets, source_weights,
            seed=config.training.seed,
        )
        train_name = "mixed"
        train_fingerprint = "mixed"
        print(f"Mixed dataset: {len(config.data.sources)} sources")
    else:
        train_name, train_dataset = _load_or_prepare(config.data.train, "training")
        train_fingerprint = backend.compute_fingerprint(config.data.train, tokenizer)

    _, val_dataset = _load_or_prepare(config.data.valid, "validation")
    print()

    # Write manifest
    print("Writing manifest...")
    write_manifest(
        run_dir,
        config.model_dump(),
        train_fingerprint,
        resolved_model.resolution_metadata,
    )
    print(f"Manifest written: {run_dir / 'manifest.json'}")
    print()

    # Create callbacks
    callbacks = [
        ConsoleCallback(num_iters=config.training.num_iters),
        MetricsLoggerCallback(log_path=run_dir / "logs" / "metrics.jsonl"),
    ]

    # Add WandB callback if configured
    if hasattr(config.training, 'wandb_project') and config.training.wandb_project:
        try:
            from mlx_forge.trainer.callbacks import WandBCallback
            callbacks.append(
                WandBCallback(
                    project=config.training.wandb_project,
                    run_name=run_dir.name,
                    config=config.model_dump(),
                )
            )
            print("WandB logging enabled")
        except ImportError:
            print("Warning: wandb not installed, skipping WandB logging")

    # Create trainer (SFT or DPO based on training_type)
    if config.training.training_type == "dpo":
        from mlx_forge.trainer.dpo_trainer import DPOTrainer
        trainer = DPOTrainer(
            model=model,
            config=config,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            callbacks=callbacks,
            checkpoint_manager=manager,
        )
    else:
        trainer = Trainer(
            model=model,
            config=config,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            callbacks=callbacks,
            checkpoint_manager=manager,
        )

    # Handle resume from checkpoint
    if resume:
        resume_path = Path(resume).expanduser()
        _validate_resume(resume_path, config)
        restored_state = manager.load(resume_path, model, trainer.optimizer)
        trainer.state = restored_state
        print(f"Resumed from {resume_path} at step {restored_state.step}")
        print()

    # Run training
    print("Starting training...")
    print()
    final_state = trainer.fit()

    print()
    print("Training complete!")
    print(f"Final step: {final_state.step}")
    print(f"Best validation loss: {final_state.best_val_loss:.4f}")
    print(f"Total tokens trained: {final_state.trained_tokens:,}")
    print(f"Checkpoints saved to: {run_dir / 'checkpoints'}")

    return final_state
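
A sketch of driving train() end to end. The YAML nesting below is inferred from the config attributes the function reads (config.model.path, config.adapter.rank, config.data.train, config.training.num_iters, and so on); the authoritative schema lives in mlx_forge/config.py, so treat the keys and values as assumptions rather than a canonical config:

# Assumed config layout -- verify field names against mlx_forge/config.py.
from pathlib import Path

import mlx_forge

Path("run.yaml").write_text("""\
model:
  path: Qwen/Qwen2.5-0.5B-Instruct   # placeholder model ID
adapter:
  method: lora
  rank: 16
data:
  train: data/train.jsonl            # placeholder paths
  valid: data/valid.jsonl
  max_seq_length: 2048
training:
  num_iters: 1000
""")

state = mlx_forge.train("run.yaml")
print(state.step, state.best_val_loss, state.trained_tokens)

To resume, pass resume=<checkpoint directory>; per _validate_resume below, that directory must contain adapters.safetensors, optimizer.safetensors, and state.json, and its recorded step must be below training.num_iters.
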

def _validate_resume(resume_path: Path, config) -> None:
    """Validate that a checkpoint directory is compatible with the current config."""
    if not resume_path.exists():
        raise FileNotFoundError(f"Checkpoint directory not found: {resume_path}")

    required = ["adapters.safetensors", "optimizer.safetensors", "state.json"]
    missing = [f for f in required if not (resume_path / f).exists()]
    if missing:
        raise FileNotFoundError(
            f"Checkpoint missing {', '.join(missing)} in {resume_path}. "
            f"Expected files: {', '.join(required)}"
        )

    state = json.loads((resume_path / "state.json").read_text())
    if state.get("schema_version", 1) > 1:
        raise ValueError(
            f"Checkpoint schema version {state['schema_version']} is newer than "
            f"supported version 1. Please upgrade MLX Forge."
        )
    if state["step"] >= config.training.num_iters:
        raise ValueError(
            f"Checkpoint is at step {state['step']} but training is configured "
            f"for {config.training.num_iters} iterations. "
            f"Increase 'num_iters' in your config to continue training."
        )
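
A sketch of the on-disk layout _validate_resume accepts. Only the three file names and the schema_version/step keys come from the code above; the directory name and step value are made up:

# Illustrative checkpoint layout. The empty .safetensors files here are
# placeholders to show the structure; a real checkpoint is written by
# CheckpointManager with actual adapter and optimizer weights.
import json
from pathlib import Path

ckpt = Path("checkpoints/000400")          # hypothetical directory name
ckpt.mkdir(parents=True, exist_ok=True)
(ckpt / "adapters.safetensors").touch()    # LoRA adapter weights
(ckpt / "optimizer.safetensors").touch()   # optimizer state
(ckpt / "state.json").write_text(
    json.dumps({"schema_version": 1, "step": 400})
)
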

def _enable_gradient_checkpointing(model) -> None:
    """Wrap each transformer layer's __call__ with mx.checkpoint."""
    import mlx.core as mx

    if hasattr(model, "model") and hasattr(model.model, "layers"):
        for layer in model.model.layers:
            layer.__call__ = mx.checkpoint(layer.__call__)
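
For context, mx.checkpoint trades compute for memory: the wrapped callable discards intermediate activations and recomputes them during the backward pass. A toy standalone illustration, assuming current MLX and not taken from mlx-forge:

# The checkpointed block yields the same gradients as the plain block,
# but recomputes its intermediates on the backward pass instead of
# keeping them in memory.
import mlx.core as mx

def block(x):
    return mx.tanh(x) * x

block_ckpt = mx.checkpoint(block)
loss = lambda x: block_ckpt(x).sum()
print(mx.grad(loss)(mx.ones((8,))))
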

def generate(
    model: str,
    prompt: str | None = None,
    messages: list[dict] | None = None,
    *,
    adapter: str | None = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 512,
    repetition_penalty: float = 1.0,
    trust_remote_code: bool = False,
    seed: int | None = None,
    stream: bool = False,
) -> GenerationResult:
    """Generate text from a model with optional LoRA adapter."""
    from mlx_forge.inference.engine import (
        generate as _generate,
    )
    from mlx_forge.inference.engine import (
        generate_tokens,
        load_for_inference,
    )

    loaded_model, tokenizer = load_for_inference(
        model,
        adapter_path=adapter,
        trust_remote_code=trust_remote_code,
    )

    if stream:
        if messages is not None:
            prompt_tokens = tokenizer.apply_chat_template(
                messages, add_generation_prompt=True
            )
            if isinstance(prompt_tokens, dict):
                prompt_tokens = prompt_tokens["input_ids"]
        elif prompt is not None:
            prompt_tokens = tokenizer.encode(prompt)
        else:
            raise ValueError("Must provide either 'prompt' or 'messages'")

        def _stream():
            for token_id in generate_tokens(
                loaded_model,
                prompt_tokens,
                tokenizer,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                repetition_penalty=repetition_penalty,
                seed=seed,
            ):
                yield tokenizer.decode([token_id])

        return _stream()

    return _generate(
        loaded_model,
        tokenizer,
        prompt=prompt,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
        seed=seed,
    )
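
A usage sketch covering both return modes of generate(); the model ID and adapter path are placeholders. Note that despite the -> GenerationResult annotation, stream=True actually returns a generator of decoded text chunks:

# Hypothetical usage -- model ID and adapter path are placeholders.
import mlx_forge

# Chat-style, non-streaming: returns a GenerationResult.
result = mlx_forge.generate(
    "Qwen/Qwen2.5-0.5B-Instruct",
    messages=[{"role": "user", "content": "Explain LoRA in one sentence."}],
    adapter="runs/my-run/checkpoints/best",   # optional LoRA adapter directory
    temperature=0.7,
    max_tokens=128,
)

# Plain-prompt, streaming: iterate over decoded chunks as they are produced.
for chunk in mlx_forge.generate(
    "Qwen/Qwen2.5-0.5B-Instruct",
    prompt="Once upon a time",
    stream=True,
    max_tokens=64,
):
    print(chunk, end="", flush=True)
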
mlx_forge/_version.py
ADDED
@@ -0,0 +1 @@
__version__ = "0.2.0"