mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/util.py ADDED
@@ -0,0 +1,174 @@
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ import os
+ import platform
+ import shutil
+ import subprocess
+ import sys
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional
+
+ from rich.console import Console
+
+ console = Console()
+
+ def sha1_text(s: str) -> str:
+     return hashlib.sha1(s.encode("utf-8")).hexdigest()
+
+ def ensure_dir(p: Path) -> Path:
+     p.mkdir(parents=True, exist_ok=True)
+     return p
+
+ def write_jsonl(path: Path, rows):
+     ensure_dir(path.parent)
+     with path.open("a", encoding="utf-8") as f:
+         for r in rows:
+             f.write(json.dumps(r, ensure_ascii=False) + "\n")
+
+ def now_ts() -> str:
+     return time.strftime("%Y%m%d_%H%M%S")
+
+ def run_cmd(cmd: list[str], cwd: Optional[Path] = None, timeout_s: Optional[int] = None) -> subprocess.CompletedProcess:
+     return subprocess.run(cmd, cwd=str(cwd) if cwd else None, timeout=timeout_s, capture_output=True, text=True)
+
+ @dataclass
+ class SystemInfo:
+     python: str
+     python_arch: str
+     platform: str
+     macos_version: Optional[str]
+     machine: str
+     cpu_count: int
+     has_metal: Optional[bool]
+     has_mlx: bool
+     mlx_version: Optional[str]
+     has_zmlx: bool
+
+ def detect_system() -> SystemInfo:
+     has_mlx = False
+     mlx_version = None
+     try:
+         import mlx  # type: ignore
+         has_mlx = True
+         mlx_version = getattr(mlx, "__version__", None)
+     except Exception:
+         pass
+
+     import importlib.util
+
+     has_zmlx = importlib.util.find_spec("zmlx") is not None
+
+     # Metal detection (best-effort): on macOS we assume Metal is present; for CI, this is not reliable.
+     has_metal = None
+     if sys.platform == "darwin":
+         has_metal = True
+
+     macos_version = None
+     if sys.platform == "darwin":
+         macos_version = platform.mac_ver()[0] or None
+
+     py_arch = platform.architecture()[0] or "unknown"
+
+     return SystemInfo(
+         python=sys.version.split()[0],
+         python_arch=py_arch,
+         platform=platform.platform(),
+         macos_version=macos_version,
+         machine=platform.machine(),
+         cpu_count=os.cpu_count() or 0,
+         has_metal=has_metal,
+         has_mlx=has_mlx,
+         mlx_version=mlx_version,
+         has_zmlx=has_zmlx,
+     )
+
+ def require(cond: bool, msg: str):
+     if not cond:
+         raise RuntimeError(msg)
+
+ def copytree(src: Path, dst: Path):
+     if dst.exists():
+         shutil.rmtree(dst)
+     shutil.copytree(src, dst)
+
+
+ def tree_map(fn, tree):
+     if tree is None:
+         return None
+     if isinstance(tree, dict):
+         return {k: tree_map(fn, v) for k, v in tree.items()}
+     if isinstance(tree, (list, tuple)):
+         return type(tree)(tree_map(fn, v) for v in tree)
+     return fn(tree)
+
+
+ def tree_add(a, b):
+     if a is None:
+         return b
+     if b is None:
+         return a
+     if isinstance(a, dict) and isinstance(b, dict):
+         return {k: tree_add(a.get(k), b.get(k)) for k in a.keys() | b.keys()}
+     if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
+         return type(a)(tree_add(x, y) for x, y in zip(a, b))
+     return a + b
+
+
+ def tree_scale(tree, scale: float):
+     return tree_map(lambda x: x * scale, tree)
+
+
+ def tree_leaves(tree) -> list:
+     leaves = []
+     if tree is None:
+         return leaves
+     if isinstance(tree, dict):
+         for v in tree.values():
+             leaves.extend(tree_leaves(v))
+     elif isinstance(tree, (list, tuple)):
+         for v in tree:
+             leaves.extend(tree_leaves(v))
+     else:
+         leaves.append(tree)
+     return leaves
+
+
+ def clip_grad_norm(grads, max_norm: float):
+     """Clip gradients by global L2 norm. Returns clipped grads."""
+     import mlx.core as mx
+
+     leaves = tree_leaves(grads)
+     if not leaves:
+         return grads
+     total_norm_sq = mx.array(0.0)
+     for g in leaves:
+         total_norm_sq = total_norm_sq + (g * g).sum()
+     total_norm = mx.sqrt(total_norm_sq)
+     clip_coef = mx.minimum(mx.array(max_norm) / mx.maximum(total_norm, mx.array(1e-8)), mx.array(1.0))
+     return tree_map(lambda g: g * clip_coef, grads)
+
+
+ def latency_summary_ms(samples: list[float]) -> dict[str, float]:
+     if not samples:
+         return {}
+     items = sorted(samples)
+     n = len(items)
+     mean = sum(items) / n
+
+     def _pct(p: float) -> float:
+         if n == 1:
+             return items[0]
+         idx = int((p / 100.0) * (n - 1))
+         return items[max(0, min(idx, n - 1))]
+
+     return {
+         "mean": mean,
+         "p50": _pct(50),
+         "p90": _pct(90),
+         "p99": _pct(99),
+         "max": items[-1],
+     }
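The tree helpers and clip_grad_norm above compose naturally for gradient accumulation. A minimal sketch, assuming mlx is installed; the gradient values and pytree layout are illustrative, not part of the package:

import mlx.core as mx
from mlxsmith.util import clip_grad_norm, tree_add, tree_scale

# Two per-batch gradient pytrees (dicts of mx.array), e.g. from separate backward passes.
grads_per_batch = [
    {"w": mx.array([1.0, 2.0]), "b": mx.array(0.5)},
    {"w": mx.array([3.0, 4.0]), "b": mx.array(1.5)},
]

accum = None
for g in grads_per_batch:
    accum = tree_add(accum, g)                       # elementwise sum across batches
avg = tree_scale(accum, 1.0 / len(grads_per_batch))  # mean gradient
clipped = clip_grad_norm(avg, max_norm=1.0)          # rescaled if the global L2 norm exceeds 1.0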
mlxsmith/verifiers/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .types import VerifyResult
+
+ __all__ = ["VerifyResult"]
mlxsmith/verifiers/compose.py ADDED
@@ -0,0 +1,109 @@
+ from __future__ import annotations
+
+ import importlib.util
+ import time
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional
+
+ from .types import VerifyResult
+
+
+ def _load_verifier(path: str):
+     verifier_path = Path(path)
+     spec = importlib.util.spec_from_file_location(verifier_path.stem, verifier_path)
+     if spec is None or spec.loader is None:
+         raise RuntimeError(f"Could not load verifier: {verifier_path}")
+     module = importlib.util.module_from_spec(spec)
+     spec.loader.exec_module(module)  # type: ignore
+     verify_fn = getattr(module, "verify", None)
+     if not callable(verify_fn):
+         raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
+     return verify_fn
+
+
+ def verify(
+     prompt: str,
+     completion: str,
+     workdir: str,
+     *,
+     verifiers: List[Any],
+     mode: str = "all",
+     weights: Optional[List[float]] = None,
+     per_verifier_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
+     **kwargs,
+ ) -> VerifyResult:
+     """Compose multiple verifiers with AND/OR/weighted reward aggregation.
+
+     verifiers: list of paths or dicts {path, kwargs}.
+     mode: all | any | weighted
+     weights: optional weights for weighted reward aggregation.
+     per_verifier_kwargs: optional mapping of path -> kwargs.
+     """
+     results = []
+     latencies: Dict[str, float] = {}
+     per_kwargs = per_verifier_kwargs or {}
+
+     for idx, entry in enumerate(verifiers):
+         if isinstance(entry, dict):
+             path = entry.get("path")
+             extra = entry.get("kwargs") or {}
+         else:
+             path = entry
+             extra = {}
+
+         if not path:
+             continue
+         verify_fn = _load_verifier(str(path))
+         merged = dict(kwargs)
+         merged.update(per_kwargs.get(str(path), {}))
+         merged.update(extra)
+
+         t0 = time.time()
+         res = verify_fn(prompt, completion, workdir, **merged)
+         latencies[str(path)] = (time.time() - t0) * 1000.0
+         results.append((str(path), res))
+
+     if not results:
+         return VerifyResult(
+             reward=0.0,
+             passed=False,
+             info={"error": "no_verifiers"},
+             artifacts_dir=workdir,
+         )
+
+     mode = (mode or "all").lower()
+     passes = [bool(getattr(r, "passed", False)) for _p, r in results]
+     rewards = [float(getattr(r, "reward", 0.0)) for _p, r in results]
+
+     if mode == "any":
+         passed = any(passes)
+         reward = max(rewards) if rewards else 0.0
+     elif mode == "weighted":
+         if weights and len(weights) == len(rewards):
+             reward = sum(w * r for w, r in zip(weights, rewards))
+         else:
+             reward = sum(rewards) / max(1, len(rewards))
+         passed = reward > 0.0 and any(passes)
+     else:
+         passed = all(passes)
+         reward = sum(rewards) / max(1, len(rewards))
+
+     return VerifyResult(
+         reward=reward,
+         passed=passed,
+         info={
+             "mode": mode,
+             "verifiers": [
+                 {
+                     "path": path,
+                     "passed": bool(getattr(res, "passed", False)),
+                     "reward": float(getattr(res, "reward", 0.0)),
+                     "info": getattr(res, "info", {}),
+                 }
+                 for path, res in results
+             ],
+             "verifier_latencies_ms": latencies,
+         },
+         artifacts_dir=workdir,
+     )
+
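A hedged usage sketch of the composition verifier: the verifier script paths below are hypothetical placeholders for any file exposing a verify(prompt, completion, workdir, **kwargs) function.

from mlxsmith.verifiers.compose import verify as compose_verify

result = compose_verify(
    prompt="Write a function that returns 42.",
    completion="def answer():\n    return 42\n",
    workdir="/tmp/verify_run",
    verifiers=[
        "my_verifiers/regex_check.py",           # plain path entry (hypothetical script)
        {"path": "my_verifiers/tests_check.py",  # dict entry with per-verifier kwargs
         "kwargs": {"timeout_s": 10}},
    ],
    mode="weighted",   # all | any | weighted
    weights=[0.3, 0.7],
)
print(result.passed, result.reward, result.info["verifier_latencies_ms"])

Per the aggregation logic above, "weighted" returns the weighted sum of individual rewards, "all" averages rewards and requires every verifier to pass, and "any" takes the maximum reward and passes if any verifier passes.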
mlxsmith/verifiers/docker_verifier.py ADDED
@@ -0,0 +1,111 @@
+ from __future__ import annotations
+
+ import shutil
+ import subprocess
+ from pathlib import Path
+
+ from .types import VerifyResult
+ from .pytest_verifier import verify as local_verify
+
+
+ def verify(
+     prompt: str,
+     completion: str,
+     workdir: str,
+     *,
+     tests_subdir: str = "tests",
+     timeout_s: int = 30,
+     reward_pass: float = 1.0,
+     reward_fail: float = 0.0,
+     image: str = "python:3.11-slim",
+     memory_mb: int = 512,
+     cpus: float = 1.0,
+     pids: int = 128,
+     use_local_fallback: bool = True,
+ ) -> VerifyResult:
+     """Run pytest inside a locked-down Docker container.
+
+     Falls back to local pytest verifier if Docker is unavailable and use_local_fallback is True.
+     """
+     if shutil.which("docker") is None:
+         if use_local_fallback:
+             return local_verify(
+                 prompt,
+                 completion,
+                 workdir,
+                 tests_subdir=tests_subdir,
+                 timeout_s=timeout_s,
+                 reward_pass=reward_pass,
+                 reward_fail=reward_fail,
+             )
+         return VerifyResult(
+             reward=reward_fail,
+             passed=False,
+             info={"error": "docker_not_found"},
+             artifacts_dir=workdir,
+         )
+
+     wd = Path(workdir)
+     wd.mkdir(parents=True, exist_ok=True)
+     if not any(wd.glob("*.py")):
+         (wd / "main.py").write_text(completion, encoding="utf-8")
+
+     tests_path = wd / tests_subdir
+     if not tests_path.exists():
+         return VerifyResult(
+             reward=reward_fail,
+             passed=False,
+             info={"error": f"Missing tests folder: {tests_subdir}"},
+             artifacts_dir=str(wd),
+         )
+
+     cmd = [
+         "docker",
+         "run",
+         "--rm",
+         "--network",
+         "none",
+         "--read-only",
+         "--pids-limit",
+         str(int(pids)),
+         "--memory",
+         f"{int(memory_mb)}m",
+         "--cpus",
+         str(float(cpus)),
+         "--tmpfs",
+         "/tmp:rw,noexec,nosuid,nodev",
+         "-v",
+         f"{wd}:/workspace:rw",
+         "-w",
+         "/workspace",
+         image,
+         "pytest",
+         "-q",
+     ]
+
+     try:
+         proc = subprocess.run(
+             cmd,
+             capture_output=True,
+             text=True,
+             timeout=timeout_s,
+         )
+         passed = proc.returncode == 0
+         return VerifyResult(
+             reward=reward_pass if passed else reward_fail,
+             passed=passed,
+             info={
+                 "returncode": proc.returncode,
+                 "stdout": proc.stdout[-4000:],
+                 "stderr": proc.stderr[-4000:],
+                 "docker_image": image,
+             },
+             artifacts_dir=str(wd),
+         )
+     except subprocess.TimeoutExpired:
+         return VerifyResult(
+             reward=reward_fail,
+             passed=False,
+             info={"error": "timeout", "timeout_s": timeout_s},
+             artifacts_dir=str(wd),
+         )
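A usage sketch for the Docker-backed verifier. Note that the container invokes pytest directly, so the image must ship pytest (the default python:3.11-slim does not); the image name, workdir, and test contents below are illustrative assumptions.

from pathlib import Path
from mlxsmith.verifiers.docker_verifier import verify as docker_verify

wd = Path("/tmp/docker_verify_demo")
(wd / "tests").mkdir(parents=True, exist_ok=True)
(wd / "tests" / "test_main.py").write_text(
    "from main import answer\n\ndef test_answer():\n    assert answer() == 42\n",
    encoding="utf-8",
)

res = docker_verify(
    prompt="Return 42.",
    completion="def answer():\n    return 42\n",  # written to main.py because workdir has no top-level .py files
    workdir=str(wd),
    image="my-registry/python-pytest:3.11",       # hypothetical image with pytest preinstalled
    memory_mb=256,
    cpus=0.5,
    timeout_s=60,
)
print(res.passed, res.info.get("returncode"))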
mlxsmith/verifiers/jsonschema.py ADDED
@@ -0,0 +1,54 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+ from typing import Any
+
+ from jsonschema import validate, ValidationError
+
+ from .types import VerifyResult
+
+
+ def _load_schema(schema: Any) -> Any:
+     if isinstance(schema, str):
+         p = Path(schema)
+         if p.exists():
+             return json.loads(p.read_text(encoding="utf-8"))
+     return schema
+
+
+ def verify(
+     prompt: str,
+     completion: str,
+     workdir: str,
+     *,
+     schema: Any,
+     reward_pass: float = 1.0,
+     reward_fail: float = 0.0,
+ ) -> VerifyResult:
+     try:
+         data = json.loads(completion)
+     except json.JSONDecodeError as e:
+         return VerifyResult(
+             reward=reward_fail,
+             passed=False,
+             info={"error": "invalid_json", "detail": str(e)},
+             artifacts_dir=None,
+         )
+
+     schema_obj = _load_schema(schema)
+     try:
+         validate(instance=data, schema=schema_obj)
+         return VerifyResult(
+             reward=reward_pass,
+             passed=True,
+             info={"schema": schema_obj},
+             artifacts_dir=None,
+         )
+     except ValidationError as e:
+         return VerifyResult(
+             reward=reward_fail,
+             passed=False,
+             info={"error": "schema_validation", "detail": str(e)},
+             artifacts_dir=None,
+         )
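A minimal usage sketch for the JSON Schema verifier (requires the jsonschema package; the schema and completion are illustrative):

from mlxsmith.verifiers.jsonschema import verify as schema_verify

schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

res = schema_verify(
    prompt="Emit a JSON object with name and age.",
    completion='{"name": "Ada", "age": 36}',
    workdir="/tmp/unused",  # not used by this verifier, but part of the shared signature
    schema=schema,
)
print(res.passed, res.reward)  # True 1.0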
mlxsmith/verifiers/pytest_verifier.py ADDED
@@ -0,0 +1,82 @@
+ from __future__ import annotations
+
+ import os
+ import subprocess
+ from pathlib import Path
+ from .types import VerifyResult
+
+ def _sandbox_env(base_env: dict | None = None, *, workdir: Path) -> dict:
+     env = dict(os.environ)
+     if base_env:
+         env.update(base_env)
+     env.setdefault("PYTHONHASHSEED", "0")
+     env["HOME"] = str(workdir)
+     env["TMPDIR"] = str(workdir)
+     env["PYTHONPATH"] = str(workdir)
+     env.pop("PYTHONSTARTUP", None)
+     env.pop("VIRTUAL_ENV", None)
+     return env
+
+
+ def verify(
+     prompt: str,
+     completion: str,
+     workdir: str,
+     *,
+     tests_subdir: str = "tests",
+     timeout_s: int = 30,
+     reward_pass: float = 1.0,
+     reward_fail: float = 0.0,
+ ) -> VerifyResult:
+     """Run pytest in a sandbox directory.
+
+     Expected usage: environment builder writes files into `workdir` and this verifier runs tests there.
+     For v0, we simply run `pytest -q` and treat exit code 0 as pass.
+
+     IMPORTANT: This runs locally. For stronger isolation, replace with a sandbox runner.
+     """
+     wd = Path(workdir)
+     wd.mkdir(parents=True, exist_ok=True)
+
+     # For code tasks, the caller can write the completion into a conventional file (e.g., main.py).
+     # Here we create a default file if none exists:
+     if not any(wd.glob("*.py")):
+         (wd / "main.py").write_text(completion, encoding="utf-8")
+
+     # Ensure tests folder exists; if not, fail deterministically.
+     tests_path = wd / tests_subdir
+     if not tests_path.exists():
+         return VerifyResult(
+             reward=reward_fail,
+             passed=False,
+             info={"error": f"Missing tests folder: {tests_subdir}"},
+             artifacts_dir=str(wd),
+         )
+
+     try:
+         proc = subprocess.run(
+             ["pytest", "-q"],
+             cwd=str(wd),
+             capture_output=True,
+             text=True,
+             timeout=timeout_s,
+             env=_sandbox_env(workdir=wd),
+         )
+         passed = proc.returncode == 0
+         return VerifyResult(
+             reward=reward_pass if passed else reward_fail,
+             passed=passed,
+             info={
+                 "returncode": proc.returncode,
+                 "stdout": proc.stdout[-4000:],
+                 "stderr": proc.stderr[-4000:],
+             },
+             artifacts_dir=str(wd),
+         )
+     except subprocess.TimeoutExpired:
+         return VerifyResult(
+             reward=reward_fail,
+             passed=False,
+             info={"error": "timeout", "timeout_s": timeout_s},
+             artifacts_dir=str(wd),
+         )
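A sketch of the workdir layout this local verifier expects, with a direct call; the directory, file names, and test contents are illustrative, and pytest must be installed in the current environment:

from pathlib import Path
from mlxsmith.verifiers.pytest_verifier import verify as pytest_verify

wd = Path("/tmp/pytest_verify_demo")
(wd / "tests").mkdir(parents=True, exist_ok=True)
(wd / "tests" / "test_main.py").write_text(
    "from main import add\n\ndef test_add():\n    assert add(2, 3) == 5\n",
    encoding="utf-8",
)

res = pytest_verify(
    prompt="Implement add(a, b).",
    completion="def add(a, b):\n    return a + b\n",  # becomes main.py since workdir has no .py files yet
    workdir=str(wd),
    timeout_s=30,
)
print(res.passed, res.info["returncode"])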
mlxsmith/verifiers/regex.py ADDED
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+
+ import re
+
+ from .types import VerifyResult
+
+ def verify(prompt: str, completion: str, workdir: str, *, pattern: str, flags: int = 0, reward_pass: float = 1.0, reward_fail: float = 0.0) -> VerifyResult:
+     m = re.search(pattern, completion, flags=flags)
+     passed = m is not None
+     return VerifyResult(
+         reward=reward_pass if passed else reward_fail,
+         passed=passed,
+         info={"pattern": pattern, "match": m.group(0) if m else None},
+         artifacts_dir=None,
+     )
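A tiny usage sketch for the regex verifier; the pattern and completion are illustrative:

import re
from mlxsmith.verifiers.regex import verify as regex_verify

res = regex_verify(
    prompt="End your answer with 'FINAL: <number>'.",
    completion="Working it out...\nFINAL: 42",
    workdir="/tmp/unused",  # unused by this verifier, but part of the shared signature
    pattern=r"FINAL:\s*\d+",
    flags=re.MULTILINE,
)
print(res.passed, res.info["match"])  # True, "FINAL: 42"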
mlxsmith/verifiers/types.py ADDED
@@ -0,0 +1,10 @@
+ from __future__ import annotations
+ from dataclasses import dataclass
+ from typing import Any, Dict, Optional
+
+ @dataclass
+ class VerifyResult:
+     reward: float
+     passed: bool
+     info: Dict[str, Any]
+     artifacts_dir: Optional[str] = None