mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/envs/system.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import re
|
|
5
|
+
import shutil
|
|
6
|
+
import tarfile
|
|
7
|
+
import tempfile
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
import yaml
|
|
13
|
+
|
|
14
|
+
from ..util import ensure_dir, now_ts, copytree
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class EnvManifest:
    # Parsed contents of an env.yaml manifest.
    name: str  # environment name; falls back to the manifest's directory name
    version: str  # version string; falls back to "0.1.0"
    description: Optional[str] = None  # human-readable summary
    verifier: Optional[str] = None  # verifier script reference from the manifest
    tasks: Optional[list] = None  # task entries (e.g. prompt/tests mappings)
    token_env: Optional[object] = None  # optional token-env declaration (string or mapping)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class EnvRef:
    # Registry reference: an env name plus an optional version pin.
    name: str
    version: Optional[str] = None  # None means "latest available"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _envs_root(project_root: Path) -> Path:
|
|
34
|
+
return project_root / "envs"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _registry_path(project_root: Path) -> Path:
    """Location of the JSON registry index inside the envs root."""
    return _envs_root(project_root).joinpath("registry.json")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def load_manifest(env_path: Path) -> EnvManifest:
    """Parse an env.yaml file into an :class:`EnvManifest`.

    A missing name falls back to the parent directory name, and a
    missing version falls back to "0.1.0".
    """
    raw = env_path.read_text(encoding="utf-8")
    data = yaml.safe_load(raw) or {}
    name = data.get("name") or env_path.parent.name
    version = data.get("version") or "0.1.0"
    return EnvManifest(
        name=str(name),
        version=str(version),
        description=data.get("description"),
        verifier=data.get("verifier"),
        tasks=data.get("tasks"),
        token_env=data.get("token_env"),
    )
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _normalize_package_module(name: str) -> str:
|
|
54
|
+
cleaned = re.sub(r"[^a-zA-Z0-9_]+", "_", name.replace("-", "_"))
|
|
55
|
+
cleaned = cleaned.strip("_") or "env"
|
|
56
|
+
if cleaned[0].isdigit():
|
|
57
|
+
cleaned = f"env_{cleaned}"
|
|
58
|
+
return cleaned.lower()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _env_scaffold_pyproject(name: str, version: str, description: str) -> str:
    """Render the pyproject.toml scaffold for a freshly initialized env package."""
    return f"""[build-system]
requires = ["setuptools>=68", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "{name}"
version = "{version}"
description = "{description}"
readme = "README.md"
requires-python = ">=3.10"

[tool.setuptools]
package-dir = {{"" = "src"}}

[tool.setuptools.packages.find]
where = ["src"]
"""
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def _env_scaffold_readme(name: str) -> str:
    """Render the README.md scaffold for a freshly initialized env package."""
    return f"""# {name}

Local MLXSmith environment package.

## Files
- `env.yaml`: task manifest consumed by mlxsmith.
- `pyproject.toml`: Python package metadata for Hub publishing.

## Usage
```bash
mlxsmith env package {name}
mlxsmith env publish envs/packages/{name}-0.1.0.tar.gz
```
"""
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def _env_scaffold_module() -> str:
    """Render src/<pkg>/environment.py, which points back at the env.yaml manifest."""
    return """from pathlib import Path

ENV_MANIFEST = Path(__file__).resolve().parents[2] / "env.yaml"


def load_environment() -> Path:
    return ENV_MANIFEST
"""
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _env_scaffold_init() -> str:
    """Render src/<pkg>/__init__.py re-exporting the environment helpers."""
    return """from .environment import ENV_MANIFEST, load_environment

__all__ = ["ENV_MANIFEST", "load_environment"]
"""
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def resolve_env_path(project_root: Path, env_ref: str) -> Path:
    """Resolve *env_ref* to a manifest (or env) path.

    An existing directory resolves to its env.yaml when present (the
    directory itself otherwise); an existing file is returned as-is.
    Anything else is treated as an env name under the project's envs
    root — that fallback path may not exist yet.
    """
    direct = Path(env_ref)
    if direct.exists():
        if not direct.is_dir():
            return direct
        manifest = direct / "env.yaml"
        if manifest.exists():
            return manifest
        return direct
    return _envs_root(project_root) / env_ref / "env.yaml"
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def init_env(project_root: Path, name: str) -> Path:
    """Scaffold a new environment under <project>/envs/<name>.

    Writes a sample env.yaml (one pytest-verified "add" task), plus a
    pyproject.toml, README, and a minimal src/ package so the env can
    be published as a Python distribution. Returns the env directory.
    """
    env_root = _envs_root(project_root) / name
    ensure_dir(env_root)
    manifest = {
        "name": name,
        "version": "0.1.0",
        "description": "Sample environment",
        "verifier": "verifiers/regex.py",
        "tasks": [
            {
                "id": "add",
                "prompt": "Write a Python function add(a, b) that returns the sum.",
                # The tests value is the literal contents of a pytest file.
                "tests": "from main import add\n\n\n"
                "def test_add():\n"
                "    assert add(2, 3) == 5\n",
            }
        ],
    }
    (env_root / "env.yaml").write_text(yaml.safe_dump(manifest, sort_keys=False), encoding="utf-8")

    # Derive a valid module name (e.g. "my-env" -> "my_env") for the src layout.
    pkg_module = _normalize_package_module(name)
    (env_root / "pyproject.toml").write_text(
        _env_scaffold_pyproject(name, manifest["version"], manifest["description"]),
        encoding="utf-8",
    )
    (env_root / "README.md").write_text(_env_scaffold_readme(name), encoding="utf-8")
    pkg_dir = ensure_dir(env_root / "src" / pkg_module)
    (pkg_dir / "environment.py").write_text(_env_scaffold_module(), encoding="utf-8")
    (pkg_dir / "__init__.py").write_text(_env_scaffold_init(), encoding="utf-8")
    return env_root
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def _load_registry(project_root: Path) -> dict:
    """Load the registry index, or an empty skeleton when none exists yet."""
    registry_file = _registry_path(project_root)
    if registry_file.exists():
        return json.loads(registry_file.read_text(encoding="utf-8"))
    return {"packages": [], "updated_at": None}
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _save_registry(project_root: Path, data: dict) -> None:
    """Stamp *data* with the current time and persist it to registry.json."""
    data["updated_at"] = now_ts()
    registry_file = _registry_path(project_root)
    ensure_dir(registry_file.parent)
    registry_file.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def _parse_env_ref(env_ref: str, version: Optional[str] = None) -> EnvRef:
    """Split "name==ver" or "name@ver" into an :class:`EnvRef`.

    An explicit *version* argument takes precedence; a conflict with a
    version embedded in *env_ref* raises RuntimeError. An embedded
    "latest" (or empty) version means "no pin".
    """
    name = env_ref.strip()
    embedded: Optional[str] = None
    if "==" in name:
        name, embedded = name.split("==", 1)
    elif "@" in name:
        name, embedded = name.rsplit("@", 1)
    name = name.strip()
    if embedded is not None:
        embedded = embedded.strip()
        if embedded in {"", "latest"}:
            embedded = None
    if version and embedded and version != embedded:
        raise RuntimeError(f"Conflicting versions: {embedded} vs {version}")
    return EnvRef(name=name, version=version or embedded)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def _version_key(version: str) -> tuple:
