mlxsmith 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff covers publicly released package versions from a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- mlxsmith/accel/__init__.py +0 -3
- mlxsmith/bench.py +12 -2
- mlxsmith/cli.py +188 -3
- mlxsmith/config_models.py +16 -2
- mlxsmith/integrations/__init__.py +19 -0
- mlxsmith/integrations/mlx_lm_lora.py +117 -0
- mlxsmith/llm/backend.py +8 -1
- mlxsmith/llm/mlx_lm_backend.py +59 -2
- mlxsmith/llm/mock_backend.py +8 -1
- mlxsmith/optim/__init__.py +3 -0
- mlxsmith/optim/muon.py +93 -0
- mlxsmith/orchestrator/daemon.py +44 -377
- mlxsmith/orchestrator/trainer_worker.py +4 -0
- mlxsmith/rlm/loop.py +53 -92
- mlxsmith/sdk/__init__.py +18 -2
- mlxsmith/sdk/losses.py +102 -1
- mlxsmith/sdk/training_client.py +24 -5
- mlxsmith/train/distill.py +6 -1
- mlxsmith/train/online_dpo.py +249 -0
- mlxsmith/train/pref.py +31 -29
- mlxsmith/train/rft.py +123 -38
- mlxsmith/train/self_verify.py +199 -0
- mlxsmith/train/sft.py +13 -2
- mlxsmith/util.py +0 -6
- mlxsmith/verifiers/llm_judge.py +278 -0
- mlxsmith/verifiers/prime.py +127 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/METADATA +29 -13
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/RECORD +32 -25
- mlxsmith/accel/zmlx_backend.py +0 -42
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/WHEEL +0 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/entry_points.txt +0 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/top_level.txt +0 -0
mlxsmith/train/self_verify.py
ADDED
@@ -0,0 +1,199 @@
+from __future__ import annotations
+
+import json
+import random
+from pathlib import Path
+from typing import Iterable, Optional
+
+from rich.console import Console
+
+from ..accel import get_backend
+from ..config import ProjectConfig
+from ..models import resolve_model_spec
+from ..runs import RunPaths, new_run, snapshot_config
+from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
+from ..llm.registry import get_llm_backend
+from ..llm.backend import BackendNotAvailable
+from ..verifiers.llm_judge import verify as judge_verify
+from .lora import LoRAConfig
+
+console = Console()
+
+
+def _iter_prompts(path: Path) -> Iterable[str]:
+    for line in path.read_text(encoding="utf-8").splitlines():
+        if not line.strip():
+            continue
+        row = json.loads(line)
+        prompt = row.get("prompt") or row.get("instruction") or row.get("input") or row.get("question") or ""
+        if not prompt and "messages" in row:
+            msgs = row.get("messages") or []
+            if msgs:
+                prompt = "\n".join([m.get("content", "") for m in msgs])
+        if prompt:
+            yield str(prompt)
+
+
+def run_self_verify(
+    project_root: Path,
+    cfg: ProjectConfig,
+    data_path: Path,
+    model_id_or_path: str,
+    accel: str,
+    *,
+    verifier_model: Optional[str] = None,
+    verifier_backend: str = "mlx-lm",
+    rubric: Optional[str] = None,
+    max_new_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    judge_mock_response: Optional[str | list[str]] = None,
+) -> RunPaths:
+    run = new_run(project_root, "self_verify")
+    snapshot_config(cfg.model_dump(), run.config_snapshot_path)
+
+    prompts = list(_iter_prompts(data_path))
+    if not prompts:
+        raise RuntimeError("No prompts found in self-verify dataset")
+
+    backend = get_backend(accel)
+    backend.patch()
+    console.print(f"[bold]SELF-VERIFY[/bold] run: {run.run_dir.name} accel={backend.name}")
+
+    policy = get_llm_backend(cfg.model.backend)
+    base_model, adapter_path, _meta = resolve_model_spec(project_root, model_id_or_path, cfg)
+
+    try:
+        policy.load(
+            base_model,
+            max_seq_len=cfg.model.max_seq_len,
+            dtype=cfg.model.dtype,
+            trust_remote_code=cfg.model.trust_remote_code,
+        )
+        if adapter_path:
+            policy.apply_adapter(str(adapter_path))
+        else:
+            lora_cfg = LoRAConfig(
+                r=cfg.lora.r,
+                alpha=cfg.lora.alpha,
+                dropout=cfg.lora.dropout,
+                target_modules=list(cfg.lora.target_modules or []),
+                num_layers=cfg.lora.num_layers,
+                scale=cfg.lora.scale,
+                fine_tune_type=cfg.lora.fine_tune_type,
+            )
+            policy.apply_lora_from_config(lora_cfg)
+    except BackendNotAvailable as e:
+        console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
+        (run.adapter_dir / "ADAPTER.txt").write_text(
+            f"Backend unavailable in this environment.\nmodel={model_id_or_path}\naccel={backend.name}\n",
+            encoding="utf-8",
+        )
+        return run
+
+    opt, _params = policy.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
+
+    total = int(cfg.train.iters)
+    grad_accum = max(1, int(cfg.train.grad_accum))
+    max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 0.0))
+    max_new = int(max_new_tokens or cfg.rft.max_new_tokens)
+    temp = float(temperature if temperature is not None else cfg.rft.temperature)
+
+    rng = random.Random(cfg.train.seed)
+    accum_grads = None
+    accum_loss = 0.0
+    accum_count = 0
+    reward_ema = 0.0
+    ema_alpha = 0.1
+
+    def _next_mock(idx: int) -> Optional[str]:
+        if judge_mock_response is None:
+            return None
+        if isinstance(judge_mock_response, list):
+            if not judge_mock_response:
+                return None
+            return judge_mock_response[min(idx, len(judge_mock_response) - 1)]
+        return judge_mock_response
+
+    for step in range(1, total + 1):
+        prompt = rng.choice(prompts)
+        gen = policy.generate_with_logprobs(
+            prompt,
+            max_new_tokens=max_new,
+            temperature=temp,
+            seed=rng.randint(0, 2**31 - 1),
+            logprobs=0,
+        )
+        completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
+
+        res = judge_verify(
+            prompt,
+            completion,
+            str(run.artifacts_dir),
+            model=verifier_model or model_id_or_path,
+            backend=verifier_backend,
+            rubric=rubric,
+            reward_mode="score",
+            mock_response=_next_mock(0),
+        )
+        reward = float(getattr(res, "reward", 0.0))
+        reward_ema = (1.0 - ema_alpha) * reward_ema + ema_alpha * reward
+        advantage = reward - reward_ema
+
+        token_ids = list(gen.token_ids)
+        prompt_len = int(gen.prompt_len)
+
+        def loss_fn(_model):
+            logp = policy.sequence_logprob(token_ids, prompt_len=prompt_len)
+            return -policy.mx.array(float(advantage)) * logp  # type: ignore
+
+        lval, grads = policy.value_and_grad(loss_fn)
+        accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
+        accum_count += 1
+        if grads is not None:
+            accum_grads = tree_add(accum_grads, grads)
+
+        if step % grad_accum == 0:
+            if accum_grads is not None:
+                scaled = tree_scale(accum_grads, 1.0 / grad_accum)
+                if max_grad_norm > 0:
+                    scaled = clip_grad_norm(scaled, max_grad_norm)
+                policy.apply_grads(opt, scaled)
+            accum_grads = None
+            accum_loss = 0.0
+            accum_count = 0
+
+        if step % cfg.train.log_every == 0 or step == 1 or step == total:
+            avg_loss = accum_loss / max(1, accum_count) if accum_count else float(lval)
+            write_jsonl(
+                run.metrics_path,
+                [
+                    {
+                        "ts": now_ts(),
+                        "step": step,
+                        "kind": "self_verify",
+                        "loss": avg_loss,
+                        "reward": reward,
+                        "advantage": advantage,
+                        "accel": backend.name,
+                    }
+                ],
+            )
+
+        if step % cfg.train.save_every == 0 or step == total:
+            policy.save_adapter(
+                str(run.adapter_dir),
+                metadata={
+                    "base_model": base_model,
+                    "source_adapter": str(adapter_path) if adapter_path else None,
+                    "run": run.run_dir.name,
+                    "kind": "self_verify",
+                },
+            )
+
+    console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+    return run
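The loop in self_verify.py is a REINFORCE-style update with an exponential-moving-average reward baseline: each sampled completion is scored by the LLM judge, the advantage is the reward minus the running EMA, and the loss is the negative advantage-weighted sequence log-probability, accumulated over `grad_accum` steps and optionally clipped. Below is a minimal, framework-free sketch of that update rule; `sample_fn` and `score_fn` are hypothetical stand-ins for `policy.generate_with_logprobs` and the judge verifier, not mlxsmith APIs.

```python
# Sketch of the self-verify update rule with an EMA reward baseline.
# `sample_fn` and `score_fn` are hypothetical stand-ins, not mlxsmith APIs.
from typing import Callable, Tuple


def self_verify_step(
    prompt: str,
    reward_ema: float,
    sample_fn: Callable[[str], Tuple[str, float]],  # -> (completion, sequence log-prob)
    score_fn: Callable[[str, str], float],          # -> judge reward in [0, 1]
    ema_alpha: float = 0.1,
) -> Tuple[float, float]:
    completion, logprob = sample_fn(prompt)
    reward = score_fn(prompt, completion)
    reward_ema = (1.0 - ema_alpha) * reward_ema + ema_alpha * reward
    advantage = reward - reward_ema
    # REINFORCE-style objective: minimize -advantage * log p(completion | prompt).
    loss = -advantage * logprob
    return loss, reward_ema
```

In the actual module, gradients of this loss are accumulated, rescaled by `1 / grad_accum`, and clipped to `max_grad_norm` before each optimizer step.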
mlxsmith/train/sft.py
CHANGED
@@ -79,7 +79,12 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
         )
         return run
 
-    opt, _params = llm.optimizer_and_params(
+    opt, _params = llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
 
     total = int(cfg.train.iters)
     grad_accum = max(1, int(cfg.train.grad_accum))
@@ -89,6 +94,7 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
     rng = random.Random(cfg.train.seed)
     accum_grads = None
     accum_loss = 0.0
+    accum_count = 0
 
     for step in range(1, total + 1):
         row = rng.choice(rows)
@@ -110,6 +116,7 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
 
         lval, grads = llm.value_and_grad(loss_fn)
        accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
+        accum_count += 1
         if grads is not None:
             accum_grads = tree_add(accum_grads, grads)
 
@@ -121,8 +128,12 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
                 llm.apply_grads(opt, scaled)
             accum_grads = None
             accum_loss = 0.0
+            accum_count = 0
 
         if step % cfg.train.log_every == 0 or step == 1 or step == total:
+            avg_loss = (accum_loss / max(1, accum_count)) if accum_count else (
+                float(lval.item()) if hasattr(lval, "item") else float(lval)
+            )
             write_jsonl(
                 run.metrics_path,
                 [
@@ -130,7 +141,7 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
                         "ts": now_ts(),
                         "step": step,
                         "kind": "sft",
-                        "loss":
+                        "loss": avg_loss,
                         "accel": backend.name,
                     }
                 ],
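The sft.py change mirrors the bookkeeping used in self_verify.py: 0.1.1 logged an incomplete `"loss"` field, while 0.1.3 threads the configured optimizer settings through `optimizer_and_params`, tracks `accum_count` next to `accum_loss`, and logs the window-averaged loss, falling back to the last raw loss right after the accumulation window is flushed. A small sketch of that averaging rule, with `last_loss` as a hypothetical stand-in for the most recent per-step value:

```python
# Window-averaged loss for logging under gradient accumulation (sketch).
def averaged_loss(accum_loss: float, accum_count: int, last_loss: float) -> float:
    # If the window was just flushed (count == 0), fall back to the last raw loss.
    return (accum_loss / max(1, accum_count)) if accum_count else last_loss
```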
mlxsmith/util.py
CHANGED
@@ -46,7 +46,6 @@ class SystemInfo:
     has_metal: Optional[bool]
     has_mlx: bool
     mlx_version: Optional[str]
-    has_zmlx: bool
 
 def detect_system() -> SystemInfo:
     has_mlx = False
@@ -58,10 +57,6 @@ def detect_system() -> SystemInfo:
     except Exception:
         pass
 
-    import importlib.util
-
-    has_zmlx = importlib.util.find_spec("zmlx") is not None
-
     # Metal detection (best-effort): on macOS we assume Metal is present; for CI, this is not reliable.
     has_metal = None
     if sys.platform == "darwin":
@@ -83,7 +78,6 @@ def detect_system() -> SystemInfo:
         has_metal=has_metal,
         has_mlx=has_mlx,
         mlx_version=mlx_version,
-        has_zmlx=has_zmlx,
     )
 
 def require(cond: bool, msg: str):
mlxsmith/verifiers/llm_judge.py
ADDED
@@ -0,0 +1,278 @@
+from __future__ import annotations
+
+import json
+import os
+import re
+import time
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+from .types import VerifyResult
+from ..llm.registry import get_llm_backend
+
+_STATE: Dict[str, Any] = {
+    "backend": None,
+    "backend_name": None,
+    "model_id": None,
+}
+
+
+def _read_text(value: Optional[str]) -> Optional[str]:
+    if not value:
+        return None
+    if value.startswith("@"):
+        path = Path(value[1:])
+        if path.exists():
+            return path.read_text(encoding="utf-8")
+    path = Path(value)
+    if path.exists():
+        return path.read_text(encoding="utf-8")
+    return value
+
+
+def _load_backend(model_id: str, backend_name: str, *, max_seq_len: Optional[int], dtype: Optional[str], trust_remote_code: bool) -> Any:
+    if (
+        _STATE["backend"] is None
+        or _STATE["backend_name"] != backend_name
+        or _STATE["model_id"] != model_id
+    ):
+        backend = get_llm_backend(backend_name)
+        backend.load(
+            model_id,
+            max_seq_len=max_seq_len,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+        _STATE["backend"] = backend
+        _STATE["backend_name"] = backend_name
+        _STATE["model_id"] = model_id
+    return _STATE["backend"]
+
+
+def _extract_json(text: str) -> Optional[dict]:
+    if not text:
+        return None
+    start = text.find("{")
+    end = text.rfind("}")
+    if start == -1 or end == -1 or end <= start:
+        return None
+    snippet = text[start : end + 1].strip()
+    try:
+        return json.loads(snippet)
+    except json.JSONDecodeError:
+        cleaned = re.sub(r",\s*}", "}", snippet)
+        cleaned = re.sub(r",\s*]", "]", cleaned)
+        cleaned = cleaned.replace("'", "\"")
+        try:
+            return json.loads(cleaned)
+        except json.JSONDecodeError:
+            return None
+
+
+def _coerce_float(val: Any) -> Optional[float]:
+    if val is None:
+        return None
+    try:
+        return float(val)
+    except (TypeError, ValueError):
+        return None
+
+
+def _aggregate_scores(scores: list[float], mode: str) -> Optional[float]:
+    if not scores:
+        return None
+    mode = (mode or "product").lower()
+    if mode == "min":
+        return min(scores)
+    if mode == "mean":
+        return sum(scores) / float(len(scores))
+    prod = 1.0
+    for s in scores:
+        prod *= s
+    return prod
+
+
+def _score_step(
+    judge,
+    *,
+    system_prompt: str,
+    step_text: str,
+    prompt: str,
+    completion: str,
+    rubric_text: str,
+    temperature: float,
+    max_new_tokens: int,
+) -> Optional[float]:
+    step_prompt = (
+        f"{system_prompt}\n\n"
+        "Score this single step from a solution.\n\n"
+        f"## Task\n{prompt}\n\n"
+        f"## Model Answer\n{completion}\n\n"
+        f"## Step\n{step_text}\n\n"
+        f"## Rubric\n{rubric_text}\n\n"
+        "Return JSON only."
+    )
+    gen = judge.generate(
+        step_prompt,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=1.0,
+        top_k=None,
+    )
+    raw = gen.text[len(step_prompt) :] if gen.text.startswith(step_prompt) else gen.text
+    parsed = _extract_json(raw) or {}
+    return _coerce_float(parsed.get("score"))
+
+
+def verify(
+    prompt: str,
+    completion: str,
+    workdir: str,
+    *,
+    model: Optional[str] = None,
+    backend: str = "mlx-lm",
+    system_prompt: Optional[str] = None,
+    rubric: Optional[str] = None,
+    mode: str = "judge",
+    temperature: float = 0.0,
+    max_new_tokens: int = 256,
+    min_score: float = 0.5,
+    reward_pass: float = 1.0,
+    reward_fail: float = 0.0,
+    reward_mode: str = "score",
+    max_seq_len: Optional[int] = None,
+    dtype: Optional[str] = None,
+    trust_remote_code: bool = False,
+    mock_response: Optional[str] = None,
+    process_agg: str = "product",
+    max_steps: int = 8,
+    **kwargs,
+) -> VerifyResult:
+    """LLM-based verifier with JSON output.
+
+    The judge should return JSON: {"passed": bool, "score": 0..1, "reason": "..."}.
+    Set mock_response to bypass backend loading (useful for tests).
+    """
+    model_id = model or os.environ.get("MLXSMITH_JUDGE_MODEL")
+    if not model_id and not mock_response:
+        raise RuntimeError("llm_judge requires `model` or MLXSMITH_JUDGE_MODEL")
+
+    rubric_text = _read_text(rubric) or "Assess correctness and completeness."
+    mode = (mode or "judge").strip().lower()
+    sys_prompt = system_prompt or (
+        "You are a strict verifier. Return ONLY JSON with keys: "
+        "passed (bool), score (0-1), reason (string)."
+    )
+
+    if mode == "thinkprm":
+        sys_prompt = system_prompt or (
+            "You are a process reward model. Evaluate the reasoning quality. "
+            "Return ONLY JSON with keys: passed (bool), score (0-1), reason (string), steps (array)."
+        )
+
+    user_prompt = (
+        "## Task\n"
+        f"{prompt}\n\n"
+        "## Model Answer\n"
+        f"{completion}\n\n"
+        "## Rubric\n"
+        f"{rubric_text}\n\n"
+        "Return JSON only."
+    )
+
+    t0 = time.time()
+    if mock_response is not None:
+        raw = str(mock_response)
+    else:
+        judge = _load_backend(
+            model_id,
+            backend,
+            max_seq_len=max_seq_len,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+        full_prompt = f"{sys_prompt}\n\n{user_prompt}"
+        gen = judge.generate(
+            full_prompt,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_p=1.0,
+            top_k=None,
+        )
+        raw = gen.text[len(full_prompt) :] if gen.text.startswith(full_prompt) else gen.text
+
+    parsed = _extract_json(raw) or {}
+    score = _coerce_float(parsed.get("score"))
+    passed_val = parsed.get("passed")
+    steps_raw = parsed.get("steps") if isinstance(parsed, dict) else None
+    step_texts: list[str] = []
+    step_scores: list[float] = []
+    if mode == "thinkprm" and isinstance(steps_raw, list):
+        for step in steps_raw[: max(1, int(max_steps))]:
+            if isinstance(step, dict):
+                text = step.get("text") or step.get("step") or step.get("content") or ""
+                if text:
+                    step_texts.append(str(text))
+                s_val = _coerce_float(step.get("score"))
+                if s_val is not None:
+                    step_scores.append(float(s_val))
+            elif isinstance(step, str):
+                step_texts.append(step)
+
+    if step_texts and (len(step_scores) < len(step_texts)) and mock_response is None:
+        judge = _load_backend(
+            model_id,
+            backend,
+            max_seq_len=max_seq_len,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
+        )
+        for idx, step_text in enumerate(step_texts):
+            if idx < len(step_scores):
+                continue
+            s_val = _score_step(
+                judge,
+                system_prompt=sys_prompt,
+                step_text=step_text,
+                prompt=prompt,
+                completion=completion,
+                rubric_text=rubric_text,
+                temperature=temperature,
+                max_new_tokens=max_new_tokens,
+            )
+            if s_val is not None:
+                step_scores.append(float(s_val))
+
+    process_score = _aggregate_scores(step_scores, process_agg) if step_scores else None
+    if mode == "thinkprm" and process_score is not None:
+        score = process_score
+    if passed_val is None and score is not None:
+        passed_val = score >= min_score
+    passed = bool(passed_val) if passed_val is not None else False
+    reason = parsed.get("reason") or parsed.get("explanation") or ""
+
+    if reward_mode == "score" and score is not None:
+        reward = max(0.0, min(1.0, float(score)))
+    else:
+        reward = reward_pass if passed else reward_fail
+
+    latency_ms = (time.time() - t0) * 1000.0
+
+    return VerifyResult(
+        reward=reward,
+        passed=passed,
+        info={
+            "mode": mode,
+            "model": model_id,
+            "score": score,
+            "process_score": process_score,
+            "process_agg": process_agg,
+            "steps": step_texts,
+            "step_scores": step_scores,
+            "passed": passed,
+            "reason": reason,
+            "raw": raw,
+            "verifier_latency_ms": latency_ms,
+        },
+        artifacts_dir=workdir,
+    )
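The `verify` entry point above can be exercised without loading any model by passing `mock_response`, which is what the `judge_mock_response` plumbing in self_verify.py relies on for tests. A hedged usage sketch; the import path follows the package layout in the RECORD, and the JSON payload is illustrative:

```python
# Exercising the judge's mock path: no backend is loaded, the JSON is parsed directly.
from mlxsmith.verifiers.llm_judge import verify

result = verify(
    prompt="What is 2 + 2?",
    completion="2 + 2 = 4",
    workdir="/tmp/judge-artifacts",
    mock_response='{"passed": true, "score": 0.9, "reason": "correct"}',
    reward_mode="score",  # reward = clamp(score, 0, 1) -> 0.9
)
print(result.reward, result.passed, result.info["reason"])
```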
mlxsmith/verifiers/prime.py
ADDED
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+import importlib.util
+import re
+from typing import Any, Dict, List, Optional
+
+from .types import VerifyResult
+
+_STATE: Dict[str, Dict[str, float]] = {
+    "values": {},
+    "counts": {},
+}
+
+
+def _load_verifier(path: str):
+    import sys
+    from pathlib import Path as _Path
+
+    verifier_path = _Path(path).resolve()
+
+    # If the file lives inside a Python package, set __package__ so that
+    # relative imports (e.g. ``from .types import ...``) work correctly.
+    pkg_name: Optional[str] = None
+    if (verifier_path.parent / "__init__.py").exists():
+        parts: list[str] = []
+        p = verifier_path.parent
+        while (p / "__init__.py").exists():
+            parts.insert(0, p.name)
+            p = p.parent
+        pkg_name = ".".join(parts)
+        root = str(p)
+        if root not in sys.path:
+            sys.path.insert(0, root)
+
+    mod_name = f"{pkg_name}._prime_loaded" if pkg_name else "prime_verifier"
+    spec = importlib.util.spec_from_file_location(mod_name, str(verifier_path))
+    if spec is None or spec.loader is None:
+        raise RuntimeError(f"Could not load verifier: {verifier_path}")
+    module = importlib.util.module_from_spec(spec)
+    if pkg_name is not None:
+        module.__package__ = pkg_name
+    spec.loader.exec_module(module)  # type: ignore
+    verify_fn = getattr(module, "verify", None)
+    if not callable(verify_fn):
+        raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
+    return verify_fn
+
+
+def _extract_steps(text: str, *, max_steps: int = 12) -> List[str]:
+    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
+    steps = []
+    for ln in lines:
+        if re.match(r"^(\d+\.|\-|\*|\+)\s+", ln):
+            steps.append(re.sub(r"^(\d+\.|\-|\*|\+)\s+", "", ln).strip())
+    if not steps:
+        steps = lines[:max_steps]
+    return steps[:max_steps]
+
+
+def _aggregate(values: List[float], mode: str) -> float:
+    if not values:
+        return 0.0
+    mode = (mode or "mean").lower()
+    if mode == "min":
+        return min(values)
+    if mode == "product":
+        out = 1.0
+        for v in values:
+            out *= v
+        return out
+    return sum(values) / float(len(values))
+
+
+def verify(
+    prompt: str,
+    completion: str,
+    workdir: str,
+    *,
+    verifier: str,
+    verifier_kwargs: Optional[Dict[str, Any]] = None,
+    ema_alpha: float = 0.2,
+    max_steps: int = 12,
+    agg: str = "mean",
+    reward_mode: str = "process",
+    min_score: float = 0.0,
+    **kwargs,
+) -> VerifyResult:
+    """PRIME-style implicit process rewards.
+
+    Uses outcome reward from a base verifier to update per-step values.
+    """
+    verify_fn = _load_verifier(verifier)
+    base = verify_fn(prompt, completion, workdir, **(verifier_kwargs or {}), **kwargs)
+    outcome_reward = float(getattr(base, "reward", 0.0))
+    steps = _extract_steps(completion, max_steps=max_steps)
+
+    step_values: List[float] = []
+    for step in steps:
+        prev = _STATE["values"].get(step, outcome_reward)
+        new_val = (1.0 - ema_alpha) * prev + ema_alpha * outcome_reward
+        _STATE["values"][step] = new_val
+        _STATE["counts"][step] = _STATE["counts"].get(step, 0.0) + 1.0
+        step_values.append(new_val)
+
+    process_reward = _aggregate(step_values, agg)
+    if reward_mode == "combined":
+        reward = (process_reward + outcome_reward) / 2.0
+    else:
+        reward = process_reward
+
+    passed = bool(getattr(base, "passed", False)) and reward >= min_score
+
+    return VerifyResult(
+        reward=reward,
+        passed=passed,
+        info={
+            "mode": "prime",
+            "base_reward": outcome_reward,
+            "process_reward": process_reward,
+            "steps": steps,
+            "step_values": step_values,
+            "agg": agg,
+            "ema_alpha": ema_alpha,
+            "base_info": getattr(base, "info", {}),
+        },
+        artifacts_dir=workdir,
+    )
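prime.py wraps any outcome-level verifier: it loads `verify(...)` from the given file path, splits the completion into numbered or bulleted steps, and keeps an EMA value per step text in module state, so repeated steps accumulate credit across calls. A hedged usage sketch; the `outcome_verifier.py` file and its toy check are hypothetical, and the base result only needs a `reward` attribute since `passed` and `info` are read with `getattr` fallbacks:

```python
# Self-contained sketch: write a toy outcome verifier to disk, then run the
# PRIME wrapper over it. The file name and toy check are illustrative only.
from pathlib import Path
from tempfile import TemporaryDirectory

from mlxsmith.verifiers.prime import verify as prime_verify

OUTCOME_VERIFIER = """
from types import SimpleNamespace

def verify(prompt, completion, workdir, **kwargs):
    reward = 1.0 if "4" in completion else 0.0  # toy outcome check
    return SimpleNamespace(reward=reward, passed=reward > 0.5, info={})
"""

with TemporaryDirectory() as tmp:
    base_path = Path(tmp) / "outcome_verifier.py"
    base_path.write_text(OUTCOME_VERIFIER, encoding="utf-8")
    res = prime_verify(
        prompt="What is 2 + 2?",
        completion="1. Add the numbers\n2. The answer is 4",
        workdir=tmp,
        verifier=str(base_path),  # loaded via importlib, must expose verify(...)
        agg="mean",
    )
    # res.reward is the aggregated per-step EMA; res.info["step_values"] lists them.
    print(res.reward, res.info["step_values"])
```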