mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/envs/system.py ADDED
@@ -0,0 +1,388 @@
+from __future__ import annotations
+
+import json
+import re
+import shutil
+import tarfile
+import tempfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import yaml
+
+from ..util import ensure_dir, now_ts, copytree
+
+
+@dataclass
+class EnvManifest:
+    name: str
+    version: str
+    description: Optional[str] = None
+    verifier: Optional[str] = None
+    tasks: Optional[list] = None
+    token_env: Optional[object] = None
+
+
+@dataclass
+class EnvRef:
+    name: str
+    version: Optional[str] = None
+
+
+def _envs_root(project_root: Path) -> Path:
+    return project_root / "envs"
+
+
+def _registry_path(project_root: Path) -> Path:
+    return _envs_root(project_root) / "registry.json"
+
+
+def load_manifest(env_path: Path) -> EnvManifest:
+    data = yaml.safe_load(env_path.read_text(encoding="utf-8")) or {}
+    return EnvManifest(
+        name=str(data.get("name") or env_path.parent.name),
+        version=str(data.get("version") or "0.1.0"),
+        description=data.get("description"),
+        verifier=data.get("verifier"),
+        tasks=data.get("tasks"),
+        token_env=data.get("token_env"),
+    )
+
+
+def _normalize_package_module(name: str) -> str:
+    cleaned = re.sub(r"[^a-zA-Z0-9_]+", "_", name.replace("-", "_"))
+    cleaned = cleaned.strip("_") or "env"
+    if cleaned[0].isdigit():
+        cleaned = f"env_{cleaned}"
+    return cleaned.lower()
+
+
+def _env_scaffold_pyproject(name: str, version: str, description: str) -> str:
+    return f"""[build-system]
+requires = ["setuptools>=68", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "{name}"
+version = "{version}"
+description = "{description}"
+readme = "README.md"
+requires-python = ">=3.10"
+
+[tool.setuptools]
+package-dir = {{"" = "src"}}
+
+[tool.setuptools.packages.find]
+where = ["src"]
+"""
+
+
+def _env_scaffold_readme(name: str) -> str:
+    return f"""# {name}
+
+Local MLXSmith environment package.
+
+## Files
+- `env.yaml`: task manifest consumed by mlxsmith.
+- `pyproject.toml`: Python package metadata for Hub publishing.
+
+## Usage
+```bash
+mlxsmith env package {name}
+mlxsmith env publish envs/packages/{name}-0.1.0.tar.gz
+```
+"""
+
+
+def _env_scaffold_module() -> str:
+    return """from pathlib import Path
+
+ENV_MANIFEST = Path(__file__).resolve().parents[2] / "env.yaml"
+
+
+def load_environment() -> Path:
+    return ENV_MANIFEST
+"""
+
+
+def _env_scaffold_init() -> str:
+    return """from .environment import ENV_MANIFEST, load_environment
+
+__all__ = ["ENV_MANIFEST", "load_environment"]
+"""
+
+
+def resolve_env_path(project_root: Path, env_ref: str) -> Path:
+    ref = Path(env_ref)
+    if ref.exists():
+        if ref.is_dir():
+            candidate = ref / "env.yaml"
+            return candidate if candidate.exists() else ref
+        return ref
+    candidate = _envs_root(project_root) / env_ref / "env.yaml"
+    return candidate
+
+
+def init_env(project_root: Path, name: str) -> Path:
+    env_root = _envs_root(project_root) / name
+    ensure_dir(env_root)
+    manifest = {
+        "name": name,
+        "version": "0.1.0",
+        "description": "Sample environment",
+        "verifier": "verifiers/regex.py",
+        "tasks": [
+            {
+                "id": "add",
+                "prompt": "Write a Python function add(a, b) that returns the sum.",
+                "tests": "from main import add\n\n\n"
+                "def test_add():\n"
+                "    assert add(2, 3) == 5\n",
+            }
+        ],
+    }
+    (env_root / "env.yaml").write_text(yaml.safe_dump(manifest, sort_keys=False), encoding="utf-8")
+
+    pkg_module = _normalize_package_module(name)
+    (env_root / "pyproject.toml").write_text(
+        _env_scaffold_pyproject(name, manifest["version"], manifest["description"]),
+        encoding="utf-8",
+    )
+    (env_root / "README.md").write_text(_env_scaffold_readme(name), encoding="utf-8")
+    pkg_dir = ensure_dir(env_root / "src" / pkg_module)
+    (pkg_dir / "environment.py").write_text(_env_scaffold_module(), encoding="utf-8")
+    (pkg_dir / "__init__.py").write_text(_env_scaffold_init(), encoding="utf-8")
+    return env_root
+
+
+def _load_registry(project_root: Path) -> dict:
+    path = _registry_path(project_root)
+    if not path.exists():
+        return {"packages": [], "updated_at": None}
+    return json.loads(path.read_text(encoding="utf-8"))
+
+
+def _save_registry(project_root: Path, data: dict) -> None:
+    data["updated_at"] = now_ts()
+    path = _registry_path(project_root)
+    ensure_dir(path.parent)
+    path.write_text(json.dumps(data, indent=2), encoding="utf-8")
+
+
+def _parse_env_ref(env_ref: str, version: Optional[str] = None) -> EnvRef:
+    name = env_ref.strip()
+    parsed_version = None
+    if "==" in name:
+        name, parsed_version = name.split("==", 1)
+    elif "@" in name:
+        name, parsed_version = name.rsplit("@", 1)
+    name = name.strip()
+    if parsed_version is not None:
+        parsed_version = parsed_version.strip()
+        if parsed_version in {"", "latest"}:
+            parsed_version = None
+    if version and parsed_version and version != parsed_version:
+        raise RuntimeError(f"Conflicting versions: {parsed_version} vs {version}")
+    return EnvRef(name=name, version=version or parsed_version)
+
+
+def _version_key(version: str) -> tuple:
+    if not version:
+        return tuple()
+    base, _, _build = version.partition("+")
+    main, _, pre = base.partition("-")
+    parts = []
+    for part in main.split("."):
+        if part.startswith("v") and part[1:].isdigit():
+            part = part[1:]
+        if part.isdigit():
+            parts.append((0, int(part)))
+        else:
+            parts.append((1, part))
+    if pre:
+        parts.append((1, pre))
+    else:
+        parts.append((2, ""))
+    return tuple(parts)
+
+
+def _select_registry_package(packages: list[dict], name: str, version: Optional[str]) -> dict:
+    matches = [p for p in packages if p.get("name") == name]
+    if not matches:
+        raise RuntimeError(f"Env not found in registry: {name}")
+    if version:
+        exact = [p for p in matches if p.get("version") == version]
+        if not exact:
+            available = sorted({p.get("version", "") for p in matches})
+            raise RuntimeError(f"Env {name} has no version {version}. Available: {', '.join(available)}")
+        return exact[0]
+    return sorted(matches, key=lambda p: _version_key(p.get("version", "")))[-1]
+
+
+def _load_manifest_from_package(package_path: Path) -> EnvManifest:
+    with tarfile.open(package_path, "r:gz") as tf:
+        for m in tf.getmembers():
+            if m.name.endswith("env.yaml"):
+                f = tf.extractfile(m)
+                if f is None:
+                    break
+                data = yaml.safe_load(f.read().decode("utf-8")) or {}
+                return EnvManifest(
+                    name=str(data.get("name") or Path(m.name).parent.name),
+                    version=str(data.get("version") or "0.1.0"),
+                    description=data.get("description"),
+                    verifier=data.get("verifier"),
+                    tasks=data.get("tasks"),
+                    token_env=data.get("token_env"),
+                )
+    raise RuntimeError("Package missing env.yaml")
+
+
+def list_registry_packages(project_root: Path, name: Optional[str] = None, all_versions: bool = False) -> list[dict]:
+    registry = _load_registry(project_root)
+    packages = registry.get("packages", [])
+    if name:
+        packages = [p for p in packages if p.get("name") == name]
+    if all_versions or not packages:
+        return sorted(packages, key=lambda p: (p.get("name", ""), _version_key(p.get("version", ""))))
+
+    latest = {}
+    for pkg in packages:
+        pkg_name = pkg.get("name")
+        if not pkg_name:
+            continue
+        prev = latest.get(pkg_name)
+        if prev is None or _version_key(pkg.get("version", "")) > _version_key(prev.get("version", "")):
+            latest[pkg_name] = pkg
+    return sorted(latest.values(), key=lambda p: p.get("name", ""))
+
+
+def registry_info(project_root: Path, env_ref: str, version: Optional[str] = None) -> tuple[dict, EnvManifest]:
+    registry = _load_registry(project_root)
+    ref = _parse_env_ref(env_ref, version=version)
+    pkg = _select_registry_package(registry.get("packages", []), ref.name, ref.version)
+    pkg_path = Path(pkg["path"])
+    if not pkg_path.exists():
+        raise RuntimeError(f"Registry package missing: {pkg_path}")
+    manifest = _load_manifest_from_package(pkg_path)
+    return pkg, manifest
+
+
+def install_env(
+    project_root: Path,
+    source: str,
+    version: Optional[str] = None,
+) -> Path:
+    src = Path(source)
+    ensure_dir(_envs_root(project_root))
+    if src.exists():
+        if version:
+            raise RuntimeError("Version pinning only applies to registry installs.")
+        if src.is_dir():
+            manifest_path = src / "env.yaml"
+            if not manifest_path.exists():
+                raise RuntimeError(f"Missing env.yaml in {src}")
+            manifest = load_manifest(manifest_path)
+            dest = _envs_root(project_root) / manifest.name
+            copytree(src, dest)
+            return dest
+        if src.suffixes[-2:] == [".tar", ".gz"]:
+            with tarfile.open(src, "r:gz") as tf:
+                members = tf.getmembers()
+                top = members[0].name.split("/")[0] if members else ""
+                tf.extractall(_envs_root(project_root))
+            env_path = _envs_root(project_root) / top / "env.yaml"
+            if env_path.exists():
+                return env_path.parent
+            raise RuntimeError(f"Invalid package: {src}")
+        raise RuntimeError(f"Unsupported env source: {src}")
+
+    # treat as registry name
+    registry = _load_registry(project_root)
+    ref = _parse_env_ref(source, version=version)
+    pkg = _select_registry_package(registry.get("packages", []), ref.name, ref.version)
+    pkg_path = Path(pkg["path"])
+    return install_env(project_root, str(pkg_path))
+
+
+def package_env(project_root: Path, env_name: str, out_path: Optional[str] = None) -> Path:
+    env_dir = _envs_root(project_root) / env_name
+    manifest_path = env_dir / "env.yaml"
+    if not manifest_path.exists():
+        raise RuntimeError(f"Missing env.yaml in {env_dir}")
+    manifest = load_manifest(manifest_path)
+
+    out_dir = Path(out_path) if out_path else _envs_root(project_root) / "packages"
+    ensure_dir(out_dir)
+    tar_path = out_dir / f"{manifest.name}-{manifest.version}.tar.gz"
+    with tarfile.open(tar_path, "w:gz") as tf:
+        tf.add(env_dir, arcname=env_dir.name)
+    return tar_path
+
+
+def publish_env(project_root: Path, package_path: str) -> Path:
+    pkg = Path(package_path)
+    if not pkg.exists():
+        raise RuntimeError(f"Missing package: {pkg}")
+
+    manifest = _load_manifest_from_package(pkg)
+
+    registry_dir = _envs_root(project_root) / "registry"
+    ensure_dir(registry_dir)
+    dest = registry_dir / f"{manifest.name}-{manifest.version}.tar.gz"
+    dest.write_bytes(pkg.read_bytes())
+
+    registry = _load_registry(project_root)
+    registry.setdefault("packages", [])
+    registry["packages"] = [
+        p
+        for p in registry["packages"]
+        if not (p.get("name") == manifest.name and p.get("version") == manifest.version)
+    ]
+    registry["packages"].append(
+        {
+            "name": manifest.name,
+            "version": manifest.version,
+            "description": manifest.description,
+            "verifier": manifest.verifier,
+            "path": str(dest),
+        }
+    )
+    _save_registry(project_root, registry)
+    return dest
+
+
+def pull_env(
+    project_root: Path,
+    env_ref: str,
+    out_dir: Optional[str] = None,
+    version: Optional[str] = None,
+    force: bool = False,
+) -> Path:
+    registry = _load_registry(project_root)
+    ref = _parse_env_ref(env_ref, version=version)
+    pkg = _select_registry_package(registry.get("packages", []), ref.name, ref.version)
+    pkg_path = Path(pkg["path"])
+    if not pkg_path.exists():
+        raise RuntimeError(f"Registry package missing: {pkg_path}")
+
+    dest = Path(out_dir) if out_dir else Path.cwd() / ref.name
+    if dest.exists():
+        if not force:
+            raise RuntimeError(f"Destination exists: {dest}")
+        shutil.rmtree(dest)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        tmp_root = Path(tmpdir)
+        with tarfile.open(pkg_path, "r:gz") as tf:
+            members = tf.getmembers()
+            if not members:
+                raise RuntimeError(f"Empty package: {pkg_path}")
+            top = members[0].name.split("/")[0]
+            tf.extractall(tmp_root)
+        src = tmp_root / top
+        if not src.exists():
+            raise RuntimeError(f"Invalid package layout: {pkg_path}")
+        shutil.copytree(src, dest)
+    return dest
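
The hunk above defines the local environment lifecycle: scaffold an env directory, tar it, publish it into a file-based registry under `envs/registry/`, and resolve it again by name or `name==version` spec. A minimal sketch of that round trip, assuming the functions are importable as `mlxsmith.envs.system` (per the file listing) and that the working directory holds the project's `envs/` tree; the environment name `demo-env` is illustrative:

```python
from pathlib import Path

from mlxsmith.envs.system import (
    init_env,
    list_registry_packages,
    package_env,
    publish_env,
    registry_info,
)

project_root = Path(".")  # assumed to contain the envs/ directory

# Scaffold envs/demo-env/ with env.yaml, pyproject.toml, README.md and a src/ package.
init_env(project_root, "demo-env")

# Tar the environment into envs/packages/demo-env-0.1.0.tar.gz.
tarball = package_env(project_root, "demo-env")

# Copy the tarball into envs/registry/ and record it in envs/registry.json.
publish_env(project_root, str(tarball))

# Look it up by spec; "demo-env@latest" or a bare "demo-env" resolve the newest version.
pkg, manifest = registry_info(project_root, "demo-env==0.1.0")
print(pkg["path"], manifest.verifier)

for entry in list_registry_packages(project_root):
    print(entry["name"], entry["version"])
```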
mlxsmith/envs/token_env.py ADDED
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import importlib
+import importlib.util
+import inspect
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Callable, Optional, Protocol
+
+
+@dataclass
+class TokenEnvStep:
+    observation: list[int]
+    reward: float
+    done: bool
+    info: dict[str, Any] = field(default_factory=dict)
+
+
+class TokenEnv(Protocol):
+    def initial_observation(self) -> list[int] | TokenEnvStep:
+        ...
+
+    def step(self, action: int) -> TokenEnvStep:
+        ...
+
+
+@dataclass
+class TokenEnvSpec:
+    factory: Callable[..., TokenEnv]
+    kwargs: dict[str, Any] = field(default_factory=dict)
+    kind: str = "custom"
+
+
+def _filter_kwargs(fn: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
+    try:
+        sig = inspect.signature(fn)
+    except (TypeError, ValueError):
+        return kwargs
+    for param in sig.parameters.values():
+        if param.kind == param.VAR_KEYWORD:
+            return kwargs
+    return {k: v for k, v in kwargs.items() if k in sig.parameters}
+
+
+def _load_from_path(path: Path):
+    spec = importlib.util.spec_from_file_location(path.stem, path)
+    if spec is None or spec.loader is None:
+        raise RuntimeError(f"Could not load token env module: {path}")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)  # type: ignore
+    return module
+
+
+def _resolve_factory(module, class_name: Optional[str]) -> Callable[..., TokenEnv]:
+    if class_name:
+        factory = getattr(module, class_name, None)
+        if factory is None:
+            raise RuntimeError(f"Token env factory not found: {class_name}")
+        if not callable(factory):
+            raise RuntimeError(f"Token env factory not callable: {class_name}")
+        return factory
+
+    for fallback in ("Env", "TokenEnv", "make_env", "load_env"):
+        factory = getattr(module, fallback, None)
+        if callable(factory):
+            return factory
+
+    raise RuntimeError("Token env factory not found (expected class or make_env/load_env).")
+
+
+def _parse_token_env_spec(project_root: Path, token_env: Any) -> TokenEnvSpec:
+    if isinstance(token_env, str):
+        if token_env in {"tasks", "task_shim"}:
+            return TokenEnvSpec(factory=StringTaskTokenEnv, kind="tasks")
+        path_part, _, class_part = token_env.partition(":")
+        class_name = class_part or None
+        if Path(path_part).suffix == ".py" or Path(path_part).exists():
+            path = Path(path_part)
+            if not path.is_absolute():
+                path = project_root / path
+            module = _load_from_path(path)
+        else:
+            module = importlib.import_module(path_part)
+        factory = _resolve_factory(module, class_name)
+        return TokenEnvSpec(factory=factory, kind="custom")
+
+    if isinstance(token_env, dict):
+        if token_env.get("type") in {"tasks", "task_shim"}:
+            return TokenEnvSpec(factory=StringTaskTokenEnv, kind="tasks")
+        class_name = token_env.get("class") or token_env.get("cls")
+        kwargs = token_env.get("kwargs") or {}
+        if "path" in token_env:
+            path = Path(token_env["path"])
+            if not path.is_absolute():
+                path = project_root / path
+            module = _load_from_path(path)
+        elif "module" in token_env:
+            module = importlib.import_module(str(token_env["module"]))
+        else:
+            raise RuntimeError("token_env requires 'path' or 'module'")
+        factory = _resolve_factory(module, class_name)
+        return TokenEnvSpec(factory=factory, kwargs=dict(kwargs), kind="custom")
+
+    raise RuntimeError("token_env must be a string or mapping")
+
+
+def load_token_env_spec(project_root: Path, env_data: dict) -> Optional[TokenEnvSpec]:
+    token_env = env_data.get("token_env")
+    if not token_env:
+        return None
+    return _parse_token_env_spec(project_root, token_env)
+
+
+def create_token_env(spec: TokenEnvSpec, **kwargs) -> TokenEnv:
+    params = dict(spec.kwargs)
+    params.update(kwargs)
+    params = _filter_kwargs(spec.factory, params)
+    return spec.factory(**params)
+
+
+class StringTaskTokenEnv:
+    def __init__(
+        self,
+        *,
+        prompt: str,
+        tests: str,
+        verifier_fn: Callable[..., Any],
+        workdir: Path,
+        max_steps: int,
+        encode: Callable[[str], list[int]],
+        decode: Callable[[list[int]], str],
+        verifier_kwargs: Optional[dict[str, Any]] = None,
+        eos_token_id: Optional[int] = None,
+    ):
+        self.prompt = prompt
+        self.tests = tests
+        self.verifier_fn = verifier_fn
+        self.workdir = Path(workdir)
+        self.max_steps = max_steps
+        self.encode = encode
+        self.decode = decode
+        self.verifier_kwargs = verifier_kwargs or {}
+        self.eos_token_id = eos_token_id
+        self._prompt_ids: list[int] = []
+        self._generated: list[int] = []
+        self._steps = 0
+
+    def initial_observation(self) -> list[int]:
+        self._prompt_ids = list(self.encode(self.prompt))
+        self._generated = []
+        self._steps = 0
+        tests_dir = self.workdir / "tests"
+        tests_dir.mkdir(parents=True, exist_ok=True)
+        (tests_dir / "test_task.py").write_text(self.tests or "", encoding="utf-8")
+        return list(self._prompt_ids)
+
+    def step(self, action: int) -> TokenEnvStep:
+        if self.eos_token_id is not None and action == self.eos_token_id:
+            done = True
+        else:
+            done = False
+        self._generated.append(int(action))
+        self._steps += 1
+
+        if self._steps >= self.max_steps:
+            done = True
+
+        reward = 0.0
+        info: dict[str, Any] = {}
+        if done:
+            completion_ids = list(self._generated)
+            if self.eos_token_id is not None and completion_ids and completion_ids[-1] == self.eos_token_id:
+                completion_ids = completion_ids[:-1]
+            completion = self.decode(completion_ids)
+            (self.workdir / "main.py").write_text(completion, encoding="utf-8")
+            t0 = time.time()
+            res = self.verifier_fn(self.prompt, completion, str(self.workdir), **self.verifier_kwargs)
+            latency_ms = (time.time() - t0) * 1000.0
+            reward = float(getattr(res, "reward", 0.0))
+            info = dict(getattr(res, "info", {}) or {})
+            info["passed"] = bool(getattr(res, "passed", False))
+            info["verifier_latency_ms"] = latency_ms
+
+        observation = list(self._prompt_ids) + list(self._generated)
+        return TokenEnvStep(
+            observation=observation,
+            reward=reward,
+            done=done,
+            info=info,
+        )
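
For a custom `token_env`, `_parse_token_env_spec` accepts either a `"path.py:ClassName"` string or a mapping with `path`/`module`, `class`, and `kwargs`, and `_resolve_factory` falls back to the names `Env`, `TokenEnv`, `make_env`, or `load_env`. A minimal sketch of such an environment; the module path, class name, and reward rule are hypothetical and not part of this package:

```python
# counting_env.py -- hypothetical module; reachable from env.yaml via, e.g.:
#   token_env:
#     path: envs/counting/counting_env.py
#     class: CountingEnv
#     kwargs: {max_steps: 8}
from mlxsmith.envs.token_env import TokenEnvStep


class CountingEnv:
    """Toy env: +1 reward whenever the sampled token id exceeds the previous one."""

    def __init__(self, max_steps: int = 8):
        # create_token_env filters runtime kwargs to this signature via _filter_kwargs,
        # so extra keyword arguments supplied by the trainer are simply dropped.
        self.max_steps = max_steps
        self._history: list[int] = []

    def initial_observation(self) -> list[int]:
        self._history = []
        return []  # empty prompt; the policy generates from scratch

    def step(self, action: int) -> TokenEnvStep:
        reward = 1.0 if self._history and action > self._history[-1] else 0.0
        self._history.append(int(action))
        done = len(self._history) >= self.max_steps
        return TokenEnvStep(
            observation=list(self._history),
            reward=reward,
            done=done,
            info={"steps": len(self._history)},
        )
```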
mlxsmith/eval.py ADDED
@@ -0,0 +1,112 @@
+from __future__ import annotations
+
+import json
+import time
+from pathlib import Path
+
+from rich.console import Console
+
+from .util import ensure_dir, now_ts
+from .config import ProjectConfig, load_config
+from .models import resolve_model_spec
+from .llm.registry import get_llm_backend
+
+console = Console()
+
+
+def _load_verifier(verifier_path: Path):
+    import importlib.util
+
+    spec = importlib.util.spec_from_file_location(verifier_path.stem, verifier_path)
+    if spec is None or spec.loader is None:
+        raise RuntimeError(f"Could not load verifier: {verifier_path}")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)  # type: ignore
+    verify_fn = getattr(module, "verify", None)
+    if not callable(verify_fn):
+        raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
+    return verify_fn
+
+
+def run_eval(project_root: Path, suite_path: Path, model_path: Path) -> Path:
+    import yaml
+
+    suite = yaml.safe_load(suite_path.read_text(encoding="utf-8")) or {}
+    out_dir = ensure_dir(project_root / "eval" / "last")
+    out_path = out_dir / "results.json"
+
+    cfg_path = project_root / "mlxsmith.yaml"
+    if cfg_path.exists():
+        cfg = load_config(cfg_path)
+    else:
+        cfg = ProjectConfig()
+    if suite.get("config"):
+        merged = cfg.model_dump()
+        merged.update(suite.get("config") or {})
+        cfg = ProjectConfig.model_validate(merged)
+
+    llm = get_llm_backend(cfg.model.backend)
+    base_model, adapter_path, _meta = resolve_model_spec(project_root, str(model_path), cfg)
+    llm.load(
+        base_model,
+        max_seq_len=cfg.model.max_seq_len,
+        dtype=cfg.model.dtype,
+        trust_remote_code=cfg.model.trust_remote_code,
+    )
+    if adapter_path:
+        llm.apply_adapter(str(adapter_path))
+
+    tasks = suite.get("tasks") or []
+    if not tasks:
+        results = {
+            "model": str(model_path),
+            "suite": suite.get("name", suite_path.name),
+            "error": "no tasks",
+        }
+        out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
+        return out_path
+
+    summaries = []
+    for task in tasks:
+        prompt = task.get("prompt", "")
+        k = int(task.get("k", 1))
+        verifier_path = task.get("verifier")
+        verify_fn = None
+        if verifier_path:
+            verify_fn = _load_verifier(project_root / verifier_path)
+        passes = 0
+        responses = []
+        t0 = time.time()
+        for i in range(k):
+            gen = llm.generate(
+                prompt,
+                max_new_tokens=int(task.get("max_new_tokens", 256)),
+                temperature=float(task.get("temperature", 0.7)),
+                top_p=float(task.get("top_p", 1.0)),
+                seed=int(task.get("seed", 0)) if task.get("seed") is not None else None,
+            )
+            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
+            responses.append(completion)
+            if verify_fn:
+                res = verify_fn(prompt, completion, str(out_dir), **(task.get("verifier_kwargs") or {}))
+                if bool(getattr(res, "passed", False)):
+                    passes += 1
+        elapsed = max(time.time() - t0, 1e-6)
+        summaries.append(
+            {
+                "task_id": task.get("id") or prompt[:32],
+                "k": k,
+                "pass@k": passes / max(1, k),
+                "latency_s": elapsed,
+            }
+        )
+
+    results = {
+        "model": str(model_path),
+        "suite": suite.get("name", suite_path.name),
+        "ts": now_ts(),
+        "summary": summaries,
+    }
+    out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
+    console.print(f"[green]Wrote[/green] {out_path}")
+    return out_path
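
`run_eval` reads the suite YAML directly: a top-level `name`, optional `config` overrides merged into the project config, and `tasks` entries with `prompt`, `k`, sampling knobs, and an optional project-relative `verifier` that must expose `verify(prompt, completion, workdir, **kwargs)` returning an object with a `passed` attribute. A sketch of driving it programmatically; the suite contents, verifier path, and the `runs/sft-latest` model reference are illustrative assumptions, not shipped defaults:

```python
from pathlib import Path

import yaml

from mlxsmith.eval import run_eval

project_root = Path(".")
suite = {
    "name": "smoke",
    "tasks": [
        {
            "id": "add",
            "prompt": "Write a Python function add(a, b) that returns the sum.",
            "k": 4,
            "max_new_tokens": 128,
            "temperature": 0.7,
            # Resolved relative to project_root; must define verify(...).
            "verifier": "verifiers/regex.py",
        }
    ],
}
suite_path = project_root / "eval" / "smoke.yaml"
suite_path.parent.mkdir(parents=True, exist_ok=True)
suite_path.write_text(yaml.safe_dump(suite, sort_keys=False), encoding="utf-8")

# model_path can be any reference resolve_model_spec understands (base model or run/adapter).
results_path = run_eval(project_root, suite_path, Path("runs/sft-latest"))
print(results_path.read_text())
```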