mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/util.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
import platform
|
|
7
|
+
import shutil
|
|
8
|
+
import subprocess
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
from dataclasses import dataclass
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Optional
|
|
14
|
+
|
|
15
|
+
from rich.console import Console
|
|
16
|
+
|
|
17
|
+
console = Console()
|
|
18
|
+
|
|
19
|
+
def sha1_text(s: str) -> str:
    """Return the hex SHA-1 digest of *s* (UTF-8 encoded).

    Used for content addressing / dedup keys, not for security purposes.
    """
    digest = hashlib.sha1()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()
|
|
21
|
+
|
|
22
|
+
def ensure_dir(p: Path) -> Path:
    """Create directory *p* (including missing parents) and return it for chaining."""
    if not p.exists():
        p.mkdir(parents=True, exist_ok=True)
    return p
|
|
25
|
+
|
|
26
|
+
def write_jsonl(path: Path, rows):
    """Append each item of *rows* to *path* as one JSON object per line (JSON Lines).

    The parent directory is created on demand and the file is opened in
    append mode, so repeated calls accumulate records.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8") as fh:
        fh.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
|
|
31
|
+
|
|
32
|
+
def now_ts() -> str:
    """Return the current local time as a filesystem-safe stamp, e.g. ``20240131_235959``."""
    return time.strftime("%Y%m%d_%H%M%S", time.localtime())
|
|
34
|
+
|
|
35
|
+
def run_cmd(cmd: list[str], cwd: Optional[Path] = None, timeout_s: Optional[int] = None) -> subprocess.CompletedProcess:
    """Run *cmd* with captured text stdout/stderr; never raises on nonzero exit.

    cwd: optional working directory. timeout_s: optional timeout in seconds
    (subprocess.TimeoutExpired propagates if it expires).
    """
    workdir = None if cwd is None else str(cwd)
    return subprocess.run(
        cmd,
        cwd=workdir,
        timeout=timeout_s,
        capture_output=True,
        text=True,
    )
|
|
37
|
+
|
|
38
|
+
@dataclass
class SystemInfo:
    """Snapshot of the host runtime environment, as gathered by detect_system()."""

    # Python interpreter version string, e.g. "3.11.4".
    python: str
    # Pointer-size label from platform.architecture(), e.g. "64bit" ("unknown" if undetectable).
    python_arch: str
    # Full platform description from platform.platform().
    platform: str
    # macOS version string on darwin, otherwise None.
    macos_version: Optional[str]
    # Hardware identifier from platform.machine(), e.g. "arm64".
    machine: str
    # Logical CPU count (0 when os.cpu_count() reports nothing).
    cpu_count: int
    # True on macOS (best-effort assumption), None when unknown.
    has_metal: Optional[bool]
    # Whether the `mlx` package imported successfully.
    has_mlx: bool
    # mlx.__version__ when available, else None.
    mlx_version: Optional[str]
    # Whether a `zmlx` module spec was found (installed, not necessarily importable).
    has_zmlx: bool
|
|
50
|
+
|
|
51
|
+
def detect_system() -> SystemInfo:
    """Probe the current host and return a populated SystemInfo.

    Detection is best-effort: mlx availability is tested by importing it,
    zmlx via the import machinery only, and Metal support is assumed
    present on macOS (this is not reliable on CI).
    """
    import importlib.util

    mlx_version = None
    try:
        import mlx  # type: ignore

        has_mlx = True
        mlx_version = getattr(mlx, "__version__", None)
    except Exception:
        has_mlx = False

    has_zmlx = importlib.util.find_spec("zmlx") is not None

    is_darwin = sys.platform == "darwin"
    # Metal detection (best-effort): on macOS we assume Metal is present.
    has_metal = True if is_darwin else None
    macos_version = (platform.mac_ver()[0] or None) if is_darwin else None

    return SystemInfo(
        python=sys.version.split()[0],
        python_arch=platform.architecture()[0] or "unknown",
        platform=platform.platform(),
        macos_version=macos_version,
        machine=platform.machine(),
        cpu_count=os.cpu_count() or 0,
        has_metal=has_metal,
        has_mlx=has_mlx,
        mlx_version=mlx_version,
        has_zmlx=has_zmlx,
    )
|
|
88
|
+
|
|
89
|
+
def require(cond: bool, msg: str):
    """Raise RuntimeError(*msg*) unless *cond* is truthy (lightweight runtime assertion)."""
    if cond:
        return
    raise RuntimeError(msg)
|
|
92
|
+
|
|
93
|
+
def copytree(src: Path, dst: Path):
    """Mirror directory tree *src* into *dst*, discarding any previous *dst* contents.

    Removing *dst* first guarantees no stale files survive the copy.
    """
    target = Path(dst)
    if target.exists():
        shutil.rmtree(target)
    shutil.copytree(src, target)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def tree_map(fn, tree):
    """Recursively apply *fn* to every leaf of a nested dict/list/tuple *tree*.

    Container types are preserved; None (at the root or as a leaf) maps to
    None without calling *fn*.
    """

    def _walk(node):
        if node is None:
            return None
        if isinstance(node, dict):
            return {key: _walk(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(_walk(item) for item in node)
        return fn(node)

    return _walk(tree)
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def tree_add(a, b):
    """Element-wise sum of two nested dict/list/tuple structures.

    None acts as the identity; dict keys are unioned (a missing entry is
    treated as None); sequences are combined pairwise; leaves use ``+``.
    """
    if a is None or b is None:
        return b if a is None else a
    if isinstance(a, dict) and isinstance(b, dict):
        all_keys = set(a) | set(b)
        return {key: tree_add(a.get(key), b.get(key)) for key in all_keys}
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return type(a)(tree_add(x, y) for x, y in zip(a, b))
    return a + b
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def tree_scale(tree, scale: float):
    """Multiply every leaf of *tree* by *scale*, preserving the nested structure."""

    def _scaled(leaf):
        return leaf * scale

    return tree_map(_scaled, tree)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def tree_leaves(tree) -> list:
    """Collect every leaf value of a nested dict/list/tuple *tree* in traversal order.

    A None tree (or None leaf) contributes nothing; dict values are visited
    in insertion order.
    """
    out: list = []

    def _collect(node):
        if node is None:
            return
        if isinstance(node, dict):
            for value in node.values():
                _collect(value)
        elif isinstance(node, (list, tuple)):
            for item in node:
                _collect(item)
        else:
            out.append(node)

    _collect(tree)
    return out
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def clip_grad_norm(grads, max_norm: float):
    """Clip gradients by global L2 norm. Returns clipped grads.

    The scale factor is ``min(max_norm / max(norm, 1e-8), 1.0)``, so trees
    already within the bound pass through unscaled; an empty tree is
    returned unchanged.
    """
    import mlx.core as mx

    leaves = tree_leaves(grads)
    if not leaves:
        return grads

    # Accumulate the squared L2 norm across every leaf in the tree.
    sq_sum = mx.array(0.0)
    for leaf in leaves:
        sq_sum = sq_sum + (leaf * leaf).sum()
    norm = mx.sqrt(sq_sum)

    # Guard against near-zero norms and never scale gradients up.
    scale = mx.minimum(mx.array(max_norm) / mx.maximum(norm, mx.array(1e-8)), mx.array(1.0))
    return tree_map(lambda leaf: leaf * scale, grads)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def latency_summary_ms(samples: list[float]) -> dict[str, float]:
    """Summarize latency *samples* (milliseconds) as mean/p50/p90/p99/max.

    Percentiles use the floor-index (nearest-rank style) method on the
    sorted samples. An empty input yields an empty dict.
    """
    if not samples:
        return {}

    ordered = sorted(samples)
    count = len(ordered)

    def percentile(p: float) -> float:
        if count == 1:
            return ordered[0]
        pos = int((p / 100.0) * (count - 1))
        return ordered[min(max(pos, 0), count - 1)]

    return {
        "mean": sum(ordered) / count,
        "p50": percentile(50),
        "p90": percentile(90),
        "p99": percentile(99),
        "max": ordered[-1],
    }
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import importlib.util
|
|
4
|
+
import time
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
from .types import VerifyResult
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _load_verifier(path: str):
|
|
12
|
+
verifier_path = Path(path)
|
|
13
|
+
spec = importlib.util.spec_from_file_location(verifier_path.stem, verifier_path)
|
|
14
|
+
if spec is None or spec.loader is None:
|
|
15
|
+
raise RuntimeError(f"Could not load verifier: {verifier_path}")
|
|
16
|
+
module = importlib.util.module_from_spec(spec)
|
|
17
|
+
spec.loader.exec_module(module) # type: ignore
|
|
18
|
+
verify_fn = getattr(module, "verify", None)
|
|
19
|
+
if not callable(verify_fn):
|
|
20
|
+
raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
|
|
21
|
+
return verify_fn
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def verify(
    prompt: str,
    completion: str,
    workdir: str,
    *,
    verifiers: List[Any],
    mode: str = "all",
    weights: Optional[List[float]] = None,
    per_verifier_kwargs: Optional[Dict[str, Dict[str, Any]]] = None,
    **kwargs,
) -> VerifyResult:
    """Compose multiple verifiers with AND/OR/weighted reward aggregation.

    verifiers: list of paths or dicts {path, kwargs}.
    mode: all | any | weighted
    weights: optional weights for weighted reward aggregation.
    per_verifier_kwargs: optional mapping of path -> kwargs.
    """
    kwargs_by_path = per_verifier_kwargs or {}
    outcomes = []
    latencies: Dict[str, float] = {}

    for entry in verifiers:
        if isinstance(entry, dict):
            path = entry.get("path")
            inline_kwargs = entry.get("kwargs") or {}
        else:
            path, inline_kwargs = entry, {}

        if not path:
            continue

        key = str(path)
        verify_fn = _load_verifier(key)
        # Precedence: shared kwargs < per-path kwargs < inline entry kwargs.
        call_kwargs = {**kwargs, **kwargs_by_path.get(key, {}), **inline_kwargs}

        started = time.time()
        result = verify_fn(prompt, completion, workdir, **call_kwargs)
        latencies[key] = (time.time() - started) * 1000.0
        outcomes.append((key, result))

    if not outcomes:
        return VerifyResult(
            reward=0.0,
            passed=False,
            info={"error": "no_verifiers"},
            artifacts_dir=workdir,
        )

    mode = (mode or "all").lower()
    pass_flags = [bool(getattr(res, "passed", False)) for _key, res in outcomes]
    reward_values = [float(getattr(res, "reward", 0.0)) for _key, res in outcomes]

    if mode == "any":
        passed = any(pass_flags)
        reward = max(reward_values) if reward_values else 0.0
    elif mode == "weighted":
        if weights and len(weights) == len(reward_values):
            reward = sum(w * r for w, r in zip(weights, reward_values))
        else:
            # Fall back to an unweighted mean when weights are absent/mismatched.
            reward = sum(reward_values) / max(1, len(reward_values))
        passed = reward > 0.0 and any(pass_flags)
    else:  # "all" (default)
        passed = all(pass_flags)
        reward = sum(reward_values) / max(1, len(reward_values))

    return VerifyResult(
        reward=reward,
        passed=passed,
        info={
            "mode": mode,
            "verifiers": [
                {
                    "path": key,
                    "passed": bool(getattr(res, "passed", False)),
                    "reward": float(getattr(res, "reward", 0.0)),
                    "info": getattr(res, "info", {}),
                }
                for key, res in outcomes
            ],
            "verifier_latencies_ms": latencies,
        },
        artifacts_dir=workdir,
    )
|
|
109
|
+
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import shutil
|
|
4
|
+
import subprocess
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from .types import VerifyResult
|
|
8
|
+
from .pytest_verifier import verify as local_verify
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def verify(
    prompt: str,
    completion: str,
    workdir: str,
    *,
    tests_subdir: str = "tests",
    timeout_s: int = 30,
    reward_pass: float = 1.0,
    reward_fail: float = 0.0,
    image: str = "python:3.11-slim",
    memory_mb: int = 512,
    cpus: float = 1.0,
    pids: int = 128,
    use_local_fallback: bool = True,
) -> VerifyResult:
    """Run pytest inside a locked-down Docker container.

    The container has no network, a read-only root filesystem, and pid /
    memory / cpu limits; only the workdir is bind-mounted read-write.
    Falls back to local pytest verifier if Docker is unavailable and
    use_local_fallback is True.
    """
    if shutil.which("docker") is None:
        if use_local_fallback:
            return local_verify(
                prompt,
                completion,
                workdir,
                tests_subdir=tests_subdir,
                timeout_s=timeout_s,
                reward_pass=reward_pass,
                reward_fail=reward_fail,
            )
        return VerifyResult(
            reward=reward_fail,
            passed=False,
            info={"error": "docker_not_found"},
            artifacts_dir=workdir,
        )

    # BUGFIX: Docker bind mounts require an absolute host path. A relative
    # workdir interpolated into `-v` is treated as a named volume, silently
    # mounting an empty volume instead of the prepared workdir.
    wd = Path(workdir).resolve()
    wd.mkdir(parents=True, exist_ok=True)
    # Default file convention: if the environment builder wrote no sources,
    # persist the completion as main.py so the tests have something to import.
    if not any(wd.glob("*.py")):
        (wd / "main.py").write_text(completion, encoding="utf-8")

    tests_path = wd / tests_subdir
    if not tests_path.exists():
        return VerifyResult(
            reward=reward_fail,
            passed=False,
            info={"error": f"Missing tests folder: {tests_subdir}"},
            artifacts_dir=str(wd),
        )

    cmd = [
        "docker",
        "run",
        "--rm",
        "--network",
        "none",
        "--read-only",
        "--pids-limit",
        str(int(pids)),
        "--memory",
        f"{int(memory_mb)}m",
        "--cpus",
        str(float(cpus)),
        "--tmpfs",
        "/tmp:rw,noexec,nosuid,nodev",
        "-v",
        f"{wd}:/workspace:rw",
        "-w",
        "/workspace",
        image,
        "pytest",
        "-q",
    ]

    try:
        proc = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
        passed = proc.returncode == 0
        return VerifyResult(
            reward=reward_pass if passed else reward_fail,
            passed=passed,
            info={
                "returncode": proc.returncode,
                # Keep only the tail of each stream so huge test output
                # cannot bloat stored results.
                "stdout": proc.stdout[-4000:],
                "stderr": proc.stderr[-4000:],
                "docker_image": image,
            },
            artifacts_dir=str(wd),
        )
    except subprocess.TimeoutExpired:
        return VerifyResult(
            reward=reward_fail,
            passed=False,
            info={"error": "timeout", "timeout_s": timeout_s},
            artifacts_dir=str(wd),
        )
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from jsonschema import validate, ValidationError
|
|
8
|
+
|
|
9
|
+
from .types import VerifyResult
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _load_schema(schema: Any) -> Any:
|
|
13
|
+
if isinstance(schema, str):
|
|
14
|
+
p = Path(schema)
|
|
15
|
+
if p.exists():
|
|
16
|
+
return json.loads(p.read_text(encoding="utf-8"))
|
|
17
|
+
return schema
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def verify(
    prompt: str,
    completion: str,
    workdir: str,
    *,
    schema: Any,
    reward_pass: float = 1.0,
    reward_fail: float = 0.0,
) -> VerifyResult:
    """Validate *completion* as JSON against *schema* (inline object or file path).

    Fails with info["error"] == "invalid_json" when the completion does not
    parse, or "schema_validation" when it violates the schema.
    """
    try:
        payload = json.loads(completion)
    except json.JSONDecodeError as exc:
        return VerifyResult(
            reward=reward_fail,
            passed=False,
            info={"error": "invalid_json", "detail": str(exc)},
            artifacts_dir=None,
        )

    resolved = _load_schema(schema)
    try:
        validate(instance=payload, schema=resolved)
    except ValidationError as exc:
        return VerifyResult(
            reward=reward_fail,
            passed=False,
            info={"error": "schema_validation", "detail": str(exc)},
            artifacts_dir=None,
        )

    return VerifyResult(
        reward=reward_pass,
        passed=True,
        info={"schema": resolved},
        artifacts_dir=None,
    )
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import subprocess
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from .types import VerifyResult
|
|
7
|
+
|
|
8
|
+
def _sandbox_env(base_env: dict | None = None, *, workdir: Path) -> dict:
|
|
9
|
+
env = dict(os.environ)
|
|
10
|
+
if base_env:
|
|
11
|
+
env.update(base_env)
|
|
12
|
+
env.setdefault("PYTHONHASHSEED", "0")
|
|
13
|
+
env["HOME"] = str(workdir)
|
|
14
|
+
env["TMPDIR"] = str(workdir)
|
|
15
|
+
env["PYTHONPATH"] = str(workdir)
|
|
16
|
+
env.pop("PYTHONSTARTUP", None)
|
|
17
|
+
env.pop("VIRTUAL_ENV", None)
|
|
18
|
+
return env
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def verify(
    prompt: str,
    completion: str,
    workdir: str,
    *,
    tests_subdir: str = "tests",
    timeout_s: int = 30,
    reward_pass: float = 1.0,
    reward_fail: float = 0.0,
) -> VerifyResult:
    """Run pytest in a sandbox directory and grade pass/fail by exit code.

    Expected usage: the environment builder writes files into *workdir* and
    this verifier runs `pytest -q` there, treating exit code 0 as pass.

    IMPORTANT: This runs locally. For stronger isolation, replace with a
    sandbox runner.
    """
    sandbox = Path(workdir)
    sandbox.mkdir(parents=True, exist_ok=True)

    # Default file convention for code tasks: persist the completion as
    # main.py when the builder wrote no Python sources of its own.
    if not any(sandbox.glob("*.py")):
        (sandbox / "main.py").write_text(completion, encoding="utf-8")

    # A missing tests folder is a deterministic failure, not an exception.
    if not (sandbox / tests_subdir).exists():
        return VerifyResult(
            reward=reward_fail,
            passed=False,
            info={"error": f"Missing tests folder: {tests_subdir}"},
            artifacts_dir=str(sandbox),
        )

    try:
        proc = subprocess.run(
            ["pytest", "-q"],
            cwd=str(sandbox),
            capture_output=True,
            text=True,
            timeout=timeout_s,
            env=_sandbox_env(workdir=sandbox),
        )
    except subprocess.TimeoutExpired:
        return VerifyResult(
            reward=reward_fail,
            passed=False,
            info={"error": "timeout", "timeout_s": timeout_s},
            artifacts_dir=str(sandbox),
        )

    passed = proc.returncode == 0
    return VerifyResult(
        reward=reward_pass if passed else reward_fail,
        passed=passed,
        info={
            "returncode": proc.returncode,
            # Keep only the tail of each stream to bound artifact size.
            "stdout": proc.stdout[-4000:],
            "stderr": proc.stderr[-4000:],
        },
        artifacts_dir=str(sandbox),
    )
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
|
|
5
|
+
from .types import VerifyResult
|
|
6
|
+
|
|
7
|
+
def verify(prompt: str, completion: str, workdir: str, *, pattern: str, flags: int = 0, reward_pass: float = 1.0, reward_fail: float = 0.0) -> VerifyResult:
    """Pass iff *completion* contains a match for regex *pattern* (via re.search).

    info records the pattern and the first matched text (None when no match).
    """
    match = re.search(pattern, completion, flags=flags)
    found = match is not None
    return VerifyResult(
        reward=reward_pass if found else reward_fail,
        passed=found,
        info={"pattern": pattern, "match": match.group(0) if found else None},
        artifacts_dir=None,
    )
|