mlx-forge 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100) hide show
  1. mlx_forge/__init__.py +456 -0
  2. mlx_forge/_version.py +1 -0
  3. mlx_forge/adapters/__init__.py +0 -0
  4. mlx_forge/adapters/lora.py +287 -0
  5. mlx_forge/adapters/targeting.py +162 -0
  6. mlx_forge/cli/__init__.py +0 -0
  7. mlx_forge/cli/data_cmd.py +220 -0
  8. mlx_forge/cli/generate_cmd.py +99 -0
  9. mlx_forge/cli/main.py +254 -0
  10. mlx_forge/cli/prepare_cmd.py +23 -0
  11. mlx_forge/cli/studio_cmd.py +21 -0
  12. mlx_forge/cli/train_cmd.py +11 -0
  13. mlx_forge/config.py +176 -0
  14. mlx_forge/data/__init__.py +0 -0
  15. mlx_forge/data/backend.py +157 -0
  16. mlx_forge/data/batching.py +238 -0
  17. mlx_forge/data/catalog.py +370 -0
  18. mlx_forge/data/converter.py +204 -0
  19. mlx_forge/data/formats.py +191 -0
  20. mlx_forge/data/mixing.py +61 -0
  21. mlx_forge/data/packing.py +93 -0
  22. mlx_forge/data/preprocessing.py +198 -0
  23. mlx_forge/data/registry.py +221 -0
  24. mlx_forge/data/validate.py +227 -0
  25. mlx_forge/inference/__init__.py +1 -0
  26. mlx_forge/inference/cache.py +122 -0
  27. mlx_forge/inference/engine.py +232 -0
  28. mlx_forge/inference/sampling.py +88 -0
  29. mlx_forge/logging/__init__.py +0 -0
  30. mlx_forge/logging/metrics.py +52 -0
  31. mlx_forge/losses/__init__.py +6 -0
  32. mlx_forge/losses/dpo.py +98 -0
  33. mlx_forge/losses/sft.py +77 -0
  34. mlx_forge/manifest.py +164 -0
  35. mlx_forge/models/__init__.py +0 -0
  36. mlx_forge/models/_base/__init__.py +12 -0
  37. mlx_forge/models/_base/activations.py +25 -0
  38. mlx_forge/models/_base/args.py +38 -0
  39. mlx_forge/models/_base/attention.py +101 -0
  40. mlx_forge/models/_base/rope.py +276 -0
  41. mlx_forge/models/architectures/__init__.py +4 -0
  42. mlx_forge/models/architectures/gemma.py +352 -0
  43. mlx_forge/models/architectures/llama.py +236 -0
  44. mlx_forge/models/architectures/phi3.py +261 -0
  45. mlx_forge/models/architectures/phi4.py +221 -0
  46. mlx_forge/models/architectures/qwen2.py +221 -0
  47. mlx_forge/models/architectures/qwen3.py +231 -0
  48. mlx_forge/models/architectures/qwen3_5.py +752 -0
  49. mlx_forge/models/loader.py +143 -0
  50. mlx_forge/models/memory.py +376 -0
  51. mlx_forge/models/quantize.py +39 -0
  52. mlx_forge/models/registry.py +108 -0
  53. mlx_forge/models/resolve.py +205 -0
  54. mlx_forge/recipes/__init__.py +5 -0
  55. mlx_forge/recipes/auto_config.py +104 -0
  56. mlx_forge/recipes/built_in/chat_sft.yaml +42 -0
  57. mlx_forge/recipes/built_in/instruction_sft.yaml +42 -0
  58. mlx_forge/recipes/built_in/preference_dpo.yaml +46 -0
  59. mlx_forge/recipes/built_in/writing_style.yaml +42 -0
  60. mlx_forge/recipes/registry.py +90 -0
  61. mlx_forge/studio/__init__.py +4 -0
  62. mlx_forge/studio/api/__init__.py +1 -0
  63. mlx_forge/studio/api/config_schema.py +152 -0
  64. mlx_forge/studio/api/data_library.py +76 -0
  65. mlx_forge/studio/api/datasets.py +44 -0
  66. mlx_forge/studio/api/inference.py +62 -0
  67. mlx_forge/studio/api/memory.py +58 -0
  68. mlx_forge/studio/api/models.py +46 -0
  69. mlx_forge/studio/api/queue.py +63 -0
  70. mlx_forge/studio/api/recipes.py +67 -0
  71. mlx_forge/studio/api/runs.py +73 -0
  72. mlx_forge/studio/api/training.py +48 -0
  73. mlx_forge/studio/frontend/assets/index-DfE9wCUu.js +46 -0
  74. mlx_forge/studio/frontend/assets/index-DoKRRrtV.css +1 -0
  75. mlx_forge/studio/frontend/index.html +14 -0
  76. mlx_forge/studio/server.py +210 -0
  77. mlx_forge/studio/services/__init__.py +1 -0
  78. mlx_forge/studio/services/data_library_service.py +46 -0
  79. mlx_forge/studio/services/dataset_service.py +73 -0
  80. mlx_forge/studio/services/memory_service.py +71 -0
  81. mlx_forge/studio/services/metrics_watcher.py +56 -0
  82. mlx_forge/studio/services/model_library_service.py +77 -0
  83. mlx_forge/studio/services/model_service.py +107 -0
  84. mlx_forge/studio/services/queue_service.py +178 -0
  85. mlx_forge/studio/services/recipe_service.py +47 -0
  86. mlx_forge/studio/services/run_service.py +242 -0
  87. mlx_forge/studio/services/training_service.py +113 -0
  88. mlx_forge/trainer/__init__.py +0 -0
  89. mlx_forge/trainer/callbacks.py +150 -0
  90. mlx_forge/trainer/checkpoint.py +187 -0
  91. mlx_forge/trainer/dpo_trainer.py +118 -0
  92. mlx_forge/trainer/optimizer.py +123 -0
  93. mlx_forge/trainer/state.py +20 -0
  94. mlx_forge/trainer/trainer.py +319 -0
  95. mlx_forge-0.2.0.dist-info/METADATA +246 -0
  96. mlx_forge-0.2.0.dist-info/RECORD +100 -0
  97. mlx_forge-0.2.0.dist-info/WHEEL +5 -0
  98. mlx_forge-0.2.0.dist-info/entry_points.txt +2 -0
  99. mlx_forge-0.2.0.dist-info/licenses/LICENSE +21 -0
  100. mlx_forge-0.2.0.dist-info/top_level.txt +1 -0
