mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/train/lora.py ADDED
@@ -0,0 +1,280 @@
+ """LoRA adapter utilities.
+
+ This module prefers MLX-LM's LoRA utilities and adapter format when available.
+ Fallback implementations are provided for environments without MLX/MLX-LM.
+ """
+
+ from __future__ import annotations
+
+ import json
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any
+
+
+ def _require_mlx():
+     import mlx.core as mx  # type: ignore
+     import mlx.nn as nn  # type: ignore
+     from mlx.utils import tree_flatten  # type: ignore
+
+     return mx, nn, tree_flatten
+
+
+ def _try_mlx_lm_utils():
+     try:
+         from mlx_lm.tuner import utils as tuner_utils  # type: ignore
+         from mlx_lm.utils import save_config as mlx_save_config  # type: ignore
+     except Exception:
+         return None, None
+     return tuner_utils, mlx_save_config
+
+
+ @dataclass
+ class LoRAConfig:
+     r: int = 16
+     alpha: int = 32
+     dropout: float = 0.0
+     target_modules: list[str] | None = None
+     num_layers: int = 0  # 0 => all layers
+     scale: float | None = None
+     fine_tune_type: str = "lora"  # lora | dora | full
+
+
+ class LoRALinear:
+     """Minimal fallback LoRA wrapper for mlx.nn.Linear."""
+
+     def __init__(self, base_linear: Any, *, r: int, alpha: int, dropout: float = 0.0):
+         mx, nn, _ = _require_mlx()
+         self.nn = nn
+         self.mx = mx
+
+         self.base = base_linear
+         self.r = int(r)
+         self.alpha = float(alpha)
+         self.scale = float(alpha) / float(r) if r > 0 else 0.0
+         self.dropout = nn.Dropout(p=float(dropout)) if dropout and dropout > 0 else None
+
+         w = getattr(base_linear, "weight")
+         out_dim, in_dim = int(w.shape[0]), int(w.shape[1])
+
+         self.A = mx.random.normal((self.r, in_dim)) * 0.01
+         self.B = mx.zeros((out_dim, self.r))
+
+     def __call__(self, x):
+         mx = self.mx
+         base_w = mx.stop_gradient(self.base.weight)
+         y = x @ base_w.T
+         if getattr(self.base, "bias", None) is not None:
+             y = y + mx.stop_gradient(self.base.bias)
+         z = x
+         if self.dropout is not None:
+             z = self.dropout(z)
+         z = (z @ self.A.T) @ self.B.T
+         return y + z * self.scale
+
+
+ def _iter_named_modules(root: Any, prefix: str = ""):
+     for name in dir(root):
+         if name.startswith("_"):
+             continue
+         try:
+             obj = getattr(root, name)
+         except Exception:
+             continue
+         if callable(getattr(obj, "__call__", None)) and hasattr(obj, "__dict__"):
+             full = f"{prefix}{name}" if not prefix else f"{prefix}.{name}"
+             yield full, root, name, obj
+
+
+ def inject_lora(
+     model: Any,
+     *,
+     r: int = 16,
+     alpha: int = 32,
+     dropout: float = 0.0,
+     target_modules: list[str] | None = None,
+ ) -> int:
+     _mx, _nn, _ = _require_mlx()
+     targets = target_modules or ["q_proj", "k_proj", "v_proj", "o_proj"]
+     wrapped = 0
+
+     for full, parent, attr, obj in list(_iter_named_modules(model)):
+         if not hasattr(obj, "weight"):
+             continue
+         if not any(attr.endswith(t) or full.endswith(t) for t in targets):
+             continue
+         cls_name = obj.__class__.__name__.lower()
+         if "linear" not in cls_name:
+             continue
+         try:
+             setattr(parent, attr, LoRALinear(obj, r=r, alpha=alpha, dropout=dropout))
+             wrapped += 1
+         except Exception:
+             continue
+
+     return wrapped
+
+
+ def lora_parameters(model: Any) -> dict[str, Any]:
+     params: dict[str, Any] = {}
+     for full, _parent, _attr, obj in list(_iter_named_modules(model)):
+         if obj.__class__.__name__ == "LoRALinear" or (hasattr(obj, "A") and hasattr(obj, "B")):
+             try:
+                 params[f"{full}.A"] = obj.A
+                 params[f"{full}.B"] = obj.B
+             except Exception:
+                 pass
+     return params
+
+
+ def _scale_for_config(cfg: LoRAConfig) -> float:
+     if cfg.scale is not None:
+         return float(cfg.scale)
+     if cfg.r == 0:
+         return float(cfg.alpha)
+     return float(cfg.alpha) / float(cfg.r)
+
+
+ def _keys_for_target_modules(model: Any, target_modules: list[str]) -> set[str]:
+     keys: set[str] = set()
+     # Prefer model.named_modules if present (MLX-LM models support this)
+     if hasattr(model, "named_modules"):
+         for name, _mod in model.named_modules():
+             if any(name.endswith(t) for t in target_modules):
+                 keys.add(name)
+         return keys
+
+     # Fallback: walk attributes
+     for full, _parent, _attr, obj in list(_iter_named_modules(model)):
+         if any(full.endswith(t) for t in target_modules):
+             keys.add(full)
+     return keys
+
+
+ def apply_lora(model: Any, cfg: LoRAConfig) -> dict:
+     """Apply LoRA layers (prefer MLX-LM utilities) and return adapter config."""
+     tuner_utils, _save_cfg = _try_mlx_lm_utils()
+     scale = _scale_for_config(cfg)
+     keys = None
+     if cfg.target_modules:
+         keys = sorted(_keys_for_target_modules(model, cfg.target_modules))
+
+     if tuner_utils is not None and hasattr(tuner_utils, "linear_to_lora_layers"):
+         # MLX-LM format
+         config = {
+             "rank": int(cfg.r),
+             "scale": float(scale),
+             "dropout": float(cfg.dropout),
+         }
+         if keys:
+             config["keys"] = keys
+         num_layers = int(cfg.num_layers)
+         tuner_utils.linear_to_lora_layers(
+             model,
+             num_layers,
+             config,
+             use_dora=(cfg.fine_tune_type == "dora"),
+         )
+         return {
+             "fine_tune_type": cfg.fine_tune_type,
+             "num_layers": num_layers,
+             "lora_parameters": config,
+         }
+
+     # Fallback to local LoRA injection
+     inject_lora(
+         model,
+         r=cfg.r,
+         alpha=cfg.alpha,
+         dropout=cfg.dropout,
+         target_modules=cfg.target_modules,
+     )
+     return {
+         "fine_tune_type": "lora",
+         "num_layers": 0,
+         "lora_parameters": {
+             "rank": int(cfg.r),
+             "scale": float(scale),
+             "dropout": float(cfg.dropout),
+         },
+     }
+
+
+ def save_adapter(model: Any, out_dir: str | Path, *, adapter_config: dict, metadata: dict | None = None) -> None:
+     out = Path(out_dir)
+     out.mkdir(parents=True, exist_ok=True)
+
+     tuner_utils, mlx_save_config = _try_mlx_lm_utils()
+     try:
+         mx, _nn, tree_flatten = _require_mlx()
+     except Exception:
+         mx = None
+         tree_flatten = None
+
+     if mx is not None and hasattr(model, "trainable_parameters") and tree_flatten is not None:
+         adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+         try:
+             mx.save_safetensors(str(out / "adapters.safetensors"), adapter_weights)
+         except Exception:
+             # fallback to numpy (MLX arrays convert via np.array)
+             import numpy as np
+
+             arrays = {k: np.array(v) for k, v in adapter_weights.items()}
+             np.savez(out / "lora.npz", **arrays)
+     else:
+         # fallback to local LoRA params
+         params = lora_parameters(model)
+         if params:
+             import numpy as np
+
+             _mx_local, _nn_local, _ = _require_mlx()  # ensure MLX is importable
+             arrays = {k: np.array(v) for k, v in params.items()}
+             np.savez(out / "lora.npz", **arrays)
+
+     config_path = out / "adapter_config.json"
+     if mlx_save_config is not None:
+         cfg = dict(adapter_config)
+         mlx_save_config(cfg, config_path)
+     else:
+         config_path.write_text(json.dumps(adapter_config, indent=2), encoding="utf-8")
+
+     if metadata is not None:
+         (out / "adapter_metadata.json").write_text(
+             json.dumps(metadata, indent=2), encoding="utf-8"
+         )
+
+
+ def load_adapter_config(adapter_dir: str | Path) -> dict | None:
+     path = Path(adapter_dir) / "adapter_config.json"
+     if not path.exists():
+         return None
+     return json.loads(path.read_text(encoding="utf-8"))
+
+
+ def apply_adapter(model: Any, adapter_dir: str | Path) -> dict | None:
+     adapter_dir = Path(adapter_dir)
+     adapter_cfg = load_adapter_config(adapter_dir)
+     tuner_utils, _ = _try_mlx_lm_utils()
+     if adapter_cfg is not None and tuner_utils is not None and hasattr(tuner_utils, "load_adapters"):
+         tuner_utils.load_adapters(model, str(adapter_dir))
+         return adapter_cfg
+
+     # Fallback: load local lora.npz into LoRALinear wrappers
+     lora_file = adapter_dir / "lora.npz"
+     if not lora_file.exists():
+         return adapter_cfg
+
+     import numpy as np
+
+     weights = dict(np.load(lora_file))
+     mx, _nn, _ = _require_mlx()
+     # apply
+     for full, _parent, _attr, obj in list(_iter_named_modules(model)):
+         if hasattr(obj, "A") and hasattr(obj, "B"):
+             key_a = f"{full}.A"
+             key_b = f"{full}.B"
+             if key_a in weights:
+                 obj.A = mx.array(weights[key_a])
+             if key_b in weights:
+                 obj.B = mx.array(weights[key_b])
+     return adapter_cfg
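
For orientation, a minimal usage sketch of the helpers above (not part of the package diff): it assumes MLX and MLX-LM are installed and that `mlx_lm.load` returns a `(model, tokenizer)` pair; the model path, rank, and `num_layers=16` are illustrative values only.

    from mlx_lm import load

    from mlxsmith.train.lora import LoRAConfig, apply_lora, save_adapter, apply_adapter

    # Load a base model (the path is a placeholder).
    model, tokenizer = load("path/to/mlx-model")

    # Wrap the target projections with LoRA layers and keep the adapter config.
    cfg = LoRAConfig(r=16, alpha=32, dropout=0.05, num_layers=16,
                     target_modules=["q_proj", "v_proj"])
    adapter_config = apply_lora(model, cfg)

    # ... train the adapter elsewhere (e.g. run_pref below) ...

    # Write adapters.safetensors (or lora.npz) plus adapter_config.json.
    save_adapter(model, "adapters/example-run", adapter_config=adapter_config,
                 metadata={"base_model": "path/to/mlx-model"})

    # Reload the adapter onto a freshly loaded copy of the base model.
    model2, _ = load("path/to/mlx-model")
    restored_cfg = apply_adapter(model2, "adapters/example-run")
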
mlxsmith/train/pref.py ADDED
@@ -0,0 +1,180 @@
+ from __future__ import annotations
+
+ import json
+ import random
+ from pathlib import Path
+
+ from rich.console import Console
+
+ from ..accel import get_backend
+ from ..config import ProjectConfig
+ from ..models import resolve_model_spec
+ from ..runs import RunPaths, new_run, snapshot_config
+ from ..util import write_jsonl, now_ts, tree_add, tree_scale
+ from ..llm.registry import get_llm_backend
+ from ..llm.backend import BackendNotAvailable
+ from .lora import LoRAConfig
+
+ console = Console()
+
+
+ def _load_pref_rows(path: Path) -> list[dict]:
+     return [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]
+
+
+ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_path: Path, accel: str) -> RunPaths:
+     run = new_run(project_root, "pref")
+     snapshot_config(cfg.model_dump(), run.config_snapshot_path)
+
+     backend = get_backend(accel)
+     backend.patch()
+     console.print(f"[bold]PREF[/bold] run: {run.run_dir.name} algo={cfg.pref.algo} accel={backend.name}")
+
+     prefs_path = data_dir / "train.jsonl"
+     if not prefs_path.exists():
+         raise RuntimeError("Preference data missing. Expect data/prefs/train.jsonl with {prompt, chosen, rejected}")
+     rows = _load_pref_rows(prefs_path)
+
+     llm = get_llm_backend(cfg.model.backend)
+     base_model, adapter_path, adapter_meta = resolve_model_spec(project_root, str(base_model_path), cfg)
+
+     try:
+         llm.load(
+             base_model,
+             max_seq_len=cfg.model.max_seq_len,
+             dtype=cfg.model.dtype,
+             trust_remote_code=cfg.model.trust_remote_code,
+         )
+         if adapter_path:
+             llm.apply_adapter(str(adapter_path))
+         else:
+             lora_cfg = LoRAConfig(
+                 r=cfg.lora.r,
+                 alpha=cfg.lora.alpha,
+                 dropout=cfg.lora.dropout,
+                 target_modules=list(cfg.lora.target_modules or []),
+                 num_layers=cfg.lora.num_layers,
+                 scale=cfg.lora.scale,
+                 fine_tune_type=cfg.lora.fine_tune_type,
+             )
+             llm.apply_lora_from_config(lora_cfg)
+     except BackendNotAvailable as e:
+         console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
+         (run.adapter_dir / "ADAPTER.txt").write_text(
+             f"Backend unavailable in this environment.\nbase={base_model}\naccel={backend.name}\n",
+             encoding="utf-8",
+         )
+         return run
+
+     ref_llm = None
+     if cfg.pref.reference_model:
+         ref_llm = get_llm_backend(cfg.model.backend)
+         try:
+             ref_llm.load(
+                 cfg.pref.reference_model,
+                 max_seq_len=cfg.model.max_seq_len,
+                 dtype=cfg.model.dtype,
+                 trust_remote_code=cfg.model.trust_remote_code,
+             )
+         except BackendNotAvailable:
+             ref_llm = None
+
+     opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+
+     beta = float(cfg.pref.beta)
+     kl_coeff = float(cfg.pref.kl_coeff)
+     rng = random.Random(cfg.train.seed)
+     total = int(cfg.train.iters)
+     grad_accum = max(1, int(cfg.train.grad_accum))
+     train_on_prompt = bool(getattr(cfg.train, "train_on_prompt", False))
+
+     accum_grads = None
+     for step in range(1, total + 1):
+         row = rng.choice(rows)
+         prompt = row.get("prompt") or ""
+         chosen = row.get("chosen") or row.get("accepted") or ""
+         rejected = row.get("rejected") or row.get("rejected_response") or ""
+         if not (prompt and chosen and rejected):
+             continue
+
+         prompt_ids = llm.encode(prompt)
+         chosen_ids = llm.encode(prompt + chosen)
+         rejected_ids = llm.encode(prompt + rejected)
+         p_len_c = len(prompt_ids)
+         p_len_r = len(prompt_ids)
+         max_len = int(cfg.model.max_seq_len)
+         if max_len:
+             if len(chosen_ids) > max_len:
+                 overflow = len(chosen_ids) - max_len
+                 chosen_ids = chosen_ids[overflow:]
+                 p_len_c = max(0, p_len_c - overflow)
+             if len(rejected_ids) > max_len:
+                 overflow = len(rejected_ids) - max_len
+                 rejected_ids = rejected_ids[overflow:]
+                 p_len_r = max(0, p_len_r - overflow)
+
+         def loss_fn(_model):
+             logp_c = llm.sequence_logprob(chosen_ids, prompt_len=p_len_c)
+             logp_r = llm.sequence_logprob(rejected_ids, prompt_len=p_len_r)
+             ref_diff = 0.0
+             if ref_llm is not None:
+                 ref_logp_c = ref_llm.sequence_logprob(chosen_ids, prompt_len=p_len_c)
+                 ref_logp_r = ref_llm.sequence_logprob(rejected_ids, prompt_len=p_len_r)
+                 ref_diff = ref_logp_c - ref_logp_r
+             diff = (logp_c - logp_r) - ref_diff
+
+             if cfg.pref.algo == "orpo":
+                 # ORPO loss = NLL(chosen) - beta * log(sigmoid(diff))
+                 nll = llm.sft_loss(chosen_ids, train_on_prompt=train_on_prompt, prompt_len=p_len_c)
+                 or_loss = -beta * llm.mx.log(llm.mx.sigmoid(diff))  # type: ignore
+                 loss = nll + or_loss
+             else:
+                 # DPO loss
+                 scaled = llm.mx.array(beta) * diff  # type: ignore
+                 loss = llm.mx.log1p(llm.mx.exp(-scaled))  # type: ignore
+
+             if ref_llm is not None and kl_coeff > 0:
+                 # Simple KL penalty on chosen responses
+                 kl = (logp_c - ref_logp_c) if ref_llm is not None else 0.0
+                 loss = loss + llm.mx.array(kl_coeff) * kl  # type: ignore
+             return loss
+
+         lval, grads = llm.value_and_grad(loss_fn)
+         if grads is not None:
+             accum_grads = tree_add(accum_grads, grads)
+
+         if step % grad_accum == 0:
+             if accum_grads is not None:
+                 llm.apply_grads(opt, tree_scale(accum_grads, 1.0 / grad_accum))
+             accum_grads = None
+
+         if step % cfg.train.log_every == 0 or step == 1 or step == total:
+             write_jsonl(
+                 run.metrics_path,
+                 [
+                     {
+                         "ts": now_ts(),
+                         "step": step,
+                         "kind": "pref",
+                         "algo": cfg.pref.algo,
+                         "beta": beta,
+                         "kl_coeff": kl_coeff,
+                         "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
+                         "accel": backend.name,
+                     }
+                 ],
+             )
+
+         if step % cfg.train.save_every == 0 or step == total:
+             llm.save_adapter(
+                 str(run.adapter_dir),
+                 metadata={
+                     "base_model": base_model,
+                     "source_adapter": str(adapter_path) if adapter_path else None,
+                     "run": run.run_dir.name,
+                     "kind": "pref",
+                 },
+             )
+
+     console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+     return run