|
|
191
|
+
if not version:
|
|
192
|
+
return tuple()
|
|
193
|
+
base, _, _build = version.partition("+")
|
|
194
|
+
main, _, pre = base.partition("-")
|
|
195
|
+
parts = []
|
|
196
|
+
for part in main.split("."):
|
|
197
|
+
if part.startswith("v") and part[1:].isdigit():
|
|
198
|
+
part = part[1:]
|
|
199
|
+
if part.isdigit():
|
|
200
|
+
parts.append((0, int(part)))
|
|
201
|
+
else:
|
|
202
|
+
parts.append((1, part))
|
|
203
|
+
if pre:
|
|
204
|
+
parts.append((1, pre))
|
|
205
|
+
else:
|
|
206
|
+
parts.append((2, ""))
|
|
207
|
+
return tuple(parts)
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
def _select_registry_package(packages: list[dict], name: str, version: Optional[str]) -> dict:
|
|
211
|
+
matches = [p for p in packages if p.get("name") == name]
|
|
212
|
+
if not matches:
|
|
213
|
+
raise RuntimeError(f"Env not found in registry: {name}")
|
|
214
|
+
if version:
|
|
215
|
+
exact = [p for p in matches if p.get("version") == version]
|
|
216
|
+
if not exact:
|
|
217
|
+
available = sorted({p.get("version", "") for p in matches})
|
|
218
|
+
raise RuntimeError(f"Env {name} has no version {version}. Available: {', '.join(available)}")
|
|
219
|
+
return exact[0]
|
|
220
|
+
return sorted(matches, key=lambda p: _version_key(p.get("version", "")))[-1]
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def _load_manifest_from_package(package_path: Path) -> EnvManifest:
    """Read the env.yaml manifest out of a packaged env tarball.

    Raises RuntimeError when the archive contains no env.yaml or the
    manifest entry cannot be extracted.
    """
    with tarfile.open(package_path, "r:gz") as tf:
        for member in tf.getmembers():
            # Match the exact file name: a bare endswith("env.yaml")
            # would also match unrelated files such as "myenv.yaml".
            if Path(member.name).name != "env.yaml":
                continue
            fileobj = tf.extractfile(member)
            if fileobj is None:
                break
            data = yaml.safe_load(fileobj.read().decode("utf-8")) or {}
            return EnvManifest(
                name=str(data.get("name") or Path(member.name).parent.name),
                version=str(data.get("version") or "0.1.0"),
                description=data.get("description"),
                verifier=data.get("verifier"),
                tasks=data.get("tasks"),
                token_env=data.get("token_env"),
            )
    raise RuntimeError("Package missing env.yaml")
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def list_registry_packages(project_root: Path, name: Optional[str] = None, all_versions: bool = False) -> list[dict]:
    """List registry entries, optionally filtered by *name*.

    By default only the newest version of each env is returned;
    *all_versions* returns every entry, sorted by name then version.
    """
    packages = _load_registry(project_root).get("packages", [])
    if name:
        packages = [pkg for pkg in packages if pkg.get("name") == name]
    if all_versions or not packages:
        return sorted(
            packages,
            key=lambda pkg: (pkg.get("name", ""), _version_key(pkg.get("version", ""))),
        )

    newest: dict = {}
    for pkg in packages:
        pkg_name = pkg.get("name")
        if not pkg_name:
            # Nameless entries cannot be grouped; skip them.
            continue
        current = newest.get(pkg_name)
        if current is None or _version_key(pkg.get("version", "")) > _version_key(current.get("version", "")):
            newest[pkg_name] = pkg
    return sorted(newest.values(), key=lambda pkg: pkg.get("name", ""))
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def registry_info(project_root: Path, env_ref: str, version: Optional[str] = None) -> tuple[dict, EnvManifest]:
    """Look up a registry entry and load the manifest from its package file."""
    ref = _parse_env_ref(env_ref, version=version)
    registry = _load_registry(project_root)
    entry = _select_registry_package(registry.get("packages", []), ref.name, ref.version)
    package_file = Path(entry["path"])
    if not package_file.exists():
        raise RuntimeError(f"Registry package missing: {package_file}")
    return entry, _load_manifest_from_package(package_file)
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def install_env(
    project_root: Path,
    source: str,
    version: Optional[str] = None,
) -> Path:
    """Install an environment from a directory, a .tar.gz package, or the registry.

    Returns the installed env directory under <project>/envs. *version*
    is only meaningful for registry installs and raises otherwise.
    """
    src = Path(source)
    ensure_dir(_envs_root(project_root))
    if src.exists():
        if version:
            raise RuntimeError("Version pinning only applies to registry installs.")
        if src.is_dir():
            manifest_path = src / "env.yaml"
            if not manifest_path.exists():
                raise RuntimeError(f"Missing env.yaml in {src}")
            manifest = load_manifest(manifest_path)
            dest = _envs_root(project_root) / manifest.name
            copytree(src, dest)
            return dest
        if src.suffixes[-2:] == [".tar", ".gz"]:
            dest_root = _envs_root(project_root)
            with tarfile.open(src, "r:gz") as tf:
                members = tf.getmembers()
                # The tarball is untrusted input: refuse entries that
                # would escape the extraction root (absolute paths or
                # ".." components) before calling extractall().
                for member in members:
                    member_path = Path(member.name)
                    if member_path.is_absolute() or ".." in member_path.parts:
                        raise RuntimeError(f"Unsafe path in package: {member.name}")
                top = members[0].name.split("/")[0] if members else ""
                tf.extractall(dest_root)
            env_path = dest_root / top / "env.yaml"
            if env_path.exists():
                return env_path.parent
            raise RuntimeError(f"Invalid package: {src}")
        raise RuntimeError(f"Unsupported env source: {src}")

    # Otherwise treat `source` as a registry name (optionally pinned),
    # resolve it to a package file, and recurse on that path.
    registry = _load_registry(project_root)
    ref = _parse_env_ref(source, version=version)
    pkg = _select_registry_package(registry.get("packages", []), ref.name, ref.version)
    pkg_path = Path(pkg["path"])
    return install_env(project_root, str(pkg_path))
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def package_env(project_root: Path, env_name: str, out_path: Optional[str] = None) -> Path:
    """Bundle <envs>/<env_name> into <name>-<version>.tar.gz and return its path.

    The archive lands in *out_path* when given, otherwise in
    <envs>/packages.
    """
    env_dir = _envs_root(project_root) / env_name
    manifest_path = env_dir / "env.yaml"
    if not manifest_path.exists():
        raise RuntimeError(f"Missing env.yaml in {env_dir}")
    manifest = load_manifest(manifest_path)

    target_dir = Path(out_path) if out_path else _envs_root(project_root) / "packages"
    ensure_dir(target_dir)
    archive = target_dir / f"{manifest.name}-{manifest.version}.tar.gz"
    with tarfile.open(archive, "w:gz") as tf:
        tf.add(env_dir, arcname=env_dir.name)
    return archive
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def publish_env(project_root: Path, package_path: str) -> Path:
    """Copy a packaged env into the local registry and index it.

    Re-publishing an existing name/version replaces the previous
    entry. Returns the path of the registry copy.
    """
    pkg = Path(package_path)
    if not pkg.exists():
        raise RuntimeError(f"Missing package: {pkg}")

    manifest = _load_manifest_from_package(pkg)

    registry_dir = _envs_root(project_root) / "registry"
    ensure_dir(registry_dir)
    dest = registry_dir / f"{manifest.name}-{manifest.version}.tar.gz"
    dest.write_bytes(pkg.read_bytes())

    registry = _load_registry(project_root)
    # Drop any stale entry for this exact name+version, then re-add.
    remaining = [
        entry
        for entry in registry.setdefault("packages", [])
        if entry.get("name") != manifest.name or entry.get("version") != manifest.version
    ]
    remaining.append(
        {
            "name": manifest.name,
            "version": manifest.version,
            "description": manifest.description,
            "verifier": manifest.verifier,
            "path": str(dest),
        }
    )
    registry["packages"] = remaining
    _save_registry(project_root, registry)
    return dest
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def pull_env(
    project_root: Path,
    env_ref: str,
    out_dir: Optional[str] = None,
    version: Optional[str] = None,
    force: bool = False,
) -> Path:
    """Materialize a registry env into a working directory.

    The package is unpacked into a temporary directory first, then
    copied to *out_dir* (default: ./<name>). *force* replaces an
    existing destination; otherwise an existing destination raises.
    """
    registry = _load_registry(project_root)
    ref = _parse_env_ref(env_ref, version=version)
    pkg = _select_registry_package(registry.get("packages", []), ref.name, ref.version)
    pkg_path = Path(pkg["path"])
    if not pkg_path.exists():
        raise RuntimeError(f"Registry package missing: {pkg_path}")

    dest = Path(out_dir) if out_dir else Path.cwd() / ref.name
    if dest.exists():
        if not force:
            raise RuntimeError(f"Destination exists: {dest}")
        shutil.rmtree(dest)

    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_root = Path(tmpdir)
        with tarfile.open(pkg_path, "r:gz") as tf:
            members = tf.getmembers()
            if not members:
                raise RuntimeError(f"Empty package: {pkg_path}")
            # Guard against path traversal from a hand-crafted tarball
            # before extractall(): no absolute paths, no ".." parts.
            for member in members:
                member_path = Path(member.name)
                if member_path.is_absolute() or ".." in member_path.parts:
                    raise RuntimeError(f"Unsafe path in package: {member.name}")
            top = members[0].name.split("/")[0]
            tf.extractall(tmp_root)
        src = tmp_root / top
        if not src.exists():
            raise RuntimeError(f"Invalid package layout: {pkg_path}")
        shutil.copytree(src, dest)
    return dest
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib
|
|
4
|
+
import importlib.util
|
|
5
|
+
import inspect
|
|
6
|
+
import time
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Any, Callable, Optional, Protocol
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class TokenEnvStep:
    # Result of one environment transition.
    observation: list[int]  # token sequence visible after this step
    reward: float  # scalar reward for this transition
    done: bool  # True when the episode has terminated
    info: dict[str, Any] = field(default_factory=dict)  # auxiliary diagnostics
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class TokenEnv(Protocol):
    """Structural (duck-typed) interface for token-level environments."""

    def initial_observation(self) -> list[int] | TokenEnvStep:
        """Reset the episode and return the initial observation."""
        ...

    def step(self, action: int) -> TokenEnvStep:
        """Apply one token action and return the resulting step."""
        ...
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class TokenEnvSpec:
    # Recipe for constructing a TokenEnv instance.
    factory: Callable[..., TokenEnv]  # callable that builds the env
    kwargs: dict[str, Any] = field(default_factory=dict)  # default factory kwargs
    kind: str = "custom"  # "tasks" for the built-in task shim, else "custom"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _filter_kwargs(fn: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
