mlxsmith 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,199 @@
+ from __future__ import annotations
+
+ import json
+ import random
+ from pathlib import Path
+ from typing import Iterable, Optional
+
+ from rich.console import Console
+
+ from ..accel import get_backend
+ from ..config import ProjectConfig
+ from ..models import resolve_model_spec
+ from ..runs import RunPaths, new_run, snapshot_config
+ from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
+ from ..llm.registry import get_llm_backend
+ from ..llm.backend import BackendNotAvailable
+ from ..verifiers.llm_judge import verify as judge_verify
+ from .lora import LoRAConfig
+
+ console = Console()
+
+
+ def _iter_prompts(path: Path) -> Iterable[str]:
+     for line in path.read_text(encoding="utf-8").splitlines():
+         if not line.strip():
+             continue
+         row = json.loads(line)
+         prompt = row.get("prompt") or row.get("instruction") or row.get("input") or row.get("question") or ""
+         if not prompt and "messages" in row:
+             msgs = row.get("messages") or []
+             if msgs:
+                 prompt = "\n".join([m.get("content", "") for m in msgs])
+         if prompt:
+             yield str(prompt)
+
+
+ def run_self_verify(
+     project_root: Path,
+     cfg: ProjectConfig,
+     data_path: Path,
+     model_id_or_path: str,
+     accel: str,
+     *,
+     verifier_model: Optional[str] = None,
+     verifier_backend: str = "mlx-lm",
+     rubric: Optional[str] = None,
+     max_new_tokens: Optional[int] = None,
+     temperature: Optional[float] = None,
+     judge_mock_response: Optional[str | list[str]] = None,
+ ) -> RunPaths:
+     run = new_run(project_root, "self_verify")
+     snapshot_config(cfg.model_dump(), run.config_snapshot_path)
+
+     prompts = list(_iter_prompts(data_path))
+     if not prompts:
+         raise RuntimeError("No prompts found in self-verify dataset")
+
+     backend = get_backend(accel)
+     backend.patch()
+     console.print(f"[bold]SELF-VERIFY[/bold] run: {run.run_dir.name} accel={backend.name}")
+
+     policy = get_llm_backend(cfg.model.backend)
+     base_model, adapter_path, _meta = resolve_model_spec(project_root, model_id_or_path, cfg)
+
+     try:
+         policy.load(
+             base_model,
+             max_seq_len=cfg.model.max_seq_len,
+             dtype=cfg.model.dtype,
+             trust_remote_code=cfg.model.trust_remote_code,
+         )
+         if adapter_path:
+             policy.apply_adapter(str(adapter_path))
+         else:
+             lora_cfg = LoRAConfig(
+                 r=cfg.lora.r,
+                 alpha=cfg.lora.alpha,
+                 dropout=cfg.lora.dropout,
+                 target_modules=list(cfg.lora.target_modules or []),
+                 num_layers=cfg.lora.num_layers,
+                 scale=cfg.lora.scale,
+                 fine_tune_type=cfg.lora.fine_tune_type,
+             )
+             policy.apply_lora_from_config(lora_cfg)
+     except BackendNotAvailable as e:
+         console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
+         (run.adapter_dir / "ADAPTER.txt").write_text(
+             f"Backend unavailable in this environment.\nmodel={model_id_or_path}\naccel={backend.name}\n",
+             encoding="utf-8",
+         )
+         return run
+
+     opt, _params = policy.optimizer_and_params(
+         lr=cfg.train.lr,
+         weight_decay=cfg.train.weight_decay,
+         optimizer=cfg.train.optimizer,
+         optimizer_kwargs=cfg.train.optimizer_kwargs,
+     )
+
+     total = int(cfg.train.iters)
+     grad_accum = max(1, int(cfg.train.grad_accum))
+     max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 0.0))
+     max_new = int(max_new_tokens or cfg.rft.max_new_tokens)
+     temp = float(temperature if temperature is not None else cfg.rft.temperature)
+
+     rng = random.Random(cfg.train.seed)
+     accum_grads = None
+     accum_loss = 0.0
+     accum_count = 0
+     reward_ema = 0.0
+     ema_alpha = 0.1
+
+     def _next_mock(idx: int) -> Optional[str]:
+         if judge_mock_response is None:
+             return None
+         if isinstance(judge_mock_response, list):
+             if not judge_mock_response:
+                 return None
+             return judge_mock_response[min(idx, len(judge_mock_response) - 1)]
+         return judge_mock_response
+
+     for step in range(1, total + 1):
+         prompt = rng.choice(prompts)
+         gen = policy.generate_with_logprobs(
+             prompt,
+             max_new_tokens=max_new,
+             temperature=temp,
+             seed=rng.randint(0, 2**31 - 1),
+             logprobs=0,
+         )
+         completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
+
+         res = judge_verify(
+             prompt,
+             completion,
+             str(run.artifacts_dir),
+             model=verifier_model or model_id_or_path,
+             backend=verifier_backend,
+             rubric=rubric,
+             reward_mode="score",
+             mock_response=_next_mock(0),
+         )
+         reward = float(getattr(res, "reward", 0.0))
+         reward_ema = (1.0 - ema_alpha) * reward_ema + ema_alpha * reward
+         advantage = reward - reward_ema
+
+         token_ids = list(gen.token_ids)
+         prompt_len = int(gen.prompt_len)
+
+         def loss_fn(_model):
+             logp = policy.sequence_logprob(token_ids, prompt_len=prompt_len)
+             return -policy.mx.array(float(advantage)) * logp  # type: ignore
+
+         lval, grads = policy.value_and_grad(loss_fn)
+         accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
+         accum_count += 1
+         if grads is not None:
+             accum_grads = tree_add(accum_grads, grads)
+
+         if step % grad_accum == 0:
+             if accum_grads is not None:
+                 scaled = tree_scale(accum_grads, 1.0 / grad_accum)
+                 if max_grad_norm > 0:
+                     scaled = clip_grad_norm(scaled, max_grad_norm)
+                 policy.apply_grads(opt, scaled)
+             accum_grads = None
+             accum_loss = 0.0
+             accum_count = 0
+
+         if step % cfg.train.log_every == 0 or step == 1 or step == total:
+             avg_loss = accum_loss / max(1, accum_count) if accum_count else float(lval)
+             write_jsonl(
+                 run.metrics_path,
+                 [
+                     {
+                         "ts": now_ts(),
+                         "step": step,
+                         "kind": "self_verify",
+                         "loss": avg_loss,
+                         "reward": reward,
+                         "advantage": advantage,
+                         "accel": backend.name,
+                     }
+                 ],
+             )
+
+         if step % cfg.train.save_every == 0 or step == total:
+             policy.save_adapter(
+                 str(run.adapter_dir),
+                 metadata={
+                     "base_model": base_model,
+                     "source_adapter": str(adapter_path) if adapter_path else None,
+                     "run": run.run_dir.name,
+                     "kind": "self_verify",
+                 },
+             )
+
+     console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+     return run
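
For orientation, a minimal sketch (not part of the package; the rewards and log-probabilities below are hypothetical numbers) of the update rule this new loop applies: an exponential moving average of past rewards acts as the baseline, and the policy-gradient loss is the negative advantage times the sequence log-probability returned by the judge-scored rollout.

reward_ema, ema_alpha = 0.0, 0.1
for reward, seq_logprob in [(1.0, -12.3), (0.0, -15.8), (1.0, -11.2)]:
    reward_ema = (1.0 - ema_alpha) * reward_ema + ema_alpha * reward   # EMA baseline
    advantage = reward - reward_ema                                    # centered reward
    loss = -advantage * seq_logprob   # REINFORCE-style objective minimized by the optimizer
    print(round(reward_ema, 3), round(advantage, 3), round(loss, 3))
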
mlxsmith/train/sft.py CHANGED
@@ -79,7 +79,12 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
          )
          return run

-     opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+     opt, _params = llm.optimizer_and_params(
+         lr=cfg.train.lr,
+         weight_decay=cfg.train.weight_decay,
+         optimizer=cfg.train.optimizer,
+         optimizer_kwargs=cfg.train.optimizer_kwargs,
+     )

      total = int(cfg.train.iters)
      grad_accum = max(1, int(cfg.train.grad_accum))
@@ -89,6 +94,7 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
      rng = random.Random(cfg.train.seed)
      accum_grads = None
      accum_loss = 0.0
+     accum_count = 0

      for step in range(1, total + 1):
          row = rng.choice(rows)
@@ -110,6 +116,7 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_

          lval, grads = llm.value_and_grad(loss_fn)
          accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
+         accum_count += 1
          if grads is not None:
              accum_grads = tree_add(accum_grads, grads)

@@ -121,8 +128,12 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
                  llm.apply_grads(opt, scaled)
              accum_grads = None
              accum_loss = 0.0
+             accum_count = 0

          if step % cfg.train.log_every == 0 or step == 1 or step == total:
+             avg_loss = (accum_loss / max(1, accum_count)) if accum_count else (
+                 float(lval.item()) if hasattr(lval, "item") else float(lval)
+             )
              write_jsonl(
                  run.metrics_path,
                  [
@@ -130,7 +141,7 @@ def run_sft(project_root: Path, cfg: ProjectConfig, data_dir: Path, model_id_or_
                          "ts": now_ts(),
                          "step": step,
                          "kind": "sft",
-                         "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
+                         "loss": avg_loss,
                          "accel": backend.name,
                      }
                  ],
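
A rough illustration (hypothetical per-micro-batch losses) of what the new avg_loss changes in the metrics: the logged value is now the mean loss over the micro-batches accumulated since the last optimizer step, rather than the loss of the most recent micro-batch only.

accum_loss, accum_count = 0.0, 0
for lval in [2.4, 2.1, 1.9]:      # losses within one gradient-accumulation window
    accum_loss += lval
    accum_count += 1
avg_loss = accum_loss / max(1, accum_count)   # ~2.13, the value the metrics row records
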
@@ -0,0 +1,278 @@
+ from __future__ import annotations
+
+ import json
+ import os
+ import re
+ import time
+ from pathlib import Path
+ from typing import Any, Dict, Optional
+
+ from .types import VerifyResult
+ from ..llm.registry import get_llm_backend
+
+ _STATE: Dict[str, Any] = {
+     "backend": None,
+     "backend_name": None,
+     "model_id": None,
+ }
+
+
+ def _read_text(value: Optional[str]) -> Optional[str]:
+     if not value:
+         return None
+     if value.startswith("@"):
+         path = Path(value[1:])
+         if path.exists():
+             return path.read_text(encoding="utf-8")
+     path = Path(value)
+     if path.exists():
+         return path.read_text(encoding="utf-8")
+     return value
+
+
+ def _load_backend(model_id: str, backend_name: str, *, max_seq_len: Optional[int], dtype: Optional[str], trust_remote_code: bool) -> Any:
+     if (
+         _STATE["backend"] is None
+         or _STATE["backend_name"] != backend_name
+         or _STATE["model_id"] != model_id
+     ):
+         backend = get_llm_backend(backend_name)
+         backend.load(
+             model_id,
+             max_seq_len=max_seq_len,
+             dtype=dtype,
+             trust_remote_code=trust_remote_code,
+         )
+         _STATE["backend"] = backend
+         _STATE["backend_name"] = backend_name
+         _STATE["model_id"] = model_id
+     return _STATE["backend"]
+
+
+ def _extract_json(text: str) -> Optional[dict]:
+     if not text:
+         return None
+     start = text.find("{")
+     end = text.rfind("}")
+     if start == -1 or end == -1 or end <= start:
+         return None
+     snippet = text[start : end + 1].strip()
+     try:
+         return json.loads(snippet)
+     except json.JSONDecodeError:
+         cleaned = re.sub(r",\s*}", "}", snippet)
+         cleaned = re.sub(r",\s*]", "]", cleaned)
+         cleaned = cleaned.replace("'", "\"")
+         try:
+             return json.loads(cleaned)
+         except json.JSONDecodeError:
+             return None
+
+
+ def _coerce_float(val: Any) -> Optional[float]:
+     if val is None:
+         return None
+     try:
+         return float(val)
+     except (TypeError, ValueError):
+         return None
+
+
+ def _aggregate_scores(scores: list[float], mode: str) -> Optional[float]:
+     if not scores:
+         return None
+     mode = (mode or "product").lower()
+     if mode == "min":
+         return min(scores)
+     if mode == "mean":
+         return sum(scores) / float(len(scores))
+     prod = 1.0
+     for s in scores:
+         prod *= s
+     return prod
+
+
+ def _score_step(
+     judge,
+     *,
+     system_prompt: str,
+     step_text: str,
+     prompt: str,
+     completion: str,
+     rubric_text: str,
+     temperature: float,
+     max_new_tokens: int,
+ ) -> Optional[float]:
+     step_prompt = (
+         f"{system_prompt}\n\n"
+         "Score this single step from a solution.\n\n"
+         f"## Task\n{prompt}\n\n"
+         f"## Model Answer\n{completion}\n\n"
+         f"## Step\n{step_text}\n\n"
+         f"## Rubric\n{rubric_text}\n\n"
+         "Return JSON only."
+     )
+     gen = judge.generate(
+         step_prompt,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         top_p=1.0,
+         top_k=None,
+     )
+     raw = gen.text[len(step_prompt) :] if gen.text.startswith(step_prompt) else gen.text
+     parsed = _extract_json(raw) or {}
+     return _coerce_float(parsed.get("score"))
+
+
+ def verify(
+     prompt: str,
+     completion: str,
+     workdir: str,
+     *,
+     model: Optional[str] = None,
+     backend: str = "mlx-lm",
+     system_prompt: Optional[str] = None,
+     rubric: Optional[str] = None,
+     mode: str = "judge",
+     temperature: float = 0.0,
+     max_new_tokens: int = 256,
+     min_score: float = 0.5,
+     reward_pass: float = 1.0,
+     reward_fail: float = 0.0,
+     reward_mode: str = "score",
+     max_seq_len: Optional[int] = None,
+     dtype: Optional[str] = None,
+     trust_remote_code: bool = False,
+     mock_response: Optional[str] = None,
+     process_agg: str = "product",
+     max_steps: int = 8,
+     **kwargs,
+ ) -> VerifyResult:
+     """LLM-based verifier with JSON output.
+
+     The judge should return JSON: {"passed": bool, "score": 0..1, "reason": "..."}.
+     Set mock_response to bypass backend loading (useful for tests).
+     """
+     model_id = model or os.environ.get("MLXSMITH_JUDGE_MODEL")
+     if not model_id and not mock_response:
+         raise RuntimeError("llm_judge requires `model` or MLXSMITH_JUDGE_MODEL")
+
+     rubric_text = _read_text(rubric) or "Assess correctness and completeness."
+     mode = (mode or "judge").strip().lower()
+     sys_prompt = system_prompt or (
+         "You are a strict verifier. Return ONLY JSON with keys: "
+         "passed (bool), score (0-1), reason (string)."
+     )
+
+     if mode == "thinkprm":
+         sys_prompt = system_prompt or (
+             "You are a process reward model. Evaluate the reasoning quality. "
+             "Return ONLY JSON with keys: passed (bool), score (0-1), reason (string), steps (array)."
+         )
+
+     user_prompt = (
+         "## Task\n"
+         f"{prompt}\n\n"
+         "## Model Answer\n"
+         f"{completion}\n\n"
+         "## Rubric\n"
+         f"{rubric_text}\n\n"
+         "Return JSON only."
+     )
+
+     t0 = time.time()
+     if mock_response is not None:
+         raw = str(mock_response)
+     else:
+         judge = _load_backend(
+             model_id,
+             backend,
+             max_seq_len=max_seq_len,
+             dtype=dtype,
+             trust_remote_code=trust_remote_code,
+         )
+         full_prompt = f"{sys_prompt}\n\n{user_prompt}"
+         gen = judge.generate(
+             full_prompt,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=1.0,
+             top_k=None,
+         )
+         raw = gen.text[len(full_prompt) :] if gen.text.startswith(full_prompt) else gen.text
+
+     parsed = _extract_json(raw) or {}
+     score = _coerce_float(parsed.get("score"))
+     passed_val = parsed.get("passed")
+     steps_raw = parsed.get("steps") if isinstance(parsed, dict) else None
+     step_texts: list[str] = []
+     step_scores: list[float] = []
+     if mode == "thinkprm" and isinstance(steps_raw, list):
+         for step in steps_raw[: max(1, int(max_steps))]:
+             if isinstance(step, dict):
+                 text = step.get("text") or step.get("step") or step.get("content") or ""
+                 if text:
+                     step_texts.append(str(text))
+                 s_val = _coerce_float(step.get("score"))
+                 if s_val is not None:
+                     step_scores.append(float(s_val))
+             elif isinstance(step, str):
+                 step_texts.append(step)
+
+     if step_texts and (len(step_scores) < len(step_texts)) and mock_response is None:
+         judge = _load_backend(
+             model_id,
+             backend,
+             max_seq_len=max_seq_len,
+             dtype=dtype,
+             trust_remote_code=trust_remote_code,
+         )
+         for idx, step_text in enumerate(step_texts):
+             if idx < len(step_scores):
+                 continue
+             s_val = _score_step(
+                 judge,
+                 system_prompt=sys_prompt,
+                 step_text=step_text,
+                 prompt=prompt,
+                 completion=completion,
+                 rubric_text=rubric_text,
+                 temperature=temperature,
+                 max_new_tokens=max_new_tokens,
+             )
+             if s_val is not None:
+                 step_scores.append(float(s_val))
+
+     process_score = _aggregate_scores(step_scores, process_agg) if step_scores else None
+     if mode == "thinkprm" and process_score is not None:
+         score = process_score
+     if passed_val is None and score is not None:
+         passed_val = score >= min_score
+     passed = bool(passed_val) if passed_val is not None else False
+     reason = parsed.get("reason") or parsed.get("explanation") or ""
+
+     if reward_mode == "score" and score is not None:
+         reward = max(0.0, min(1.0, float(score)))
+     else:
+         reward = reward_pass if passed else reward_fail
+
+     latency_ms = (time.time() - t0) * 1000.0
+
+     return VerifyResult(
+         reward=reward,
+         passed=passed,
+         info={
+             "mode": mode,
+             "model": model_id,
+             "score": score,
+             "process_score": process_score,
+             "process_agg": process_agg,
+             "steps": step_texts,
+             "step_scores": step_scores,
+             "passed": passed,
+             "reason": reason,
+             "raw": raw,
+             "verifier_latency_ms": latency_ms,
+         },
+         artifacts_dir=workdir,
+     )
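
A minimal usage sketch (the prompt, answer, and workdir below are illustrative, not from the package docs): mock_response lets you exercise this verifier offline, since it skips backend loading entirely and parses the supplied string as the judge's JSON output.

from mlxsmith.verifiers.llm_judge import verify

res = verify(
    "What is 2 + 2?",
    "2 + 2 = 4",
    "/tmp/judge-artifacts",  # hypothetical workdir for artifacts
    mock_response='{"passed": true, "score": 1.0, "reason": "correct"}',
)
print(res.reward, res.passed, res.info["score"])  # 1.0 True 1.0 with reward_mode="score"
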
@@ -0,0 +1,127 @@
+ from __future__ import annotations
+
+ import importlib.util
+ import re
+ from typing import Any, Dict, List, Optional
+
+ from .types import VerifyResult
+
+ _STATE: Dict[str, Dict[str, float]] = {
+     "values": {},
+     "counts": {},
+ }
+
+
+ def _load_verifier(path: str):
+     import sys
+     from pathlib import Path as _Path
+
+     verifier_path = _Path(path).resolve()
+
+     # If the file lives inside a Python package, set __package__ so that
+     # relative imports (e.g. ``from .types import ...``) work correctly.
+     pkg_name: Optional[str] = None
+     if (verifier_path.parent / "__init__.py").exists():
+         parts: list[str] = []
+         p = verifier_path.parent
+         while (p / "__init__.py").exists():
+             parts.insert(0, p.name)
+             p = p.parent
+         pkg_name = ".".join(parts)
+         root = str(p)
+         if root not in sys.path:
+             sys.path.insert(0, root)
+
+     mod_name = f"{pkg_name}._prime_loaded" if pkg_name else "prime_verifier"
+     spec = importlib.util.spec_from_file_location(mod_name, str(verifier_path))
+     if spec is None or spec.loader is None:
+         raise RuntimeError(f"Could not load verifier: {verifier_path}")
+     module = importlib.util.module_from_spec(spec)
+     if pkg_name is not None:
+         module.__package__ = pkg_name
+     spec.loader.exec_module(module)  # type: ignore
+     verify_fn = getattr(module, "verify", None)
+     if not callable(verify_fn):
+         raise RuntimeError(f"Verifier must define verify(...): {verifier_path}")
+     return verify_fn
+
+
+ def _extract_steps(text: str, *, max_steps: int = 12) -> List[str]:
+     lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
+     steps = []
+     for ln in lines:
+         if re.match(r"^(\d+\.|\-|\*|\+)\s+", ln):
+             steps.append(re.sub(r"^(\d+\.|\-|\*|\+)\s+", "", ln).strip())
+     if not steps:
+         steps = lines[:max_steps]
+     return steps[:max_steps]
+
+
+ def _aggregate(values: List[float], mode: str) -> float:
+     if not values:
+         return 0.0
+     mode = (mode or "mean").lower()
+     if mode == "min":
+         return min(values)
+     if mode == "product":
+         out = 1.0
+         for v in values:
+             out *= v
+         return out
+     return sum(values) / float(len(values))
+
+
+ def verify(
+     prompt: str,
+     completion: str,
+     workdir: str,
+     *,
+     verifier: str,
+     verifier_kwargs: Optional[Dict[str, Any]] = None,
+     ema_alpha: float = 0.2,
+     max_steps: int = 12,
+     agg: str = "mean",
+     reward_mode: str = "process",
+     min_score: float = 0.0,
+     **kwargs,
+ ) -> VerifyResult:
+     """PRIME-style implicit process rewards.
+
+     Uses outcome reward from a base verifier to update per-step values.
+     """
+     verify_fn = _load_verifier(verifier)
+     base = verify_fn(prompt, completion, workdir, **(verifier_kwargs or {}), **kwargs)
+     outcome_reward = float(getattr(base, "reward", 0.0))
+     steps = _extract_steps(completion, max_steps=max_steps)
+
+     step_values: List[float] = []
+     for step in steps:
+         prev = _STATE["values"].get(step, outcome_reward)
+         new_val = (1.0 - ema_alpha) * prev + ema_alpha * outcome_reward
+         _STATE["values"][step] = new_val
+         _STATE["counts"][step] = _STATE["counts"].get(step, 0.0) + 1.0
+         step_values.append(new_val)
+
+     process_reward = _aggregate(step_values, agg)
+     if reward_mode == "combined":
+         reward = (process_reward + outcome_reward) / 2.0
+     else:
+         reward = process_reward
+
+     passed = bool(getattr(base, "passed", False)) and reward >= min_score
+
+     return VerifyResult(
+         reward=reward,
+         passed=passed,
+         info={
+             "mode": "prime",
+             "base_reward": outcome_reward,
+             "process_reward": process_reward,
+             "steps": steps,
+             "step_values": step_values,
+             "agg": agg,
+             "ema_alpha": ema_alpha,
+             "base_info": getattr(base, "info", {}),
+         },
+         artifacts_dir=workdir,
+     )
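
A small numeric sketch (hypothetical step text and outcome rewards) of the per-step EMA update used above: each step's stored value drifts toward the outcome rewards of the completions that contain it, and aggregating those values gives the process reward.

ema_alpha = 0.2
values = {}                                 # mirrors _STATE["values"], keyed by step text
step = "Add 2 and 2"                        # hypothetical extracted step
for outcome_reward in [0.0, 1.0, 1.0]:      # hypothetical outcome rewards across rollouts
    prev = values.get(step, outcome_reward)  # first sighting seeds with the outcome itself
    values[step] = (1.0 - ema_alpha) * prev + ema_alpha * outcome_reward
print(values[step])  # ~0.36; _aggregate() over such per-step values yields the process reward
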