mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/llm/mlx_lm_backend.py
@@ -0,0 +1,509 @@
+from __future__ import annotations
+
+import inspect
+from typing import Sequence, Any, List, Dict, Optional
+
+from .backend import Generation, BackendNotAvailable
+from ..train.lora import apply_adapter, apply_lora, LoRAConfig, save_adapter
+
+
+class MlxLMBackend:
+    """Backend built on MLX + mlx-lm."""
+
+    name = "mlx-lm"
+
+    def __init__(self, *, lora_config: dict | None = None):
+        self.lora_config = lora_config or {}
+        self.model = None
+        self.tokenizer = None
+        self.nn = None
+        self.mx = None
+        self.optim = None
+        self._lora_applied = False
+        self._adapter_config: dict | None = None
+
+    def _require(self):
+        try:
+            import mlx.core as mx  # type: ignore
+            import mlx.nn as nn  # type: ignore
+            import mlx.optimizers as optim  # type: ignore
+        except Exception as e:  # pragma: no cover
+            raise BackendNotAvailable(
+                "MLX is not installed. Try: pip install -e '.[mlx,llm]'"
+            ) from e
+        self.mx = mx
+        self.nn = nn
+        self.optim = optim
+
+        try:
+            import mlx_lm  # type: ignore
+        except Exception as e:  # pragma: no cover
+            raise BackendNotAvailable(
+                "mlx-lm is not installed. Try: pip install -e '.[llm]'"
+            ) from e
+        return mlx_lm
+
+    def _call_with_supported_kwargs(self, fn, *args, **kwargs):
+        try:
+            sig = inspect.signature(fn)
+            supported = {}
+            for k, v in kwargs.items():
+                if k in sig.parameters:
+                    supported[k] = v
+            return fn(*args, **supported)
+        except Exception:
+            return fn(*args, **kwargs)
+
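# --- Editorial sketch (not part of the package diff) --------------------------
# The signature-filtering pattern used by _call_with_supported_kwargs above,
# shown standalone.  `load_stub` is a hypothetical callee here; keyword
# arguments it does not declare are dropped before the call.
import inspect

def load_stub(path, tokenizer_config=None):
    return path, tokenizer_config

def call_with_supported_kwargs_demo(fn, *args, **kwargs):
    sig = inspect.signature(fn)
    supported = {k: v for k, v in kwargs.items() if k in sig.parameters}
    return fn(*args, **supported)

# `adapter_path` is not a parameter of load_stub, so it is silently dropped:
assert call_with_supported_kwargs_demo(
    load_stub, "model-dir", tokenizer_config={}, adapter_path="x"
) == ("model-dir", {})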
+    def load(
+        self,
+        model_id_or_path: str,
+        *,
+        max_seq_len: int | None = None,
+        dtype: str | None = None,
+        tokenizer_config: dict | None = None,
+        model_config: dict | None = None,
+        adapter_path: str | None = None,
+        trust_remote_code: bool | None = None,
+    ) -> None:
+        mlx_lm = self._require()
+
+        if tokenizer_config is None:
+            tokenizer_config = {}
+        if trust_remote_code is not None:
+            tokenizer_config = dict(tokenizer_config)
+            tokenizer_config.setdefault("trust_remote_code", trust_remote_code)
+
+        if model_config is None:
+            model_config = {}
+        # Pass dtype/max_seq_len as hints when supported in model config.
+        if dtype is not None:
+            model_config = dict(model_config)
+            model_config.setdefault("dtype", dtype)
+        if max_seq_len is not None:
+            model_config = dict(model_config)
+            model_config.setdefault("max_seq_len", max_seq_len)
+
+        load_fn = getattr(mlx_lm, "load", None)
+        if callable(load_fn):
+            model, tokenizer = self._call_with_supported_kwargs(
+                load_fn,
+                model_id_or_path,
+                tokenizer_config=tokenizer_config,
+                model_config=model_config,
+                adapter_path=adapter_path,
+            )
+        else:  # pragma: no cover
+            utils = getattr(mlx_lm, "utils", None)
+            if utils is None or not callable(getattr(utils, "load", None)):
+                raise BackendNotAvailable("Could not find mlx_lm.load(...) API")
+            model, tokenizer = self._call_with_supported_kwargs(
+                utils.load,
+                model_id_or_path,
+                tokenizer_config=tokenizer_config,
+                model_config=model_config,
+                adapter_path=adapter_path,
+            )
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self._lora_applied = False
+        self._adapter_config = None
+
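# --- Editorial usage sketch (not part of the package diff) --------------------
# Loading a model through the backend; assumes mlx and mlx-lm are installed and
# that the model id below (illustrative only) resolves to a local or Hub model.
#
#     backend = MlxLMBackend()
#     backend.load(
#         "mlx-community/SomeModel-4bit",   # illustrative id
#         max_seq_len=2048,
#         dtype="float16",
#         trust_remote_code=False,
#     )
#     ids = backend.encode("hello")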
+    def apply_adapter(self, adapter_path: str) -> None:
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        self._adapter_config = apply_adapter(self.model, adapter_path)
+        self._lora_applied = True
+
+    def apply_lora_from_config(self, cfg: LoRAConfig) -> dict:
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        adapter_cfg = apply_lora(self.model, cfg)
+        self._lora_applied = True
+        self._adapter_config = adapter_cfg
+        return adapter_cfg
+
+    def encode(self, text: str) -> list[int]:
+        if self.tokenizer is None:
+            raise RuntimeError("Backend not loaded")
+        tok = self.tokenizer
+        if hasattr(tok, "encode"):
+            out = tok.encode(text)
+            if isinstance(out, dict) and "input_ids" in out:
+                return list(out["input_ids"])
+            if isinstance(out, (list, tuple)):
+                return list(out)
+        if hasattr(tok, "__call__"):
+            out = tok(text)
+            if isinstance(out, dict) and "input_ids" in out:
+                return list(out["input_ids"])
+        raise RuntimeError("Tokenizer does not support encode")
+
+    def decode(self, ids: Sequence[int]) -> str:
+        if self.tokenizer is None:
+            raise RuntimeError("Backend not loaded")
+        tok = self.tokenizer
+        if hasattr(tok, "decode"):
+            return tok.decode(list(ids))
+        raise RuntimeError("Tokenizer does not support decode")
+
+    def _forward_logits(self, ids: Sequence[int]):
+        assert self.mx is not None
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        mx = self.mx
+        x = mx.array([list(ids)], dtype=mx.int32)
+        return self.model(x)
+
+    def _response_logprobs(self, ids: Sequence[int], *, prompt_len: int) -> list[float]:
+        assert self.mx is not None
+        mx = self.mx
+        if not ids:
+            return []
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([list(ids)[1:]], dtype=mx.int32)
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+        start = max(0, prompt_len - 1)
+        if start >= int(getattr(logp, "size", len(ids) - 1)):
+            return []
+        values = logp[:, start:]
+        try:
+            flat = values.flatten().tolist()
+        except Exception:
+            try:
+                flat = [float(v) for v in values.reshape(-1)]
+            except Exception:
+                flat = [float(v) for v in values]
+        return [float(v) for v in flat]
+
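# --- Editorial sketch (not part of the package diff) --------------------------
# The per-token log-prob math used above, in pure Python: the log-probability of
# the chosen token equals its logit minus logsumexp over the vocabulary, i.e.
# log softmax(logits)[label] == logits[label] - logsumexp(logits).
import math

demo_logits = [2.0, 1.0, 0.1]
demo_label = 0
demo_lse = math.log(sum(math.exp(v) for v in demo_logits))
logp_direct = math.log(math.exp(demo_logits[demo_label]) / sum(math.exp(v) for v in demo_logits))
logp_via_lse = demo_logits[demo_label] - demo_lse
assert abs(logp_direct - logp_via_lse) < 1e-9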
+    def _extract_top_k_logprobs(
+        self,
+        logits: Any,
+        k: int,
+        sampled_ids: Optional[Sequence[int]] = None,
+    ) -> List[Dict[str, float]]:
+        """Extract top-k logprobs from logits.
+
+        Args:
+            logits: Logits array [batch, seq_len, vocab_size]
+            k: Number of top logprobs to extract
+            sampled_ids: Optional token IDs that were actually sampled
+
+        Returns:
+            List of dicts mapping token string to logprob for each position
+        """
+        assert self.mx is not None
+        mx = self.mx
+
+        # Get log softmax
+        log_probs = mx.log(mx.softmax(logits, axis=-1))
+
+        # generate_with_logprobs passes a single-position [batch, vocab_size]
+        # slice; add a sequence axis so the loop below handles 2-D and 3-D input.
+        if len(log_probs.shape) == 2:
+            log_probs = mx.expand_dims(log_probs, 1)
+
+        # Get top-k indices and values
+        # MLX doesn't have topk directly, so we use argsort
+        sorted_indices = mx.argsort(-log_probs, axis=-1)
+
+        results = []
+        batch_size, seq_len, vocab_size = log_probs.shape
+
+        # Limit k to vocab size
+        k = min(k, vocab_size)
+
+        for b in range(batch_size):
+            for t in range(seq_len):
+                # Get top-k for this position
+                top_k_indices = sorted_indices[b, t, :k]
+                top_k_logprobs = log_probs[b, t, top_k_indices]
+
+                # Build dict
+                token_logprobs = {}
+                for idx, logprob in zip(top_k_indices.tolist(), top_k_logprobs.tolist()):
+                    # Try to decode token
+                    try:
+                        token_str = self.decode([idx])
+                        # Escape special characters for JSON compatibility
+                        token_str = token_str.replace('\n', '\\n').replace('\t', '\\t')
+                    except Exception:
+                        token_str = f"<token_{idx}>"
+                    token_logprobs[token_str] = float(logprob)
+
+                results.append(token_logprobs)
+
+        return results
+
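# --- Editorial note (not part of the package diff) -----------------------------
# For a [1, T, V] logits array and k=2, _extract_top_k_logprobs returns one dict
# per position, ordered batch-major then position, e.g. (token strings and
# values illustrative):
#     [{" the": -0.31, " a": -1.72}, {" cat": -0.95, " dog": -1.10}, ...]
# len(results) == batch_size * seq_len, and each dict has at most k entries.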
+    def generate(
+        self,
+        prompt: str,
+        *,
+        max_new_tokens: int = 256,
+        temperature: float = 0.8,
+        top_p: float = 1.0,
+        top_k: int | None = None,
+        seed: int | None = None,
+    ) -> Generation:
+        assert self.mx is not None
+        mx = self.mx
+        if seed is not None:
+            mx.random.seed(seed)
+
+        prompt_ids = self.encode(prompt)
+        ids = list(prompt_ids)
+        prompt_len = len(prompt_ids)
+
+        # Prefer mlx_lm sampler if available
+        sampler = None
+        try:
+            from mlx_lm.sample_utils import make_sampler  # type: ignore
+
+            sampler = make_sampler(
+                temp=float(temperature),
+                top_p=float(top_p),
+                top_k=int(top_k or 0),
+            )
+        except Exception:
+            sampler = None
+
+        for _ in range(max_new_tokens):
+            logits = self._forward_logits(ids)
+            last = logits[:, -1, :]
+            if temperature <= 0:
+                next_id = int(mx.argmax(last, axis=-1).item())
+            elif sampler is not None:
+                next_id = int(sampler(last).item())
+            else:
+                probs = mx.softmax(last / float(temperature), axis=-1)
+                next_id = int(mx.random.categorical(mx.log(probs)).item())
+            ids.append(next_id)
+        text = self.decode(ids)
+        return Generation(text=text, token_ids=ids, prompt_len=prompt_len)
+
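# --- Editorial usage sketch (not part of the package diff) --------------------
# Sampling from a loaded backend.  Note that Generation.text is decoded from the
# full id list, so it includes the prompt; slice with prompt_len for the
# completion only.
#
#     gen = backend.generate("Write a haiku about the sea.", max_new_tokens=64,
#                            temperature=0.7, top_p=0.95, seed=0)
#     completion = backend.decode(gen.token_ids[gen.prompt_len:])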
+    def generate_with_logprobs(
+        self,
+        prompt: str,
+        *,
+        max_new_tokens: int = 256,
+        temperature: float = 0.8,
+        top_p: float = 1.0,
+        top_k_sampling: int | None = None,
+        seed: int | None = None,
+        logprobs: int = 0,  # Number of top logprobs to return per token
+    ) -> Generation:
+        """Generate with logprobs support including top-k logprobs.
+
+        Args:
+            prompt: Input prompt
+            max_new_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            top_k_sampling: Top-k sampling parameter (named to avoid conflict with logprobs)
+            seed: Random seed
+            logprobs: Number of top logprobs to return per token (0 = only sampled token)
+
+        Returns:
+            Generation with logprobs and optionally top_k_logprobs
+        """
+        assert self.mx is not None
+        mx = self.mx
+        if seed is not None:
+            mx.random.seed(seed)
+
+        prompt_ids = self.encode(prompt)
+        ids = list(prompt_ids)
+        prompt_len = len(prompt_ids)
+
+        # Storage for per-token info
+        per_token_logprobs: list[float] = []
+        per_token_top_k: list[dict[str, float]] = []
+
+        # Prefer mlx_lm sampler if available
+        sampler = None
+        try:
+            from mlx_lm.sample_utils import make_sampler  # type: ignore
+
+            sampler = make_sampler(
+                temp=float(temperature),
+                top_p=float(top_p),
+                top_k=int(top_k_sampling or 0),
+            )
+        except Exception:
+            sampler = None
+
+        for _ in range(max_new_tokens):
+            logits = self._forward_logits(ids)
+            last = logits[:, -1, :]  # [batch=1, vocab_size]
+
+            # Get log probabilities for this position
+            log_probs = mx.log(mx.softmax(last, axis=-1))
+
+            if temperature <= 0:
+                next_id = int(mx.argmax(last, axis=-1).item())
+            elif sampler is not None:
+                next_id = int(sampler(last).item())
+            else:
+                probs = mx.softmax(last / float(temperature), axis=-1)
+                next_id = int(mx.random.categorical(mx.log(probs)).item())
+
+            # Get logprob of sampled token
+            sampled_logprob = float(log_probs[0, next_id].item())
+            per_token_logprobs.append(sampled_logprob)
+
+            # Get top-k logprobs if requested
+            if logprobs > 0:
+                top_k_logprobs = self._extract_top_k_logprobs(
+                    last,
+                    k=logprobs,
+                    sampled_ids=[next_id],
+                )
+                if top_k_logprobs:
+                    per_token_top_k.append(top_k_logprobs[0])
+
+            ids.append(next_id)
+
+        text = self.decode(ids)
+
+        return Generation(
+            text=text,
+            token_ids=ids,
+            prompt_len=prompt_len,
+            logprobs=per_token_logprobs,
+            top_k_logprobs=per_token_top_k if per_token_top_k else None,
+        )
+
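# --- Editorial usage sketch (not part of the package diff) --------------------
# Greedy decoding with per-token logprobs and top-3 alternatives; assumes a
# loaded backend as above.
#
#     gen = backend.generate_with_logprobs("2 + 2 =", max_new_tokens=4,
#                                          temperature=0.0, logprobs=3)
#     for lp, top in zip(gen.logprobs, gen.top_k_logprobs or []):
#         print(lp, top)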
+    def sft_loss(self, token_ids: Sequence[int], *, train_on_prompt: bool, prompt_len: int) -> Any:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+
+        if not train_on_prompt:
+            mask = [0] * max(0, prompt_len - 1) + [1] * (len(ids) - prompt_len)
+            mask = mx.array([mask], dtype=mx.float32)
+        else:
+            mask = mx.ones(labels.shape, dtype=mx.float32)
+
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        nll = (lse - chosen) * mask
+        denom = mx.maximum(mask.sum(), mx.array(1.0))
+        return nll.sum() / denom
+
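# --- Editorial sketch (not part of the package diff) --------------------------
# The mask arithmetic in sft_loss above: labels drop the first token, so they
# have len(ids) - 1 entries; the mask (prompt_len - 1 zeros, then
# len(ids) - prompt_len ones) has exactly the same length.
demo_ids = list(range(10))   # 10 tokens total (illustrative)
demo_prompt_len = 4          # first 4 tokens are the prompt
demo_mask = [0] * max(0, demo_prompt_len - 1) + [1] * (len(demo_ids) - demo_prompt_len)
assert len(demo_mask) == len(demo_ids) - 1 == 9
assert sum(demo_mask) == len(demo_ids) - demo_prompt_len == 6  # response positions only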
+    def rl_loss(self, token_ids: Sequence[int], *, prompt_len: int, advantage: float) -> Any:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+
+        start = max(0, prompt_len - 1)
+        logp_resp = logp[:, start:]
+        return -float(advantage) * logp_resp.sum() / mx.maximum(mx.array(1.0), mx.array(logp_resp.size))
+
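# --- Editorial note (not part of the package diff) -----------------------------
# rl_loss above is a REINFORCE-style objective: loss = -advantage * mean(logp of
# response tokens).  Minimizing it raises response log-probs when the advantage
# is positive and lowers them when it is negative.  Tiny numeric check:
demo_logp_resp = [-1.2, -0.8, -2.0]   # illustrative per-token log-probs
demo_advantage = 0.5
demo_loss = -demo_advantage * sum(demo_logp_resp) / max(1.0, float(len(demo_logp_resp)))
assert round(demo_loss, 4) == 0.6667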
+    def sequence_logprob(self, token_ids: Sequence[int], *, prompt_len: int) -> Any:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+        start = max(0, prompt_len - 1)
+        return logp[:, start:].sum()
+
+    def token_logprobs(
+        self,
+        token_ids: Sequence[int],
+        *,
+        prompt_len: int,
+        top_k: int = 0,
+        include_prompt: bool = False,
+    ) -> tuple[list[float], list[dict[str, float]] | None]:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        if len(ids) < 2:
+            return [], ([] if top_k > 0 else None)
+
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+
+        start = 0 if include_prompt else max(0, prompt_len - 1)
+        values = logp[:, start:]
+        try:
+            flat = values.flatten().tolist()
+        except Exception:
+            try:
+                flat = [float(v) for v in values.reshape(-1)]
+            except Exception:
+                flat = [float(v) for v in values]
+        logprobs = [float(v) for v in flat]
+
+        if top_k <= 0:
+            return logprobs, None
+
+        top_k_all = self._extract_top_k_logprobs(logits, k=int(top_k))
+        top_k_list = top_k_all[start:] if top_k_all else []
+        return logprobs, top_k_list
+
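# --- Editorial usage sketch (not part of the package diff) --------------------
# Scoring an existing prompt + response without sampling; assumes a loaded
# backend.  sum(lps) is the response log-likelihood under the current weights.
#
#     prompt_ids = backend.encode(prompt)
#     ids = backend.encode(prompt + response)
#     lps, top = backend.token_logprobs(ids, prompt_len=len(prompt_ids), top_k=5)
#     total_logprob = sum(lps)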
+    def value_and_grad(self, loss_fn):
+        if self.nn is None or self.model is None:
+            return loss_fn(self.model), None
+        vag = getattr(self.nn, "value_and_grad", None)
+        if callable(vag):
+            return vag(self.model, loss_fn)(self.model)
+        return loss_fn(self.model), None
+
+    def optimizer_and_params(self, *, lr: float, weight_decay: float = 0.0) -> tuple[Any, Any]:
+        assert self.optim is not None
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+
+        params = None
+        if hasattr(self.model, "trainable_parameters"):
+            params = self.model.trainable_parameters()
+        if params is None or not params:
+            # fallback: train LoRA params if injected
+            from ..train.lora import lora_parameters
+
+            params = lora_parameters(self.model)
+        if not params:
+            params = getattr(self.model, "parameters", lambda: self.model)()
+
+        opt = self.optim.AdamW(learning_rate=lr, weight_decay=weight_decay)
+        opt.init(params)
+        return opt, params
+
+    def apply_grads(self, optimizer: Any, grads: Any) -> None:
+        assert self.mx is not None
+        mx = self.mx
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        optimizer.update(self.model, grads)
+        try:
+            mx.eval(self.model.parameters(), optimizer.state)
+        except Exception:  # pragma: no cover
+            pass
+
+    def save_adapter(self, out_dir: str, *, metadata: dict | None = None) -> None:
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        adapter_cfg = self._adapter_config or {
+            "fine_tune_type": "lora",
+            "num_layers": 0,
+            "lora_parameters": {},
+        }
+        save_adapter(self.model, out_dir, adapter_config=adapter_cfg, metadata=metadata)
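# --- Editorial usage sketch (not part of the package diff) --------------------
# One LoRA training step wired from the methods above; assumes a loaded backend
# with apply_lora_from_config() already called, and `text` / `prompt_len` chosen
# by the caller.
#
#     opt, params = backend.optimizer_and_params(lr=1e-4, weight_decay=0.0)
#     ids = backend.encode(text)
#     loss_fn = lambda model: backend.sft_loss(ids, train_on_prompt=False,
#                                              prompt_len=prompt_len)
#     loss, grads = backend.value_and_grad(loss_fn)
#     if grads is not None:
#         backend.apply_grads(opt, grads)
#     backend.save_adapter("runs/adapter")   # illustrative output directory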