mlx_forge/__init__.py ADDED
@@ -0,0 +1,456 @@
1
+ """MLX Forge — LoRA SFT training framework for MLX on Apple Silicon."""
2
+
3
+ import json
4
+ from pathlib import Path
5
+
6
+ from transformers import AutoTokenizer
7
+
8
+ from mlx_forge._version import __version__ as __version__
9
+ from mlx_forge.data.formats import detect_format, validate_samples
10
+ from mlx_forge.data.preprocessing import tokenize_dataset
11
+ from mlx_forge.inference.engine import GenerationResult
12
+
13
+
14
def prepare(
    data_path: str,
    model: str,
    output: str | None = None,
    *,
    name: str | None = None,
    trust_remote_code: bool = False,
    max_seq_length: int = 2048,
    mask_prompt: bool = True,
    revision: str | None = None,
) -> dict:
    """Pre-tokenize a JSONL dataset and persist it via the datasets backend.

    Args:
        data_path: Path to the JSONL data file.
        model: HuggingFace model ID or local path (used to load the tokenizer).
        output: Ignored (kept for CLI compat). Storage is now in ~/.mlxforge/datasets/
        name: Dataset name for the registry; derived from the filename when omitted.
        trust_remote_code: Trust remote code when loading the tokenizer.
        max_seq_length: Maximum sequence length.
        mask_prompt: Mask prompt tokens from the loss.
        revision: Optional HF revision/commit hash.

    Returns:
        Dict of statistics (sample count, total tokens, etc.).

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        ValueError: If the file holds no samples or validation fails.
    """
    from mlx_forge.data import backend
    from mlx_forge.models.resolve import resolve_model

    # Map an HF repo ID (or local path) onto a concrete local snapshot.
    print(f"Resolving model: {model}...")
    resolved = resolve_model(
        model,
        revision=revision,
        trust_remote_code=trust_remote_code,
    )
    print()

    print(f"Loading tokenizer from {resolved.local_path}...")
    tokenizer = AutoTokenizer.from_pretrained(
        resolved.local_path,
        trust_remote_code=trust_remote_code,
    )

    # Read the raw JSONL rows, skipping blank lines.
    print(f"Reading {data_path}...")
    source = Path(data_path)
    if not source.exists():
        raise FileNotFoundError(f"Data file not found: {data_path}")

    with open(source) as fh:
        samples = [json.loads(row) for row in fh if row.strip()]

    if not samples:
        raise ValueError(f"No samples found in {data_path}")

    fmt = detect_format(samples)
    print(f"Detected format: {fmt}")

    # Fail fast, reporting at most the first ten validation errors.
    print(f"Validating {len(samples)} samples...")
    errors = validate_samples(samples, fmt)
    if errors:
        error_msg = "\n".join(errors[:10])
        if len(errors) > 10:
            error_msg += f"\n... and {len(errors) - 10} more errors"
        raise ValueError(f"Validation failed:\n{error_msg}")

    dataset_name = name or source.stem

    # Reuse a previously tokenized copy when one exists for this model.
    if backend.tokenized_exists(dataset_name, model):
        print(f"Already processed: {dataset_name} for {model}")
        cached = backend.get_processed_path(dataset_name, model)
        meta = json.loads((cached / "meta.json").read_text())
        print(f" {meta['num_samples']} samples, {meta['total_tokens']} tokens")
        return meta

    print(f"Tokenizing {len(samples)} samples...")
    tokenized = tokenize_dataset(
        samples,
        tokenizer,
        fmt,
        mask_prompt=mask_prompt,
        max_seq_length=max_seq_length,
    )

    # Persist through the backend; it writes meta.json alongside the data.
    print("Saving to datasets backend...")
    saved = backend.save_tokenized(dataset_name, model, tokenized)
    meta = json.loads((saved / "meta.json").read_text())

    print(f" Preprocessed {meta['num_samples']} samples")
    print(f" Total tokens: {meta['total_tokens']}")
    print(f" Min/mean/max length: {meta['min_length']}/{meta['mean_length']:.1f}/{meta['max_length']}")

    return meta
