mlxsmith 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/accel/__init__.py +0 -3
- mlxsmith/bench.py +12 -2
- mlxsmith/cli.py +188 -3
- mlxsmith/config_models.py +16 -2
- mlxsmith/integrations/__init__.py +19 -0
- mlxsmith/integrations/mlx_lm_lora.py +117 -0
- mlxsmith/llm/backend.py +8 -1
- mlxsmith/llm/mlx_lm_backend.py +59 -2
- mlxsmith/llm/mock_backend.py +8 -1
- mlxsmith/optim/__init__.py +3 -0
- mlxsmith/optim/muon.py +93 -0
- mlxsmith/orchestrator/daemon.py +44 -377
- mlxsmith/orchestrator/trainer_worker.py +4 -0
- mlxsmith/rlm/loop.py +53 -92
- mlxsmith/sdk/__init__.py +18 -2
- mlxsmith/sdk/losses.py +102 -1
- mlxsmith/sdk/training_client.py +24 -5
- mlxsmith/train/distill.py +6 -1
- mlxsmith/train/online_dpo.py +249 -0
- mlxsmith/train/pref.py +31 -29
- mlxsmith/train/rft.py +123 -38
- mlxsmith/train/self_verify.py +199 -0
- mlxsmith/train/sft.py +13 -2
- mlxsmith/util.py +0 -6
- mlxsmith/verifiers/llm_judge.py +278 -0
- mlxsmith/verifiers/prime.py +127 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/METADATA +29 -13
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/RECORD +32 -25
- mlxsmith/accel/zmlx_backend.py +0 -42
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/WHEEL +0 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/entry_points.txt +0 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {mlxsmith-0.1.1.dist-info → mlxsmith-0.1.3.dist-info}/top_level.txt +0 -0
mlxsmith/train/online_dpo.py
ADDED
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+import json
+import random
+from pathlib import Path
+from typing import Iterable, Optional
+
+from rich.console import Console
+
+from ..accel import get_backend
+from ..config import ProjectConfig
+from ..models import resolve_model_spec
+from ..runs import RunPaths, new_run, snapshot_config
+from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
+from ..llm.registry import get_llm_backend
+from ..llm.backend import BackendNotAvailable
+from ..sdk.losses import preference_loss
+from ..verifiers.llm_judge import verify as judge_verify
+from .lora import LoRAConfig
+
+console = Console()
+
+
+def _iter_prompts(path: Path) -> Iterable[str]:
+    for line in path.read_text(encoding="utf-8").splitlines():
+        if not line.strip():
+            continue
+        row = json.loads(line)
+        prompt = row.get("prompt") or row.get("instruction") or row.get("input") or row.get("question") or ""
+        if not prompt and "messages" in row:
+            msgs = row.get("messages") or []
+            if msgs:
+                prompt = "\n".join([m.get("content", "") for m in msgs])
+        if prompt:
+            yield str(prompt)
+
+
+def run_online_dpo(
+    project_root: Path,
+    cfg: ProjectConfig,
+    data_path: Path,
+    model_id_or_path: str,
+    accel: str,
+    *,
+    judge_model: Optional[str] = None,
+    judge_backend: str = "mlx-lm",
+    rubric: Optional[str] = None,
+    group_size: Optional[int] = None,
+    max_new_tokens: Optional[int] = None,
+    temperature: Optional[float] = None,
+    judge_mock_response: Optional[str | list[str]] = None,
+) -> RunPaths:
+    run = new_run(project_root, "online_dpo")
+    snapshot_config(cfg.model_dump(), run.config_snapshot_path)
+
+    prompts = list(_iter_prompts(data_path))
+    if not prompts:
+        raise RuntimeError("No prompts found in online DPO dataset")
+
+    backend = get_backend(accel)
+    backend.patch()
+    console.print(f"[bold]ONLINE-DPO[/bold] run: {run.run_dir.name} accel={backend.name}")
+
+    llm = get_llm_backend(cfg.model.backend)
+    base_model, adapter_path, _meta = resolve_model_spec(project_root, model_id_or_path, cfg)
+
+    try:
+        llm.load(
+            base_model,
+            max_seq_len=cfg.model.max_seq_len,
+            dtype=cfg.model.dtype,
+            trust_remote_code=cfg.model.trust_remote_code,
+        )
+        if adapter_path:
+            llm.apply_adapter(str(adapter_path))
+        else:
+            lora_cfg = LoRAConfig(
+                r=cfg.lora.r,
+                alpha=cfg.lora.alpha,
+                dropout=cfg.lora.dropout,
+                target_modules=list(cfg.lora.target_modules or []),
+                num_layers=cfg.lora.num_layers,
+                scale=cfg.lora.scale,
+                fine_tune_type=cfg.lora.fine_tune_type,
+            )
+            llm.apply_lora_from_config(lora_cfg)
+    except BackendNotAvailable as e:
+        console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
+        (run.adapter_dir / "ADAPTER.txt").write_text(
+            f"Backend unavailable in this environment.\nmodel={model_id_or_path}\naccel={backend.name}\n",
+            encoding="utf-8",
+        )
+        return run
+
+    ref_llm = None
+    if cfg.pref.reference_model:
+        ref_llm = get_llm_backend(cfg.model.backend)
+        try:
+            ref_llm.load(
+                cfg.pref.reference_model,
+                max_seq_len=cfg.model.max_seq_len,
+                dtype=cfg.model.dtype,
+                trust_remote_code=cfg.model.trust_remote_code,
+            )
+        except BackendNotAvailable:
+            ref_llm = None
+
+    opt, _params = llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
+
+    loss_type = str(cfg.pref.loss_type or cfg.pref.algo)
+    beta = float(cfg.pref.beta)
+    kl_coeff = float(cfg.pref.kl_coeff)
+    delta = float(cfg.pref.delta)
+
+    total = int(cfg.train.iters)
+    grad_accum = max(1, int(cfg.train.grad_accum))
+    max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 0.0))
+    group = int(group_size or cfg.rft.rollouts or 4)
+    max_new = int(max_new_tokens or cfg.rft.max_new_tokens)
+    temp = float(temperature if temperature is not None else cfg.rft.temperature)
+
+    rng = random.Random(cfg.train.seed)
+    accum_grads = None
+    accum_loss = 0.0
+    accum_count = 0
+
+    def _next_mock(idx: int) -> Optional[str]:
+        if judge_mock_response is None:
+            return None
+        if isinstance(judge_mock_response, list):
+            if not judge_mock_response:
+                return None
+            return judge_mock_response[min(idx, len(judge_mock_response) - 1)]
+        return judge_mock_response
+
+    for step in range(1, total + 1):
+        prompt = rng.choice(prompts)
+        candidates: list[tuple[str, float]] = []
+        for k in range(group):
+            gen = llm.generate(
+                prompt,
+                max_new_tokens=max_new,
+                temperature=temp,
+                seed=rng.randint(0, 2**31 - 1),
+            )
+            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
+            res = judge_verify(
+                prompt,
+                completion,
+                str(run.artifacts_dir),
+                model=judge_model,
+                backend=judge_backend,
+                rubric=rubric,
+                reward_mode="score",
+                mock_response=_next_mock(k),
+            )
+            reward = float(getattr(res, "reward", 0.0))
+            candidates.append((completion, reward))
+
+        if len(candidates) < 2:
+            continue
+        chosen, chosen_r = max(candidates, key=lambda x: x[1])
+        rejected, rejected_r = min(candidates, key=lambda x: x[1])
+        if chosen == rejected:
+            continue
+
+        prompt_ids = llm.encode(prompt)
+        chosen_ids = llm.encode(prompt + chosen)
+        rejected_ids = llm.encode(prompt + rejected)
+        p_len_c = len(prompt_ids)
+        p_len_r = len(prompt_ids)
+        max_len = int(cfg.model.max_seq_len)
+        if max_len:
+            if len(chosen_ids) > max_len:
+                overflow = len(chosen_ids) - max_len
+                chosen_ids = chosen_ids[overflow:]
+                p_len_c = max(0, p_len_c - overflow)
+            if len(rejected_ids) > max_len:
+                overflow = len(rejected_ids) - max_len
+                rejected_ids = rejected_ids[overflow:]
+                p_len_r = max(0, p_len_r - overflow)
+
+        def loss_fn(_model):
+            return preference_loss(
+                llm,
+                chosen_ids,
+                rejected_ids,
+                prompt_len_chosen=p_len_c,
+                prompt_len_rejected=p_len_r,
+                algo=loss_type,
+                beta=beta,
+                reference_backend=ref_llm,
+                kl_coeff=kl_coeff,
+                train_on_prompt=bool(cfg.train.train_on_prompt),
+                delta=delta,
+            )
+
+        lval, grads = llm.value_and_grad(loss_fn)
+        accum_loss += float(lval.item()) if hasattr(lval, "item") else float(lval)
+        accum_count += 1
+        if grads is not None:
+            accum_grads = tree_add(accum_grads, grads)
+
+        if step % grad_accum == 0:
+            if accum_grads is not None:
+                scaled = tree_scale(accum_grads, 1.0 / grad_accum)
+                if max_grad_norm > 0:
+                    scaled = clip_grad_norm(scaled, max_grad_norm)
+                llm.apply_grads(opt, scaled)
+            accum_grads = None
+            accum_loss = 0.0
+            accum_count = 0
+
+        if step % cfg.train.log_every == 0 or step == 1 or step == total:
+            avg_loss = accum_loss / max(1, accum_count) if accum_count else float(lval)
+            write_jsonl(
+                run.metrics_path,
+                [
+                    {
+                        "ts": now_ts(),
+                        "step": step,
+                        "kind": "online_dpo",
+                        "algo": loss_type,
+                        "loss": avg_loss,
+                        "reward_best": chosen_r,
+                        "reward_worst": rejected_r,
+                        "accel": backend.name,
+                    }
+                ],
+            )
+
+        if step % cfg.train.save_every == 0 or step == total:
+            llm.save_adapter(
+                str(run.adapter_dir),
+                metadata={
+                    "base_model": base_model,
+                    "source_adapter": str(adapter_path) if adapter_path else None,
+                    "run": run.run_dir.name,
+                    "kind": "online_dpo",
+                },
+            )
+
+    console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
+    return run
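The loop above reduces to a simple recipe: sample a group of completions per prompt, score each with the LLM judge, and train on the best-versus-worst pair. A condensed sketch of that pairing step; `generate` and `judge` are placeholders standing in for the `llm.generate` and `judge_verify` calls above, not part of the package API:

    # Condensed sketch of the online-DPO pair construction used above.
    # `generate` and `judge` are placeholders for llm.generate / judge_verify.
    def best_worst_pair(prompt, generate, judge, group=4):
        candidates = []
        for _ in range(group):
            completion = generate(prompt)
            candidates.append((completion, judge(prompt, completion)))
        chosen = max(candidates, key=lambda c: c[1])    # highest judge reward
        rejected = min(candidates, key=lambda c: c[1])  # lowest judge reward
        # Degenerate group: identical best and worst yields no training signal.
        return None if chosen[0] == rejected[0] else (chosen, rejected)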
mlxsmith/train/pref.py
CHANGED
@@ -10,7 +10,8 @@ from ..accel import get_backend
 from ..config import ProjectConfig
 from ..models import resolve_model_spec
 from ..runs import RunPaths, new_run, snapshot_config
-from ..util import write_jsonl, now_ts, tree_add, tree_scale
+from ..util import write_jsonl, now_ts, tree_add, tree_scale, clip_grad_norm
+from ..sdk.losses import preference_loss
 from ..llm.registry import get_llm_backend
 from ..llm.backend import BackendNotAvailable
 from .lora import LoRAConfig
@@ -28,7 +29,8 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
 
     backend = get_backend(accel)
     backend.patch()
-
+    loss_type = str(cfg.pref.loss_type or cfg.pref.algo)
+    console.print(f"[bold]PREF[/bold] run: {run.run_dir.name} algo={loss_type} accel={backend.name}")
 
     prefs_path = data_dir / "train.jsonl"
     if not prefs_path.exists():
@@ -79,10 +81,16 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
     except BackendNotAvailable:
         ref_llm = None
 
-    opt, _params = llm.optimizer_and_params(…
+    opt, _params = llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
 
     beta = float(cfg.pref.beta)
     kl_coeff = float(cfg.pref.kl_coeff)
+    delta = float(cfg.pref.delta)
     rng = random.Random(cfg.train.seed)
     total = int(cfg.train.iters)
     grad_accum = max(1, int(cfg.train.grad_accum))
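Each trainer now forwards cfg.train.optimizer and cfg.train.optimizer_kwargs into optimizer_and_params, which is plausibly how the optimizer added in the new mlxsmith/optim/muon.py is selected. A hedged sketch of a direct call against that surface; only the keyword names come from this diff, while the "muon" value and its kwargs are illustrative assumptions:

    # Sketch: the lr/weight_decay/optimizer/optimizer_kwargs keywords appear in
    # the diff; the "muon" name and momentum kwarg are assumptions, not confirmed.
    opt, params = llm.optimizer_and_params(
        lr=1e-5,
        weight_decay=0.0,
        optimizer="muon",
        optimizer_kwargs={"momentum": 0.95},
    )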
@@ -114,30 +122,19 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
             p_len_r = max(0, p_len_r - overflow)
 
         def loss_fn(_model):
-            …
-                loss = nll + or_loss
-            else:
-                # DPO loss
-                scaled = llm.mx.array(beta) * diff  # type: ignore
-                loss = llm.mx.log1p(llm.mx.exp(-scaled))  # type: ignore
-
-            if ref_llm is not None and kl_coeff > 0:
-                # Simple KL penalty on chosen responses
-                kl = (logp_c - ref_logp_c) if ref_llm is not None else 0.0
-                loss = loss + llm.mx.array(kl_coeff) * kl  # type: ignore
-            return loss
+            return preference_loss(
+                llm,
+                chosen_ids,
+                rejected_ids,
+                prompt_len_chosen=p_len_c,
+                prompt_len_rejected=p_len_r,
+                algo=loss_type,
+                beta=beta,
+                reference_backend=ref_llm,
+                kl_coeff=kl_coeff,
+                train_on_prompt=train_on_prompt,
+                delta=delta,
+            )
 
         lval, grads = llm.value_and_grad(loss_fn)
         if grads is not None:
@@ -145,7 +142,11 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
 
         if step % grad_accum == 0:
             if accum_grads is not None:
-                …
+                scaled = tree_scale(accum_grads, 1.0 / grad_accum)
+                max_grad_norm = float(getattr(cfg.train, "max_grad_norm", 0.0))
+                if max_grad_norm > 0:
+                    scaled = clip_grad_norm(scaled, max_grad_norm)
+                llm.apply_grads(opt, scaled)
             accum_grads = None
 
         if step % cfg.train.log_every == 0 or step == 1 or step == total:
@@ -156,9 +157,10 @@ def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_
                         "ts": now_ts(),
                         "step": step,
                         "kind": "pref",
-                        "algo": …
+                        "algo": loss_type,
                         "beta": beta,
                         "kl_coeff": kl_coeff,
+                        "delta": delta,
                         "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
                         "accel": backend.name,
                     }
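For reference, the inline body removed above computed the standard DPO objective: with diff = log p_theta(y_chosen|x) - log p_theta(y_rejected|x), the loss was log1p(exp(-beta * diff)), i.e. -log sigmoid(beta * diff). The new preference_loss helper in sdk/losses.py centralizes this, and its algo/delta arguments suggest it also covers the other supported variants. A self-contained numeric sketch of that core term:

    import math

    def dpo_term(logp_chosen, logp_rejected, beta=0.1):
        # -log(sigmoid(beta * diff)) written as log1p(exp(-beta * diff)),
        # matching the removed inline code; stable for moderate |diff|.
        diff = logp_chosen - logp_rejected
        return math.log1p(math.exp(-beta * diff))

    print(dpo_term(-12.0, -12.5))  # diff = 0.5 -> ~0.669, just under log(2)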
mlxsmith/train/rft.py
CHANGED
@@ -46,12 +46,13 @@ def _rollout_token_env(
     max_steps: int,
     temperature: float,
     seed: int,
-) -> tuple[list[int], int, str, float, dict, int]:
+) -> tuple[list[int], int, str, float, dict, int, list[float]]:
     obs = env.initial_observation()
     obs_tokens, reward, done, info = _normalize_observation(obs)
     prompt_len = len(obs_tokens)
     full_tokens = list(obs_tokens)
     gen_tokens = 0
+    behavior_logprobs: list[float] = []
 
     for idx in range(max_steps):
         if done:
@@ -67,6 +68,8 @@ def _rollout_token_env(
             logprobs=0,
         )
         new_token = int(gen.token_ids[-1])
+        if gen.logprobs:
+            behavior_logprobs.append(float(gen.logprobs[-1]))
         full_tokens.append(new_token)
         gen_tokens += 1
 
@@ -77,7 +80,62 @@ def _rollout_token_env(
         obs_tokens = list(step.observation) if step.observation else list(full_tokens)
 
     completion = llm.decode(full_tokens[prompt_len:])
-    return full_tokens, prompt_len, completion, reward, info, gen_tokens
+    return full_tokens, prompt_len, completion, reward, info, gen_tokens, behavior_logprobs
+
+
+def _pg_loss(
+    llm,
+    token_ids: list[int],
+    *,
+    prompt_len: int,
+    advantage: float,
+    behavior_logprobs: list[float] | None,
+    loss_type: str,
+    epsilon_low: float,
+    epsilon_high: float,
+    token_level: bool,
+    ref_llm=None,
+    kl_coeff: float = 0.0,
+):
+    mx = llm.mx  # type: ignore
+    logp = None
+    if token_level:
+        token_logps, _ = llm.token_logprobs(
+            token_ids, prompt_len=prompt_len, top_k=0, include_prompt=False
+        )
+        if not token_logps:
+            return mx.array(0.0)
+        if loss_type == "dapo" and behavior_logprobs:
+            n = min(len(token_logps), len(behavior_logprobs))
+            total = mx.array(0.0)
+            for lp, bp in zip(token_logps[:n], behavior_logprobs[:n]):
+                ratio = mx.exp(mx.array(lp) - mx.array(bp))
+                clipped = mx.minimum(
+                    mx.maximum(ratio, mx.array(1.0 - epsilon_low)),
+                    mx.array(1.0 + epsilon_high),
+                )
+                total = total + clipped
+            loss = -mx.array(float(advantage)) * total / mx.array(float(n))
+        else:
+            avg_logp = sum(token_logps) / float(len(token_logps))
+            loss = -mx.array(float(advantage)) * mx.array(avg_logp)
+    else:
+        logp = llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+        if loss_type == "dapo" and behavior_logprobs:
+            behavior = sum(behavior_logprobs)
+            ratio = mx.exp(logp - mx.array(float(behavior)))
+            clipped = mx.minimum(mx.maximum(ratio, mx.array(1.0 - epsilon_low)), mx.array(1.0 + epsilon_high))
+            loss = -mx.array(float(advantage)) * clipped
+        else:
+            loss = -mx.array(float(advantage)) * logp
+
+    if ref_llm is not None and kl_coeff > 0:
+        if logp is None:
+            logp = llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+        ref_logp = ref_llm.sequence_logprob(token_ids, prompt_len=prompt_len)
+        loss = loss + mx.array(float(kl_coeff)) * (logp - ref_logp)
+
+    return loss
 
 
 def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_path: Path, base_model_path: Path, accel: str) -> RunPaths:
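The core of _pg_loss above is an importance ratio pi_new / pi_behavior clipped to [1 - epsilon_low, 1 + epsilon_high]; the asymmetric bounds follow the DAPO-style "clip-higher" idea, and note that the code multiplies the advantage by the clipped ratio directly rather than taking PPO's pessimistic min over clipped and unclipped terms. A scalar sketch of one sequence's contribution, independent of MLX:

    import math

    def clipped_pg_term(logp_new, logp_behavior, advantage,
                        epsilon_low=0.2, epsilon_high=0.2):
        # ratio = pi_new / pi_behavior, from summed token logprobs
        ratio = math.exp(logp_new - logp_behavior)
        clipped = min(max(ratio, 1.0 - epsilon_low), 1.0 + epsilon_high)
        # As in _pg_loss: -advantage * clipped ratio (no PPO-style min)
        return -advantage * clipped

    # epsilon_high > epsilon_low widens the upward clip, preserving gradient
    # signal when the policy upweights previously unlikely tokens.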
@@ -146,7 +204,12 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
     except BackendNotAvailable:
         ref_llm = None
 
-    opt, _params = llm.optimizer_and_params(…
+    opt, _params = llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
 
     rng = random.Random(cfg.train.seed)
     total_iters = int(cfg.train.iters)
@@ -155,6 +218,10 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
     max_new = int(getattr(cfg.rft, "max_new_tokens", 256))
     kl_coeff = float(cfg.rft.kl_coeff)
     normalize_adv = bool(cfg.rft.normalize_advantage)
+    loss_type = str(cfg.rft.loss_type or cfg.rft.algo)
+    epsilon_low = float(getattr(cfg.rft, "epsilon_low", 0.2))
+    epsilon_high = float(getattr(cfg.rft, "epsilon_high", epsilon_low))
+    token_level = bool(getattr(cfg.rft, "token_level_loss", False))
 
     if token_env_spec is not None:
         base_name = env.get("name") or "token_env"
@@ -204,7 +271,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
                 seed=rng.randint(0, 2**31 - 1),
             )
 
-            token_ids, prompt_len, completion, reward, info, gen_count = _rollout_token_env(
+            token_ids, prompt_len, completion, reward, info, gen_count, behavior_logprobs = _rollout_token_env(
                 llm,
                 env_instance,
                 max_steps=max_new,
@@ -225,27 +292,35 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
                 continue
 
             passed = bool(info.get("passed", reward > 0.0))
-            gens.append((token_ids, prompt_len, completion, passed, reward, info))
+            gens.append((token_ids, prompt_len, completion, passed, reward, info, behavior_logprobs))
 
         gen_elapsed = max(time.time() - gen_start, 1e-6)
         tps = gen_tokens / gen_elapsed
 
-        mean_r = sum(r for *_rest, r, _info in gens) / max(1, len(gens))
+        mean_r = sum(r for *_rest, r, _info, _bp in gens) / max(1, len(gens))
         std_r = (
-            sum((r - mean_r) ** 2 for *_rest, r, _info in gens) / max(1, len(gens))
+            sum((r - mean_r) ** 2 for *_rest, r, _info, _bp in gens) / max(1, len(gens))
         ) ** 0.5
-        advs = [r - mean_r for *_rest, r, _info in gens]
-        if normalize_adv and std_r > 1e-6:
+        advs = [r - mean_r for *_rest, r, _info, _bp in gens]
+        if loss_type != "dr_grpo" and normalize_adv and std_r > 1e-6:
            advs = [a / std_r for a in advs]
 
         def loss_fn(_model):
             loss = llm.mx.array(0.0)  # type: ignore
-            for (token_ids, prompt_len, _comp, _passed, _reward, _info), adv in zip(gens, advs):
-                …
+            for (token_ids, prompt_len, _comp, _passed, _reward, _info, bps), adv in zip(gens, advs):
+                pg = _pg_loss(
+                    llm,
+                    token_ids,
+                    prompt_len=prompt_len,
+                    advantage=float(adv),
+                    behavior_logprobs=bps,
+                    loss_type=loss_type,
+                    epsilon_low=epsilon_low,
+                    epsilon_high=epsilon_high,
+                    token_level=token_level,
+                    ref_llm=ref_llm,
+                    kl_coeff=kl_coeff,
+                )
                 loss = loss + pg
             return loss / llm.mx.array(float(len(gens)))  # type: ignore
@@ -256,8 +331,8 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
         best_idx = max(range(len(gens)), key=lambda i: gens[i][4])
         best = gens[best_idx]
         pass_at_1 = 1.0 if gens[0][3] else 0.0
-        pass_at_k = 1.0 if any(…
-        acceptance = sum(1 for …
+        pass_at_k = 1.0 if any(g[3] for g in gens) else 0.0
+        acceptance = sum(1 for g in gens if g[3]) / max(1, len(gens))
 
         latency_summary = latency_summary_ms(verifier_latencies_ms)
         per_verifier_summary = {
@@ -269,7 +344,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
             "ts": now_ts(),
             "step": step,
             "kind": "rft",
-            "algo": …
+            "algo": loss_type,
             "task_id": task_id,
             "mean_reward": mean_r,
             "std_reward": std_r,
@@ -290,7 +365,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
            metrics["verifier_latency_ms_by_path"] = per_verifier_summary
        write_jsonl(run.metrics_path, [metrics])
 
-        for (token_ids, prompt_len, completion, passed, reward, _info) in gens:
+        for (token_ids, prompt_len, completion, passed, reward, _info, _bps) in gens:
             if passed:
                 prompt_text = llm.decode(token_ids[:prompt_len]) if prompt_len > 0 else ""
                 write_jsonl(
@@ -331,14 +406,16 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
         per_verifier_latencies: dict[str, list[float]] = {}
 
        for k in range(rollouts):
-            gen = llm.…
+            gen = llm.generate_with_logprobs(
                prompt,
                max_new_tokens=max_new,
                temperature=temperature,
                seed=rng.randint(0, 2**31 - 1),
+                logprobs=0,
            )
            completion = gen.text[len(prompt) :] if gen.text.startswith(prompt) else gen.text
            gen_tokens += max(0, len(gen.token_ids) - gen.prompt_len)
+            behavior_logprobs = list(gen.logprobs) if gen.logprobs is not None else []
 
            wdir = ensure_dir(run.artifacts_dir / task_id / f"step_{step:06d}" / f"rollout_{k:02d}")
            if "tests" in task:
@@ -359,27 +436,35 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
 
            passed = bool(getattr(res, "passed", False))
            reward = float(getattr(res, "reward", 0.0))
-            gens.append((gen, completion, passed, reward))
+            gens.append((gen, completion, passed, reward, behavior_logprobs))
 
        gen_elapsed = max(time.time() - gen_start, 1e-6)
        tps = gen_tokens / gen_elapsed
 
-        mean_r = sum(r for *_rest, r in gens) / max(1, len(gens))
+        mean_r = sum(r for *_rest, r, _bp in gens) / max(1, len(gens))
        std_r = (
-            sum((r - mean_r) ** 2 for *_rest, r in gens) / max(1, len(gens))
+            sum((r - mean_r) ** 2 for *_rest, r, _bp in gens) / max(1, len(gens))
        ) ** 0.5
-        advs = [r - mean_r for *_rest, r in gens]
-        if normalize_adv and std_r > 1e-6:
+        advs = [r - mean_r for *_rest, r, _bp in gens]
+        if loss_type != "dr_grpo" and normalize_adv and std_r > 1e-6:
            advs = [a / std_r for a in advs]
 
        def loss_fn(_model):
            loss = llm.mx.array(0.0)  # type: ignore
-            for (gen, _comp, _passed, _reward), adv in zip(gens, advs):
-                …
+            for (gen, _comp, _passed, _reward, bps), adv in zip(gens, advs):
+                pg = _pg_loss(
+                    llm,
+                    list(gen.token_ids),
+                    prompt_len=gen.prompt_len,
+                    advantage=float(adv),
+                    behavior_logprobs=bps,
+                    loss_type=loss_type,
+                    epsilon_low=epsilon_low,
+                    epsilon_high=epsilon_high,
+                    token_level=token_level,
+                    ref_llm=ref_llm,
+                    kl_coeff=kl_coeff,
+                )
                loss = loss + pg
            return loss / llm.mx.array(float(len(gens)))  # type: ignore
@@ -390,8 +475,8 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
        best_idx = max(range(len(gens)), key=lambda i: gens[i][3])
        best = gens[best_idx]
        pass_at_1 = 1.0 if gens[0][2] else 0.0
-        pass_at_k = 1.0 if any(…
-        acceptance = sum(1 for …
+        pass_at_k = 1.0 if any(g[2] for g in gens) else 0.0
+        acceptance = sum(1 for g in gens if g[2]) / max(1, len(gens))
 
        latency_summary = latency_summary_ms([t * 1000.0 for t in verifier_times])
        per_verifier_summary = {
@@ -406,10 +491,10 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
                "ts": now_ts(),
                "step": step,
                "kind": "rft",
-                …
+                "algo": loss_type,
+                "task_id": task_id,
+                "mean_reward": mean_r,
+                "std_reward": std_r,
                "best_reward": best[3],
                "best_passed": best[2],
                "pass@1": pass_at_1,
@@ -429,7 +514,7 @@ def run_rft(project_root: Path, cfg: ProjectConfig, env_path: Path, verifier_pat
            ],
        )
 
-        for (gen, completion, passed, reward) in gens:
+        for (gen, completion, passed, reward, _bp) in gens:
            if passed:
                write_jsonl(
                    accepted_path,