|
|
36
|
+
try:
|
|
37
|
+
sig = inspect.signature(fn)
|
|
38
|
+
except (TypeError, ValueError):
|
|
39
|
+
return kwargs
|
|
40
|
+
for param in sig.parameters.values():
|
|
41
|
+
if param.kind == param.VAR_KEYWORD:
|
|
42
|
+
return kwargs
|
|
43
|
+
return {k: v for k, v in kwargs.items() if k in sig.parameters}
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _load_from_path(path: Path):
|
|
47
|
+
spec = importlib.util.spec_from_file_location(path.stem, path)
|
|
48
|
+
if spec is None or spec.loader is None:
|
|
49
|
+
raise RuntimeError(f"Could not load token env module: {path}")
|
|
50
|
+
module = importlib.util.module_from_spec(spec)
|
|
51
|
+
spec.loader.exec_module(module) # type: ignore
|
|
52
|
+
return module
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _resolve_factory(module, class_name: Optional[str]) -> Callable[..., TokenEnv]:
|
|
56
|
+
if class_name:
|
|
57
|
+
factory = getattr(module, class_name, None)
|
|
58
|
+
if factory is None:
|
|
59
|
+
raise RuntimeError(f"Token env factory not found: {class_name}")
|
|
60
|
+
if not callable(factory):
|
|
61
|
+
raise RuntimeError(f"Token env factory not callable: {class_name}")
|
|
62
|
+
return factory
|
|
63
|
+
|
|
64
|
+
for fallback in ("Env", "TokenEnv", "make_env", "load_env"):
|
|
65
|
+
factory = getattr(module, fallback, None)
|
|
66
|
+
if callable(factory):
|
|
67
|
+
return factory
|
|
68
|
+
|
|
69
|
+
raise RuntimeError("Token env factory not found (expected class or make_env/load_env).")
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _parse_token_env_spec(project_root: Path, token_env: Any) -> TokenEnvSpec:
    """Translate a manifest `token_env` entry into a :class:`TokenEnvSpec`.

    Accepts either a string ("tasks"/"task_shim", "path.py:Class", or
    a dotted module path) or a mapping with path/module, class (or
    cls), and kwargs keys.
    """
    if isinstance(token_env, str):
        if token_env in {"tasks", "task_shim"}:
            return TokenEnvSpec(factory=StringTaskTokenEnv, kind="tasks")
        location, _, attr = token_env.partition(":")
        factory_name = attr or None
        # A ".py" suffix or an existing path means "load from file";
        # anything else is treated as an importable module path.
        if Path(location).suffix == ".py" or Path(location).exists():
            file_path = Path(location)
            if not file_path.is_absolute():
                file_path = project_root / file_path
            module = _load_from_path(file_path)
        else:
            module = importlib.import_module(location)
        return TokenEnvSpec(factory=_resolve_factory(module, factory_name), kind="custom")

    if isinstance(token_env, dict):
        if token_env.get("type") in {"tasks", "task_shim"}:
            return TokenEnvSpec(factory=StringTaskTokenEnv, kind="tasks")
        factory_name = token_env.get("class") or token_env.get("cls")
        extra_kwargs = token_env.get("kwargs") or {}
        if "path" in token_env:
            file_path = Path(token_env["path"])
            if not file_path.is_absolute():
                file_path = project_root / file_path
            module = _load_from_path(file_path)
        elif "module" in token_env:
            module = importlib.import_module(str(token_env["module"]))
        else:
            raise RuntimeError("token_env requires 'path' or 'module'")
        factory = _resolve_factory(module, factory_name)
        return TokenEnvSpec(factory=factory, kwargs=dict(extra_kwargs), kind="custom")

    raise RuntimeError("token_env must be a string or mapping")
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def load_token_env_spec(project_root: Path, env_data: dict) -> Optional[TokenEnvSpec]:
    """Return the TokenEnvSpec declared by *env_data*, or None when absent."""
    declared = env_data.get("token_env")
    if not declared:
        return None
    return _parse_token_env_spec(project_root, declared)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def create_token_env(spec: TokenEnvSpec, **kwargs) -> TokenEnv:
    """Instantiate the env, merging spec defaults with call-site kwargs.

    Call-site kwargs override the spec's defaults; arguments the
    factory does not accept are silently dropped.
    """
    merged = {**spec.kwargs, **kwargs}
    return spec.factory(**_filter_kwargs(spec.factory, merged))
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class StringTaskTokenEnv:
    """Token-level episode over a single string coding task.

    The agent emits one token per step; when the episode terminates
    (EOS token or max_steps), the decoded completion is written to
    ``workdir/main.py`` and scored with ``verifier_fn``.
    """

    def __init__(
        self,
        *,
        prompt: str,
        tests: str,
        verifier_fn: Callable[..., Any],
        workdir: Path,
        max_steps: int,
        encode: Callable[[str], list[int]],
        decode: Callable[[list[int]], str],
        verifier_kwargs: Optional[dict[str, Any]] = None,
        eos_token_id: Optional[int] = None,
    ):
        self.prompt = prompt
        self.tests = tests
        self.verifier_fn = verifier_fn
        self.workdir = Path(workdir)
        self.max_steps = max_steps
        self.encode = encode
        self.decode = decode
        self.verifier_kwargs = verifier_kwargs or {}
        self.eos_token_id = eos_token_id
        # Per-episode state, (re)initialized by initial_observation().
        self._prompt_ids: list[int] = []
        self._generated: list[int] = []
        self._steps = 0

    def initial_observation(self) -> list[int]:
        """Reset episode state, write the test file, and return prompt tokens."""
        self._prompt_ids = list(self.encode(self.prompt))
        self._generated = []
        self._steps = 0
        tests_dir = self.workdir / "tests"
        tests_dir.mkdir(parents=True, exist_ok=True)
        (tests_dir / "test_task.py").write_text(self.tests or "", encoding="utf-8")
        return list(self._prompt_ids)

    def step(self, action: int) -> TokenEnvStep:
        """Append one token; on episode end, run the verifier and score.

        The episode ends when the EOS token is emitted or max_steps is
        reached. Reward is 0.0 on every non-terminal step.
        """
        # EOS is detected before appending, but the token is still
        # recorded below (it is stripped again before decoding).
        if self.eos_token_id is not None and action == self.eos_token_id:
            done = True
        else:
            done = False
        self._generated.append(int(action))
        self._steps += 1

        if self._steps >= self.max_steps:
            done = True

        reward = 0.0
        info: dict[str, Any] = {}
        if done:
            completion_ids = list(self._generated)
            # Drop a trailing EOS so it does not leak into the decoded text.
            if self.eos_token_id is not None and completion_ids and completion_ids[-1] == self.eos_token_id:
                completion_ids = completion_ids[:-1]
            completion = self.decode(completion_ids)
            (self.workdir / "main.py").write_text(completion, encoding="utf-8")
            t0 = time.time()
            res = self.verifier_fn(self.prompt, completion, str(self.workdir), **self.verifier_kwargs)
            latency_ms = (time.time() - t0) * 1000.0
            # Verifier result is duck-typed: reward/passed/info attributes.
            reward = float(getattr(res, "reward", 0.0))
            info = dict(getattr(res, "info", {}) or {})
            info["passed"] = bool(getattr(res, "passed", False))
            info["verifier_latency_ms"] = latency_ms

        observation = list(self._prompt_ids) + list(self._generated)
        return TokenEnvStep(
            observation=observation,
            reward=reward,
            done=done,
            info=info,
        )