120
+
121
+
122
def train(config, resume: str | None = None): # -> TrainState
    """Run LoRA SFT training from a config file or TrainingConfig object.

    Orchestrates the full pipeline: model/tokenizer resolution, run-directory
    creation, optional quantization (QLoRA), LoRA injection, data loading
    (with cache-or-prepare fallback), manifest/callback setup, optional
    checkpoint resume, and the training loop itself.

    Args:
        config: Path to a YAML config file (str) or a TrainingConfig instance.
        resume: Path to checkpoint directory to resume from.

    Returns:
        Final TrainState after training completes.

    Raises:
        FileNotFoundError: From resume validation when checkpoint files are missing.
        ValueError: From resume validation on incompatible checkpoints.
    """
    import yaml

    from mlx_forge.adapters.lora import apply_lora
    from mlx_forge.adapters.targeting import get_patterns, resolve_targets
    from mlx_forge.config import TrainingConfig
    from mlx_forge.data import backend
    from mlx_forge.manifest import write_manifest
    from mlx_forge.models.loader import load_model
    from mlx_forge.models.resolve import resolve_model
    from mlx_forge.trainer.callbacks import ConsoleCallback, MetricsLoggerCallback
    from mlx_forge.trainer.trainer import Trainer

    # Load config if it's a path (string); a TrainingConfig passes through.
    if isinstance(config, str):
        config = TrainingConfig.from_yaml(config)

    print("MLX Forge v0 — Training")
    print(f"Model: {config.model.path}")
    print(f"Adapter: {config.adapter.method} (rank={config.adapter.rank})")
    print()

    # Resolve model (HF repo ID -> local path)
    print("Resolving model...")
    resolved_model = resolve_model(
        config.model.path,
        revision=config.model.revision,
        trust_remote_code=config.model.trust_remote_code,
    )
    print()

    # Resolve tokenizer if a separate path is specified; otherwise load_model
    # receives tokenizer_path=None and falls back to the model's own tokenizer.
    if config.model.tokenizer_path:
        print("Resolving tokenizer...")
        resolved_tokenizer = resolve_model(
            config.model.tokenizer_path,
            trust_remote_code=config.model.trust_remote_code,
        )
        tokenizer_path = resolved_tokenizer.local_path
        print()
    else:
        tokenizer_path = None

    # Create run directory (CheckpointManager derives run_dir from the config).
    from mlx_forge.trainer.checkpoint import CheckpointManager
    manager = CheckpointManager(config)
    run_dir = manager.run_dir
    run_dir.mkdir(parents=True, exist_ok=True)

    print(f"Run directory: {run_dir}")
    print()

    # Snapshot the effective config into the run directory for reproducibility.
    (run_dir / "config.yaml").write_text(yaml.dump(config.model_dump(), default_flow_style=False))

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    model, tokenizer = load_model(
        resolved_model.local_path,
        tokenizer_path=tokenizer_path,
        trust_remote_code=config.model.trust_remote_code,
    )
    print(f"Model loaded: {type(model).__name__}")
    print()

    # Quantize model if configured (QLoRA: quantize THEN apply LoRA, so the
    # adapters stay full-precision on top of the quantized base weights).
    if config.model.quantization:
        from mlx_forge.models.quantize import quantize_model
        quantize_model(model, config.model.quantization)
        print(f"Quantized to {config.model.quantization.bits}-bit "
              f"(group_size={config.model.quantization.group_size})")
        print()

    # Apply LoRA adapters to the modules matched by the targeting patterns.
    print("Applying LoRA adapters...")
    patterns = get_patterns(config.adapter)
    targets = resolve_targets(model, patterns, config.adapter.num_layers)
    print(f"Matched {len(targets)} modules")

    apply_lora(model, targets, config.adapter)

    # Count parameters (trainable = LoRA weights only after apply_lora).
    from mlx.utils import tree_flatten
    trainable_params = sum(p.size for _, p in tree_flatten(model.trainable_parameters()))
    total_params = sum(p.size for _, p in tree_flatten(model.parameters()))
    print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
    print()

    # Enable gradient checkpointing if configured (trades compute for memory).
    if config.training.gradient_checkpointing:
        _enable_gradient_checkpointing(model)
        print("Gradient checkpointing enabled")
        print()

    # Load or prepare training and validation data. The prepare() cache is
    # keyed on config.model.path, not the tokenizer path.
    tokenizer_for_data = config.model.tokenizer_path or config.model.path

    def _load_or_prepare(data_path: str, label: str):
        """Load data from backend or run prepare if not cached."""
        print(f"Loading {label} data...")
        dataset_name = Path(data_path).stem

        if backend.tokenized_exists(dataset_name, config.model.path):
            print(f"Cache hit: {dataset_name}")
            ds = backend.load_tokenized(dataset_name, config.model.path)
            print(f" {len(ds)} samples (memory-mapped)")
            return dataset_name, ds
        else:
            print(f"Cache miss for {data_path}. Running prepare...")
            prepare(
                data_path,
                tokenizer_for_data,
                name=dataset_name,
                trust_remote_code=config.model.trust_remote_code,
                max_seq_length=config.data.max_seq_length,
                mask_prompt=config.data.mask_prompt,
            )
            ds = backend.load_tokenized(dataset_name, config.model.path)
            return dataset_name, ds

    # Multi-source mixing or single dataset
    if config.data.sources:
        from mlx_forge.data.mixing import MixedDatasetIterator

        source_datasets = []
        source_weights = []
        for src in config.data.sources:
            data_path = src.path or src.dataset
            _, ds = _load_or_prepare(data_path, f"source ({data_path})")
            source_datasets.append(ds)
            source_weights.append(src.weight)

        train_dataset = MixedDatasetIterator(
            source_datasets, source_weights,
            seed=config.training.seed,
        )
        # NOTE(review): train_name is assigned in both branches but never read
        # afterwards in this function — candidate for removal; confirm no
        # external consumer before deleting.
        train_name = "mixed"
        # Mixed runs record the literal "mixed" instead of a content
        # fingerprint, so manifest provenance is weaker for mixed datasets.
        train_fingerprint = "mixed"
        print(f"Mixed dataset: {len(config.data.sources)} sources")
    else:
        train_name, train_dataset = _load_or_prepare(config.data.train, "training")
        train_fingerprint = backend.compute_fingerprint(config.data.train, tokenizer)

    _, val_dataset = _load_or_prepare(config.data.valid, "validation")
    print()

    # Write manifest (records config, data fingerprint, model resolution info).
    print("Writing manifest...")
    write_manifest(
        run_dir,
        config.model_dump(),
        train_fingerprint,
        resolved_model.resolution_metadata,
    )
    print(f"Manifest written: {run_dir / 'manifest.json'}")
    print()

    # Create callbacks: console progress + JSONL metrics log in the run dir.
    callbacks = [
        ConsoleCallback(num_iters=config.training.num_iters),
        MetricsLoggerCallback(log_path=run_dir / "logs" / "metrics.jsonl"),
    ]

    # Add WandB callback if configured (hasattr guards older config schemas).
    if hasattr(config.training, 'wandb_project') and config.training.wandb_project:
        try:
            from mlx_forge.trainer.callbacks import WandBCallback
            callbacks.append(
                WandBCallback(
                    project=config.training.wandb_project,
                    run_name=run_dir.name,
                    config=config.model_dump(),
                )
            )
            print("WandB logging enabled")
        except ImportError:
            print("Warning: wandb not installed, skipping WandB logging")

    # Create trainer (SFT or DPO based on training_type)
    if config.training.training_type == "dpo":
        from mlx_forge.trainer.dpo_trainer import DPOTrainer
        trainer = DPOTrainer(
            model=model,
            config=config,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            callbacks=callbacks,
            checkpoint_manager=manager,
        )
    else:
        trainer = Trainer(
            model=model,
            config=config,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            callbacks=callbacks,
            checkpoint_manager=manager,
        )

    # Handle resume from checkpoint: validate compatibility first, then
    # restore model/optimizer/state in place.
    if resume:
        resume_path = Path(resume).expanduser()
        _validate_resume(resume_path, config)
        restored_state = manager.load(resume_path, model, trainer.optimizer)
        trainer.state = restored_state
        print(f"Resumed from {resume_path} at step {restored_state.step}")
        print()

    # Run training
    print("Starting training...")
    print()
    final_state = trainer.fit()

    print()
    print("Training complete!")
    print(f"Final step: {final_state.step}")
    print(f"Best validation loss: {final_state.best_val_loss:.4f}")
    print(f"Total tokens trained: {final_state.trained_tokens:,}")
    print(f"Checkpoints saved to: {run_dir / 'checkpoints'}")

    return final_state
