mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,454 @@
|
|
|
1
|
+
"""Orchestrator Daemon for MLXSmith Multi-Process RLM.
|
|
2
|
+
|
|
3
|
+
Central queue-based job scheduler that coordinates between:
|
|
4
|
+
- Inference server (generates rollouts)
|
|
5
|
+
- Trainer worker (consumes batches and updates weights)
|
|
6
|
+
|
|
7
|
+
Manages rollout requests, training batches, and weight updates.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import json
|
|
13
|
+
import multiprocessing as mp
|
|
14
|
+
import signal
|
|
15
|
+
import sys
|
|
16
|
+
import time
|
|
17
|
+
import traceback
|
|
18
|
+
from dataclasses import dataclass, field
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any, Dict, List, Optional, Callable
|
|
21
|
+
|
|
22
|
+
from rich.console import Console
|
|
23
|
+
|
|
24
|
+
from ..config import ProjectConfig
|
|
25
|
+
from ..rlm.corpus import append_corpus, load_corpus, sample_corpus
|
|
26
|
+
from ..rlm.gating import load_state, save_state, should_accept, update_state
|
|
27
|
+
from ..rlm.history import append_history
|
|
28
|
+
from ..rlm.inference import Rollout, build_tasks
|
|
29
|
+
from ..rlm.weights import WeightPointerStore, WeightPointerIPC
|
|
30
|
+
from ..runs import new_run, snapshot_config
|
|
31
|
+
from ..util import ensure_dir, now_ts, write_jsonl
|
|
32
|
+
from .queue import MessageQueue, MessageType, Message
|
|
33
|
+
from .inference_worker import InferenceConfig, run_inference_worker
|
|
34
|
+
from .trainer_worker import TrainerConfig, run_trainer_worker
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
console = Console()
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class DaemonConfig:
    """Configuration for orchestrator daemon."""
    # Root of the project; all run artifacts are placed under <root>/runs.
    project_root: Path
    # Model identifier/spec handed to both the inference and trainer workers.
    model_spec: str

    # Process management
    inference_port: int = 8080
    inference_host: str = "0.0.0.0"
    max_restarts: int = 3  # per-process restart budget before the worker stays dead
    restart_delay: float = 5.0  # minimum seconds between restarts of one process
    health_check_interval: float = 10.0  # seconds between queue-based health pings

    # Training config
    iterations: int = 50
    tasks_per_iter: int = 80
    rollouts_per_task: int = 8
    batch_size: int = 32

    # Paths (default to <project_root>/runs/... when left as None)
    weights_dir: Optional[Path] = None
    checkpoint_dir: Optional[Path] = None

    # Gating
    gating_mode: str = "strict"
    gating_threshold: float = 0.0
    gating_ema_alpha: float = 0.2  # EMA smoothing factor for gating statistics

    # Verifier
    verifier_backend: str = "pytest"
    verifier_timeout_s: int = 30
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class ProcessHandle:
    """Handle for a managed process."""
    name: str  # logical name, e.g. "inference" or "trainer"
    process: mp.Process  # the spawned worker process
    config: Any  # worker config object used to (re)spawn this process
    restart_count: int = 0  # how many times this worker has been restarted
    last_restart: float = 0.0  # time.time() of the most recent restart (0.0 = never)
    healthy: bool = True  # flipped to False when the process is found dead
    start_time: float = field(default_factory=time.time)  # spawn timestamp
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class OrchestratorDaemon:
    """Orchestrator daemon for multi-process RLM.

    Responsibilities:
    - Spawn and manage inference and trainer processes
    - Coordinate rollout requests and training batches
    - Manage weight pointer updates
    - Handle process lifecycle, monitoring, and restarts
    - Graceful shutdown handling
    """

    def __init__(self, config: DaemonConfig, project_cfg: ProjectConfig):
        """Set up daemon state and ensure run directories exist.

        Args:
            config: Daemon-level settings (ports, restart policy, paths).
            project_cfg: Project configuration consumed by the workers.
        """
        self.config = config
        self.project_cfg = project_cfg
        self.queue = MessageQueue(maxsize=10000)
        self._processes: Dict[str, ProcessHandle] = {}
        self._shutdown = False
        self._pointer_store: Optional[WeightPointerStore] = None
        self._current_iteration = 0
        self._metrics: List[Dict] = []

        # All artifacts live under <project_root>/runs unless overridden.
        self._weights_dir = config.weights_dir or (config.project_root / "runs" / "rlm_weights")
        self._checkpoint_dir = config.checkpoint_dir or (config.project_root / "runs" / "rlm_checkpoints")
        self._state_path = config.project_root / "runs" / "rlm_state.json"
        self._history_path = config.project_root / "runs" / "rlm_history.jsonl"
        self._corpus_path = config.project_root / "runs" / "rlm_corpus.jsonl"

        ensure_dir(self._weights_dir)
        ensure_dir(self._checkpoint_dir)

    def _setup_signal_handlers(self) -> None:
        """Setup signal handlers for graceful shutdown."""
        def signal_handler(sig, frame):
            # Only flip the flag; the main loop observes it and unwinds
            # through _shutdown_all() in run()'s finally block.
            console.print("[yellow]Orchestrator received shutdown signal[/yellow]")
            self._shutdown = True

        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

    def _spawn_inference_worker(self) -> ProcessHandle:
        """Spawn the inference worker process.

        Returns:
            A ProcessHandle wrapping the already-started process.
        """
        inf_config = InferenceConfig(
            model_spec=self.config.model_spec,
            backend=self.project_cfg.model.backend,
            host=self.config.inference_host,
            port=self.config.inference_port,
            max_seq_len=self.project_cfg.model.max_seq_len,
            dtype=self.project_cfg.model.dtype,
            trust_remote_code=self.project_cfg.model.trust_remote_code,
            use_chat_template=self.project_cfg.model.use_chat_template,
            weights_dir=self._weights_dir,
            hot_reload=True,
        )

        # daemon=False so the worker is free to spawn children of its own.
        process = mp.Process(
            target=run_inference_worker,
            args=(inf_config, self.queue),
            name="inference_worker",
            daemon=False,
        )

        handle = ProcessHandle(
            name="inference",
            process=process,
            config=inf_config,
        )

        process.start()
        console.print(f"[green]Spawned inference worker (PID: {process.pid})[/green]")

        return handle

    def _spawn_trainer_worker(self) -> ProcessHandle:
        """Spawn the trainer worker process.

        Returns:
            A ProcessHandle wrapping the already-started process.
        """
        # Resolve base model (local import avoids a module-level cycle)
        from ..models import resolve_model_spec
        base_model, adapter_path, _ = resolve_model_spec(
            self.config.project_root, self.config.model_spec, self.project_cfg
        )

        trainer_config = TrainerConfig(
            model_spec=self.config.model_spec,
            base_model=base_model,
            backend=self.project_cfg.model.backend,
            max_seq_len=self.project_cfg.model.max_seq_len,
            dtype=self.project_cfg.model.dtype,
            trust_remote_code=self.project_cfg.model.trust_remote_code,
            lr=self.project_cfg.train.lr,
            weight_decay=self.project_cfg.train.weight_decay,
            kl_coeff=self.project_cfg.rft.kl_coeff,
            normalize_advantage=self.project_cfg.rft.normalize_advantage,
            lora_r=self.project_cfg.lora.r,
            lora_alpha=self.project_cfg.lora.alpha,
            lora_dropout=self.project_cfg.lora.dropout,
            lora_target_modules=list(self.project_cfg.lora.target_modules or []),
            lora_num_layers=self.project_cfg.lora.num_layers,
            weights_dir=self._weights_dir,
            checkpoint_dir=self._checkpoint_dir,
            reference_model=self.project_cfg.rft.reference_model,
        )

        # Create process
        process = mp.Process(
            target=run_trainer_worker,
            args=(trainer_config, self.queue),
            name="trainer_worker",
            daemon=False,
        )

        handle = ProcessHandle(
            name="trainer",
            process=process,
            config=trainer_config,
        )

        process.start()
        console.print(f"[green]Spawned trainer worker (PID: {process.pid})[/green]")

        return handle

    def _monitor_processes(self) -> None:
        """Monitor worker processes and restart dead ones.

        Restarts are rate-limited by ``restart_delay`` and capped by
        ``max_restarts``; a worker that exceeds the cap stays dead.
        """
        current_time = time.time()

        for name, handle in list(self._processes.items()):
            # Check if process is alive
            if not handle.process.is_alive():
                if self._shutdown:
                    continue

                console.print(f"[red]Process {name} (PID: {handle.process.pid}) died[/red]")
                handle.healthy = False

                # Check restart limit
                if handle.restart_count >= self.config.max_restarts:
                    console.print(f"[red]Process {name} exceeded max restarts[/red]")
                    continue

                # Rate-limit restarts. FIX: this previously called
                # time.sleep(restart_delay) inline, stalling the whole daemon
                # loop (health checks, weight forwarding, message draining);
                # now the restart is simply deferred to a later monitor pass.
                if current_time - handle.last_restart < self.config.restart_delay:
                    continue

                # Restart process
                console.print(f"[yellow]Restarting {name}...[/yellow]")

                if name == "inference":
                    new_handle = self._spawn_inference_worker()
                elif name == "trainer":
                    new_handle = self._spawn_trainer_worker()
                else:
                    continue

                new_handle.restart_count = handle.restart_count + 1
                new_handle.last_restart = current_time
                self._processes[name] = new_handle

    def _health_check(self) -> Dict[str, Any]:
        """Perform health checks on all processes via queues.

        Fire-and-forget: pings are enqueued here; any responses are consumed
        later by the main loop. The returned dict is currently always empty
        (placeholder for aggregated results).
        """
        results = {}

        # Check inference via queue
        if "inference" in self._processes:
            self.queue.send(
                "control",
                MessageType.HEALTH_CHECK,
                {},
                source="daemon",
            )
            # Response will be processed in main loop

        # Check trainer via queue
        if "trainer" in self._processes:
            self.queue.send(
                "train_batches",  # Trainer reads from train_batches
                MessageType.HEALTH_CHECK,
                {},
                source="daemon",
            )

        return results

    def _forward_weight_updates(self) -> None:
        """Forward one pending weight update from trainer to inference.

        Also persists the new weight pointer so a (re)started inference
        worker picks up the latest adapter from disk.
        """
        # Non-blocking poll of the trainer's weight-update channel.
        msg = self.queue.receive("weight_updates", timeout=0)
        if msg is None:
            return
        if msg.msg_type != MessageType.WEIGHT_UPDATE:
            # FIX: previously any non-WEIGHT_UPDATE message on this channel
            # was dropped silently; surface it so misrouted traffic is visible.
            console.print(f"[yellow]Dropping unexpected message on weight_updates: {msg.msg_type}[/yellow]")
            return

        # Forward to inference worker
        self.queue.send(
            "weight_forward",
            MessageType.WEIGHT_UPDATE,
            msg.payload,
            source="daemon",
        )

        # Also update inference pointer on disk
        if self._pointer_store:
            pointer = WeightPointerIPC(
                base_model=msg.payload.get("base_model", ""),
                adapter_path=msg.payload.get("adapter_path"),
                iteration=msg.payload.get("version", 0),
                updated_at=now_ts(),
                version=msg.payload.get("version", 0),
                name="inference",
            )
            self._pointer_store.save(pointer)
            console.print(f"[blue]Forwarded weight update: {pointer.adapter_path}[/blue]")

    def _shutdown_all(self) -> None:
        """Shutdown all processes gracefully.

        Escalates per process: cooperative SHUTDOWN message -> join(10s)
        -> terminate() -> join(5s) -> kill().
        """
        console.print("[yellow]Shutting down all processes...[/yellow]")

        # Send shutdown messages (one per managed process)
        for name in self._processes:
            self.queue.send(
                "control",
                MessageType.SHUTDOWN,
                {},
                source="daemon",
            )

        # Wait for processes to terminate
        for name, handle in self._processes.items():
            console.print(f" Waiting for {name}...")
            handle.process.join(timeout=10.0)

            if handle.process.is_alive():
                console.print(f" Force terminating {name}")
                handle.process.terminate()
                handle.process.join(timeout=5.0)

            if handle.process.is_alive():
                handle.process.kill()

        # Stop queue manager
        self.queue.stop()

        console.print("[green]All processes shutdown[/green]")

    def run_iteration(self, iteration: int) -> bool:
        """Run a single RLM iteration.

        Returns True if iteration completed successfully.

        NOTE(review): phases 1-4 below are placeholders; the actual
        coordination happens via the queues and the inference worker API.
        """
        console.print(f"\n[bold blue]=== RLM Iteration {iteration} ===[/bold blue]")

        run = new_run(self.config.project_root, "rlm")
        snapshot_config(self.project_cfg.model_dump(), run.config_snapshot_path)

        # Phase 1: Generate tasks (via inference worker API)
        console.print(" [dim]Generating tasks...[/dim]")
        # Tasks are generated by querying inference worker

        # Phase 2: Collect rollouts (via inference worker)
        console.print(" [dim]Collecting rollouts...[/dim]")
        # Rollouts are generated via /internal/rollout endpoint

        # Phase 3: Send training batch to trainer
        console.print(" [dim]Sending training batch...[/dim]")

        # Phase 4: Wait for training completion
        console.print(" [dim]Waiting for training...[/dim]")

        # This is a placeholder - actual implementation would
        # coordinate via queues and the inference worker API

        return True

    def run(self) -> None:
        """Run the orchestrator daemon until shutdown is requested."""
        self._setup_signal_handlers()

        console.print("[bold green]Starting MLXSmith Orchestrator[/bold green]")

        # Start queue manager
        self.queue.start()
        console.print("[dim]Queue manager started[/dim]")

        # Initialize weight pointer store
        self._pointer_store = WeightPointerStore(self._weights_dir)
        console.print(f"[dim]Weight store: {self._weights_dir}[/dim]")

        # Spawn worker processes
        console.print("[dim]Spawning worker processes...[/dim]")
        self._processes["inference"] = self._spawn_inference_worker()
        self._processes["trainer"] = self._spawn_trainer_worker()

        # Fixed grace period for workers to load models and bind ports.
        console.print("[dim]Waiting for workers to initialize...[/dim]")
        time.sleep(5.0)

        # Resume iteration counting from persisted gating state.
        state = load_state(self._state_path)
        self._current_iteration = state.last_iteration + 1

        last_health_check = time.time()

        try:
            # Main orchestrator loop
            while not self._shutdown:
                # Monitor processes
                self._monitor_processes()

                # Health checks
                current_time = time.time()
                if current_time - last_health_check > self.config.health_check_interval:
                    self._health_check()
                    last_health_check = current_time

                # Forward weight updates
                self._forward_weight_updates()

                # Process queue messages
                self._process_queue_messages()

                # Small sleep to prevent busy waiting
                time.sleep(0.01)

        except KeyboardInterrupt:
            console.print("[yellow]Interrupted by user[/yellow]")
        except Exception as e:
            # Top-level boundary: report and fall through to cleanup.
            console.print(f"[red]Orchestrator error: {e}[/red]")
            traceback.print_exc()
        finally:
            self._shutdown_all()

    def _process_queue_messages(self) -> None:
        """Drain one pending message from each response channel.

        NOTE(review): handlers are placeholders; messages are consumed so the
        channels do not back up.
        """
        # Process responses from inference
        msg = self.queue.receive("rollout_responses", timeout=0)
        if msg:
            # Handle rollout response
            pass

        # Process training completion
        msg = self.queue.receive("train_complete", timeout=0)
        if msg:
            # Handle training completion
            pass

        # Process checkpoints
        msg = self.queue.receive("checkpoints", timeout=0)
        if msg:
            # Handle checkpoint notification
            pass
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def run_daemon(
    project_root: Path,
    project_cfg: ProjectConfig,
    model_spec: Optional[str] = None,
    iterations: Optional[int] = None,
) -> None:
    """Build a DaemonConfig from project settings and run the orchestrator.

    Explicit ``model_spec`` / ``iterations`` arguments take precedence over
    the corresponding values in ``project_cfg``.
    """
    rlm = project_cfg.rlm
    daemon_cfg = DaemonConfig(
        project_root=project_root,
        model_spec=model_spec or project_cfg.model.id,
        iterations=iterations or rlm.iterations,
        tasks_per_iter=rlm.tasks_per_iter,
        rollouts_per_task=rlm.rollouts_per_task,
        gating_mode=rlm.gating,
        gating_threshold=rlm.gating_threshold,
        gating_ema_alpha=rlm.gating_ema_alpha,
        verifier_backend=rlm.verifier_backend,
        verifier_timeout_s=rlm.verifier_timeout_s,
    )
    OrchestratorDaemon(daemon_cfg, project_cfg).run()