mlxsmith-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/rlm/loop.py
ADDED
@@ -0,0 +1,1297 @@
"""RLM Loop - Orchestrator Entry Point for MLXSmith.

This module provides both:
1. Legacy single-process RLM loop (run_rlm)
2. Multi-process orchestrator mode (run_rlm_orchestrated)

The orchestrator mode splits the RLM loop into:
- Orchestrator Daemon: Queue-based job scheduler
- Inference Worker Process: OpenAI-compatible API server
- Trainer Worker Process: Training batch consumer
"""

from __future__ import annotations

import json
import multiprocessing as mp
import signal
import sys
import time
import traceback
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Callable

from rich.console import Console

from ..config import ProjectConfig
from ..eval import run_eval
from ..llm.registry import get_llm_backend
from ..models import resolve_model_spec
from ..runs import new_run, snapshot_config
from ..train.lora import LoRAConfig
from ..util import copytree, ensure_dir, now_ts, write_jsonl
from ..verifiers.docker_verifier import verify as docker_verify
from ..verifiers.pytest_verifier import verify as pytest_verify
from .corpus import append_corpus, load_corpus, sample_corpus
from .gating import load_state, save_state, should_accept, update_state
from .generate import GeneratedTask, generate_tasks, filter_tasks
from .history import append_history
from .inference import Rollout, build_tasks
from .mutate import mutate_tasks
from .trainer import train_on_rollouts
from .weights import (
    WeightPointer,
    WeightPointerIPC,
    WeightPointerStore,
    load_pointer,
    save_pointer,
)

console = Console()


def _score_from_eval(result_path: Path) -> float:
    try:
        data = json.loads(result_path.read_text(encoding="utf-8"))
        summary = data.get("summary") or []
        if not summary:
            return 0.0
        return sum(item.get("pass@k", 0.0) for item in summary) / max(1, len(summary))
    except Exception:
        return 0.0


def _load_suite_prompts(path: Path) -> list[str]:
    try:
        import yaml

        suite = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
        tasks = suite.get("tasks") or []
        return [str(t.get("prompt")) for t in tasks if t.get("prompt")]
    except Exception:
        return []


# =============================================================================
# Legacy Single-Process RLM Loop
# =============================================================================

def run_rlm(
    project_root: Path,
    cfg: ProjectConfig,
    *,
    model_spec: Optional[str] = None,
    iterations: Optional[int] = None,
    resume: bool = False,
) -> None:
    """Run single-process RLM loop (legacy mode)."""
    rlm_cfg = cfg.rlm
    state_path = project_root / "runs" / "rlm_state.json"
    history_path = project_root / "runs" / "rlm_history.jsonl"
    corpus_path = project_root / "runs" / "rlm_corpus.jsonl"
    weights_dir = ensure_dir(project_root / "runs" / "rlm_weights")

    state = load_state(state_path)

    if model_spec is None:
        model_spec = cfg.model.id

    base_model, initial_adapter, _meta = resolve_model_spec(project_root, model_spec, cfg)
    infer_ptr_path = weights_dir / "infer.json"
    train_ptr_path = weights_dir / "train.json"
    infer_ptr = load_pointer(infer_ptr_path, base_model=base_model, name="inference")
    train_ptr = load_pointer(train_ptr_path, base_model=base_model, name="trainer")

    if initial_adapter and not infer_ptr.adapter_path:
        infer_ptr = WeightPointer(
            base_model=base_model,
            adapter_path=str(initial_adapter),
            iteration=state.last_iteration,
            updated_at=now_ts(),
            name="inference",
        )
        save_pointer(infer_ptr_path, infer_ptr)

    if initial_adapter and not train_ptr.adapter_path:
        train_ptr = WeightPointer(
            base_model=base_model,
            adapter_path=str(initial_adapter),
            iteration=state.last_iteration,
            updated_at=now_ts(),
            name="trainer",
        )
        save_pointer(train_ptr_path, train_ptr)

    if resume and state.current_adapter:
        train_ptr = WeightPointer(
            base_model=base_model,
            adapter_path=state.current_adapter,
            iteration=state.last_iteration,
            updated_at=now_ts(),
            name="trainer",
        )
        save_pointer(train_ptr_path, train_ptr)
        if not infer_ptr.adapter_path:
            infer_ptr = WeightPointer(
                base_model=base_model,
                adapter_path=state.current_adapter,
                iteration=state.last_iteration,
                updated_at=now_ts(),
                name="inference",
            )
            save_pointer(infer_ptr_path, infer_ptr)

    start_iter = state.last_iteration + 1 if resume else 1
    total_iters = iterations if iterations is not None else int(rlm_cfg.iterations)

    def iter_range():
        if total_iters == 0:
            i = start_iter
            while True:
                yield i
                i += 1
        else:
            for i in range(start_iter, start_iter + total_iters):
                yield i

    for iteration in iter_range():
        run = new_run(project_root, "rlm")
        snapshot_config(cfg.model_dump(), run.config_snapshot_path)
        console.print(f"[bold]RLM[/bold] iteration {iteration} run={run.run_dir.name}")

        infer_llm = get_llm_backend(cfg.model.backend)
        infer_llm.load(
            infer_ptr.base_model,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
        if infer_ptr.adapter_path:
            infer_llm.apply_adapter(str(infer_ptr.adapter_path))

        train_llm = get_llm_backend(cfg.model.backend)
        train_llm.load(
            train_ptr.base_model,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
        if train_ptr.adapter_path:
            train_llm.apply_adapter(str(train_ptr.adapter_path))
        else:
            lora_cfg = LoRAConfig(
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                dropout=cfg.lora.dropout,
                target_modules=list(cfg.lora.target_modules or []),
                num_layers=cfg.lora.num_layers,
                scale=cfg.lora.scale,
                fine_tune_type=cfg.lora.fine_tune_type,
            )
            train_llm.apply_lora_from_config(lora_cfg)

        ref_llm = None
        if cfg.rft.reference_model:
            ref_llm = get_llm_backend(cfg.model.backend)
            ref_llm.load(
                cfg.rft.reference_model,
                max_seq_len=cfg.model.max_seq_len,
                dtype=cfg.model.dtype,
                trust_remote_code=cfg.model.trust_remote_code,
            )

        opt, _params = train_llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

        corpus_rows = load_corpus(corpus_path, max_size=int(rlm_cfg.corpus_max))
        existing_prompts = [row.get("prompt", "") for row in corpus_rows if row.get("prompt")]
        if rlm_cfg.benchmark_suite:
            suite_path = project_root / rlm_cfg.benchmark_suite
            if suite_path.exists():
                existing_prompts.extend(_load_suite_prompts(suite_path))
        if rlm_cfg.holdout_suite:
            holdout_path = project_root / rlm_cfg.holdout_suite
            if holdout_path.exists():
                existing_prompts.extend(_load_suite_prompts(holdout_path))

        tasks = build_tasks(
            infer_llm,
            cfg,
            require_recursion=bool(rlm_cfg.require_recursion),
            tasks_per_iter=int(rlm_cfg.tasks_per_iter),
            mutations_per_task=int(rlm_cfg.mutations_per_task),
            max_total=int(rlm_cfg.tasks_per_iter),
            existing_prompts=existing_prompts,
        )

        write_jsonl(run.run_dir / "tasks.jsonl", [task.__dict__ for task in tasks])
        rollouts, passed_samples = collect_rollouts_via_api(
            tasks,
            cfg,
            api_url=f"http://localhost:{cfg.serve.port}",
            artifacts_dir=run.artifacts_dir,
            verifier_backend=str(rlm_cfg.verifier_backend),
            weight_adapter=infer_ptr.adapter_path,
        )

        metrics_rows = train_on_rollouts(
            train_llm,
            rollouts,
            cfg,
            optimizer=opt,
            train_adapter=train_ptr.adapter_path,
            ref_llm=ref_llm,
        )
        for row in metrics_rows:
            row["iteration"] = iteration
        write_jsonl(run.metrics_path, metrics_rows)

        if passed_samples:
            append_corpus(corpus_path, passed_samples, max_size=int(rlm_cfg.corpus_max))

        # Optional corpus rehearsal via SFT
        mix_ratio = float(rlm_cfg.mix_old_ratio)
        if mix_ratio > 0 and corpus_rows:
            n_samples = int(max(1, len(tasks) * mix_ratio))
            for row in sample_corpus(corpus_rows, n=n_samples, hard_ratio=float(rlm_cfg.hard_ratio)):
                prompt = row.get("prompt", "")
                response = row.get("response", "")
                if not prompt or not response:
                    continue
                prompt_ids = train_llm.encode(prompt)
                ids = train_llm.encode(prompt + response)
                max_len = int(cfg.model.max_seq_len)
                if max_len and len(ids) > max_len:
                    overflow = len(ids) - max_len
                    ids = ids[overflow:]
                    prompt_ids = prompt_ids[overflow:] if overflow < len(prompt_ids) else []

                def sft_loss_fn(_model):
                    return train_llm.sft_loss(ids, train_on_prompt=cfg.train.train_on_prompt, prompt_len=len(prompt_ids))

                lval, grads = train_llm.value_and_grad(sft_loss_fn)
                if grads is not None:
                    train_llm.apply_grads(opt, grads)

        train_llm.save_adapter(
            str(run.adapter_dir),
            metadata={
                "base_model": train_ptr.base_model,
                "source_adapter": str(train_ptr.adapter_path) if train_ptr.adapter_path else None,
                "run": run.run_dir.name,
                "kind": "rlm",
                "iteration": iteration,
            },
        )

        # Evaluate
        adapter_score = 0.0
        if rlm_cfg.benchmark_suite:
            suite_path = project_root / rlm_cfg.benchmark_suite
            if suite_path.exists():
                eval_path = run_eval(project_root, suite_path, run.adapter_dir)
                adapter_score = _score_from_eval(eval_path)

        holdout_score = None
        if rlm_cfg.holdout_suite:
            holdout_path = project_root / rlm_cfg.holdout_suite
            if holdout_path.exists():
                holdout_eval = run_eval(project_root, holdout_path, run.adapter_dir)
                holdout_score = _score_from_eval(holdout_eval)

        accepted = should_accept(
            adapter_score,
            state,
            mode=rlm_cfg.gating,
            threshold=float(rlm_cfg.gating_threshold),
            ema_alpha=float(rlm_cfg.gating_ema_alpha),
        )
        state = update_state(
            state,
            iteration=iteration,
            score=adapter_score,
            adapter_path=str(run.adapter_dir),
            accepted=accepted,
            ema_alpha=float(rlm_cfg.gating_ema_alpha),
        )
        save_state(state_path, state)

        if state.current_adapter:
            train_ptr = WeightPointer(
                base_model=base_model,
                adapter_path=state.current_adapter,
                iteration=iteration,
                updated_at=now_ts(),
                name="trainer",
            )
            save_pointer(train_ptr_path, train_ptr)

        infer_staleness = int(getattr(rlm_cfg, "infer_staleness", 0))
        if infer_staleness <= 0:
            infer_ptr = train_ptr
            save_pointer(infer_ptr_path, infer_ptr)
        else:
            lag = max(0, int(train_ptr.iteration) - int(infer_ptr.iteration))
            if lag >= infer_staleness:
                infer_ptr = train_ptr
                save_pointer(infer_ptr_path, infer_ptr)

        append_history(
            history_path,
            {
                "iteration": iteration,
                "timestamp": now_ts(),
                "adapter_score": adapter_score,
                "holdout_score": holdout_score,
                "best_score": state.best_score,
                "accepted": accepted,
                "adapter_dir": str(run.adapter_dir),
            },
        )

        gating_path = run.run_dir / "gating.json"
        gating_path.write_text(
            json.dumps(
                {
                    "iteration": iteration,
                    "accepted": accepted,
                    "adapter_score": adapter_score,
                    "holdout_score": holdout_score,
                    "best_score": state.best_score,
                    "current_adapter": state.current_adapter,
                },
                indent=2,
            ),
            encoding="utf-8",
        )

        if rlm_cfg.sleep_between > 0:
            time.sleep(float(rlm_cfg.sleep_between))

+
# =============================================================================
|
|
373
|
+
# Multi-Process Orchestrated RLM
|
|
374
|
+
# =============================================================================
|
|
375
|
+
|
|
376
|
+
from ..orchestrator.queue import MessageQueue, MessageType, Message
|
|
377
|
+
from ..orchestrator.inference_worker import InferenceConfig, run_inference_worker
|
|
378
|
+
from ..orchestrator.trainer_worker import TrainerConfig, run_trainer_worker
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
@dataclass
|
|
382
|
+
class OrchestratorState:
|
|
383
|
+
"""State for the orchestrated RLM loop."""
|
|
384
|
+
iteration: int = 0
|
|
385
|
+
run_id: str = ""
|
|
386
|
+
pending_rollouts: int = 0
|
|
387
|
+
pending_training: bool = False
|
|
388
|
+
current_adapter: Optional[str] = None
|
|
389
|
+
best_score: float = 0.0
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
class RLMOrchestrator:
|
|
393
|
+
"""Multi-process RLM orchestrator.
|
|
394
|
+
|
|
395
|
+
Spawns and manages inference and trainer processes,
|
|
396
|
+
coordinates rollout generation and training via queues.
|
|
397
|
+
"""
|
|
398
|
+
|
|
399
|
+
def __init__(
|
|
400
|
+
self,
|
|
401
|
+
project_root: Path,
|
|
402
|
+
cfg: ProjectConfig,
|
|
403
|
+
model_spec: str,
|
|
404
|
+
iterations: int = 50,
|
|
405
|
+
resume: bool = False,
|
|
406
|
+
):
|
|
407
|
+
self.project_root = project_root
|
|
408
|
+
self.cfg = cfg
|
|
409
|
+
self.model_spec = model_spec
|
|
410
|
+
self.iterations = iterations
|
|
411
|
+
self.resume = resume
|
|
412
|
+
|
|
413
|
+
self._base_model, self._initial_adapter, _ = resolve_model_spec(
|
|
414
|
+
self.project_root, self.model_spec, self.cfg
|
|
415
|
+
)
|
|
416
|
+
self._rollout_timeout_s = 120.0
|
|
417
|
+
self._train_timeout_s = 900.0
|
|
418
|
+
|
|
419
|
+
# Paths
|
|
420
|
+
self.state_path = project_root / "runs" / "rlm_state.json"
|
|
421
|
+
self.history_path = project_root / "runs" / "rlm_history.jsonl"
|
|
422
|
+
self.corpus_path = project_root / "runs" / "rlm_corpus.jsonl"
|
|
423
|
+
self.weights_dir = ensure_dir(project_root / "runs" / "rlm_weights")
|
|
424
|
+
|
|
425
|
+
# State
|
|
426
|
+
self.gating_state = load_state(self.state_path)
|
|
427
|
+
self.orchestrator_state = OrchestratorState()
|
|
428
|
+
|
|
429
|
+
# IPC
|
|
430
|
+
self.queue = MessageQueue(maxsize=10000)
|
|
431
|
+
self._pointer_store = WeightPointerStore(self.weights_dir)
|
|
432
|
+
|
|
433
|
+
# Processes
|
|
434
|
+
self._inference_process: Optional[mp.Process] = None
|
|
435
|
+
self._trainer_process: Optional[mp.Process] = None
|
|
436
|
+
self._shutdown = False
|
|
437
|
+
|
|
438
|
+
# Rollout buffer
|
|
439
|
+
self._rollout_buffer: List[Rollout] = []
|
|
440
|
+
self._passed_samples: List[Dict] = []
|
|
441
|
+
|
|
442
|
+
def _setup_signal_handlers(self) -> None:
|
|
443
|
+
"""Setup signal handlers for graceful shutdown."""
|
|
444
|
+
def signal_handler(sig, frame):
|
|
445
|
+
console.print("[yellow]Orchestrator received shutdown signal[/yellow]")
|
|
446
|
+
self._shutdown = True
|
|
447
|
+
|
|
448
|
+
signal.signal(signal.SIGTERM, signal_handler)
|
|
449
|
+
signal.signal(signal.SIGINT, signal_handler)
|
|
450
|
+
|
|
451
|
+
def _start_inference_worker(self) -> None:
|
|
452
|
+
"""Start the inference worker process."""
|
|
453
|
+
inf_config = InferenceConfig(
|
|
454
|
+
model_spec=self.model_spec,
|
|
455
|
+
backend=self.cfg.model.backend,
|
|
456
|
+
host=self.cfg.serve.host,
|
|
457
|
+
port=self.cfg.serve.port,
|
|
458
|
+
max_seq_len=self.cfg.model.max_seq_len,
|
|
459
|
+
dtype=self.cfg.model.dtype,
|
|
460
|
+
trust_remote_code=self.cfg.model.trust_remote_code,
|
|
461
|
+
use_chat_template=self.cfg.model.use_chat_template,
|
|
462
|
+
weights_dir=self.weights_dir,
|
|
463
|
+
hot_reload=True,
|
|
464
|
+
)
|
|
465
|
+
|
|
466
|
+
# Initialize inference pointer
|
|
467
|
+
pointer = WeightPointerIPC(
|
|
468
|
+
base_model=self._base_model,
|
|
469
|
+
adapter_path=str(self._initial_adapter) if self._initial_adapter else None,
|
|
470
|
+
iteration=self.gating_state.last_iteration,
|
|
471
|
+
updated_at=now_ts(),
|
|
472
|
+
version=self.gating_state.last_iteration,
|
|
473
|
+
name="inference",
|
|
474
|
+
)
|
|
475
|
+
self._pointer_store.save(pointer)
|
|
476
|
+
|
|
477
|
+
self._inference_process = mp.Process(
|
|
478
|
+
target=run_inference_worker,
|
|
479
|
+
args=(inf_config, self.queue),
|
|
480
|
+
name="inference_worker",
|
|
481
|
+
daemon=False,
|
|
482
|
+
)
|
|
483
|
+
self._inference_process.start()
|
|
484
|
+
console.print(f"[green]Started inference worker (PID: {self._inference_process.pid})[/green]")
|
|
485
|
+
|
|
486
|
+
def _start_trainer_worker(self) -> None:
|
|
487
|
+
"""Start the trainer worker process."""
|
|
488
|
+
trainer_config = TrainerConfig(
|
|
489
|
+
model_spec=self.model_spec,
|
|
490
|
+
base_model=self._base_model,
|
|
491
|
+
backend=self.cfg.model.backend,
|
|
492
|
+
max_seq_len=self.cfg.model.max_seq_len,
|
|
493
|
+
dtype=self.cfg.model.dtype,
|
|
494
|
+
trust_remote_code=self.cfg.model.trust_remote_code,
|
|
495
|
+
lr=self.cfg.train.lr,
|
|
496
|
+
weight_decay=self.cfg.train.weight_decay,
|
|
497
|
+
kl_coeff=self.cfg.rft.kl_coeff,
|
|
498
|
+
normalize_advantage=self.cfg.rft.normalize_advantage,
|
|
499
|
+
lora_r=self.cfg.lora.r,
|
|
500
|
+
lora_alpha=self.cfg.lora.alpha,
|
|
501
|
+
lora_dropout=self.cfg.lora.dropout,
|
|
502
|
+
lora_target_modules=list(self.cfg.lora.target_modules or []),
|
|
503
|
+
lora_num_layers=self.cfg.lora.num_layers,
|
|
504
|
+
weights_dir=self.weights_dir,
|
|
505
|
+
checkpoint_dir=self.project_root / "runs" / "rlm_checkpoints",
|
|
506
|
+
reference_model=self.cfg.rft.reference_model,
|
|
507
|
+
)
|
|
508
|
+
|
|
509
|
+
# Initialize trainer pointer
|
|
510
|
+
pointer = WeightPointerIPC(
|
|
511
|
+
base_model=self._base_model,
|
|
512
|
+
adapter_path=str(self._initial_adapter) if self._initial_adapter else None,
|
|
513
|
+
iteration=self.gating_state.last_iteration,
|
|
514
|
+
updated_at=now_ts(),
|
|
515
|
+
version=self.gating_state.last_iteration,
|
|
516
|
+
name="trainer",
|
|
517
|
+
)
|
|
518
|
+
self._pointer_store.save(pointer)
|
|
519
|
+
|
|
520
|
+
self._trainer_process = mp.Process(
|
|
521
|
+
target=run_trainer_worker,
|
|
522
|
+
args=(trainer_config, self.queue),
|
|
523
|
+
name="trainer_worker",
|
|
524
|
+
daemon=False,
|
|
525
|
+
)
|
|
526
|
+
self._trainer_process.start()
|
|
527
|
+
console.print(f"[green]Started trainer worker (PID: {self._trainer_process.pid})[/green]")
|
|
528
|
+
|
|
529
|
+
def _stop_workers(self) -> None:
|
|
530
|
+
"""Stop all worker processes."""
|
|
531
|
+
console.print("[yellow]Stopping workers...[/yellow]")
|
|
532
|
+
try:
|
|
533
|
+
self.queue.send("rollout_requests", MessageType.SHUTDOWN, {}, source="orchestrator")
|
|
534
|
+
self.queue.send("control", MessageType.SHUTDOWN, {}, source="orchestrator")
|
|
535
|
+
except Exception:
|
|
536
|
+
pass
|
|
537
|
+
|
|
538
|
+
# Terminate processes
|
|
539
|
+
for name, proc in [("inference", self._inference_process), ("trainer", self._trainer_process)]:
|
|
540
|
+
if proc and proc.is_alive():
|
|
541
|
+
console.print(f" Stopping {name} worker...")
|
|
542
|
+
proc.terminate()
|
|
543
|
+
proc.join(timeout=10.0)
|
|
544
|
+
if proc.is_alive():
|
|
545
|
+
proc.kill()
|
|
546
|
+
proc.join(timeout=5.0)
|
|
547
|
+
|
|
548
|
+
self.queue.stop()
|
|
549
|
+
console.print("[green]Workers stopped[/green]")
|
|
550
|
+
|
|
551
|
+
def _wait_for_inference(self, timeout: float = 60.0) -> bool:
|
|
552
|
+
"""Wait for inference worker to be ready via queue health check."""
|
|
553
|
+
start = time.time()
|
|
554
|
+
while time.time() - start < timeout and not self._shutdown:
|
|
555
|
+
try:
|
|
556
|
+
self.queue.send("rollout_requests", MessageType.HEALTH_CHECK, {}, source="orchestrator")
|
|
557
|
+
msg = self.queue.receive("rollout_responses", timeout=1.0)
|
|
558
|
+
if msg and msg.msg_type == MessageType.HEALTH_RESPONSE:
|
|
559
|
+
return True
|
|
560
|
+
except Exception:
|
|
561
|
+
pass
|
|
562
|
+
time.sleep(0.5)
|
|
563
|
+
return False
|
|
564
|
+
|
|
565
|
+
def _generate_rollout_via_api(
|
|
566
|
+
self,
|
|
567
|
+
task: GeneratedTask,
|
|
568
|
+
rollouts_per_task: int,
|
|
569
|
+
) -> List[Rollout]:
|
|
570
|
+
"""Generate rollouts for a task via inference API."""
|
|
571
|
+
import requests
|
|
572
|
+
rollouts = []
|
|
573
|
+
|
|
574
|
+
for k in range(rollouts_per_task):
|
|
575
|
+
try:
|
|
576
|
+
resp = requests.post(
|
|
577
|
+
f"http://localhost:{self.cfg.serve.port}/internal/rollout",
|
|
578
|
+
json={
|
|
579
|
+
"prompt": task.prompt,
|
|
580
|
+
"max_tokens": int(self.cfg.rft.max_new_tokens),
|
|
581
|
+
"temperature": float(self.cfg.rft.temperature),
|
|
582
|
+
"top_p": float(self.cfg.infer.top_p),
|
|
583
|
+
"top_k": self.cfg.infer.top_k,
|
|
584
|
+
"seed": int(time.time() * 1000) % (2**31 - 1),
|
|
585
|
+
"include_tokens": True,
|
|
586
|
+
"include_logprobs": True,
|
|
587
|
+
},
|
|
588
|
+
timeout=120.0,
|
|
589
|
+
)
|
|
590
|
+
|
|
591
|
+
if resp.status_code != 200:
|
|
592
|
+
console.print(f"[red]Rollout API error: {resp.status_code}[/red]")
|
|
593
|
+
continue
|
|
594
|
+
|
|
595
|
+
data = resp.json()
|
|
596
|
+
completion = data.get("completion", "")
|
|
597
|
+
|
|
598
|
+
# Run verifier
|
|
599
|
+
from ..util import ensure_dir
|
|
600
|
+
wdir = ensure_dir(self.project_root / "runs" / ".temp" / task.id / f"rollout_{k:02d}")
|
|
601
|
+
(wdir / "main.py").write_text(completion, encoding="utf-8")
|
|
602
|
+
|
|
603
|
+
# Write tests
|
|
604
|
+
tests_dir = ensure_dir(wdir / "tests")
|
|
605
|
+
(tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
|
|
606
|
+
|
|
607
|
+
t0 = time.time()
|
|
608
|
+
if self.cfg.rlm.verifier_backend == "docker":
|
|
609
|
+
res = docker_verify(
|
|
610
|
+
task.prompt,
|
|
611
|
+
completion,
|
|
612
|
+
str(wdir),
|
|
613
|
+
timeout_s=int(self.cfg.rlm.verifier_timeout_s),
|
|
614
|
+
image=self.cfg.rlm.docker_image,
|
|
615
|
+
memory_mb=int(self.cfg.rlm.docker_memory_mb),
|
|
616
|
+
cpus=float(self.cfg.rlm.docker_cpus),
|
|
617
|
+
pids=int(self.cfg.rlm.docker_pids),
|
|
618
|
+
)
|
|
619
|
+
else:
|
|
620
|
+
from ..verifiers.pytest_verifier import verify as pytest_verify
|
|
621
|
+
res = pytest_verify(
|
|
622
|
+
task.prompt,
|
|
623
|
+
completion,
|
|
624
|
+
str(wdir),
|
|
625
|
+
timeout_s=int(self.cfg.rlm.verifier_timeout_s),
|
|
626
|
+
)
|
|
627
|
+
latency_ms = (time.time() - t0) * 1000.0
|
|
628
|
+
|
|
629
|
+
passed = bool(getattr(res, "passed", False))
|
|
630
|
+
reward = float(getattr(res, "reward", 0.0))
|
|
631
|
+
|
|
632
|
+
rollouts.append(Rollout(
|
|
633
|
+
task_id=task.id,
|
|
634
|
+
prompt=task.prompt,
|
|
635
|
+
completion=completion,
|
|
636
|
+
token_ids=data.get("token_ids", []),
|
|
637
|
+
prompt_len=data.get("prompt_len", 0),
|
|
638
|
+
logprobs=data.get("logprobs"),
|
|
639
|
+
passed=passed,
|
|
640
|
+
reward=reward,
|
|
641
|
+
verifier_latency_ms=latency_ms,
|
|
642
|
+
weight_adapter=self._pointer_store.load("inference", self._base_model).adapter_path,
|
|
643
|
+
))
|
|
644
|
+
|
|
645
|
+
if passed:
|
|
646
|
+
self._passed_samples.append({
|
|
647
|
+
"id": task.id,
|
|
648
|
+
"prompt": task.prompt,
|
|
649
|
+
"response": completion,
|
|
650
|
+
"reward": reward,
|
|
651
|
+
"ts": now_ts(),
|
|
652
|
+
})
|
|
653
|
+
|
|
654
|
+
except Exception as e:
|
|
655
|
+
console.print(f"[red]Rollout error: {e}[/red]")
|
|
656
|
+
continue
|
|
657
|
+
|
|
658
|
+
return rollouts
|
|
659
|
+
|
|
660
|
+
def _generate_rollout_via_queue(
|
|
661
|
+
self,
|
|
662
|
+
task: GeneratedTask,
|
|
663
|
+
rollouts_per_task: int,
|
|
664
|
+
) -> List[Rollout]:
|
|
665
|
+
"""Generate rollouts for a task via the message queue."""
|
|
666
|
+
rollouts: List[Rollout] = []
|
|
667
|
+
|
|
668
|
+
for k in range(rollouts_per_task):
|
|
669
|
+
try:
|
|
670
|
+
req = self.queue.send(
|
|
671
|
+
"rollout_requests",
|
|
672
|
+
MessageType.ROLLOUT_REQUEST,
|
|
673
|
+
{
|
|
674
|
+
"prompt": task.prompt,
|
|
675
|
+
"max_tokens": int(self.cfg.rft.max_new_tokens),
|
|
676
|
+
"temperature": float(self.cfg.rft.temperature),
|
|
677
|
+
"top_p": float(self.cfg.infer.top_p),
|
|
678
|
+
"top_k": self.cfg.infer.top_k,
|
|
679
|
+
"seed": int(time.time() * 1000) % (2**31 - 1),
|
|
680
|
+
},
|
|
681
|
+
source="orchestrator",
|
|
682
|
+
)
|
|
683
|
+
|
|
684
|
+
# Wait for matching response
|
|
685
|
+
response = None
|
|
686
|
+
start = time.time()
|
|
687
|
+
while time.time() - start < self._rollout_timeout_s:
|
|
688
|
+
msg = self.queue.receive("rollout_responses", timeout=0.5)
|
|
689
|
+
if not msg:
|
|
690
|
+
continue
|
|
691
|
+
if msg.msg_type != MessageType.ROLLOUT_RESPONSE:
|
|
692
|
+
continue
|
|
693
|
+
if msg.payload.get("request_id") == req.msg_id:
|
|
694
|
+
response = msg
|
|
695
|
+
break
|
|
696
|
+
|
|
697
|
+
if response is None:
|
|
698
|
+
console.print("[red]Rollout queue timeout[/red]")
|
|
699
|
+
continue
|
|
700
|
+
|
|
701
|
+
data = response.payload
|
|
702
|
+
completion = data.get("completion", "")
|
|
703
|
+
|
|
704
|
+
# Run verifier
|
|
705
|
+
wdir = ensure_dir(self.project_root / "runs" / ".temp" / task.id / f"rollout_{k:02d}")
|
|
706
|
+
(wdir / "main.py").write_text(completion, encoding="utf-8")
|
|
707
|
+
|
|
708
|
+
tests_dir = ensure_dir(wdir / "tests")
|
|
709
|
+
(tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
|
|
710
|
+
|
|
711
|
+
t0 = time.time()
|
|
712
|
+
if self.cfg.rlm.verifier_backend == "docker":
|
|
713
|
+
res = docker_verify(
|
|
714
|
+
task.prompt,
|
|
715
|
+
completion,
|
|
716
|
+
str(wdir),
|
|
717
|
+
timeout_s=int(self.cfg.rlm.verifier_timeout_s),
|
|
718
|
+
image=self.cfg.rlm.docker_image,
|
|
719
|
+
memory_mb=int(self.cfg.rlm.docker_memory_mb),
|
|
720
|
+
cpus=float(self.cfg.rlm.docker_cpus),
|
|
721
|
+
pids=int(self.cfg.rlm.docker_pids),
|
|
722
|
+
)
|
|
723
|
+
else:
|
|
724
|
+
from ..verifiers.pytest_verifier import verify as pytest_verify
|
|
725
|
+
res = pytest_verify(
|
|
726
|
+
task.prompt,
|
|
727
|
+
completion,
|
|
728
|
+
str(wdir),
|
|
729
|
+
timeout_s=int(self.cfg.rlm.verifier_timeout_s),
|
|
730
|
+
)
|
|
731
|
+
latency_ms = (time.time() - t0) * 1000.0
|
|
732
|
+
|
|
733
|
+
passed = bool(getattr(res, "passed", False))
|
|
734
|
+
reward = float(getattr(res, "reward", 0.0))
|
|
735
|
+
|
|
736
|
+
rollouts.append(
|
|
737
|
+
Rollout(
|
|
738
|
+
task_id=task.id,
|
|
739
|
+
prompt=task.prompt,
|
|
740
|
+
completion=completion,
|
|
741
|
+
token_ids=data.get("token_ids", []),
|
|
742
|
+
prompt_len=data.get("prompt_len", 0),
|
|
743
|
+
logprobs=data.get("logprobs"),
|
|
744
|
+
passed=passed,
|
|
745
|
+
reward=reward,
|
|
746
|
+
verifier_latency_ms=latency_ms,
|
|
747
|
+
weight_adapter=self._pointer_store.load("inference", self._base_model).adapter_path,
|
|
748
|
+
)
|
|
749
|
+
)
|
|
750
|
+
|
|
751
|
+
if passed:
|
|
752
|
+
self._passed_samples.append(
|
|
753
|
+
{
|
|
754
|
+
"id": task.id,
|
|
755
|
+
"prompt": task.prompt,
|
|
756
|
+
"response": completion,
|
|
757
|
+
"reward": reward,
|
|
758
|
+
"ts": now_ts(),
|
|
759
|
+
}
|
|
760
|
+
)
|
|
761
|
+
|
|
762
|
+
except Exception as e:
|
|
763
|
+
console.print(f"[red]Rollout error: {e}[/red]")
|
|
764
|
+
continue
|
|
765
|
+
|
|
766
|
+
return rollouts
|
|
767
|
+
|
|
    def _send_training_batch(self, rollouts: List[Rollout], iteration: int, run_id: str) -> Message:
        """Send a training batch to the trainer worker via queue."""
        save_checkpoint = True

        payload = {
            "iteration": iteration,
            "run_id": run_id,
            "save_checkpoint": save_checkpoint,
            "rollouts": [
                {
                    "task_id": r.task_id,
                    "prompt": r.prompt,
                    "completion": r.completion,
                    "token_ids": r.token_ids,
                    "prompt_len": r.prompt_len,
                    "logprobs": r.logprobs,
                    "passed": r.passed,
                    "reward": r.reward,
                    "verifier_latency_ms": r.verifier_latency_ms,
                    "weight_adapter": r.weight_adapter,
                }
                for r in rollouts
            ],
        }
        return self.queue.send("train_batches", MessageType.TRAIN_BATCH, payload, source="orchestrator")

    def _drain_queue(self, queue_name: str) -> None:
        """Drain all pending messages from a queue."""
        while True:
            msg = self.queue.receive(queue_name, timeout=0)
            if msg is None:
                break

    def run_iteration(self, iteration: int) -> bool:
        """Run a single orchestrated RLM iteration."""
        console.print(f"\n[bold blue]=== Orchestrated RLM Iteration {iteration} ===[/bold blue]")

        run = new_run(self.project_root, "rlm")
        snapshot_config(self.cfg.model_dump(), run.config_snapshot_path)

        # Generate tasks using a temporary LLM instance
        # (In future: task generation could also go through inference worker)
        console.print(" [dim]Generating tasks...[/dim]")

        llm = get_llm_backend(self.cfg.model.backend)
        pointer = self._pointer_store.load("inference", self._base_model)
        llm.load(
            pointer.base_model,
            max_seq_len=self.cfg.model.max_seq_len,
            dtype=self.cfg.model.dtype,
            trust_remote_code=self.cfg.model.trust_remote_code,
        )
        if pointer.adapter_path:
            llm.apply_adapter(pointer.adapter_path)

        corpus_rows = load_corpus(self.corpus_path, max_size=int(self.cfg.rlm.corpus_max))
        existing_prompts = [row.get("prompt", "") for row in corpus_rows if row.get("prompt")]

        tasks = build_tasks(
            llm,
            self.cfg,
            require_recursion=bool(self.cfg.rlm.require_recursion),
            tasks_per_iter=int(self.cfg.rlm.tasks_per_iter),
            mutations_per_task=int(self.cfg.rlm.mutations_per_task),
            max_total=int(self.cfg.rlm.tasks_per_iter),
            existing_prompts=existing_prompts,
        )

        write_jsonl(run.run_dir / "tasks.jsonl", [task.__dict__ for task in tasks])

        # Generate rollouts via inference queue
        console.print(f" [dim]Generating {len(tasks) * self.cfg.rlm.rollouts_per_task} rollouts...[/dim]")
        all_rollouts = []
        for i, task in enumerate(tasks):
            rollouts = self._generate_rollout_via_queue(
                task,
                rollouts_per_task=int(self.cfg.rlm.rollouts_per_task),
            )
            all_rollouts.extend(rollouts)
            if (i + 1) % 10 == 0:
                console.print(f" {i + 1}/{len(tasks)} tasks completed")

        # Save rollouts
        write_jsonl(run.artifacts_dir / "rollouts.jsonl", [
            {
                "task_id": r.task_id,
                "prompt": r.prompt,
                "completion": r.completion,
                "token_ids": r.token_ids,
                "prompt_len": r.prompt_len,
                "logprobs": r.logprobs,
                "passed": r.passed,
                "reward": r.reward,
            }
            for r in all_rollouts
        ])

        # Train via trainer worker (queue)
        console.print(" [dim]Training on rollouts...[/dim]")
        train_msg = self._send_training_batch(all_rollouts, iteration, run.run_dir.name)

        train_resp = None
        start = time.time()
        while time.time() - start < self._train_timeout_s:
            msg = self.queue.receive("train_complete", timeout=1.0)
            if not msg:
                continue
            if msg.payload.get("request_id") == train_msg.msg_id:
                train_resp = msg
                break

        if train_resp is None:
            console.print("[red]Trainer timed out[/red]")
            return False

        train_result = train_resp.payload.get("result") or {}
        checkpoint_path = train_resp.payload.get("checkpoint_path")

        write_jsonl(
            run.metrics_path,
            [
                {
                    "ts": now_ts(),
                    "kind": "rlm_train",
                    "iteration": iteration,
                    "loss": train_result.get("loss"),
                    "num_tasks": train_result.get("num_tasks"),
                    "num_rollouts": train_result.get("num_rollouts"),
                }
            ],
        )

        if not checkpoint_path:
            console.print("[red]Trainer returned no checkpoint[/red]")
            return False

        copytree(Path(checkpoint_path), run.adapter_dir)

        # Drain any weight update notifications from trainer
        self._drain_queue("weight_updates")
        self._drain_queue("checkpoints")

        # Update corpus
        if self._passed_samples:
            append_corpus(self.corpus_path, self._passed_samples, max_size=int(self.cfg.rlm.corpus_max))
            self._passed_samples = []

        # Evaluate
        adapter_score = 0.0
        if self.cfg.rlm.benchmark_suite:
            suite_path = self.project_root / self.cfg.rlm.benchmark_suite
            if suite_path.exists():
                eval_path = run_eval(self.project_root, suite_path, run.adapter_dir)
                adapter_score = _score_from_eval(eval_path)

        holdout_score = None
        if self.cfg.rlm.holdout_suite:
            holdout_path = self.project_root / self.cfg.rlm.holdout_suite
            if holdout_path.exists():
                holdout_eval = run_eval(self.project_root, holdout_path, run.adapter_dir)
                holdout_score = _score_from_eval(holdout_eval)

        # Gating
        accepted = should_accept(
            adapter_score,
            self.gating_state,
            mode=self.cfg.rlm.gating,
            threshold=float(self.cfg.rlm.gating_threshold),
            ema_alpha=float(self.cfg.rlm.gating_ema_alpha),
        )
        self.gating_state = update_state(
            self.gating_state,
            iteration=iteration,
            score=adapter_score,
            adapter_path=str(run.adapter_dir),
            accepted=accepted,
            ema_alpha=float(self.cfg.rlm.gating_ema_alpha),
        )
        save_state(self.state_path, self.gating_state)

        # Update weight pointers
        if self.gating_state.current_adapter:
            train_pointer = WeightPointerIPC(
                base_model=self._base_model,
                adapter_path=self.gating_state.current_adapter,
                iteration=iteration,
                updated_at=now_ts(),
                version=iteration,
                name="trainer",
            )
            self._pointer_store.save(train_pointer)

            # Update inference pointer (hot reload)
            infer_staleness = int(getattr(self.cfg.rlm, "infer_staleness", 0))
            current_infer = self._pointer_store.load("inference", self._base_model)
            update_infer = infer_staleness <= 0
            if not update_infer:
                lag = max(0, int(iteration) - int(current_infer.iteration))
                update_infer = lag >= infer_staleness

            if update_infer:
                infer_pointer = WeightPointerIPC(
                    base_model=self._base_model,
                    adapter_path=self.gating_state.current_adapter,
                    iteration=iteration,
                    updated_at=now_ts(),
                    version=iteration,
                    name="inference",
                )
                self._pointer_store.save(infer_pointer)
                try:
                    self.queue.send(
                        "weight_forward",
                        MessageType.WEIGHT_UPDATE,
                        {
                            "adapter_path": self.gating_state.current_adapter,
                            "version": iteration,
                            "base_model": self._base_model,
                        },
                        source="orchestrator",
                    )
                except Exception as e:
                    console.print(f"[yellow]Hot reload trigger failed: {e}[/yellow]")

        # History
        append_history(
            self.history_path,
            {
                "iteration": iteration,
                "timestamp": now_ts(),
                "adapter_score": adapter_score,
                "holdout_score": holdout_score,
                "best_score": self.gating_state.best_score,
                "accepted": accepted,
                "adapter_dir": str(run.adapter_dir),
            },
        )

        gating_path = run.run_dir / "gating.json"
        gating_path.write_text(
            json.dumps(
                {
                    "iteration": iteration,
                    "accepted": accepted,
                    "adapter_score": adapter_score,
                    "holdout_score": holdout_score,
                    "best_score": self.gating_state.best_score,
                    "current_adapter": self.gating_state.current_adapter,
                },
                indent=2,
            ),
            encoding="utf-8",
        )

        if self.cfg.rlm.sleep_between > 0:
            time.sleep(float(self.cfg.rlm.sleep_between))

        return True

    def run(self) -> None:
        """Run the orchestrated RLM loop."""
        self._setup_signal_handlers()

        console.print("[bold green]Starting Orchestrated RLM[/bold green]")

        # Start queue manager
        self.queue.start()

        # Start workers
        self._start_inference_worker()
        self._start_trainer_worker()

        console.print("[dim]Waiting for inference server...[/dim]")
        if not self._wait_for_inference(timeout=120.0):
            console.print("[red]Inference server failed to start[/red]")
            self._stop_workers()
            return
        console.print("[green]Inference server ready[/green]")

        # Determine iteration range
        start_iter = self.gating_state.last_iteration + 1 if self.resume else 1
        total_iters = self.iterations

        def iter_range():
            if total_iters == 0:
                i = start_iter
                while True:
                    yield i
                    i += 1
            else:
                for i in range(start_iter, start_iter + total_iters):
                    yield i

        try:
            for iteration in iter_range():
                if self._shutdown:
                    break

                success = self.run_iteration(iteration)
                if not success:
                    console.print("[red]Iteration failed, stopping[/red]")
                    break

        except KeyboardInterrupt:
            console.print("[yellow]Interrupted by user[/yellow]")
        except Exception as e:
            console.print(f"[red]Orchestrator error: {e}[/red]")
            traceback.print_exc()
        finally:
            self._stop_workers()


def run_rlm_orchestrated(
    project_root: Path,
    cfg: ProjectConfig,
    *,
    model_spec: Optional[str] = None,
    iterations: Optional[int] = None,
    resume: bool = False,
) -> None:
    """Run multi-process orchestrated RLM loop.

    This mode spawns separate inference and trainer processes,
    coordinating via weight pointers and queue messages.

    Benefits:
    - Inference server remains responsive during training
    - Hot-reload of weights without restart
    - Better resource isolation
    - Foundation for distributed training
    """
    spec = model_spec or cfg.model.id
    iters = iterations or cfg.rlm.iterations

    orchestrator = RLMOrchestrator(
        project_root=project_root,
        cfg=cfg,
        model_spec=spec,
        iterations=iters,
        resume=resume,
    )
    orchestrator.run()

def collect_rollouts_via_api(
    tasks: List[GeneratedTask],
    cfg: ProjectConfig,
    api_url: str,
    artifacts_dir: Path,
    verifier_backend: str,
    weight_adapter: Optional[str],
) -> tuple[List[Rollout], list[dict]]:
    """Collect rollouts via inference API (for legacy loop with external inference)."""
    try:
        import requests
    except ModuleNotFoundError:
        requests = None
    rollouts: List[Rollout] = []
    passed_samples: list[dict] = []

    # Probe the API endpoint; fall back to local inference if unreachable.
    _api_available = False
    if requests is not None:
        try:
            requests.get(api_url, timeout=2.0)
            _api_available = True
        except Exception:
            _api_available = False

    if not _api_available:
        # Resolve model spec to separate base model from adapter.
        base_model, resolved_adapter, _meta = resolve_model_spec(
            Path.cwd(), cfg.model.id, cfg
        )
        llm = get_llm_backend(cfg.model.backend)
        llm.load(
            base_model,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
        adapter_to_apply = weight_adapter or (str(resolved_adapter) if resolved_adapter else None)
        if adapter_to_apply:
            llm.apply_adapter(adapter_to_apply)
        for task in tasks:
            for k in range(int(cfg.rlm.rollouts_per_task)):
                try:
                    gen = llm.generate_with_logprobs(
                        task.prompt,
                        max_new_tokens=int(cfg.rft.max_new_tokens),
                        temperature=float(cfg.rft.temperature),
                        top_p=float(cfg.infer.top_p),
                        top_k=cfg.infer.top_k,
                        seed=int(time.time() * 1000) % (2**31 - 1),
                    )
                except TypeError:
                    gen = llm.generate_with_logprobs(
                        task.prompt,
                        max_new_tokens=int(cfg.rft.max_new_tokens),
                        temperature=float(cfg.rft.temperature),
                        top_p=float(cfg.infer.top_p),
                        top_k_sampling=cfg.infer.top_k,
                        seed=int(time.time() * 1000) % (2**31 - 1),
                    )
                completion = gen.text[len(task.prompt) :] if gen.text.startswith(task.prompt) else gen.text
                wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
                (wdir / "main.py").write_text(completion, encoding="utf-8")
                (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")
                t0 = time.time()
                if verifier_backend == "docker":
                    res = docker_verify(
                        task.prompt,
                        completion,
                        str(wdir),
                        timeout_s=int(cfg.rlm.verifier_timeout_s),
                        image=cfg.rlm.docker_image,
                        memory_mb=int(cfg.rlm.docker_memory_mb),
                        cpus=float(cfg.rlm.docker_cpus),
                        pids=int(cfg.rlm.docker_pids),
                    )
                else:
                    res = pytest_verify(
                        task.prompt,
                        completion,
                        str(wdir),
                        timeout_s=int(cfg.rlm.verifier_timeout_s),
                    )
                latency_ms = (time.time() - t0) * 1000.0
                passed = bool(getattr(res, "passed", False))
                reward = float(getattr(res, "reward", 0.0))
                rollouts.append(
                    Rollout(
                        task_id=task.id,
                        prompt=task.prompt,
                        completion=completion,
                        token_ids=list(gen.token_ids),
                        prompt_len=gen.prompt_len,
                        logprobs=list(gen.logprobs) if gen.logprobs else None,
                        passed=passed,
                        reward=reward,
                        verifier_latency_ms=latency_ms,
                        weight_adapter=weight_adapter,
                    )
                )
                if passed:
                    passed_samples.append(
                        {
                            "id": task.id,
                            "prompt": task.prompt,
                            "response": completion,
                            "reward": reward,
                            "ts": now_ts(),
                        }
                    )
        return rollouts, passed_samples

    for task in tasks:
        for k in range(int(cfg.rlm.rollouts_per_task)):
            try:
                resp = requests.post(
                    f"{api_url}/internal/rollout",
                    json={
                        "prompt": task.prompt,
                        "max_tokens": int(cfg.rft.max_new_tokens),
                        "temperature": float(cfg.rft.temperature),
                        "top_p": float(cfg.infer.top_p),
                        "top_k": cfg.infer.top_k,
                        "seed": int(time.time() * 1000) % (2**31 - 1),
                        "include_tokens": True,
                        "include_logprobs": True,
                    },
                    timeout=120.0,
                )

                if resp.status_code != 200:
                    continue

                data = resp.json()
                completion = data.get("completion", "")

                wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
                (wdir / "main.py").write_text(completion, encoding="utf-8")
                (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")

                t0 = time.time()
                if verifier_backend == "docker":
                    res = docker_verify(
                        task.prompt,
                        completion,
                        str(wdir),
                        timeout_s=int(cfg.rlm.verifier_timeout_s),
                        image=cfg.rlm.docker_image,
                        memory_mb=int(cfg.rlm.docker_memory_mb),
                        cpus=float(cfg.rlm.docker_cpus),
                        pids=int(cfg.rlm.docker_pids),
                    )
                else:
                    res = pytest_verify(task.prompt, completion, str(wdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
                latency_ms = (time.time() - t0) * 1000.0

                passed = bool(getattr(res, "passed", False))
                reward = float(getattr(res, "reward", 0.0))

                rollouts.append(Rollout(
                    task_id=task.id,
                    prompt=task.prompt,
                    completion=completion,
                    token_ids=data.get("token_ids", []),
                    prompt_len=data.get("prompt_len", 0),
                    logprobs=data.get("logprobs"),
                    passed=passed,
                    reward=reward,
                    verifier_latency_ms=latency_ms,
                    weight_adapter=weight_adapter,
                ))

                if passed:
                    passed_samples.append({
                        "id": task.id,
                        "prompt": task.prompt,
                        "response": completion,
                        "reward": reward,
                        "ts": now_ts(),
                    })

            except Exception as e:
                console.print(f"[red]Rollout error: {e}[/red]")
                continue

    return rollouts, passed_samples
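For orientation, a minimal sketch of how the two entry points in this module might be driven from a script. This is an illustration, not part of the wheel: the bare `ProjectConfig()` construction and the iteration count are assumptions (a real project would load its own config file via the package's config machinery).

# Hypothetical driver script (illustration only, not shipped in the wheel)
from pathlib import Path

from mlxsmith.config import ProjectConfig
from mlxsmith.rlm.loop import run_rlm, run_rlm_orchestrated

project_root = Path(".")
cfg = ProjectConfig()  # assumption: defaults suffice; normally loaded from the project's config

# Legacy single-process loop: build tasks, collect rollouts, train, gate, repeat.
run_rlm(project_root, cfg, iterations=3)

# Orchestrated mode: spawns inference and trainer worker processes and
# coordinates them via message queues and weight pointers.
run_rlm_orchestrated(project_root, cfg, iterations=3)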