352
+
353
+
354
+ def _validate_resume(resume_path: Path, config) -> None:
355
+ """Validate that a checkpoint directory is compatible with the current config."""
356
+ if not resume_path.exists():
357
+ raise FileNotFoundError(f"Checkpoint directory not found: {resume_path}")
358
+
359
+ required = ["adapters.safetensors", "optimizer.safetensors", "state.json"]
360
+ missing = [f for f in required if not (resume_path / f).exists()]
361
+ if missing:
362
+ raise FileNotFoundError(
363
+ f"Checkpoint missing {', '.join(missing)} in {resume_path}. "
364
+ f"Expected files: {', '.join(required)}"
365
+ )
366
+
367
+ state = json.loads((resume_path / "state.json").read_text())
368
+ if state.get("schema_version", 1) > 1:
369
+ raise ValueError(
370
+ f"Checkpoint schema version {state['schema_version']} is newer than "
371
+ f"supported version 1. Please upgrade MLX Forge."
372
+ )
373
+ if state["step"] >= config.training.num_iters:
374
+ raise ValueError(
375
+ f"Checkpoint is at step {state['step']} but training is configured "
376
+ f"for {config.training.num_iters} iterations. "
377
+ f"Increase 'num_iters' in your config to continue training."
378
+ )
379
+
380
+
381
def _enable_gradient_checkpointing(model) -> None:
    """Wrap each transformer layer's __call__ with mx.checkpoint.

    Gradient checkpointing trades compute for memory: layer activations are
    recomputed during the backward pass rather than stored.
    """
    import mlx.core as mx

    # Only models exposing the conventional `model.layers` stack are handled;
    # any other structure is silently left unchanged.
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        for layer in model.model.layers:
            # NOTE(review): assigning __call__ on the *instance* is bypassed by
            # the `layer(x)` call syntax (Python resolves dunders on the type),
            # so this only takes effect where callers invoke
            # `layer.__call__(x)` explicitly — confirm against the
            # architecture implementations and mlx.nn.Module call path.
            layer.__call__ = mx.checkpoint(layer.__call__)
