mlxsmith-0.1.0-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/rlm/loop.py ADDED
@@ -0,0 +1,1297 @@
+ """RLM Loop - Orchestrator Entry Point for MLXSmith.
+
+ This module provides both:
+ 1. Legacy single-process RLM loop (run_rlm)
+ 2. Multi-process orchestrator mode (run_rlm_orchestrated)
+
+ The orchestrator mode splits the RLM loop into:
+ - Orchestrator Daemon: Queue-based job scheduler
+ - Inference Worker Process: OpenAI-compatible API server
+ - Trainer Worker Process: Training batch consumer
+ """
+
+ from __future__ import annotations
+
+ import json
+ import multiprocessing as mp
+ import signal
+ import sys
+ import time
+ import traceback
+ from dataclasses import dataclass, asdict
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Callable
+
+ from rich.console import Console
+
+ from ..config import ProjectConfig
+ from ..eval import run_eval
+ from ..llm.registry import get_llm_backend
+ from ..models import resolve_model_spec
+ from ..runs import new_run, snapshot_config
+ from ..train.lora import LoRAConfig
+ from ..util import copytree, ensure_dir, now_ts, write_jsonl
+ from ..verifiers.docker_verifier import verify as docker_verify
+ from ..verifiers.pytest_verifier import verify as pytest_verify
+ from .corpus import append_corpus, load_corpus, sample_corpus
+ from .gating import load_state, save_state, should_accept, update_state
+ from .generate import GeneratedTask, generate_tasks, filter_tasks
+ from .history import append_history
+ from .inference import Rollout, build_tasks
+ from .mutate import mutate_tasks
+ from .trainer import train_on_rollouts
+ from .weights import (
+     WeightPointer,
+     WeightPointerIPC,
+     WeightPointerStore,
+     load_pointer,
+     save_pointer,
+ )
+
+ console = Console()
+
+
+ def _score_from_eval(result_path: Path) -> float:
+     try:
+         data = json.loads(result_path.read_text(encoding="utf-8"))
+         summary = data.get("summary") or []
+         if not summary:
+             return 0.0
+         return sum(item.get("pass@k", 0.0) for item in summary) / max(1, len(summary))
+     except Exception:
+         return 0.0
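# A minimal sketch (not from loop.py) of what _score_from_eval computes, using a
# made-up results file. The {"summary": [{"pass@k": ...}, ...]} shape is assumed
# from the code above; the path and numbers are illustrative only.
import json
from pathlib import Path

example = {"summary": [{"pass@k": 1.0}, {"pass@k": 0.0}, {"pass@k": 0.5}]}
results_path = Path("/tmp/eval_results.json")            # hypothetical location
results_path.write_text(json.dumps(example), encoding="utf-8")
# _score_from_eval(results_path) -> (1.0 + 0.0 + 0.5) / 3 == 0.5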
+
+
+ def _load_suite_prompts(path: Path) -> list[str]:
+     try:
+         import yaml
+
+         suite = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
+         tasks = suite.get("tasks") or []
+         return [str(t.get("prompt")) for t in tasks if t.get("prompt")]
+     except Exception:
+         return []
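# A minimal sketch (not from loop.py) of the suite layout _load_suite_prompts
# expects: a YAML mapping with a "tasks" list whose items carry a "prompt" key.
# The file path and prompt texts are made up; entries without "prompt" are skipped.
from pathlib import Path

suite_yaml = """\
tasks:
  - prompt: "Write a recursive factorial function."
  - prompt: "Implement binary search over a sorted list."
  - note: "this entry has no prompt and is ignored"
"""
suite_path = Path("/tmp/benchmark_suite.yaml")            # hypothetical location
suite_path.write_text(suite_yaml, encoding="utf-8")
# _load_suite_prompts(suite_path) -> the two prompt strings above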
74
+
75
+
76
+ # =============================================================================
77
+ # Legacy Single-Process RLM Loop
78
+ # =============================================================================
79
+
80
+ def run_rlm(
81
+ project_root: Path,
82
+ cfg: ProjectConfig,
83
+ *,
84
+ model_spec: Optional[str] = None,
85
+ iterations: Optional[int] = None,
86
+ resume: bool = False,
87
+ ) -> None:
88
+ """Run single-process RLM loop (legacy mode)."""
89
+ rlm_cfg = cfg.rlm
90
+ state_path = project_root / "runs" / "rlm_state.json"
91
+ history_path = project_root / "runs" / "rlm_history.jsonl"
92
+ corpus_path = project_root / "runs" / "rlm_corpus.jsonl"
93
+ weights_dir = ensure_dir(project_root / "runs" / "rlm_weights")
94
+
95
+ state = load_state(state_path)
96
+
97
+ if model_spec is None:
98
+ model_spec = cfg.model.id
99
+
100
+ base_model, initial_adapter, _meta = resolve_model_spec(project_root, model_spec, cfg)
101
+ infer_ptr_path = weights_dir / "infer.json"
102
+ train_ptr_path = weights_dir / "train.json"
103
+ infer_ptr = load_pointer(infer_ptr_path, base_model=base_model, name="inference")
104
+ train_ptr = load_pointer(train_ptr_path, base_model=base_model, name="trainer")
105
+
106
+ if initial_adapter and not infer_ptr.adapter_path:
107
+ infer_ptr = WeightPointer(
108
+ base_model=base_model,
109
+ adapter_path=str(initial_adapter),
110
+ iteration=state.last_iteration,
111
+ updated_at=now_ts(),
112
+ name="inference",
113
+ )
114
+ save_pointer(infer_ptr_path, infer_ptr)
115
+
116
+ if initial_adapter and not train_ptr.adapter_path:
117
+ train_ptr = WeightPointer(
118
+ base_model=base_model,
119
+ adapter_path=str(initial_adapter),
120
+ iteration=state.last_iteration,
121
+ updated_at=now_ts(),
122
+ name="trainer",
123
+ )
124
+ save_pointer(train_ptr_path, train_ptr)
125
+
126
+ if resume and state.current_adapter:
127
+ train_ptr = WeightPointer(
128
+ base_model=base_model,
129
+ adapter_path=state.current_adapter,
130
+ iteration=state.last_iteration,
131
+ updated_at=now_ts(),
132
+ name="trainer",
133
+ )
134
+ save_pointer(train_ptr_path, train_ptr)
135
+ if not infer_ptr.adapter_path:
136
+ infer_ptr = WeightPointer(
137
+ base_model=base_model,
138
+ adapter_path=state.current_adapter,
139
+ iteration=state.last_iteration,
140
+ updated_at=now_ts(),
141
+ name="inference",
142
+ )
143
+ save_pointer(infer_ptr_path, infer_ptr)
144
+
145
+ start_iter = state.last_iteration + 1 if resume else 1
146
+ total_iters = iterations if iterations is not None else int(rlm_cfg.iterations)
147
+
148
+ def iter_range():
149
+ if total_iters == 0:
150
+ i = start_iter
151
+ while True:
152
+ yield i
153
+ i += 1
154
+ else:
155
+ for i in range(start_iter, start_iter + total_iters):
156
+ yield i
157
+
158
+ for iteration in iter_range():
159
+ run = new_run(project_root, "rlm")
160
+ snapshot_config(cfg.model_dump(), run.config_snapshot_path)
161
+ console.print(f"[bold]RLM[/bold] iteration {iteration} run={run.run_dir.name}")
162
+
163
+ infer_llm = get_llm_backend(cfg.model.backend)
164
+ infer_llm.load(
165
+ infer_ptr.base_model,
166
+ max_seq_len=cfg.model.max_seq_len,
167
+ dtype=cfg.model.dtype,
168
+ trust_remote_code=cfg.model.trust_remote_code,
169
+ )
170
+ if infer_ptr.adapter_path:
171
+ infer_llm.apply_adapter(str(infer_ptr.adapter_path))
172
+
173
+ train_llm = get_llm_backend(cfg.model.backend)
174
+ train_llm.load(
175
+ train_ptr.base_model,
176
+ max_seq_len=cfg.model.max_seq_len,
177
+ dtype=cfg.model.dtype,
178
+ trust_remote_code=cfg.model.trust_remote_code,
179
+ )
180
+ if train_ptr.adapter_path:
181
+ train_llm.apply_adapter(str(train_ptr.adapter_path))
182
+ else:
183
+ lora_cfg = LoRAConfig(
184
+ r=cfg.lora.r,
185
+ alpha=cfg.lora.alpha,
186
+ dropout=cfg.lora.dropout,
187
+ target_modules=list(cfg.lora.target_modules or []),
188
+ num_layers=cfg.lora.num_layers,
189
+ scale=cfg.lora.scale,
190
+ fine_tune_type=cfg.lora.fine_tune_type,
191
+ )
192
+ train_llm.apply_lora_from_config(lora_cfg)
193
+
194
+ ref_llm = None
195
+ if cfg.rft.reference_model:
196
+ ref_llm = get_llm_backend(cfg.model.backend)
197
+ ref_llm.load(
198
+ cfg.rft.reference_model,
199
+ max_seq_len=cfg.model.max_seq_len,
200
+ dtype=cfg.model.dtype,
201
+ trust_remote_code=cfg.model.trust_remote_code,
202
+ )
203
+
204
+ opt, _params = train_llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
205
+
206
+ corpus_rows = load_corpus(corpus_path, max_size=int(rlm_cfg.corpus_max))
207
+ existing_prompts = [row.get("prompt", "") for row in corpus_rows if row.get("prompt")]
208
+ if rlm_cfg.benchmark_suite:
209
+ suite_path = project_root / rlm_cfg.benchmark_suite
210
+ if suite_path.exists():
211
+ existing_prompts.extend(_load_suite_prompts(suite_path))
212
+ if rlm_cfg.holdout_suite:
213
+ holdout_path = project_root / rlm_cfg.holdout_suite
214
+ if holdout_path.exists():
215
+ existing_prompts.extend(_load_suite_prompts(holdout_path))
216
+
217
+ tasks = build_tasks(
218
+ infer_llm,
219
+ cfg,
220
+ require_recursion=bool(rlm_cfg.require_recursion),
221
+ tasks_per_iter=int(rlm_cfg.tasks_per_iter),
222
+ mutations_per_task=int(rlm_cfg.mutations_per_task),
223
+ max_total=int(rlm_cfg.tasks_per_iter),
224
+ existing_prompts=existing_prompts,
225
+ )
226
+
227
+ write_jsonl(run.run_dir / "tasks.jsonl", [task.__dict__ for task in tasks])
228
+ rollouts, passed_samples = collect_rollouts_via_api(
229
+ tasks,
230
+ cfg,
231
+ api_url=f"http://localhost:{cfg.serve.port}",
232
+ artifacts_dir=run.artifacts_dir,
233
+ verifier_backend=str(rlm_cfg.verifier_backend),
234
+ weight_adapter=infer_ptr.adapter_path,
235
+ )
236
+
237
+ metrics_rows = train_on_rollouts(
238
+ train_llm,
239
+ rollouts,
240
+ cfg,
241
+ optimizer=opt,
242
+ train_adapter=train_ptr.adapter_path,
243
+ ref_llm=ref_llm,
244
+ )
245
+ for row in metrics_rows:
246
+ row["iteration"] = iteration
247
+ write_jsonl(run.metrics_path, metrics_rows)
248
+
249
+ if passed_samples:
250
+ append_corpus(corpus_path, passed_samples, max_size=int(rlm_cfg.corpus_max))
251
+
252
+ # Optional corpus rehearsal via SFT
253
+ mix_ratio = float(rlm_cfg.mix_old_ratio)
254
+ if mix_ratio > 0 and corpus_rows:
255
+ n_samples = int(max(1, len(tasks) * mix_ratio))
256
+ for row in sample_corpus(corpus_rows, n=n_samples, hard_ratio=float(rlm_cfg.hard_ratio)):
257
+ prompt = row.get("prompt", "")
258
+ response = row.get("response", "")
259
+ if not prompt or not response:
260
+ continue
261
+ prompt_ids = train_llm.encode(prompt)
262
+ ids = train_llm.encode(prompt + response)
263
+ max_len = int(cfg.model.max_seq_len)
264
+ if max_len and len(ids) > max_len:
265
+ overflow = len(ids) - max_len
266
+ ids = ids[overflow:]
267
+ prompt_ids = prompt_ids[overflow:] if overflow < len(prompt_ids) else []
268
+
269
+ def sft_loss_fn(_model):
270
+ return train_llm.sft_loss(ids, train_on_prompt=cfg.train.train_on_prompt, prompt_len=len(prompt_ids))
271
+
272
+ lval, grads = train_llm.value_and_grad(sft_loss_fn)
273
+ if grads is not None:
274
+ train_llm.apply_grads(opt, grads)
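# A minimal sketch (not from loop.py) of the left-truncation used for corpus
# rehearsal above, with made-up lengths: when prompt+response exceeds
# max_seq_len, the oldest tokens are dropped and the prompt span shrinks by the
# same amount, keeping the prompt/response boundary aligned with the kept tokens.
def left_truncate(ids: list, prompt_ids: list, max_len: int):
    if max_len and len(ids) > max_len:
        overflow = len(ids) - max_len
        ids = ids[overflow:]
        prompt_ids = prompt_ids[overflow:] if overflow < len(prompt_ids) else []
    return ids, prompt_ids

ids, prompt_ids = left_truncate(list(range(1050)), list(range(40)), 1024)
assert len(ids) == 1024 and len(prompt_ids) == 14   # 26 tokens dropped from both spans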
275
+
276
+ train_llm.save_adapter(
277
+ str(run.adapter_dir),
278
+ metadata={
279
+ "base_model": train_ptr.base_model,
280
+ "source_adapter": str(train_ptr.adapter_path) if train_ptr.adapter_path else None,
281
+ "run": run.run_dir.name,
282
+ "kind": "rlm",
283
+ "iteration": iteration,
284
+ },
285
+ )
286
+
287
+ # Evaluate
288
+ adapter_score = 0.0
289
+ if rlm_cfg.benchmark_suite:
290
+ suite_path = project_root / rlm_cfg.benchmark_suite
291
+ if suite_path.exists():
292
+ eval_path = run_eval(project_root, suite_path, run.adapter_dir)
293
+ adapter_score = _score_from_eval(eval_path)
294
+
295
+ holdout_score = None
296
+ if rlm_cfg.holdout_suite:
297
+ holdout_path = project_root / rlm_cfg.holdout_suite
298
+ if holdout_path.exists():
299
+ holdout_eval = run_eval(project_root, holdout_path, run.adapter_dir)
300
+ holdout_score = _score_from_eval(holdout_eval)
301
+
302
+ accepted = should_accept(
303
+ adapter_score,
304
+ state,
305
+ mode=rlm_cfg.gating,
306
+ threshold=float(rlm_cfg.gating_threshold),
307
+ ema_alpha=float(rlm_cfg.gating_ema_alpha),
308
+ )
309
+ state = update_state(
310
+ state,
311
+ iteration=iteration,
312
+ score=adapter_score,
313
+ adapter_path=str(run.adapter_dir),
314
+ accepted=accepted,
315
+ ema_alpha=float(rlm_cfg.gating_ema_alpha),
316
+ )
317
+ save_state(state_path, state)
318
+
319
+ if state.current_adapter:
320
+ train_ptr = WeightPointer(
321
+ base_model=base_model,
322
+ adapter_path=state.current_adapter,
323
+ iteration=iteration,
324
+ updated_at=now_ts(),
325
+ name="trainer",
326
+ )
327
+ save_pointer(train_ptr_path, train_ptr)
328
+
329
+ infer_staleness = int(getattr(rlm_cfg, "infer_staleness", 0))
330
+ if infer_staleness <= 0:
331
+ infer_ptr = train_ptr
332
+ save_pointer(infer_ptr_path, infer_ptr)
333
+ else:
334
+ lag = max(0, int(train_ptr.iteration) - int(infer_ptr.iteration))
335
+ if lag >= infer_staleness:
336
+ infer_ptr = train_ptr
337
+ save_pointer(infer_ptr_path, infer_ptr)
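# A minimal sketch (not from loop.py) of the pointer-sync rule above, restated
# as a pure function with a worked example: with infer_staleness=2 the
# inference pointer is refreshed only once the trainer is >= 2 iterations ahead.
def should_sync_inference(train_iter: int, infer_iter: int, infer_staleness: int) -> bool:
    if infer_staleness <= 0:
        return True                                    # always follow the trainer
    lag = max(0, train_iter - infer_iter)
    return lag >= infer_staleness

assert should_sync_inference(5, 5, 2) is False         # lag 0 < 2: keep serving old weights
assert should_sync_inference(7, 5, 2) is True          # lag 2 >= 2: hot-swap to the new adapter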
338
+
339
+ append_history(
340
+ history_path,
341
+ {
342
+ "iteration": iteration,
343
+ "timestamp": now_ts(),
344
+ "adapter_score": adapter_score,
345
+ "holdout_score": holdout_score,
346
+ "best_score": state.best_score,
347
+ "accepted": accepted,
348
+ "adapter_dir": str(run.adapter_dir),
349
+ },
350
+ )
351
+
352
+ gating_path = run.run_dir / "gating.json"
353
+ gating_path.write_text(
354
+ json.dumps(
355
+ {
356
+ "iteration": iteration,
357
+ "accepted": accepted,
358
+ "adapter_score": adapter_score,
359
+ "holdout_score": holdout_score,
360
+ "best_score": state.best_score,
361
+ "current_adapter": state.current_adapter,
362
+ },
363
+ indent=2,
364
+ ),
365
+ encoding="utf-8",
366
+ )
367
+
368
+ if rlm_cfg.sleep_between > 0:
369
+ time.sleep(float(rlm_cfg.sleep_between))
370
+
371
+
372
+ # =============================================================================
373
+ # Multi-Process Orchestrated RLM
374
+ # =============================================================================
375
+
376
+ from ..orchestrator.queue import MessageQueue, MessageType, Message
377
+ from ..orchestrator.inference_worker import InferenceConfig, run_inference_worker
378
+ from ..orchestrator.trainer_worker import TrainerConfig, run_trainer_worker
379
+
380
+
381
+ @dataclass
382
+ class OrchestratorState:
383
+ """State for the orchestrated RLM loop."""
384
+ iteration: int = 0
385
+ run_id: str = ""
386
+ pending_rollouts: int = 0
387
+ pending_training: bool = False
388
+ current_adapter: Optional[str] = None
389
+ best_score: float = 0.0
390
+
391
+
392
+ class RLMOrchestrator:
393
+ """Multi-process RLM orchestrator.
394
+
395
+ Spawns and manages inference and trainer processes,
396
+ coordinates rollout generation and training via queues.
397
+ """
398
+
399
+ def __init__(
400
+ self,
401
+ project_root: Path,
402
+ cfg: ProjectConfig,
403
+ model_spec: str,
404
+ iterations: int = 50,
405
+ resume: bool = False,
406
+ ):
407
+ self.project_root = project_root
408
+ self.cfg = cfg
409
+ self.model_spec = model_spec
410
+ self.iterations = iterations
411
+ self.resume = resume
412
+
413
+ self._base_model, self._initial_adapter, _ = resolve_model_spec(
414
+ self.project_root, self.model_spec, self.cfg
415
+ )
416
+ self._rollout_timeout_s = 120.0
417
+ self._train_timeout_s = 900.0
418
+
419
+ # Paths
420
+ self.state_path = project_root / "runs" / "rlm_state.json"
421
+ self.history_path = project_root / "runs" / "rlm_history.jsonl"
422
+ self.corpus_path = project_root / "runs" / "rlm_corpus.jsonl"
423
+ self.weights_dir = ensure_dir(project_root / "runs" / "rlm_weights")
424
+
425
+ # State
426
+ self.gating_state = load_state(self.state_path)
427
+ self.orchestrator_state = OrchestratorState()
428
+
429
+ # IPC
430
+ self.queue = MessageQueue(maxsize=10000)
431
+ self._pointer_store = WeightPointerStore(self.weights_dir)
432
+
433
+ # Processes
434
+ self._inference_process: Optional[mp.Process] = None
435
+ self._trainer_process: Optional[mp.Process] = None
436
+ self._shutdown = False
437
+
438
+ # Rollout buffer
439
+ self._rollout_buffer: List[Rollout] = []
440
+ self._passed_samples: List[Dict] = []
441
+
442
+ def _setup_signal_handlers(self) -> None:
443
+ """Setup signal handlers for graceful shutdown."""
444
+ def signal_handler(sig, frame):
445
+ console.print("[yellow]Orchestrator received shutdown signal[/yellow]")
446
+ self._shutdown = True
447
+
448
+ signal.signal(signal.SIGTERM, signal_handler)
449
+ signal.signal(signal.SIGINT, signal_handler)
450
+
451
+ def _start_inference_worker(self) -> None:
452
+ """Start the inference worker process."""
453
+ inf_config = InferenceConfig(
454
+ model_spec=self.model_spec,
455
+ backend=self.cfg.model.backend,
456
+ host=self.cfg.serve.host,
457
+ port=self.cfg.serve.port,
458
+ max_seq_len=self.cfg.model.max_seq_len,
459
+ dtype=self.cfg.model.dtype,
460
+ trust_remote_code=self.cfg.model.trust_remote_code,
461
+ use_chat_template=self.cfg.model.use_chat_template,
462
+ weights_dir=self.weights_dir,
463
+ hot_reload=True,
464
+ )
465
+
466
+ # Initialize inference pointer
467
+ pointer = WeightPointerIPC(
468
+ base_model=self._base_model,
469
+ adapter_path=str(self._initial_adapter) if self._initial_adapter else None,
470
+ iteration=self.gating_state.last_iteration,
471
+ updated_at=now_ts(),
472
+ version=self.gating_state.last_iteration,
473
+ name="inference",
474
+ )
475
+ self._pointer_store.save(pointer)
476
+
477
+ self._inference_process = mp.Process(
478
+ target=run_inference_worker,
479
+ args=(inf_config, self.queue),
480
+ name="inference_worker",
481
+ daemon=False,
482
+ )
483
+ self._inference_process.start()
484
+ console.print(f"[green]Started inference worker (PID: {self._inference_process.pid})[/green]")
485
+
486
+ def _start_trainer_worker(self) -> None:
487
+ """Start the trainer worker process."""
488
+ trainer_config = TrainerConfig(
489
+ model_spec=self.model_spec,
490
+ base_model=self._base_model,
491
+ backend=self.cfg.model.backend,
492
+ max_seq_len=self.cfg.model.max_seq_len,
493
+ dtype=self.cfg.model.dtype,
494
+ trust_remote_code=self.cfg.model.trust_remote_code,
495
+ lr=self.cfg.train.lr,
496
+ weight_decay=self.cfg.train.weight_decay,
497
+ kl_coeff=self.cfg.rft.kl_coeff,
498
+ normalize_advantage=self.cfg.rft.normalize_advantage,
499
+ lora_r=self.cfg.lora.r,
500
+ lora_alpha=self.cfg.lora.alpha,
501
+ lora_dropout=self.cfg.lora.dropout,
502
+ lora_target_modules=list(self.cfg.lora.target_modules or []),
503
+ lora_num_layers=self.cfg.lora.num_layers,
504
+ weights_dir=self.weights_dir,
505
+ checkpoint_dir=self.project_root / "runs" / "rlm_checkpoints",
506
+ reference_model=self.cfg.rft.reference_model,
507
+ )
508
+
509
+ # Initialize trainer pointer
510
+ pointer = WeightPointerIPC(
511
+ base_model=self._base_model,
512
+ adapter_path=str(self._initial_adapter) if self._initial_adapter else None,
513
+ iteration=self.gating_state.last_iteration,
514
+ updated_at=now_ts(),
515
+ version=self.gating_state.last_iteration,
516
+ name="trainer",
517
+ )
518
+ self._pointer_store.save(pointer)
519
+
520
+ self._trainer_process = mp.Process(
521
+ target=run_trainer_worker,
522
+ args=(trainer_config, self.queue),
523
+ name="trainer_worker",
524
+ daemon=False,
525
+ )
526
+ self._trainer_process.start()
527
+ console.print(f"[green]Started trainer worker (PID: {self._trainer_process.pid})[/green]")
528
+
529
+ def _stop_workers(self) -> None:
530
+ """Stop all worker processes."""
531
+ console.print("[yellow]Stopping workers...[/yellow]")
532
+ try:
533
+ self.queue.send("rollout_requests", MessageType.SHUTDOWN, {}, source="orchestrator")
534
+ self.queue.send("control", MessageType.SHUTDOWN, {}, source="orchestrator")
535
+ except Exception:
536
+ pass
537
+
538
+ # Terminate processes
539
+ for name, proc in [("inference", self._inference_process), ("trainer", self._trainer_process)]:
540
+ if proc and proc.is_alive():
541
+ console.print(f" Stopping {name} worker...")
542
+ proc.terminate()
543
+ proc.join(timeout=10.0)
544
+ if proc.is_alive():
545
+ proc.kill()
546
+ proc.join(timeout=5.0)
547
+
548
+ self.queue.stop()
549
+ console.print("[green]Workers stopped[/green]")
550
+
551
+ def _wait_for_inference(self, timeout: float = 60.0) -> bool:
552
+ """Wait for inference worker to be ready via queue health check."""
553
+ start = time.time()
554
+ while time.time() - start < timeout and not self._shutdown:
555
+ try:
556
+ self.queue.send("rollout_requests", MessageType.HEALTH_CHECK, {}, source="orchestrator")
557
+ msg = self.queue.receive("rollout_responses", timeout=1.0)
558
+ if msg and msg.msg_type == MessageType.HEALTH_RESPONSE:
559
+ return True
560
+ except Exception:
561
+ pass
562
+ time.sleep(0.5)
563
+ return False
564
+
565
+ def _generate_rollout_via_api(
566
+ self,
567
+ task: GeneratedTask,
568
+ rollouts_per_task: int,
569
+ ) -> List[Rollout]:
570
+ """Generate rollouts for a task via inference API."""
571
+ import requests
572
+ rollouts = []
573
+
574
+ for k in range(rollouts_per_task):
575
+ try:
576
+ resp = requests.post(
577
+ f"http://localhost:{self.cfg.serve.port}/internal/rollout",
578
+ json={
579
+ "prompt": task.prompt,
580
+ "max_tokens": int(self.cfg.rft.max_new_tokens),
581
+ "temperature": float(self.cfg.rft.temperature),
582
+ "top_p": float(self.cfg.infer.top_p),
583
+ "top_k": self.cfg.infer.top_k,
584
+ "seed": int(time.time() * 1000) % (2**31 - 1),
585
+ "include_tokens": True,
586
+ "include_logprobs": True,
587
+ },
588
+ timeout=120.0,
589
+ )
590
+
591
+ if resp.status_code != 200:
592
+ console.print(f"[red]Rollout API error: {resp.status_code}[/red]")
593
+ continue
594
+
595
+ data = resp.json()
596
+ completion = data.get("completion", "")
597
+
598
+ # Run verifier
599
+ from ..util import ensure_dir
600
+ wdir = ensure_dir(self.project_root / "runs" / ".temp" / task.id / f"rollout_{k:02d}")
601
+ (wdir / "main.py").write_text(completion, encoding="utf-8")
602
+
603
+ # Write tests
604
+ tests_dir = ensure_dir(wdir / "tests")
605
+ (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
606
+
607
+ t0 = time.time()
608
+ if self.cfg.rlm.verifier_backend == "docker":
609
+ res = docker_verify(
610
+ task.prompt,
611
+ completion,
612
+ str(wdir),
613
+ timeout_s=int(self.cfg.rlm.verifier_timeout_s),
614
+ image=self.cfg.rlm.docker_image,
615
+ memory_mb=int(self.cfg.rlm.docker_memory_mb),
616
+ cpus=float(self.cfg.rlm.docker_cpus),
617
+ pids=int(self.cfg.rlm.docker_pids),
618
+ )
619
+ else:
620
+ from ..verifiers.pytest_verifier import verify as pytest_verify
621
+ res = pytest_verify(
622
+ task.prompt,
623
+ completion,
624
+ str(wdir),
625
+ timeout_s=int(self.cfg.rlm.verifier_timeout_s),
626
+ )
627
+ latency_ms = (time.time() - t0) * 1000.0
628
+
629
+ passed = bool(getattr(res, "passed", False))
630
+ reward = float(getattr(res, "reward", 0.0))
631
+
632
+ rollouts.append(Rollout(
633
+ task_id=task.id,
634
+ prompt=task.prompt,
635
+ completion=completion,
636
+ token_ids=data.get("token_ids", []),
637
+ prompt_len=data.get("prompt_len", 0),
638
+ logprobs=data.get("logprobs"),
639
+ passed=passed,
640
+ reward=reward,
641
+ verifier_latency_ms=latency_ms,
642
+ weight_adapter=self._pointer_store.load("inference", self._base_model).adapter_path,
643
+ ))
644
+
645
+ if passed:
646
+ self._passed_samples.append({
647
+ "id": task.id,
648
+ "prompt": task.prompt,
649
+ "response": completion,
650
+ "reward": reward,
651
+ "ts": now_ts(),
652
+ })
653
+
654
+ except Exception as e:
655
+ console.print(f"[red]Rollout error: {e}[/red]")
656
+ continue
657
+
658
+ return rollouts
659
+
660
+ def _generate_rollout_via_queue(
661
+ self,
662
+ task: GeneratedTask,
663
+ rollouts_per_task: int,
664
+ ) -> List[Rollout]:
665
+ """Generate rollouts for a task via the message queue."""
666
+ rollouts: List[Rollout] = []
667
+
668
+ for k in range(rollouts_per_task):
669
+ try:
670
+ req = self.queue.send(
671
+ "rollout_requests",
672
+ MessageType.ROLLOUT_REQUEST,
673
+ {
674
+ "prompt": task.prompt,
675
+ "max_tokens": int(self.cfg.rft.max_new_tokens),
676
+ "temperature": float(self.cfg.rft.temperature),
677
+ "top_p": float(self.cfg.infer.top_p),
678
+ "top_k": self.cfg.infer.top_k,
679
+ "seed": int(time.time() * 1000) % (2**31 - 1),
680
+ },
681
+ source="orchestrator",
682
+ )
683
+
684
+ # Wait for matching response
685
+ response = None
686
+ start = time.time()
687
+ while time.time() - start < self._rollout_timeout_s:
688
+ msg = self.queue.receive("rollout_responses", timeout=0.5)
689
+ if not msg:
690
+ continue
691
+ if msg.msg_type != MessageType.ROLLOUT_RESPONSE:
692
+ continue
693
+ if msg.payload.get("request_id") == req.msg_id:
694
+ response = msg
695
+ break
696
+
697
+ if response is None:
698
+ console.print("[red]Rollout queue timeout[/red]")
699
+ continue
700
+
701
+ data = response.payload
702
+ completion = data.get("completion", "")
703
+
704
+ # Run verifier
705
+ wdir = ensure_dir(self.project_root / "runs" / ".temp" / task.id / f"rollout_{k:02d}")
706
+ (wdir / "main.py").write_text(completion, encoding="utf-8")
707
+
708
+ tests_dir = ensure_dir(wdir / "tests")
709
+ (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
710
+
711
+ t0 = time.time()
712
+ if self.cfg.rlm.verifier_backend == "docker":
713
+ res = docker_verify(
714
+ task.prompt,
715
+ completion,
716
+ str(wdir),
717
+ timeout_s=int(self.cfg.rlm.verifier_timeout_s),
718
+ image=self.cfg.rlm.docker_image,
719
+ memory_mb=int(self.cfg.rlm.docker_memory_mb),
720
+ cpus=float(self.cfg.rlm.docker_cpus),
721
+ pids=int(self.cfg.rlm.docker_pids),
722
+ )
723
+ else:
724
+ from ..verifiers.pytest_verifier import verify as pytest_verify
725
+ res = pytest_verify(
726
+ task.prompt,
727
+ completion,
728
+ str(wdir),
729
+ timeout_s=int(self.cfg.rlm.verifier_timeout_s),
730
+ )
731
+ latency_ms = (time.time() - t0) * 1000.0
732
+
733
+ passed = bool(getattr(res, "passed", False))
734
+ reward = float(getattr(res, "reward", 0.0))
735
+
736
+ rollouts.append(
737
+ Rollout(
738
+ task_id=task.id,
739
+ prompt=task.prompt,
740
+ completion=completion,
741
+ token_ids=data.get("token_ids", []),
742
+ prompt_len=data.get("prompt_len", 0),
743
+ logprobs=data.get("logprobs"),
744
+ passed=passed,
745
+ reward=reward,
746
+ verifier_latency_ms=latency_ms,
747
+ weight_adapter=self._pointer_store.load("inference", self._base_model).adapter_path,
748
+ )
749
+ )
750
+
751
+ if passed:
752
+ self._passed_samples.append(
753
+ {
754
+ "id": task.id,
755
+ "prompt": task.prompt,
756
+ "response": completion,
757
+ "reward": reward,
758
+ "ts": now_ts(),
759
+ }
760
+ )
761
+
762
+ except Exception as e:
763
+ console.print(f"[red]Rollout error: {e}[/red]")
764
+ continue
765
+
766
+ return rollouts
767
+
768
+ def _send_training_batch(self, rollouts: List[Rollout], iteration: int, run_id: str) -> Message:
769
+ """Send a training batch to the trainer worker via queue."""
770
+ save_checkpoint = True
771
+
772
+ payload = {
773
+ "iteration": iteration,
774
+ "run_id": run_id,
775
+ "save_checkpoint": save_checkpoint,
776
+ "rollouts": [
777
+ {
778
+ "task_id": r.task_id,
779
+ "prompt": r.prompt,
780
+ "completion": r.completion,
781
+ "token_ids": r.token_ids,
782
+ "prompt_len": r.prompt_len,
783
+ "logprobs": r.logprobs,
784
+ "passed": r.passed,
785
+ "reward": r.reward,
786
+ "verifier_latency_ms": r.verifier_latency_ms,
787
+ "weight_adapter": r.weight_adapter,
788
+ }
789
+ for r in rollouts
790
+ ],
791
+ }
792
+ return self.queue.send("train_batches", MessageType.TRAIN_BATCH, payload, source="orchestrator")
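# A minimal sketch (not from loop.py) of the reply that run_iteration() later
# waits for on the "train_complete" queue: the payload must echo the TRAIN_BATCH
# message id as "request_id" and carry a "result" dict plus a "checkpoint_path".
# Field names follow the consuming code below; the values here are made up, and
# the real reply is produced by mlxsmith/orchestrator/trainer_worker.py.
def example_train_complete_payload(train_batch_msg_id: str) -> dict:
    return {
        "request_id": train_batch_msg_id,                        # matched against train_msg.msg_id
        "result": {"loss": 0.42, "num_tasks": 8, "num_rollouts": 32},
        "checkpoint_path": "runs/rlm_checkpoints/iter_0001",     # hypothetical checkpoint dir
    }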
793
+
794
+ def _drain_queue(self, queue_name: str) -> None:
795
+ """Drain all pending messages from a queue."""
796
+ while True:
797
+ msg = self.queue.receive(queue_name, timeout=0)
798
+ if msg is None:
799
+ break
800
+
801
+ def run_iteration(self, iteration: int) -> bool:
802
+ """Run a single orchestrated RLM iteration."""
803
+ console.print(f"\n[bold blue]=== Orchestrated RLM Iteration {iteration} ===[/bold blue]")
804
+
805
+ run = new_run(self.project_root, "rlm")
806
+ snapshot_config(self.cfg.model_dump(), run.config_snapshot_path)
807
+
808
+ # Generate tasks using a temporary LLM instance
809
+ # (In future: task generation could also go through inference worker)
810
+ console.print(" [dim]Generating tasks...[/dim]")
811
+
812
+ llm = get_llm_backend(self.cfg.model.backend)
813
+ pointer = self._pointer_store.load("inference", self._base_model)
814
+ llm.load(
815
+ pointer.base_model,
816
+ max_seq_len=self.cfg.model.max_seq_len,
817
+ dtype=self.cfg.model.dtype,
818
+ trust_remote_code=self.cfg.model.trust_remote_code,
819
+ )
820
+ if pointer.adapter_path:
821
+ llm.apply_adapter(pointer.adapter_path)
822
+
823
+ corpus_rows = load_corpus(self.corpus_path, max_size=int(self.cfg.rlm.corpus_max))
824
+ existing_prompts = [row.get("prompt", "") for row in corpus_rows if row.get("prompt")]
825
+
826
+ tasks = build_tasks(
827
+ llm,
828
+ self.cfg,
829
+ require_recursion=bool(self.cfg.rlm.require_recursion),
830
+ tasks_per_iter=int(self.cfg.rlm.tasks_per_iter),
831
+ mutations_per_task=int(self.cfg.rlm.mutations_per_task),
832
+ max_total=int(self.cfg.rlm.tasks_per_iter),
833
+ existing_prompts=existing_prompts,
834
+ )
835
+
836
+ write_jsonl(run.run_dir / "tasks.jsonl", [task.__dict__ for task in tasks])
837
+
838
+ # Generate rollouts via inference queue
839
+ console.print(f" [dim]Generating {len(tasks) * self.cfg.rlm.rollouts_per_task} rollouts...[/dim]")
840
+ all_rollouts = []
841
+ for i, task in enumerate(tasks):
842
+ rollouts = self._generate_rollout_via_queue(
843
+ task,
844
+ rollouts_per_task=int(self.cfg.rlm.rollouts_per_task),
845
+ )
846
+ all_rollouts.extend(rollouts)
847
+ if (i + 1) % 10 == 0:
848
+ console.print(f" {i + 1}/{len(tasks)} tasks completed")
849
+
850
+ # Save rollouts
851
+ write_jsonl(run.artifacts_dir / "rollouts.jsonl", [
852
+ {
853
+ "task_id": r.task_id,
854
+ "prompt": r.prompt,
855
+ "completion": r.completion,
856
+ "token_ids": r.token_ids,
857
+ "prompt_len": r.prompt_len,
858
+ "logprobs": r.logprobs,
859
+ "passed": r.passed,
860
+ "reward": r.reward,
861
+ }
862
+ for r in all_rollouts
863
+ ])
864
+
865
+ # Train via trainer worker (queue)
866
+ console.print(" [dim]Training on rollouts...[/dim]")
867
+ train_msg = self._send_training_batch(all_rollouts, iteration, run.run_dir.name)
868
+
869
+ train_resp = None
870
+ start = time.time()
871
+ while time.time() - start < self._train_timeout_s:
872
+ msg = self.queue.receive("train_complete", timeout=1.0)
873
+ if not msg:
874
+ continue
875
+ if msg.payload.get("request_id") == train_msg.msg_id:
876
+ train_resp = msg
877
+ break
878
+
879
+ if train_resp is None:
880
+ console.print("[red]Trainer timed out[/red]")
881
+ return False
882
+
883
+ train_result = train_resp.payload.get("result") or {}
884
+ checkpoint_path = train_resp.payload.get("checkpoint_path")
885
+
886
+ write_jsonl(
887
+ run.metrics_path,
888
+ [
889
+ {
890
+ "ts": now_ts(),
891
+ "kind": "rlm_train",
892
+ "iteration": iteration,
893
+ "loss": train_result.get("loss"),
894
+ "num_tasks": train_result.get("num_tasks"),
895
+ "num_rollouts": train_result.get("num_rollouts"),
896
+ }
897
+ ],
898
+ )
899
+
900
+ if not checkpoint_path:
901
+ console.print("[red]Trainer returned no checkpoint[/red]")
902
+ return False
903
+
904
+ copytree(Path(checkpoint_path), run.adapter_dir)
905
+
906
+ # Drain any weight update notifications from trainer
907
+ self._drain_queue("weight_updates")
908
+ self._drain_queue("checkpoints")
909
+
910
+ # Update corpus
911
+ if self._passed_samples:
912
+ append_corpus(self.corpus_path, self._passed_samples, max_size=int(self.cfg.rlm.corpus_max))
913
+ self._passed_samples = []
914
+
915
+ # Evaluate
916
+ adapter_score = 0.0
917
+ if self.cfg.rlm.benchmark_suite:
918
+ suite_path = self.project_root / self.cfg.rlm.benchmark_suite
919
+ if suite_path.exists():
920
+ eval_path = run_eval(self.project_root, suite_path, run.adapter_dir)
921
+ adapter_score = _score_from_eval(eval_path)
922
+
923
+ holdout_score = None
924
+ if self.cfg.rlm.holdout_suite:
925
+ holdout_path = self.project_root / self.cfg.rlm.holdout_suite
926
+ if holdout_path.exists():
927
+ holdout_eval = run_eval(self.project_root, holdout_path, run.adapter_dir)
928
+ holdout_score = _score_from_eval(holdout_eval)
929
+
930
+ # Gating
931
+ accepted = should_accept(
932
+ adapter_score,
933
+ self.gating_state,
934
+ mode=self.cfg.rlm.gating,
935
+ threshold=float(self.cfg.rlm.gating_threshold),
936
+ ema_alpha=float(self.cfg.rlm.gating_ema_alpha),
937
+ )
938
+ self.gating_state = update_state(
939
+ self.gating_state,
940
+ iteration=iteration,
941
+ score=adapter_score,
942
+ adapter_path=str(run.adapter_dir),
943
+ accepted=accepted,
944
+ ema_alpha=float(self.cfg.rlm.gating_ema_alpha),
945
+ )
946
+ save_state(self.state_path, self.gating_state)
947
+
948
+ # Update weight pointers
949
+ if self.gating_state.current_adapter:
950
+ train_pointer = WeightPointerIPC(
951
+ base_model=self._base_model,
952
+ adapter_path=self.gating_state.current_adapter,
953
+ iteration=iteration,
954
+ updated_at=now_ts(),
955
+ version=iteration,
956
+ name="trainer",
957
+ )
958
+ self._pointer_store.save(train_pointer)
959
+
960
+ # Update inference pointer (hot reload)
961
+ infer_staleness = int(getattr(self.cfg.rlm, "infer_staleness", 0))
962
+ current_infer = self._pointer_store.load("inference", self._base_model)
963
+ update_infer = infer_staleness <= 0
964
+ if not update_infer:
965
+ lag = max(0, int(iteration) - int(current_infer.iteration))
966
+ update_infer = lag >= infer_staleness
967
+
968
+ if update_infer:
969
+ infer_pointer = WeightPointerIPC(
970
+ base_model=self._base_model,
971
+ adapter_path=self.gating_state.current_adapter,
972
+ iteration=iteration,
973
+ updated_at=now_ts(),
974
+ version=iteration,
975
+ name="inference",
976
+ )
977
+ self._pointer_store.save(infer_pointer)
978
+ try:
979
+ self.queue.send(
980
+ "weight_forward",
981
+ MessageType.WEIGHT_UPDATE,
982
+ {
983
+ "adapter_path": self.gating_state.current_adapter,
984
+ "version": iteration,
985
+ "base_model": self._base_model,
986
+ },
987
+ source="orchestrator",
988
+ )
989
+ except Exception as e:
990
+ console.print(f"[yellow]Hot reload trigger failed: {e}[/yellow]")
991
+
992
+ # History
993
+ append_history(
994
+ self.history_path,
995
+ {
996
+ "iteration": iteration,
997
+ "timestamp": now_ts(),
998
+ "adapter_score": adapter_score,
999
+ "holdout_score": holdout_score,
1000
+ "best_score": self.gating_state.best_score,
1001
+ "accepted": accepted,
1002
+ "adapter_dir": str(run.adapter_dir),
1003
+ },
1004
+ )
1005
+
1006
+ gating_path = run.run_dir / "gating.json"
1007
+ gating_path.write_text(
1008
+ json.dumps(
1009
+ {
1010
+ "iteration": iteration,
1011
+ "accepted": accepted,
1012
+ "adapter_score": adapter_score,
1013
+ "holdout_score": holdout_score,
1014
+ "best_score": self.gating_state.best_score,
1015
+ "current_adapter": self.gating_state.current_adapter,
1016
+ },
1017
+ indent=2,
1018
+ ),
1019
+ encoding="utf-8",
1020
+ )
1021
+
1022
+ if self.cfg.rlm.sleep_between > 0:
1023
+ time.sleep(float(self.cfg.rlm.sleep_between))
1024
+
1025
+ return True
1026
+
1027
+ def run(self) -> None:
1028
+ """Run the orchestrated RLM loop."""
1029
+ self._setup_signal_handlers()
1030
+
1031
+ console.print("[bold green]Starting Orchestrated RLM[/bold green]")
1032
+
1033
+ # Start queue manager
1034
+ self.queue.start()
1035
+
1036
+ # Start workers
1037
+ self._start_inference_worker()
1038
+ self._start_trainer_worker()
1039
+
1040
+ console.print("[dim]Waiting for inference server...[/dim]")
1041
+ if not self._wait_for_inference(timeout=120.0):
1042
+ console.print("[red]Inference server failed to start[/red]")
1043
+ self._stop_workers()
1044
+ return
1045
+ console.print("[green]Inference server ready[/green]")
1046
+
1047
+ # Determine iteration range
1048
+ start_iter = self.gating_state.last_iteration + 1 if self.resume else 1
1049
+ total_iters = self.iterations
1050
+
1051
+ def iter_range():
1052
+ if total_iters == 0:
1053
+ i = start_iter
1054
+ while True:
1055
+ yield i
1056
+ i += 1
1057
+ else:
1058
+ for i in range(start_iter, start_iter + total_iters):
1059
+ yield i
1060
+
1061
+ try:
1062
+ for iteration in iter_range():
1063
+ if self._shutdown:
1064
+ break
1065
+
1066
+ success = self.run_iteration(iteration)
1067
+ if not success:
1068
+ console.print("[red]Iteration failed, stopping[/red]")
1069
+ break
1070
+
1071
+ except KeyboardInterrupt:
1072
+ console.print("[yellow]Interrupted by user[/yellow]")
1073
+ except Exception as e:
1074
+ console.print(f"[red]Orchestrator error: {e}[/red]")
1075
+ traceback.print_exc()
1076
+ finally:
1077
+ self._stop_workers()
1078
+
1079
+
1080
+ def run_rlm_orchestrated(
+     project_root: Path,
+     cfg: ProjectConfig,
+     *,
+     model_spec: Optional[str] = None,
+     iterations: Optional[int] = None,
+     resume: bool = False,
+ ) -> None:
+     """Run multi-process orchestrated RLM loop.
+
+     This mode spawns separate inference and trainer processes,
+     coordinating via weight pointers and queue messages.
+
+     Benefits:
+     - Inference server remains responsive during training
+     - Hot-reload of weights without restart
+     - Better resource isolation
+     - Foundation for distributed training
+     """
+     spec = model_spec or cfg.model.id
+     iters = iterations or cfg.rlm.iterations
+
+     orchestrator = RLMOrchestrator(
+         project_root=project_root,
+         cfg=cfg,
+         model_spec=spec,
+         iterations=iters,
+         resume=resume,
+     )
+     orchestrator.run()
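# A minimal sketch (not from loop.py) of driving the orchestrated loop
# programmatically, assuming `cfg` is an already-constructed ProjectConfig for
# the project at `root` (the mlxsmith CLI normally loads it). The iteration
# count is arbitrary; passing iterations=None falls back to cfg.rlm.iterations.
from pathlib import Path

from mlxsmith.rlm.loop import run_rlm_orchestrated


def launch_orchestrated(root: Path, cfg) -> None:
    # Spawns the inference and trainer worker processes, runs 10 RLM iterations,
    # then shuts the workers down.
    run_rlm_orchestrated(root, cfg, iterations=10, resume=False)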
1110
+
1111
+
1112
+ def collect_rollouts_via_api(
1113
+ tasks: List[GeneratedTask],
1114
+ cfg: ProjectConfig,
1115
+ api_url: str,
1116
+ artifacts_dir: Path,
1117
+ verifier_backend: str,
1118
+ weight_adapter: Optional[str],
1119
+ ) -> tuple[List[Rollout], list[dict]]:
1120
+ """Collect rollouts via inference API (for legacy loop with external inference)."""
1121
+ try:
1122
+ import requests
1123
+ except ModuleNotFoundError:
1124
+ requests = None
1125
+ rollouts: List[Rollout] = []
1126
+ passed_samples: list[dict] = []
1127
+
1128
+ # Probe the API endpoint; fall back to local inference if unreachable.
1129
+ _api_available = False
1130
+ if requests is not None:
1131
+ try:
1132
+ requests.get(api_url, timeout=2.0)
1133
+ _api_available = True
1134
+ except Exception:
1135
+ _api_available = False
1136
+
1137
+ if not _api_available:
1138
+ # Resolve model spec to separate base model from adapter.
1139
+ base_model, resolved_adapter, _meta = resolve_model_spec(
1140
+ Path.cwd(), cfg.model.id, cfg
1141
+ )
1142
+ llm = get_llm_backend(cfg.model.backend)
1143
+ llm.load(
1144
+ base_model,
1145
+ max_seq_len=cfg.model.max_seq_len,
1146
+ dtype=cfg.model.dtype,
1147
+ trust_remote_code=cfg.model.trust_remote_code,
1148
+ )
1149
+ adapter_to_apply = weight_adapter or (str(resolved_adapter) if resolved_adapter else None)
1150
+ if adapter_to_apply:
1151
+ llm.apply_adapter(adapter_to_apply)
1152
+ for task in tasks:
1153
+ for k in range(int(cfg.rlm.rollouts_per_task)):
1154
+ try:
1155
+ gen = llm.generate_with_logprobs(
1156
+ task.prompt,
1157
+ max_new_tokens=int(cfg.rft.max_new_tokens),
1158
+ temperature=float(cfg.rft.temperature),
1159
+ top_p=float(cfg.infer.top_p),
1160
+ top_k=cfg.infer.top_k,
1161
+ seed=int(time.time() * 1000) % (2**31 - 1),
1162
+ )
1163
+ except TypeError:
1164
+ gen = llm.generate_with_logprobs(
1165
+ task.prompt,
1166
+ max_new_tokens=int(cfg.rft.max_new_tokens),
1167
+ temperature=float(cfg.rft.temperature),
1168
+ top_p=float(cfg.infer.top_p),
1169
+ top_k_sampling=cfg.infer.top_k,
1170
+ seed=int(time.time() * 1000) % (2**31 - 1),
1171
+ )
1172
+ completion = gen.text[len(task.prompt) :] if gen.text.startswith(task.prompt) else gen.text
1173
+ wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
1174
+ (wdir / "main.py").write_text(completion, encoding="utf-8")
1175
+ (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")
1176
+ t0 = time.time()
1177
+ if verifier_backend == "docker":
1178
+ res = docker_verify(
1179
+ task.prompt,
1180
+ completion,
1181
+ str(wdir),
1182
+ timeout_s=int(cfg.rlm.verifier_timeout_s),
1183
+ image=cfg.rlm.docker_image,
1184
+ memory_mb=int(cfg.rlm.docker_memory_mb),
1185
+ cpus=float(cfg.rlm.docker_cpus),
1186
+ pids=int(cfg.rlm.docker_pids),
1187
+ )
1188
+ else:
1189
+ res = pytest_verify(
1190
+ task.prompt,
1191
+ completion,
1192
+ str(wdir),
1193
+ timeout_s=int(cfg.rlm.verifier_timeout_s),
1194
+ )
1195
+ latency_ms = (time.time() - t0) * 1000.0
1196
+ passed = bool(getattr(res, "passed", False))
1197
+ reward = float(getattr(res, "reward", 0.0))
1198
+ rollouts.append(
1199
+ Rollout(
1200
+ task_id=task.id,
1201
+ prompt=task.prompt,
1202
+ completion=completion,
1203
+ token_ids=list(gen.token_ids),
1204
+ prompt_len=gen.prompt_len,
1205
+ logprobs=list(gen.logprobs) if gen.logprobs else None,
1206
+ passed=passed,
1207
+ reward=reward,
1208
+ verifier_latency_ms=latency_ms,
1209
+ weight_adapter=weight_adapter,
1210
+ )
1211
+ )
1212
+ if passed:
1213
+ passed_samples.append(
1214
+ {
1215
+ "id": task.id,
1216
+ "prompt": task.prompt,
1217
+ "response": completion,
1218
+ "reward": reward,
1219
+ "ts": now_ts(),
1220
+ }
1221
+ )
1222
+ return rollouts, passed_samples
1223
+
1224
+ for task in tasks:
1225
+ for k in range(int(cfg.rlm.rollouts_per_task)):
1226
+ try:
1227
+ resp = requests.post(
1228
+ f"{api_url}/internal/rollout",
1229
+ json={
1230
+ "prompt": task.prompt,
1231
+ "max_tokens": int(cfg.rft.max_new_tokens),
1232
+ "temperature": float(cfg.rft.temperature),
1233
+ "top_p": float(cfg.infer.top_p),
1234
+ "top_k": cfg.infer.top_k,
1235
+ "seed": int(time.time() * 1000) % (2**31 - 1),
1236
+ "include_tokens": True,
1237
+ "include_logprobs": True,
1238
+ },
1239
+ timeout=120.0,
1240
+ )
1241
+
1242
+ if resp.status_code != 200:
1243
+ continue
1244
+
1245
+ data = resp.json()
1246
+ completion = data.get("completion", "")
1247
+
1248
+ wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
1249
+ (wdir / "main.py").write_text(completion, encoding="utf-8")
1250
+ (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")
1251
+
1252
+ t0 = time.time()
1253
+ if verifier_backend == "docker":
1254
+ res = docker_verify(
1255
+ task.prompt,
1256
+ completion,
1257
+ str(wdir),
1258
+ timeout_s=int(cfg.rlm.verifier_timeout_s),
1259
+ image=cfg.rlm.docker_image,
1260
+ memory_mb=int(cfg.rlm.docker_memory_mb),
1261
+ cpus=float(cfg.rlm.docker_cpus),
1262
+ pids=int(cfg.rlm.docker_pids),
1263
+ )
1264
+ else:
1265
+ res = pytest_verify(task.prompt, completion, str(wdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
1266
+ latency_ms = (time.time() - t0) * 1000.0
1267
+
1268
+ passed = bool(getattr(res, "passed", False))
1269
+ reward = float(getattr(res, "reward", 0.0))
1270
+
1271
+ rollouts.append(Rollout(
1272
+ task_id=task.id,
1273
+ prompt=task.prompt,
1274
+ completion=completion,
1275
+ token_ids=data.get("token_ids", []),
1276
+ prompt_len=data.get("prompt_len", 0),
1277
+ logprobs=data.get("logprobs"),
1278
+ passed=passed,
1279
+ reward=reward,
1280
+ verifier_latency_ms=latency_ms,
1281
+ weight_adapter=weight_adapter,
1282
+ ))
1283
+
1284
+ if passed:
1285
+ passed_samples.append({
1286
+ "id": task.id,
1287
+ "prompt": task.prompt,
1288
+ "response": completion,
1289
+ "reward": reward,
1290
+ "ts": now_ts(),
1291
+ })
1292
+
1293
+ except Exception as e:
1294
+ console.print(f"[red]Rollout error: {e}[/red]")
1295
+ continue
1296
+
1297
+ return rollouts, passed_samples