mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/rlm/generate.py
@@ -0,0 +1,249 @@
+ from __future__ import annotations
+
+ import json
+ import re
+ from dataclasses import dataclass
+ from typing import Iterable, List, Optional, Sequence, Set
+
+ from ..util import sha1_text
+
+
+ @dataclass
+ class GeneratedTask:
+     id: str
+     prompt: str
+     tests: str
+     description: Optional[str] = None
+
+
+ _FALLBACK_TASKS = [
+     {
+         "id": "task_factorial",
+         "description": "Write a function solve(n: int) -> int that returns n! for n >= 0.",
+         "signature": "def solve(n: int) -> int",
+         "tests": [
+             {"input": [0], "expected": 1},
+             {"input": [5], "expected": 120},
+             {"input": [7], "expected": 5040},
+         ],
+     },
+     {
+         "id": "task_reverse",
+         "description": "Write a function solve(s: str) -> str that returns the reverse of s.",
+         "signature": "def solve(s: str) -> str",
+         "tests": [
+             {"input": ["abc"], "expected": "cba"},
+             {"input": [""], "expected": ""},
+             {"input": ["racecar"], "expected": "racecar"},
+         ],
+     },
+ ]
+
+ _DEFAULT_BLOCKLIST = [
+     r"\bsubprocess\b",
+     r"\bos\.system\b",
+     r"\bshutil\.rmtree\b",
+     r"\brm\s+-rf\b",
+     r"\brequests\b",
+     r"\burllib\b",
+     r"\bsocket\b",
+     r"\bhttp[s]?://",
+     r"\bpip\s+install\b",
+     r"\bapt-get\b",
+     r"\bbrew\s+install\b",
+ ]
+
+
+ def extract_json_objects(text: str) -> List[dict]:
+     # Try fenced ```json blocks first; fall back to scanning the whole text.
+     fenced = re.findall(r"```json\s*(.*?)```", text, flags=re.DOTALL | re.IGNORECASE)
+     chunks = fenced if fenced else [text]
+
+     items: List[dict] = []
+     for chunk in chunks:
+         buf = ""
+         depth = 0
+         for ch in chunk:
+             if ch == "{":
+                 depth += 1
+             if depth > 0:
+                 buf += ch
+             if ch == "}" and depth > 0:
+                 depth -= 1
+                 if depth == 0:
+                     try:
+                         items.append(json.loads(buf))
+                     except Exception:
+                         pass
+                     buf = ""
+     return items
+
+
+ def _tests_from_cases(cases: Iterable[dict]) -> str:
+     lines = ["from main import solve", "", ""]
+     for idx, case in enumerate(cases):
+         args = case.get("input", [])
+         expected = case.get("expected")
+         if not isinstance(args, (list, tuple)):
+             args = [args]
+         lines.append(f"def test_case_{idx}():")
+         args_expr = ", ".join(repr(a) for a in args)
+         lines.append(f"    assert solve({args_expr}) == {repr(expected)}")
+         lines.append("")
+     return "\n".join(lines).strip() + "\n"
+
+
+ def task_to_prompt(task: dict, *, require_recursion: bool) -> str:
+     desc = task.get("description") or task.get("prompt") or ""
+     sig = task.get("signature") or "def solve(...):"
+     suffix = "\nUse recursion." if require_recursion else ""
+     return f"{desc}\n{sig}\nReturn only Python code.{suffix}".strip()
+
+
+ def task_to_tests(task: dict) -> str:
+     tests = task.get("tests")
+     if isinstance(tests, str):
+         return tests
+     if isinstance(tests, list):
+         # List of pre-formatted test strings (e.g. from Qwen3-style generation)
+         if tests and isinstance(tests[0], str):
+             return "\n\n".join(tests).strip() + "\n"
+         # List of structured {input, expected} dicts
+         return _tests_from_cases(tests)
+     # Fallback: a trivial test that always fails, to avoid false positives.
+     return "def test_placeholder():\n    assert False\n"
+
+
+ def _token_set(text: str) -> Set[str]:
+     return set(re.findall(r"[a-z0-9_]+", text.lower()))
+
+
+ def _jaccard(a: Set[str], b: Set[str]) -> float:
+     if not a and not b:
+         return 1.0
+     denom = len(a | b)
+     if denom == 0:
+         return 0.0
+     return len(a & b) / denom
+
+
+ def filter_tasks(
+     tasks: Sequence[GeneratedTask],
+     *,
+     existing_prompts: Optional[Sequence[str]] = None,
+     similarity_threshold: float = 0.85,
+     min_desc_len: int = 10,
+     min_asserts: int = 2,
+     max_prompt_len: int = 2000,
+     min_tests_len: int = 20,
+     max_tests_len: int = 8000,
+     blocked_patterns: Optional[Sequence[str]] = None,
+ ) -> List[GeneratedTask]:
+     """Filter tasks by basic quality checks, a blocklist, near-duplicate similarity, and exact dedup."""
+     existing_prompts = existing_prompts or []
+     existing_tokens = [_token_set(p) for p in existing_prompts if p]
+     seen_hashes: Set[str] = set()
+     filtered: List[GeneratedTask] = []
+     patterns = list(blocked_patterns) if blocked_patterns else _DEFAULT_BLOCKLIST
+
+     for task in tasks:
+         prompt = task.prompt or ""
+         if len(prompt) > max_prompt_len:
+             continue
+         desc = task.description or prompt
+         if len(desc.strip()) < min_desc_len:
+             continue
+         tests = task.tests or ""
+         if len(tests.strip()) < min_tests_len or len(tests) > max_tests_len:
+             continue
+         if patterns:
+             blocked = False
+             for pattern in patterns:
+                 if re.search(pattern, prompt, flags=re.IGNORECASE) or re.search(pattern, tests, flags=re.IGNORECASE):
+                     blocked = True
+                     break
+             if blocked:
+                 continue
+         asserts = sum(1 for line in tests.splitlines() if line.strip().startswith("assert"))
+         if asserts < min_asserts:
+             continue
+
+         key = sha1_text(prompt + "\n" + task.tests)
+         if key in seen_hashes:
+             continue
+
+         tokens = _token_set(prompt)
+         too_similar = False
+         for ex in existing_tokens:
+             if _jaccard(tokens, ex) >= similarity_threshold:
+                 too_similar = True
+                 break
+         if too_similar:
+             continue
+
+         seen_hashes.add(key)
+         filtered.append(task)
+
+     return filtered
+
+
+ def generate_tasks(
+     llm,
+     *,
+     tasks_per_iter: int,
+     temperature: float,
+     max_new_tokens: int,
+     top_p: float,
+     top_k: Optional[int],
+     require_recursion: bool,
+     task_domains: Iterable[str],
+ ) -> List[GeneratedTask]:
+     if tasks_per_iter <= 0:
+         return []
+
+     domain_list = ", ".join(task_domains) if task_domains else "general"
+     prompt = (
+         "Generate JSON objects for coding tasks. Each JSON must include: "
+         "id, description, signature, tests (array of {input, expected}). "
+         f"Domain focus: {domain_list}. Return only JSON objects."
+     )
+
+     gen = llm.generate(
+         prompt,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         top_k=top_k,
+     )
+     items = extract_json_objects(gen.text)
+     tasks: List[GeneratedTask] = []
+     for item in items:
+         tid = item.get("id") or sha1_text(json.dumps(item, sort_keys=True))[:12]
+         task_prompt = task_to_prompt(item, require_recursion=require_recursion)
+         tests = task_to_tests(item)
+         tasks.append(
+             GeneratedTask(
+                 id=str(tid),
+                 prompt=task_prompt,
+                 tests=tests,
+                 description=item.get("description"),
+             )
+         )
+         if len(tasks) >= tasks_per_iter:
+             break
+
+     if not tasks:
+         for item in _FALLBACK_TASKS:
+             tid = item.get("id") or sha1_text(json.dumps(item, sort_keys=True))[:12]
+             tasks.append(
+                 GeneratedTask(
+                     id=str(tid),
+                     prompt=task_to_prompt(item, require_recursion=require_recursion),
+                     tests=task_to_tests(item),
+                     description=item.get("description"),
+                 )
+             )
+             if len(tasks) >= tasks_per_iter:
+                 break
+
+     return tasks
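For context, a minimal usage sketch of the helpers above (not part of the diff). It feeds a hand-written model response through extract_json_objects, builds GeneratedTask records, and filters them; the raw string and the task id "task_abs" are illustrative stand-ins, and the sketch assumes the wheel is installed so mlxsmith.rlm.generate is importable.

from mlxsmith.rlm.generate import (
    GeneratedTask,
    extract_json_objects,
    filter_tasks,
    task_to_prompt,
    task_to_tests,
)

# A stand-in model response containing one fenced JSON task.
raw = """```json
{"id": "task_abs",
 "description": "Write a function solve(x: int) -> int that returns |x|.",
 "signature": "def solve(x: int) -> int",
 "tests": [{"input": [-3], "expected": 3}, {"input": [4], "expected": 4}]}
```"""

items = extract_json_objects(raw)  # brace-matched parse, fenced blocks preferred
tasks = [
    GeneratedTask(
        id=item["id"],
        prompt=task_to_prompt(item, require_recursion=False),
        tests=task_to_tests(item),
        description=item.get("description"),
    )
    for item in items
]
# Filtering drops near-duplicates of previously seen prompts and weak tasks.
kept = filter_tasks(tasks, existing_prompts=["Write a function solve(n: int) -> int that returns n!"])
print([t.id for t in kept])  # expected: ['task_abs']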
mlxsmith/rlm/history.py
@@ -0,0 +1,12 @@
+ from __future__ import annotations
+
+ import json
+ from pathlib import Path
+
+ from ..util import ensure_dir
+
+
+ def append_history(path: Path, record: dict) -> None:
+     ensure_dir(path.parent)
+     with path.open("a", encoding="utf-8") as f:
+         f.write(json.dumps(record, ensure_ascii=False) + "\n")
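A usage sketch (illustrative, not part of the diff): append_history writes one JSON object per line (JSONL), so an iteration log can be appended to incrementally and re-read line by line. The path and record fields below are hypothetical.

import json
from pathlib import Path

from mlxsmith.rlm.history import append_history

log = Path("runs/rlm/history.jsonl")  # hypothetical location
append_history(log, {"iter": 0, "pass_rate": 0.25})
append_history(log, {"iter": 1, "pass_rate": 0.40})

# Re-reading: one json.loads per line.
records = [json.loads(line) for line in log.read_text(encoding="utf-8").splitlines()]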
mlxsmith/rlm/inference.py
@@ -0,0 +1,150 @@
+ from __future__ import annotations
+
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Iterable, List, Optional
+
+ from ..config import ProjectConfig
+ from ..util import ensure_dir, now_ts
+ from ..verifiers.docker_verifier import verify as docker_verify
+ from ..verifiers.pytest_verifier import verify as pytest_verify
+ from .generate import GeneratedTask, generate_tasks, filter_tasks
+ from .mutate import mutate_tasks
+
+
+ @dataclass
+ class Rollout:
+     task_id: str
+     prompt: str
+     completion: str
+     token_ids: list[int]
+     prompt_len: int
+     logprobs: Optional[list[float]]
+     passed: bool
+     reward: float
+     verifier_latency_ms: float
+     weight_adapter: Optional[str]
+
+
+ def build_tasks(
+     llm,
+     cfg: ProjectConfig,
+     *,
+     require_recursion: bool,
+     tasks_per_iter: int,
+     mutations_per_task: int,
+     max_total: int,
+     existing_prompts: Optional[list[str]] = None,
+ ) -> List[GeneratedTask]:
+     tasks = generate_tasks(
+         llm,
+         tasks_per_iter=tasks_per_iter,
+         temperature=float(cfg.rft.temperature),
+         max_new_tokens=int(cfg.rlm.task_gen_max_new_tokens),
+         top_p=float(cfg.infer.top_p),
+         top_k=cfg.infer.top_k,
+         require_recursion=require_recursion,
+         task_domains=list(cfg.rlm.task_domains or []),
+     )
+     if bool(cfg.rlm.use_task_mutation) and int(mutations_per_task) > 0:
+         tasks = mutate_tasks(
+             llm,
+             tasks,
+             mutations_per_task=mutations_per_task,
+             max_total=max_total,
+             temperature=float(cfg.rft.temperature),
+             max_new_tokens=int(cfg.rlm.task_gen_max_new_tokens),
+             top_p=float(cfg.infer.top_p),
+             top_k=cfg.infer.top_k,
+             require_recursion=require_recursion,
+         )
+     filtered = filter_tasks(
+         tasks,
+         existing_prompts=existing_prompts,
+         similarity_threshold=float(getattr(cfg.rlm, "similarity_threshold", 0.85)),
+         min_desc_len=int(getattr(cfg.rlm, "min_task_desc_len", 10)),
+         min_asserts=int(getattr(cfg.rlm, "min_task_asserts", 2)),
+         max_prompt_len=int(getattr(cfg.rlm, "max_task_prompt_len", 2000)),
+         min_tests_len=int(getattr(cfg.rlm, "min_task_tests_len", 20)),
+         max_tests_len=int(getattr(cfg.rlm, "max_task_tests_len", 8000)),
+         blocked_patterns=getattr(cfg.rlm, "blocked_task_patterns", None),
+     )
+     return filtered or tasks
+
+
+ def _write_task_tests(task: GeneratedTask, workdir: Path) -> None:
+     tests_dir = ensure_dir(workdir / "tests")
+     (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
+
+
+ def collect_rollouts(
+     llm,
+     tasks: Iterable[GeneratedTask],
+     cfg: ProjectConfig,
+     *,
+     artifacts_dir: Path,
+     verifier_backend: str,
+     weight_adapter: Optional[str],
+ ) -> tuple[List[Rollout], list[dict]]:
+     rollouts: List[Rollout] = []
+     passed_samples: list[dict] = []
+
+     for task in tasks:
+         for k in range(int(cfg.rlm.rollouts_per_task)):
+             gen = llm.generate_with_logprobs(
+                 task.prompt,
+                 max_new_tokens=int(cfg.rft.max_new_tokens),
+                 temperature=float(cfg.rft.temperature),
+                 seed=int(time.time() * 1000) % (2**31 - 1),
+             )
+             completion = gen.text[len(task.prompt) :] if gen.text.startswith(task.prompt) else gen.text
+             wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
+             (wdir / "main.py").write_text(completion, encoding="utf-8")
+             _write_task_tests(task, wdir)
+
+             t0 = time.time()
+             if verifier_backend == "docker":
+                 res = docker_verify(
+                     task.prompt,
+                     completion,
+                     str(wdir),
+                     timeout_s=int(cfg.rlm.verifier_timeout_s),
+                     image=cfg.rlm.docker_image,
+                     memory_mb=int(cfg.rlm.docker_memory_mb),
+                     cpus=float(cfg.rlm.docker_cpus),
+                     pids=int(cfg.rlm.docker_pids),
+                 )
+             else:
+                 res = pytest_verify(task.prompt, completion, str(wdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
+             latency_ms = (time.time() - t0) * 1000.0
+
+             passed = bool(getattr(res, "passed", False))
+             reward = float(getattr(res, "reward", 0.0))
+             rollouts.append(
+                 Rollout(
+                     task_id=task.id,
+                     prompt=task.prompt,
+                     completion=completion,
+                     token_ids=list(gen.token_ids),
+                     prompt_len=int(gen.prompt_len),
+                     logprobs=list(gen.logprobs) if gen.logprobs is not None else None,
+                     passed=passed,
+                     reward=reward,
+                     verifier_latency_ms=latency_ms,
+                     weight_adapter=weight_adapter,
+                 )
+             )
+
+             if passed:
+                 passed_samples.append(
+                     {
+                         "id": task.id,
+                         "prompt": task.prompt,
+                         "response": completion,
+                         "reward": reward,
+                         "ts": now_ts(),
+                     }
+                 )
+
+     return rollouts, passed_samples
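To make the Rollout output concrete, here is an illustrative aggregation sketch (not part of the diff): it summarizes per-task pass rates from the list returned by collect_rollouts. Only fields defined on Rollout above are used; the pass_rates helper itself is hypothetical.

from collections import defaultdict

def pass_rates(rollouts) -> dict[str, float]:
    # Count attempts and passes per task_id, then compute the ratio.
    totals: dict[str, int] = defaultdict(int)
    passes: dict[str, int] = defaultdict(int)
    for r in rollouts:
        totals[r.task_id] += 1
        if r.passed:
            passes[r.task_id] += 1
    return {tid: passes[tid] / totals[tid] for tid in totals}

# e.g. rollouts, _ = collect_rollouts(...); print(pass_rates(rollouts))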