|
mlxsmith/eval.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from rich.console import Console
|
|
8
|
+
|
|
9
|
+
from .util import ensure_dir, now_ts
|
|
10
|
+
from .config import ProjectConfig, load_config
|
|
11
|
+
from .models import resolve_model_spec
|
|
12
|
+
from .llm.registry import get_llm_backend
|
|
13
|
+
|
|
14
|
+
console = Console()
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def _load_verifier(verifier_path: Path):
|
|
18
|
+
import importlib.util
|
|
19
|
+
|
|
20
|
+
spec = importlib.util.spec_from_file_location(verifier_path.stem, verifier_path)
|
|
21
|
+
if spec is None or spec.loader is None:
|
|
22
|
+
raise RuntimeError(f"Could not load verifier: {verifier_path}")
|
|
23
|
+
module = importlib.util.module_from_spec(spec)
|
|
24
|
+
spec.loader.exec_module(module) # type: ignore
|
|
25
|
+
verify_fn = getattr(module, "verify", None)
|
|
26
|
+
if not callable(verify_fn):
|
|
27
|
+
raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
|
|
28
|
+
return verify_fn
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def run_eval(project_root: Path, suite_path: Path, model_path: Path) -> Path:
    """Run an eval suite against a model and write eval/last/results.json.

    The suite YAML may carry a "config" mapping (merged over the
    project config), a "name", and a list of "tasks" with prompt,
    k, sampling params, and an optional per-task verifier script.
    Returns the path of the written results file.
    """
    import yaml

    suite = yaml.safe_load(suite_path.read_text(encoding="utf-8")) or {}
    out_dir = ensure_dir(project_root / "eval" / "last")
    out_path = out_dir / "results.json"

    # Project config, overridden by any suite-level "config" mapping.
    cfg_path = project_root / "mlxsmith.yaml"
    if cfg_path.exists():
        cfg = load_config(cfg_path)
    else:
        cfg = ProjectConfig()
    if suite.get("config"):
        merged = cfg.model_dump()
        merged.update(suite.get("config") or {})
        cfg = ProjectConfig.model_validate(merged)

    llm = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, _meta = resolve_model_spec(project_root, str(model_path), cfg)
    llm.load(
        base_model,
        max_seq_len=cfg.model.max_seq_len,
        dtype=cfg.model.dtype,
        trust_remote_code=cfg.model.trust_remote_code,
    )
    if adapter_path:
        llm.apply_adapter(str(adapter_path))

    tasks = suite.get("tasks") or []
    if not tasks:
        # Still write a results file so callers always get a path back.
        results = {
            "model": str(model_path),
            "suite": suite.get("name", suite_path.name),
            "error": "no tasks",
        }
        out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
        return out_path

    summaries = []
    for task in tasks:
        prompt = task.get("prompt", "")
        k = int(task.get("k", 1))
        verifier_path = task.get("verifier")
        verify_fn = None
        if verifier_path:
            verify_fn = _load_verifier(project_root / verifier_path)
        passes = 0
        responses = []
        t0 = time.time()
        for i in range(k):
            gen = llm.generate(
                prompt,
                max_new_tokens=int(task.get("max_new_tokens", 256)),
                temperature=float(task.get("temperature", 0.7)),
                top_p=float(task.get("top_p", 1.0)),
                seed=int(task.get("seed", 0)) if task.get("seed") is not None else None,
            )
            # Strip the echoed prompt when the backend returns it.
            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
            responses.append(completion)
            if verify_fn:
                res = verify_fn(prompt, completion, str(out_dir), **(task.get("verifier_kwargs") or {}))
                if bool(getattr(res, "passed", False)):
                    passes += 1
        elapsed = max(time.time() - t0, 1e-6)
        summaries.append(
            {
                "task_id": task.get("id") or prompt[:32],
                "k": k,
                # NOTE(review): this is the empirical pass *rate* over k
                # samples, not the unbiased pass@k estimator — confirm
                # the intended metric before comparing across reports.
                "pass@k": passes / max(1, k),
                "latency_s": elapsed,
            }
        )

    results = {
        "model": str(model_path),
        "suite": suite.get("name", suite_path.name),
        "ts": now_ts(),
        "summary": summaries,
    }
    out_path.write_text(json.dumps(results, indent=2), encoding="utf-8")
    console.print(f"[green]Wrote[/green] {out_path}")
    return out_path
|