mlxsmith 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import json
+import random
+from pathlib import Path
+from typing import Iterable, Optional
+
+from rich.console import Console
+
+from ..accel import get_backend
+from ..config import ProjectConfig
+from ..models import resolve_model_spec
+from ..runs import RunPaths, new_run, snapshot_config
+from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
+from ..llm.registry import get_llm_backend
+from ..llm.backend import BackendNotAvailable
+from ..sdk.losses import preference_loss
+from ..verifiers.llm_judge import verify as judge_verify
+from .lora import LoRAConfig
+
+console = Console()
+
+
+def _iter_prompts(path: Path) -> Iterable[str]:
+    for line in path.read_text(encoding="utf-8").splitlines():
+        if not line.strip():
+            continue
+        row = json.loads(line)
+        prompt = row.get("prompt") or row.get("instruction") or row.get("input") or row.get("question") or ""
+        if not prompt and "messages" in row:
+            msgs = row.get("messages") or []
+            if msgs:
+                prompt = "\n".join([m.get("content", "") for m in msgs])
+        if prompt:
+            yield str(prompt)
+
+
+def run_online_dpo(
+    project_root: Path,
+    cfg: ProjectConfig,
+    data_path: Path,
+    model_id_or_path: str,
+    accel: str,
+    *,
+    judge_model: Optional[str] = None,
+    judge_backend: str = "mlx-lm",
+    rubric: Optional[str] = None,
+    group_size: Optional[int] = None,
+    max_new_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    judge_mock_response: Optional[str | list[str]] = None,
+) -> RunPaths:
+    run = new_run(project_root, "online_dpo")
+    snapshot_config(cfg.model_dump(), run.config_snapshot_path)
+
+    prompts = list(_iter_prompts(data_path))
+    if not prompts:
+        raise RuntimeError("No prompts found in online DPO dataset")
+
+    backend = get_backend(accel)
+    backend.patch()
+    console.print(f"[bold]ONLINE-DPO[/bold] run: {run.run_dir.name} accel={backend.name}")
+
+    llm = get_llm_backend(cfg.model.backend)
+    base_model, adapter_path, _meta = resolve_model_spec(project_root, model_id_or_path, cfg)
+
+    try:
+        llm.load(
+            base_model,
+            max_seq_len=cfg.model.max_seq_len,
+            dtype=cfg.model.dtype,
+            trust_remote_code=cfg.model.trust_remote_code,
+        )
+        if adapter_path:
+            llm.apply_adapter(str(adapter_path))
+        else:
+            lora_cfg = LoRAConfig(
+                r=cfg.lora.r,
+                alpha=cfg.lora.alpha,
+                dropout=cfg.lora.dropout,
+                target_modules=list(cfg.lora.target_modules or []),
+                num_layers=cfg.lora.num_layers,
+                scale=cfg.lora.scale,
+                fine_tune_type=cfg.lora.fine_tune_type,
+            )
+            llm.apply_lora_from_config(lora_cfg)
+    except BackendNotAvailable as e:
+        console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
+        (run.adapter_dir / "ADAPTER.txt").write_text(
+            f"Backend unavailable in this environment.\nmodel={model_id_or_path}\naccel={backend.name}\n",
+            encoding="utf-8",
+        )
+        return run
+
+    ref_llm = None
+    if cfg.pref.reference_model:
+        ref_llm = get_llm_backend(cfg.model.backend)
+        try:
+            ref_llm.load(
+                cfg.pref.reference_model,
+                max_seq_len=cfg.model.max_seq_len,
+                dtype=cfg.model.dtype,
+                trust_remote_code=cfg.model.trust_remote_code,
+            )
+        except BackendNotAvailable:
+            ref_llm = None
+
+    opt, _params = llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
+
+    loss_type = str(cfg.pref.loss_type or cfg.pref.algo)
+    beta = float(cfg.pref.beta)
+    kl_coeff = float(cfg.pref.kl_coeff)
+    delta = float(cfg.pref.delta)
+
+    total = int(cfg.train.iters)
+    grad_accum = max(1, int(cfg.train.grad_accum))
+    max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 0.0))
+    group = int(group_size or cfg.rft.rollouts or 4)
+    max_new = int(max_new_tokens or cfg.rft.max_new_tokens)
+    temp = float(temperature if temperature is not None else cfg.rft.temperature)
+
+    rng = random.Random(cfg.train.seed)
+    accum_grads = None
+    accum_loss = 0.0
+    accum_count = 0
+
+    def _next_mock(idx: int) -> Optional[str]:
+        if judge_mock_response is None:
+            return None
+        if isinstance(judge_mock_response, list):
+            if not judge_mock_response:
+                return None
+            return judge_mock_response[min(idx, len(judge_mock_response) - 1)]
+        return judge_mock_response
+
+    for step in range(1, total + 1):
+        prompt = rng.choice(prompts)
+        candidates: list[tuple[str, float]] = []
+        for k in range(group):
+            gen = llm.generate(
+                prompt,
+                max_new_tokens=max_new,
+                temperature=temp,
+                seed=rng.randint(0, 2**31 - 1),
+            )
+            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
+            res = judge_verify(
+                prompt,
+                completion,
+                str(run.artifacts_dir),
+                model=judge_model,
+                backend=judge_backend,
+                rubric=rubric,
+                reward_mode="score",
+                mock_response=_next_mock(k),
+            )
+            reward = float(getattr(res, "reward", 0.0))
+            candidates.append((completion, reward))
+
+        if len(candidates) < 2:
+            continue
+        chosen, chosen_r = max(candidates, key=lambda x: x[1])
+        rejected, rejected_r = min(candidates, key=lambda x: x[1])
+        if chosen == rejected:
+            continue
+
+        prompt_ids = llm.encode(prompt)
+        chosen_ids = llm.encode(prompt + chosen)
+        rejected_ids = llm.encode(prompt + rejected)
+        p_len_c = len(prompt_ids)
+        p_len_r = len(prompt_ids)
+        max_len = int(cfg.model.max_seq_len)
+        if max_len:
+            if len(chosen_ids) > max_len:
+                overflow = len(chosen_ids) - max_len
+                chosen_ids = chosen_ids[overflow:]
+                p_len_c = max(0, p_len_c - overflow)
+            if len(rejected_ids) > max_len:
+                overflow = len(rejected_ids) - max_len
+                rejected_ids = rejected_ids[overflow:]
+                p_len_r = max(0, p_len_r - overflow)
+
+        def loss_fn(_model):
+            return preference_loss(
+                llm,
+                chosen_ids,
+                rejected_ids,
+                prompt_len_chosen=p_len_c,
+                prompt_len_rejected=p_len_r,
+                algo=loss_type,
+                beta=beta,
+                reference_backend=ref_llm,
+                kl_coeff=kl_coeff,
+                train_on_prompt=bool(cfg.train.train_on_prompt),
+                delta=delta,
+            )
+
+        lval, grads = llm.value_and_grad(loss_fn)
+        accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
+        accum_count += 1
+        if grads is not None:
+            accum_grads = tree_add(accum_grads, grads)
+
+        if step % grad_accum == 0:
+            if accum_grads is not None:
+                scaled = tree_scale(accum_grads, 1.0 / grad_accum)
+                if max_grad_norm > 0:
+                    scaled = clip_grad_norm(scaled, max_grad_norm)
+                llm.apply_grads(opt, scaled)
+            accum_grads = None
+            accum_loss = 0.0
+            accum_count = 0
+
+        if step % cfg.train.log_every == 0 or step == 1 or step == total:
+            avg_loss = accum_loss / max(1, accum_count) if accum_count else float(lval)
+            write_jsonl(
+                run.metrics_path,
+                [
+                    {
+                        "ts": now_ts(),
+                        "step": step,
+                        "kind": "online_dpo",
+                        "algo": loss_type,
+                        "loss": avg_loss,
+                        "reward_best": chosen_r,
+                        "reward_worst": rejected_r,
+                        "accel": backend.name,
+                    }
+                ],
+            )
+
+        if step % cfg.train.save_every == 0 or step == total:
+            llm.save_adapter(
+                str(run.adapter_dir),
+                metadata={
+                    "base_model": base_model,
+                    "source_adapter": str(adapter_path) if adapter_path else None,
+                    "run": run.run_dir.name,
+                    "kind": "online_dpo",
+                },
+            )
+
+    console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+    return run
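The new training loop above constructs its preference pairs online rather than reading them from disk: each step samples a group of completions for one prompt, scores every completion with the LLM judge, and uses the highest- and lowest-scoring completions as the chosen/rejected pair. Below is a minimal, framework-free sketch of that pairing step; `sample_completion` and `judge_score` are hypothetical stand-ins for the `llm.generate(...)` and `judge_verify(...)` calls in the diff.

```python
from typing import Callable, Optional, Tuple


def build_preference_pair(
    prompt: str,
    sample_completion: Callable[[str], str],   # stand-in for llm.generate(...)
    judge_score: Callable[[str, str], float],  # stand-in for judge_verify(...).reward
    group_size: int = 4,
) -> Optional[Tuple[str, str]]:
    """Sample a group of completions, score each, and return (chosen, rejected)."""
    scored = []
    for _ in range(group_size):
        completion = sample_completion(prompt)
        scored.append((completion, judge_score(prompt, completion)))
    if len(scored) < 2:
        return None
    chosen, _ = max(scored, key=lambda s: s[1])
    rejected, _ = min(scored, key=lambda s: s[1])
    if chosen == rejected:
        return None  # identical best/worst carries no preference signal, skip the step
    return chosen, rejected
```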
mlxsmith/train/pref.py CHANGED
@@ -10,7 +10,8 @@ from ..accel import get_backend
 from ..config import ProjectConfig
 from ..models import resolve_model_spec
 from ..runs import RunPaths, new_run, snapshot_config
-from ..util import write_jsonl, now_ts, tree_add, tree_scale
+from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
+from ..sdk.losses import preference_loss
 from ..llm.registry import get_llm_backend
 from ..llm.backend import BackendNotAvailable
 from .lora import LoRAConfig
@@ -28,7 +29,8 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_

     backend = get_backend(accel)
     backend.patch()
-    console.print(f"[bold]PREF[/bold] run: {run.run_dir.name} algo={cfg.pref.algo} accel={backend.name}")
+    loss_type = str(cfg.pref.loss_type or cfg.pref.algo)
+    console.print(f"[bold]PREF[/bold] run: {run.run_dir.name} algo={loss_type} accel={backend.name}")

     prefs_path = data_dir / "train.jsonl"
     if not prefs_path.exists():
@@ -79,10 +81,16 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
         except BackendNotAvailable:
             ref_llm = None

-    opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+    opt, _params = llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )

     beta = float(cfg.pref.beta)
     kl_coeff = float(cfg.pref.kl_coeff)
+    delta = float(cfg.pref.delta)
     rng = random.Random(cfg.train.seed)
     total = int(cfg.train.iters)
     grad_accum = max(1, int(cfg.train.grad_accum))
@@ -114,30 +122,19 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
                 p_len_r = max(0, p_len_r - overflow)

         def loss_fn(_model):
-            logp_c = llm.sequence_logprob(chosen_ids, prompt_len=p_len_c)
-            logp_r = llm.sequence_logprob(rejected_ids, prompt_len=p_len_r)
-            ref_diff = 0.0
-            if ref_llm is not None:
-                ref_logp_c = ref_llm.sequence_logprob(chosen_ids, prompt_len=p_len_c)
-                ref_logp_r = ref_llm.sequence_logprob(rejected_ids, prompt_len=p_len_r)
-                ref_diff = ref_logp_c - ref_logp_r
-            diff = (logp_c - logp_r) - ref_diff
-
-            if cfg.pref.algo == "orpo":
-                # ORPO loss = NLL(chosen) - beta * log(sigmoid(diff))
-                nll = llm.sft_loss(chosen_ids, train_on_prompt=train_on_prompt, prompt_len=p_len_c)
-                or_loss = -beta * llm.mx.log(llm.mx.sigmoid(diff))  # type: ignore
-                loss = nll + or_loss
-            else:
-                # DPO loss
-                scaled = llm.mx.array(beta) * diff  # type: ignore
-                loss = llm.mx.log1p(llm.mx.exp(-scaled))  # type: ignore
-
-            if ref_llm is not None and kl_coeff > 0:
-                # Simple KL penalty on chosen responses
-                kl = (logp_c - ref_logp_c) if ref_llm is not None else 0.0
-                loss = loss + llm.mx.array(kl_coeff) * kl  # type: ignore
-            return loss
+            return preference_loss(
+                llm,
+                chosen_ids,
+                rejected_ids,
+                prompt_len_chosen=p_len_c,
+                prompt_len_rejected=p_len_r,
+                algo=loss_type,
+                beta=beta,
+                reference_backend=ref_llm,
+                kl_coeff=kl_coeff,
+                train_on_prompt=train_on_prompt,
+                delta=delta,
+            )

         lval, grads = llm.value_and_grad(loss_fn)
         if grads is not None:
@@ -145,7 +142,11 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_

         if step % grad_accum == 0:
             if accum_grads is not None:
-                llm.apply_grads(opt, tree_scale(accum_grads, 1.0 / grad_accum))
+                scaled = tree_scale(accum_grads, 1.0 / grad_accum)
+                max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 0.0))
+                if max_grad_norm > 0:
+                    scaled = clip_grad_norm(scaled, max_grad_norm)
+                llm.apply_grads(opt, scaled)
             accum_grads = None

         if step % cfg.train.log_every == 0 or step == 1 or step == total:
@@ -156,9 +157,10 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
                         "ts": now_ts(),
                         "step": step,
                         "kind": "pref",
-                        "algo": cfg.pref.algo,
+                        "algo": loss_type,
                         "beta": beta,
                         "kl_coeff": kl_coeff,
+                        "delta": delta,
                         "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
                         "accel": backend.name,
                     }
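The removed inline `loss_fn` above spells out the preference objectives that this change presumably centralizes in `..sdk.losses.preference_loss` (whose exact signature is not shown in this diff). For reference, here is a scalar restatement of those removed formulas, assuming plain float log-probabilities: the DPO branch is log(1 + exp(-beta * diff)) and the ORPO branch is NLL(chosen) - beta * log(sigmoid(diff)), where diff is the policy-minus-reference log-prob margin. The function name below is a hypothetical helper, not part of mlxsmith.

```python
import math


def preference_loss_scalar(
    logp_chosen: float,
    logp_rejected: float,
    *,
    algo: str = "dpo",
    beta: float = 0.1,
    nll_chosen: float = 0.0,               # only used by the ORPO branch
    ref_logp_chosen: float | None = None,  # reference-model log-probs, if available
    ref_logp_rejected: float | None = None,
    kl_coeff: float = 0.0,
) -> float:
    """Scalar restatement of the removed loss_fn above (hypothetical helper)."""
    ref_diff = 0.0
    if ref_logp_chosen is not None and ref_logp_rejected is not None:
        ref_diff = ref_logp_chosen - ref_logp_rejected
    diff = (logp_chosen - logp_rejected) - ref_diff

    if algo == "orpo":
        # ORPO: NLL(chosen) - beta * log(sigmoid(diff)) == NLL(chosen) + beta * log(1 + exp(-diff))
        loss = nll_chosen + beta * math.log1p(math.exp(-diff))
    else:
        # DPO: -log(sigmoid(beta * diff)) == log(1 + exp(-beta * diff))
        loss = math.log1p(math.exp(-beta * diff))

    if ref_logp_chosen is not None and kl_coeff > 0:
        # simple KL-style penalty on the chosen response, as in the removed code
        loss += kl_coeff * (logp_chosen - ref_logp_chosen)
    return loss
```

The `delta` parameter threaded through the new call sites is not given a formula anywhere in this diff, so it is deliberately left out of the sketch.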
mlxsmith/train/rft.py CHANGED
@@ -46,12 +46,13 @@ def _rollout_token_env(
     max_steps: int,
     temperature: float,
     seed: int,
-) -> tuple[list[int], int, str, float, dict, int]:
+) -> tuple[list[int], int, str, float, dict, int, list[float]]:
     obs = env.initial_observation()
     obs_tokens, reward, done, info = _normalize_observation(obs)
     prompt_len = len(obs_tokens)
     full_tokens = list(obs_tokens)
     gen_tokens = 0
+    behavior_logprobs: list[float] = []

     for idx in range(max_steps):
         if done:
@@ -67,6 +68,8 @@
             logprobs=0,
         )
         new_token = int(gen.token_ids[-1])
+        if gen.logprobs:
+            behavior_logprobs.append(float(gen.logprobs[-1]))
         full_tokens.append(new_token)
         gen_tokens += 1

@@ -77,7 +80,62 @@
         obs_tokens = list(step.observation) if step.observation else list(full_tokens)

     completion = llm.decode(full_tokens[prompt_len:])
-    return full_tokens, prompt_len, completion, reward, info, gen_tokens
+    return full_tokens, prompt_len, completion, reward, info, gen_tokens, behavior_logprobs
+
+
+def _pg_loss(
+    llm,
+    token_ids: list[int],
+    *,
+    prompt_len: int,
+    advantage: float,
+    behavior_logprobs: list[float] | None,
+    loss_type: str,
+    epsilon_low: float,
+    epsilon_high: float,
+    token_level: bool,
+    ref_llm=None,
+    kl_coeff: float = 0.0,
+):
+    mx = llm.mx  # type: ignore
+    logp = None
+    if token_level:
+        token_logps, _ = llm.token_logprobs(
+            token_ids, prompt_len=prompt_len, top_k=0, include_prompt=False
+        )
+        if not token_logps:
+            return mx.array(0.0)
+        if loss_type == "dapo" and behavior_logprobs:
+            n = min(len(token_logps), len(behavior_logprobs))
+            total = mx.array(0.0)
+            for lp, bp in zip(token_logps[:n], behavior_logprobs[:n]):
+                ratio = mx.exp(mx.array(lp) - mx.array(bp))
+                clipped = mx.minimum(
+                    mx.maximum(ratio, mx.array(1.0 - epsilon_low)),
+                    mx.array(1.0 + epsilon_high),
+                )
+                total = total + clipped
+            loss = -mx.array(float(advantage)) * total / mx.array(float(n))
+        else:
+            avg_logp = sum(token_logps) / float(len(token_logps))
+            loss = -mx.array(float(advantage)) * mx.array(avg_logp)
+    else:
+        logp = llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+        if loss_type == "dapo" and behavior_logprobs:
+            behavior = sum(behavior_logprobs)
+            ratio = mx.exp(logp - mx.array(float(behavior)))
+            clipped = mx.minimum(mx.maximum(ratio, mx.array(1.0 - epsilon_low)), mx.array(1.0 + epsilon_high))
+            loss = -mx.array(float(advantage)) * clipped
+        else:
+            loss = -mx.array(float(advantage)) * logp
+
+    if ref_llm is not None and kl_coeff > 0:
+        if logp is None:
+            logp = llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+        ref_logp = ref_llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+        loss = loss + mx.array(float(kl_coeff)) * (logp - ref_logp)
+
+    return loss


 def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_path: Path, base_model_path: Path, accel: str) -> RunPaths:
@@ -146,7 +204,12 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
         except BackendNotAvailable:
             ref_llm = None

-    opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+    opt, _params = llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )

     rng = random.Random(cfg.train.seed)
     total_iters = int(cfg.train.iters)
@@ -155,6 +218,10 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
     max_new = int(getattr(cfg.rft, "max_new_tokens", 256))
     kl_coeff = float(cfg.rft.kl_coeff)
     normalize_adv = bool(cfg.rft.normalize_advantage)
+    loss_type = str(cfg.rft.loss_type or cfg.rft.algo)
+    epsilon_low = float(getattr(cfg.rft, "epsilon_low", 0.2))
+    epsilon_high = float(getattr(cfg.rft, "epsilon_high", epsilon_low))
+    token_level = bool(getattr(cfg.rft, "token_level_loss", False))

     if token_env_spec is not None:
         base_name = env.get("name") or "token_env"
@@ -204,7 +271,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
                     seed=rng.randint(0, 2**31 - 1),
                 )

-                token_ids, prompt_len, completion, reward, info, gen_count = _rollout_token_env(
+                token_ids, prompt_len, completion, reward, info, gen_count, behavior_logprobs = _rollout_token_env(
                     llm,
                     env_instance,
                     max_steps=max_new,
@@ -225,27 +292,35 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
                     continue

                 passed = bool(info.get("passed", reward > 0.0))
-                gens.append((token_ids, prompt_len, completion, passed, reward, info))
+                gens.append((token_ids, prompt_len, completion, passed, reward, info, behavior_logprobs))

             gen_elapsed = max(time.time() - gen_start, 1e-6)
             tps = gen_tokens / gen_elapsed

-            mean_r = sum(r for *_rest, r, _info in gens) / max(1, len(gens))
+            mean_r = sum(r for *_rest, r, _info, _bp in gens) / max(1, len(gens))
             std_r = (
-                sum((r - mean_r) ** 2 for *_rest, r, _info in gens) / max(1, len(gens))
+                sum((r - mean_r) ** 2 for *_rest, r, _info, _bp in gens) / max(1, len(gens))
             ) ** 0.5
-            advs = [r - mean_r for *_rest, r, _info in gens]
-            if normalize_adv and std_r > 1e-6:
+            advs = [r - mean_r for *_rest, r, _info, _bp in gens]
+            if loss_type != "dr_grpo" and normalize_adv and std_r > 1e-6:
                 advs = [a / std_r for a in advs]

             def loss_fn(_model):
                 loss = llm.mx.array(0.0)  # type: ignore
-                for (token_ids, prompt_len, _comp, _passed, _reward, _info), adv in zip(gens, advs):
-                    logp = llm.sequence_logprob(token_ids, prompt_len=prompt_len)
-                    pg = -llm.mx.array(float(adv)) * logp  # type: ignore
-                    if ref_llm is not None and kl_coeff > 0:
-                        ref_logp = ref_llm.sequence_logprob(token_ids, prompt_len=prompt_len)
-                        pg = pg + llm.mx.array(kl_coeff) * (logp - ref_logp)  # type: ignore
+                for (token_ids, prompt_len, _comp, _passed, _reward, _info, bps), adv in zip(gens, advs):
+                    pg = _pg_loss(
+                        llm,
+                        token_ids,
+                        prompt_len=prompt_len,
+                        advantage=float(adv),
+                        behavior_logprobs=bps,
+                        loss_type=loss_type,
+                        epsilon_low=epsilon_low,
+                        epsilon_high=epsilon_high,
+                        token_level=token_level,
+                        ref_llm=ref_llm,
+                        kl_coeff=kl_coeff,
+                    )
                     loss = loss + pg
                 return loss / llm.mx.array(float(len(gens)))  # type: ignore

@@ -256,8 +331,8 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
             best_idx = max(range(len(gens)), key=lambda i: gens[i][4])
             best = gens[best_idx]
             pass_at_1 = 1.0 if gens[0][3] else 0.0
-            pass_at_k = 1.0 if any(passed for *_g, passed, _r, _i in gens) else 0.0
-            acceptance = sum(1 for *_g, passed, _r, _i in gens if passed) / max(1, len(gens))
+            pass_at_k = 1.0 if any(g[3] for g in gens) else 0.0
+            acceptance = sum(1 for g in gens if g[3]) / max(1, len(gens))

             latency_summary = latency_summary_ms(verifier_latencies_ms)
             per_verifier_summary = {
@@ -269,7 +344,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
                 "ts": now_ts(),
                 "step": step,
                 "kind": "rft",
-                "algo": cfg.rft.algo,
+                "algo": loss_type,
                 "task_id": task_id,
                 "mean_reward": mean_r,
                 "std_reward": std_r,
@@ -290,7 +365,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
             metrics["verifier_latency_ms_by_path"] = per_verifier_summary
             write_jsonl(run.metrics_path, [metrics])

-            for (token_ids, prompt_len, completion, passed, reward, _info) in gens:
+            for (token_ids, prompt_len, completion, passed, reward, _info, _bps) in gens:
                 if passed:
                     prompt_text = llm.decode(token_ids[:prompt_len]) if prompt_len > 0 else ""
                     write_jsonl(
@@ -331,14 +406,16 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
         per_verifier_latencies: dict[str, list[float]] = {}

        for k in range(rollouts):
-            gen = llm.generate(
+            gen = llm.generate_with_logprobs(
                 prompt,
                 max_new_tokens=max_new,
                 temperature=temperature,
                 seed=rng.randint(0, 2**31 - 1),
+                logprobs=0,
             )
             completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
             gen_tokens += max(0, len(gen.token_ids) - gen.prompt_len)
+            behavior_logprobs = list(gen.logprobs) if gen.logprobs is not None else []

             wdir = ensure_dir(run.artifacts_dir / task_id / f"step_{step:06d}" / f"rollout_{k:02d}")
             if "tests" in task:
@@ -359,27 +436,35 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat

             passed = bool(getattr(res, "passed", False))
             reward = float(getattr(res, "reward", 0.0))
-            gens.append((gen, completion, passed, reward))
+            gens.append((gen, completion, passed, reward, behavior_logprobs))

         gen_elapsed = max(time.time() - gen_start, 1e-6)
         tps = gen_tokens / gen_elapsed

-        mean_r = sum(r for *_rest, r in gens) / max(1, len(gens))
+        mean_r = sum(r for *_rest, r, _bp in gens) / max(1, len(gens))
         std_r = (
-            sum((r - mean_r) ** 2 for *_rest, r in gens) / max(1, len(gens))
+            sum((r - mean_r) ** 2 for *_rest, r, _bp in gens) / max(1, len(gens))
         ) ** 0.5
-        advs = [r - mean_r for *_rest, r in gens]
-        if normalize_adv and std_r > 1e-6:
+        advs = [r - mean_r for *_rest, r, _bp in gens]
+        if loss_type != "dr_grpo" and normalize_adv and std_r > 1e-6:
             advs = [a / std_r for a in advs]

         def loss_fn(_model):
             loss = llm.mx.array(0.0)  # type: ignore
-            for (gen, _comp, _passed, _reward), adv in zip(gens, advs):
-                logp = llm.sequence_logprob(gen.token_ids, prompt_len=gen.prompt_len)
-                pg = -llm.mx.array(float(adv)) * logp  # type: ignore
-                if ref_llm is not None and kl_coeff > 0:
-                    ref_logp = ref_llm.sequence_logprob(gen.token_ids, prompt_len=gen.prompt_len)
-                    pg = pg + llm.mx.array(kl_coeff) * (logp - ref_logp)  # type: ignore
+            for (gen, _comp, _passed, _reward, bps), adv in zip(gens, advs):
+                pg = _pg_loss(
+                    llm,
+                    list(gen.token_ids),
+                    prompt_len=gen.prompt_len,
+                    advantage=float(adv),
+                    behavior_logprobs=bps,
+                    loss_type=loss_type,
+                    epsilon_low=epsilon_low,
+                    epsilon_high=epsilon_high,
+                    token_level=token_level,
+                    ref_llm=ref_llm,
+                    kl_coeff=kl_coeff,
+                )
                 loss = loss + pg
             return loss / llm.mx.array(float(len(gens)))  # type: ignore

@@ -390,8 +475,8 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
         best_idx = max(range(len(gens)), key=lambda i: gens[i][3])
         best = gens[best_idx]
         pass_at_1 = 1.0 if gens[0][2] else 0.0
-        pass_at_k = 1.0 if any(passed for _g, _c, passed, _r in gens) else 0.0
-        acceptance = sum(1 for *_rest, passed, _reward in gens if passed) / max(1, len(gens))
+        pass_at_k = 1.0 if any(g[2] for g in gens) else 0.0
+        acceptance = sum(1 for g in gens if g[2]) / max(1, len(gens))

         latency_summary = latency_summary_ms([t * 1000.0 for t in verifier_times])
         per_verifier_summary = {
@@ -406,10 +491,10 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
                     "ts": now_ts(),
                     "step": step,
                     "kind": "rft",
-                    "algo": cfg.rft.algo,
-                    "task_id": task_id,
-                    "mean_reward": mean_r,
-                    "std_reward": std_r,
+                    "algo": loss_type,
+                    "task_id": task_id,
+                    "mean_reward": mean_r,
+                    "std_reward": std_r,
                     "best_reward": best[3],
                     "best_passed": best[2],
                     "pass@1": pass_at_1,
@@ -429,7 +514,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
             ],
         )

-        for (gen, completion, passed, reward) in gens:
+        for (gen, completion, passed, reward, _bp) in gens:
             if passed:
                 write_jsonl(
                     accepted_path,
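In its sequence-level form, the `_pg_loss` helper added to rft.py above is a REINFORCE-style term that optionally replaces the raw log-probability with an importance ratio against the behavior log-probabilities recorded at sampling time, clipped to [1 - epsilon_low, 1 + epsilon_high] when `loss_type == "dapo"` (and, as the diff also shows, advantages skip the standard-deviation normalization when `loss_type == "dr_grpo"`). Here is a scalar sketch of that sequence-level branch, with plain floats standing in for mx arrays; the function name is a hypothetical helper, not part of mlxsmith.

```python
import math


def clipped_pg_term(
    logp: float,                  # sequence log-prob under the current policy
    behavior_logp: float | None,  # summed per-token log-probs recorded at sampling time
    advantage: float,
    *,
    loss_type: str = "grpo",
    epsilon_low: float = 0.2,
    epsilon_high: float = 0.2,
) -> float:
    """Scalar version of the sequence-level branch of _pg_loss (hypothetical helper)."""
    if loss_type == "dapo" and behavior_logp is not None:
        # importance ratio of the current policy vs. the sampling (behavior) policy,
        # clipped before being weighted by the advantage
        ratio = math.exp(logp - behavior_logp)
        clipped = min(max(ratio, 1.0 - epsilon_low), 1.0 + epsilon_high)
        return -advantage * clipped
    # plain policy-gradient term otherwise
    return -advantage * logp
```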