mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/rlm/mutate.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+
+ import json
+ from typing import Iterable, List, Optional
+
+ from ..util import sha1_text
+ from .generate import GeneratedTask, extract_json_objects, task_to_prompt, task_to_tests
+
+
+ def mutate_tasks(
+     llm,
+     tasks: Iterable[GeneratedTask],
+     *,
+     mutations_per_task: int,
+     max_total: Optional[int] = None,
+     temperature: float = 0.7,
+     max_new_tokens: int = 512,
+     top_p: float = 1.0,
+     top_k: Optional[int] = None,
+     require_recursion: bool = False,
+ ) -> List[GeneratedTask]:
+     tasks_list = list(tasks)
+     if mutations_per_task <= 0 or not tasks_list:
+         return tasks_list
+
+     mutated: List[GeneratedTask] = list(tasks_list)
+
+     for task in tasks_list:
+         for idx in range(mutations_per_task):
+             prompt = (
+                 "Mutate the following coding task to increase diversity. "
+                 "Return ONE JSON object with fields: id, description, signature, tests.\n\n"
+                 f"ORIGINAL_ID: {task.id}\n"
+                 f"ORIGINAL_PROMPT:\n{task.prompt}\n\n"
+                 f"ORIGINAL_TESTS:\n{task.tests}\n"
+             )
+             gen = llm.generate(
+                 prompt,
+                 max_new_tokens=max_new_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+             )
+             items = extract_json_objects(gen.text)
+             if not items:
+                 continue
+
+             item = items[0]
+             tid = item.get("id") or f"{task.id}_m{idx}"  # the f-string fallback is always non-empty, so no hash fallback is needed
+             task_prompt = task_to_prompt(item, require_recursion=require_recursion)
+             tests = task_to_tests(item)
+
+             if len(task_prompt) < 10 or not tests:
+                 continue
+
+             mutated.append(
+                 GeneratedTask(
+                     id=str(tid),
+                     prompt=task_prompt,
+                     tests=tests,
+                     description=item.get("description"),
+                 )
+             )
+
+             if max_total is not None and len(mutated) >= max_total:
+                 break
+         if max_total is not None and len(mutated) >= max_total:
+             break
+
+     # Deduplicate by id/prompt hash
+     seen = set()
+     deduped: List[GeneratedTask] = []
+     for t in mutated:
+         key = t.id or sha1_text(t.prompt)
+         if key in seen:
+             continue
+         seen.add(key)
+         deduped.append(t)
+         if max_total is not None and len(deduped) >= max_total:
+             break
+
+     return deduped
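For orientation, a minimal sketch of driving mutate_tasks; the StubLLM class and its canned JSON reply are hypothetical, standing in for any backend whose generate() returns an object with a .text attribute, and it assumes GeneratedTask takes only the four fields used above:

from types import SimpleNamespace

from mlxsmith.rlm.generate import GeneratedTask
from mlxsmith.rlm.mutate import mutate_tasks

class StubLLM:
    """Hypothetical backend: always replies with one mutated task as JSON."""
    def generate(self, prompt, **sampling_kwargs):
        text = ('{"id": "sum_list_m0", "description": "Sum a list of ints", '
                '"signature": "def sum_list(xs):", '
                '"tests": "assert sum_list([1, 2]) == 3"}')
        return SimpleNamespace(text=text)

seed = [GeneratedTask(id="sum_list", prompt="Write sum_list(xs) returning the sum.",
                      tests="assert sum_list([]) == 0", description=None)]
out = mutate_tasks(StubLLM(), seed, mutations_per_task=1, max_total=4)
print([t.id for t in out])  # originals first, then deduplicated mutations

Originals are always retained; max_total caps the combined list after deduplication.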
mlxsmith/rlm/trainer.py ADDED
@@ -0,0 +1,73 @@
+ from __future__ import annotations
+
+ from collections import defaultdict
+ from typing import Iterable, Optional
+
+ from ..config import ProjectConfig
+ from ..util import now_ts, latency_summary_ms
+ from .inference import Rollout
+
+
+ def train_on_rollouts(
+     llm,
+     rollouts: Iterable[Rollout],
+     cfg: ProjectConfig,
+     *,
+     optimizer: object,
+     train_adapter: Optional[str] = None,
+     ref_llm: Optional[object] = None,
+ ) -> list[dict]:
+     grouped = defaultdict(list)
+     for r in rollouts:
+         grouped[r.task_id].append(r)
+
+     metrics_rows: list[dict] = []
+
+     for task_id, rows in grouped.items():
+         if not rows:
+             continue
+
+         mean_r = sum(r.reward for r in rows) / max(1, len(rows))
+         std_r = (sum((r.reward - mean_r) ** 2 for r in rows) / max(1, len(rows))) ** 0.5
+         advs = [r.reward - mean_r for r in rows]
+         if bool(cfg.rft.normalize_advantage) and std_r > 1e-6:
+             advs = [a / std_r for a in advs]
+
+         def loss_fn(_model):
+             loss = llm.mx.array(0.0)  # type: ignore
+             for rollout, adv in zip(rows, advs):
+                 logp = llm.sequence_logprob(rollout.token_ids, prompt_len=rollout.prompt_len)
+                 if rollout.logprobs and rollout.weight_adapter and rollout.weight_adapter != train_adapter:
+                     behavior_logp = llm.mx.array(sum(rollout.logprobs))  # type: ignore
+                     ratio = llm.mx.exp(logp - behavior_logp)  # type: ignore
+                     pg = -ratio * llm.mx.array(float(adv))  # type: ignore
+                 else:
+                     pg = -llm.mx.array(float(adv)) * logp  # type: ignore
+                 if ref_llm is not None and cfg.rft.kl_coeff > 0:
+                     ref_logp = ref_llm.sequence_logprob(rollout.token_ids, prompt_len=rollout.prompt_len)
+                     pg = pg + llm.mx.array(cfg.rft.kl_coeff) * (logp - ref_logp)  # type: ignore
+                 loss = loss + pg
+             return loss / llm.mx.array(float(len(rows)))  # type: ignore
+
+         lval, grads = llm.value_and_grad(loss_fn)
+         if grads is not None:
+             llm.apply_grads(optimizer, grads)
+
+         latency_summary = latency_summary_ms([float(r.verifier_latency_ms) for r in rows])
+         metrics = {
+             "ts": now_ts(),
+             "task_id": task_id,
+             "mean_reward": mean_r,
+             "std_reward": std_r,
+             "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
+             "verifier_latency_ms": latency_summary.get("mean", 0.0),
+             "verifier_latency_ms_mean": latency_summary.get("mean", 0.0),
+             "verifier_latency_ms_p50": latency_summary.get("p50", 0.0),
+             "verifier_latency_ms_p90": latency_summary.get("p90", 0.0),
+             "verifier_latency_ms_p99": latency_summary.get("p99", 0.0),
+             "verifier_latency_ms_max": latency_summary.get("max", 0.0),
+             "weight_adapter": rows[0].weight_adapter,
+         }
+         metrics_rows.append(metrics)
+
+     return metrics_rows
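The advantage computation above centers each rollout's reward on its task group's mean, optionally dividing by the group's standard deviation; a self-contained numeric illustration with invented rewards (no model or MLX required):

rewards = [1.0, 0.0, 0.5, 1.0]                # verifier rewards for one task's rollouts
mean_r = sum(rewards) / len(rewards)          # baseline: group mean
std_r = (sum((r - mean_r) ** 2 for r in rewards) / len(rewards)) ** 0.5
advs = [r - mean_r for r in rewards]          # centered advantages
if std_r > 1e-6:                              # same guard as the normalize_advantage path
    advs = [a / std_r for a in advs]
print([round(a, 3) for a in advs])            # positive = better than the group average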
mlxsmith/rlm/weights.py ADDED
@@ -0,0 +1,263 @@
+ """Weight pointer system for tracking adapter weights across RLM iterations.
+
+ Extends to support IPC for multi-process orchestration with atomic updates
+ and hot-reload capabilities.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import multiprocessing as mp
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional, Callable
+
+ from ..util import ensure_dir, now_ts
+
+
+ @dataclass
+ class WeightPointer:
+     base_model: str
+     adapter_path: Optional[str]
+     iteration: int
+     updated_at: str
+     name: Optional[str] = None
+
+
+ @dataclass
+ class WeightPointerIPC:
+     """Extended WeightPointer with IPC support for multi-process orchestration.
+
+     Includes versioning and atomic update mechanisms for hot-reloading.
+     """
+     base_model: str
+     adapter_path: Optional[str]
+     iteration: int
+     updated_at: str
+     version: int = 0  # Monotonic version for ordering updates
+     checksum: Optional[str] = None  # Optional checksum for integrity
+     name: Optional[str] = None
+
+     def to_dict(self) -> dict:
+         return {
+             "base_model": self.base_model,
+             "adapter_path": self.adapter_path,
+             "iteration": self.iteration,
+             "updated_at": self.updated_at,
+             "version": self.version,
+             "checksum": self.checksum,
+             "name": self.name,
+         }
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "WeightPointerIPC":
+         return cls(
+             base_model=data["base_model"],
+             adapter_path=data.get("adapter_path"),
+             iteration=data.get("iteration", 0),
+             updated_at=data.get("updated_at", now_ts()),
+             version=data.get("version", 0),
+             checksum=data.get("checksum"),
+             name=data.get("name"),
+         )
+
+
+ class WeightPointerStore:
+     """Atomic weight pointer store for IPC between processes.
+
+     Uses file-based atomic updates with versioning to ensure
+     inference workers always see consistent state.
+     """
+
+     def __init__(self, weights_dir: Path):
+         self._weights_dir = Path(weights_dir)
+         self._lock = mp.Lock()  # only guards processes that inherit this store; unrelated readers rely on the atomic replace below
+
+     def get_path(self, name: str) -> Path:
+         """Get the storage path for a named pointer."""
+         return self._weights_dir / f"{name}.json"
+
+     def get_atomic_path(self, name: str) -> Path:
+         """Get the temporary path for atomic writes."""
+         return self._weights_dir / f".{name}.tmp"
+
+     def load(self, name: str, base_model: str) -> WeightPointerIPC:
+         """Load a weight pointer from storage."""
+         path = self.get_path(name)
+
+         with self._lock:
+             if not path.exists():
+                 return WeightPointerIPC(
+                     base_model=base_model,
+                     adapter_path=None,
+                     iteration=0,
+                     updated_at=now_ts(),
+                     version=0,
+                     name=name,
+                 )
+
+             try:
+                 data = json.loads(path.read_text(encoding="utf-8"))
+                 return WeightPointerIPC.from_dict(data)
+             except Exception:
+                 return WeightPointerIPC(
+                     base_model=base_model,
+                     adapter_path=None,
+                     iteration=0,
+                     updated_at=now_ts(),
+                     version=0,
+                     name=name,
+                 )
+
+     def save(self, pointer: WeightPointerIPC) -> None:
+         """Atomically save a weight pointer."""
+         path = self.get_path(pointer.name or "default")
+         tmp_path = self.get_atomic_path(pointer.name or "default")
+
+         ensure_dir(self._weights_dir)
+
+         with self._lock:
+             # Write to temp file
+             tmp_path.write_text(
+                 json.dumps(pointer.to_dict(), indent=2),
+                 encoding="utf-8",
+             )
+             # Atomic replace (rename would fail on Windows if the target exists)
+             tmp_path.replace(path)
+
+     def update(
+         self,
+         name: str,
+         adapter_path: Optional[str] = None,
+         iteration: Optional[int] = None,
+         checksum: Optional[str] = None,
+     ) -> WeightPointerIPC:
+         """Update a weight pointer atomically."""
+         current = self.load(name, "")  # keeps the existing base_model; empty only if no pointer has been saved yet
+
+         new_pointer = WeightPointerIPC(
+             base_model=current.base_model,
+             adapter_path=adapter_path if adapter_path is not None else current.adapter_path,
+             iteration=iteration if iteration is not None else current.iteration,
+             updated_at=now_ts(),
+             version=current.version + 1,
+             checksum=checksum,
+             name=name,
+         )
+
+         self.save(new_pointer)
+         return new_pointer
+
+     def watch(
+         self,
+         name: str,
+         base_model: str,
+         callback: Callable[[WeightPointerIPC], None],
+         poll_interval: float = 1.0,
+     ) -> "WeightWatcher":
+         """Create a watcher that monitors for pointer changes."""
+         return WeightWatcher(self, name, base_model, callback, poll_interval)
+
+
+ class WeightWatcher:
+     """Watches a weight pointer for changes and triggers callbacks.
+
+     Used by inference workers to hot-reload weights when updates
+     are published by the trainer.
+     """
+
+     def __init__(
+         self,
+         store: WeightPointerStore,
+         name: str,
+         base_model: str,
+         callback: Callable[[WeightPointerIPC], None],
+         poll_interval: float = 1.0,
+     ):
+         self._store = store
+         self._name = name
+         self._base_model = base_model
+         self._callback = callback
+         self._poll_interval = poll_interval
+         self._last_version = -1
+         self._stop_event = mp.Event()  # shared across processes; a plain bool attribute set in the parent would never reach the child
+         self._process: Optional[mp.Process] = None
+
+     def start(self) -> None:
+         """Start watching in a background process."""
+         self._stop_event.clear()
+         self._process = mp.Process(target=self._watch_loop)
+         self._process.start()
+
+     def stop(self) -> None:
+         """Stop the watcher."""
+         self._stop_event.set()
+         if self._process:
+             self._process.join(timeout=5.0)
+             if self._process.is_alive():
+                 self._process.terminate()
+             self._process = None
+
+     def _watch_loop(self) -> None:
+         """Internal watch loop running in a separate process."""
+         while not self._stop_event.is_set():
+             try:
+                 pointer = self._store.load(self._name, self._base_model)
+                 if pointer.version > self._last_version:
+                     self._last_version = pointer.version
+                     self._callback(pointer)
+             except Exception:
+                 pass  # Continue watching despite errors
+
+             time.sleep(self._poll_interval)
+
+
+ def load_pointer(path: Path, *, base_model: str, name: Optional[str] = None) -> WeightPointer:
+     """Load a weight pointer from disk (backward compatible)."""
+     if not path.exists():
+         return WeightPointer(
+             base_model=base_model,
+             adapter_path=None,
+             iteration=0,
+             updated_at=now_ts(),
+             name=name,
+         )
+     data = json.loads(path.read_text(encoding="utf-8"))
+     return WeightPointer(
+         base_model=data.get("base_model") or base_model,
+         adapter_path=data.get("adapter_path"),
+         iteration=int(data.get("iteration", 0)),
+         updated_at=data.get("updated_at") or now_ts(),
+         name=data.get("name") or name,
+     )
+
+
+ def save_pointer(path: Path, pointer: WeightPointer) -> None:
+     """Save a weight pointer to disk (backward compatible)."""
+     ensure_dir(path.parent)
+     path.write_text(
+         json.dumps(
+             {
+                 "base_model": pointer.base_model,
+                 "adapter_path": pointer.adapter_path,
+                 "iteration": pointer.iteration,
+                 "updated_at": pointer.updated_at,
+                 "name": pointer.name,
+             },
+             indent=2,
+         ),
+         encoding="utf-8",
+     )
+
+
+ def load_pointer_ipc(path: Path, base_model: str, name: str) -> WeightPointerIPC:
+     """Load an IPC-enabled weight pointer."""
+     store = WeightPointerStore(path.parent)
+     return store.load(name, base_model)
+
+
+ def save_pointer_ipc(path: Path, pointer: WeightPointerIPC) -> None:
+     """Save an IPC-enabled weight pointer."""
+     store = WeightPointerStore(path.parent)
+     store.save(pointer)
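A sketch of the publish/reload flow the docstrings describe, under assumed paths and model name (both invented here); the trainer bumps the version through update(), and a worker reloads only when the version advances:

from pathlib import Path

from mlxsmith.rlm.weights import WeightPointerStore

store = WeightPointerStore(Path("runs/rft_0001/weights"))  # hypothetical directory

# Trainer side: publish a new adapter; version increments monotonically.
store.update("policy", adapter_path="runs/rft_0001/adapter", iteration=3)

# Inference side: poll and hot-reload only on a version change.
last_seen = -1
pointer = store.load("policy", base_model="example/base-model")
if pointer.version > last_seen:
    last_seen = pointer.version
    print("reload adapter from", pointer.adapter_path)

In practice the polling half would live in a WeightWatcher callback rather than inline as above.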
mlxsmith/runs.py ADDED
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ import yaml
+
+ from .util import ensure_dir
+
+ @dataclass
+ class RunPaths:
+     run_dir: Path
+     logs_dir: Path
+     checkpoints_dir: Path
+     adapter_dir: Path
+     artifacts_dir: Path
+     metrics_path: Path
+     config_snapshot_path: Path
+
+ def new_run(root: Path, kind: str) -> RunPaths:
+     runs_root = ensure_dir(root / "runs")
+     # monotonically increasing id by counting existing runs of same kind
+     existing = sorted([p for p in runs_root.glob(f"{kind}_*") if p.is_dir()])
+     next_idx = len(existing) + 1
+     run_name = f"{kind}_{next_idx:04d}"
+     run_dir = ensure_dir(runs_root / run_name)
+     logs_dir = ensure_dir(run_dir / "logs")
+     checkpoints_dir = ensure_dir(run_dir / "checkpoints")
+     adapter_dir = ensure_dir(run_dir / "adapter")
+     artifacts_dir = ensure_dir(run_dir / "artifacts")
+     metrics_path = run_dir / "metrics.jsonl"
+     config_snapshot_path = run_dir / "config.snapshot.yaml"
+     return RunPaths(
+         run_dir=run_dir,
+         logs_dir=logs_dir,
+         checkpoints_dir=checkpoints_dir,
+         adapter_dir=adapter_dir,
+         artifacts_dir=artifacts_dir,
+         metrics_path=metrics_path,
+         config_snapshot_path=config_snapshot_path,
+     )
+
+ def snapshot_config(cfg_dict: dict, path: Path):
+     path.write_text(yaml.safe_dump(cfg_dict, sort_keys=False), encoding="utf-8")
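Finally, a small usage sketch for the run-directory helpers; the kind string and config contents are made up:

from pathlib import Path

from mlxsmith.runs import new_run, snapshot_config

paths = new_run(Path("."), kind="sft")   # creates ./runs/sft_0001/ and its subdirectories
snapshot_config({"model": "example", "lr": 1e-4}, paths.config_snapshot_path)
print(paths.metrics_path)                # ./runs/sft_0001/metrics.jsonl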