388
+
389
+
390
def generate(
    model: str,
    prompt: str | None = None,
    messages: list[dict] | None = None,
    *,
    adapter: str | None = None,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_tokens: int = 512,
    repetition_penalty: float = 1.0,
    trust_remote_code: bool = False,
    seed: int | None = None,
    stream: bool = False,
) -> GenerationResult:
    """Generate text from a model with optional LoRA adapter.

    When ``stream`` is true, returns a generator that yields decoded text
    pieces token by token; otherwise returns the engine's GenerationResult.
    """
    from mlx_forge.inference.engine import generate as _generate
    from mlx_forge.inference.engine import generate_tokens, load_for_inference

    loaded_model, tokenizer = load_for_inference(
        model,
        adapter_path=adapter,
        trust_remote_code=trust_remote_code,
    )

    # Non-streaming path: delegate everything (including prompt/messages
    # handling) to the engine's generate().
    if not stream:
        return _generate(
            loaded_model,
            tokenizer,
            prompt=prompt,
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )

    # Streaming path: encode the prompt here, then decode token by token.
    if messages is not None:
        encoded = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True
        )
        # Some tokenizers return a dict of tensors rather than raw ids.
        prompt_tokens = encoded["input_ids"] if isinstance(encoded, dict) else encoded
    elif prompt is not None:
        prompt_tokens = tokenizer.encode(prompt)
    else:
        raise ValueError("Must provide either 'prompt' or 'messages'")

    def _stream():
        token_iter = generate_tokens(
            loaded_model,
            prompt_tokens,
            tokenizer,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
            seed=seed,
        )
        for tid in token_iter:
            yield tokenizer.decode([tid])

    return _stream()
mlx_forge/_version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "0.2.0"
File without changes