mlxsmith 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registries, and is provided for informational purposes only.
- mlxsmith/bench.py +12 -2
- mlxsmith/cli.py +187 -1
- mlxsmith/config_models.py +15 -1
- mlxsmith/integrations/__init__.py +19 -0
- mlxsmith/integrations/mlx_lm_lora.py +117 -0
- mlxsmith/llm/backend.py +8 -1
- mlxsmith/llm/mlx_lm_backend.py +59 -2
- mlxsmith/llm/mock_backend.py +8 -1
- mlxsmith/optim/__init__.py +3 -0
- mlxsmith/optim/muon.py +93 -0
- mlxsmith/orchestrator/daemon.py +44 -377
- mlxsmith/orchestrator/trainer_worker.py +4 -0
- mlxsmith/rlm/loop.py +53 -92
- mlxsmith/sdk/__init__.py +18 -2
- mlxsmith/sdk/losses.py +102 -1
- mlxsmith/sdk/training_client.py +24 -5
- mlxsmith/train/distill.py +6 -1
- mlxsmith/train/online_dpo.py +249 -0
- mlxsmith/train/pref.py +31 -29
- mlxsmith/train/rft.py +123 -38
- mlxsmith/train/self_verify.py +199 -0
- mlxsmith/train/sft.py +13 -2
- mlxsmith/verifiers/llm_judge.py +278 -0
- mlxsmith/verifiers/prime.py +127 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/METADATA +27 -1
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/RECORD +30 -22
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/WHEEL +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/entry_points.txt +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/top_level.txt +0 -0
mlxsmith/orchestrator/daemon.py
CHANGED
@@ -1,33 +1,18 @@
 """Orchestrator Daemon for MLXSmith Multi-Process RLM.
 
-
-
-- Trainer worker (consumes batches and updates weights)
-
-Manages rollout requests, training batches, and weight updates.
+Thin wrapper around the orchestrated RLM loop to keep the legacy daemon API
+usable while delegating implementation to the maintained orchestrator.
 """
 
 from __future__ import annotations
 
-
-import signal
-import time
-import traceback
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from pathlib import Path
-from typing import
+from typing import Optional
 
 from rich.console import Console
 
 from ..config import ProjectConfig
-from ..rlm.gating import load_state
-from ..rlm.weights import WeightPointerStore, WeightPointerIPC
-from ..runs import new_run, snapshot_config
-from ..util import ensure_dir, now_ts
-from .queue import MessageQueue, MessageType
-from .inference_worker import InferenceConfig, run_inference_worker
-from .trainer_worker import TrainerConfig, run_trainer_worker
-
 
 console = Console()
 
@@ -35,394 +20,76 @@ console = Console()
 @dataclass
 class DaemonConfig:
     """Configuration for orchestrator daemon."""
+
     project_root: Path
     model_spec: str
-
-    # Process management
+
+    # Process management (reserved for future extensions)
     inference_port: int = 8080
     inference_host: str = "0.0.0.0"
     max_restarts: int = 3
     restart_delay: float = 5.0
     health_check_interval: float = 10.0
-
+
     # Training config
     iterations: int = 50
     tasks_per_iter: int = 80
     rollouts_per_task: int = 8
     batch_size: int = 32
-
-    # Paths
+
+    # Paths (currently derived from project_root in the orchestrator)
     weights_dir: Optional[Path] = None
     checkpoint_dir: Optional[Path] = None
-
+
     # Gating
     gating_mode: str = "strict"
     gating_threshold: float = 0.0
     gating_ema_alpha: float = 0.2
-
+
     # Verifier
     verifier_backend: str = "pytest"
     verifier_timeout_s: int = 30
 
 
-@dataclass
-class ProcessHandle:
-    """Handle for a managed process."""
-    name: str
-    process: mp.Process
-    config: Any
-    restart_count: int = 0
-    last_restart: float = 0.0
-    healthy: bool = True
-    start_time: float = field(default_factory=time.time)
-
-
 class OrchestratorDaemon:
-    """Orchestrator daemon for multi-process RLM.
-
-    Responsibilities:
-    - Spawn and manage inference and trainer processes
-    - Coordinate rollout requests and training batches
-    - Manage weight pointer updates
-    - Handle process lifecycle, monitoring, and restarts
-    - Graceful shutdown handling
-    """
-
+    """Orchestrator daemon wrapper for multi-process RLM."""
+
     def __init__(self, config: DaemonConfig, project_cfg: ProjectConfig):
         self.config = config
-
-        self.
-        self.
-
-        self.
-
-
-
-        # Setup paths
-        self._weights_dir = config.weights_dir or (config.project_root / "runs" / "rlm_weights")
-        self._checkpoint_dir = config.checkpoint_dir or (config.project_root / "runs" / "rlm_checkpoints")
-        self._state_path = config.project_root / "runs" / "rlm_state.json"
-        self._history_path = config.project_root / "runs" / "rlm_history.jsonl"
-        self._corpus_path = config.project_root / "runs" / "rlm_corpus.jsonl"
-
-        ensure_dir(self._weights_dir)
-        ensure_dir(self._checkpoint_dir)
-
-    def _setup_signal_handlers(self) -> None:
-        """Setup signal handlers for graceful shutdown."""
-        def signal_handler(sig, frame):
-            console.print("[yellow]Orchestrator received shutdown signal[/yellow]")
-            self._shutdown = True
-
-        signal.signal(signal.SIGTERM, signal_handler)
-        signal.signal(signal.SIGINT, signal_handler)
-
-    def _spawn_inference_worker(self) -> ProcessHandle:
-        """Spawn the inference worker process."""
-        inf_config = InferenceConfig(
+        # Work on a copy so daemon overrides don't mutate caller state.
+        self.project_cfg = project_cfg.model_copy(deep=True)
+        self._apply_overrides()
+        from ..rlm.loop import RLMOrchestrator
+        self._orchestrator = RLMOrchestrator(
+            project_root=self.config.project_root,
+            cfg=self.project_cfg,
             model_spec=self.config.model_spec,
-
-
-            port=self.config.inference_port,
-            max_seq_len=self.project_cfg.model.max_seq_len,
-            dtype=self.project_cfg.model.dtype,
-            trust_remote_code=self.project_cfg.model.trust_remote_code,
-            use_chat_template=self.project_cfg.model.use_chat_template,
-            weights_dir=self._weights_dir,
-            hot_reload=True,
+            iterations=self.config.iterations,
+            resume=False,
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        process.start()
-        console.print(f"[green]Spawned inference worker (PID: {process.pid})[/green]")
-
-        return handle
-
-    def _spawn_trainer_worker(self) -> ProcessHandle:
-        """Spawn the trainer worker process."""
-        # Resolve base model
-        from ..models import resolve_model_spec
-        base_model, adapter_path, _ = resolve_model_spec(
-            self.config.project_root, self.config.model_spec, self.project_cfg
-        )
-
-        trainer_config = TrainerConfig(
-            model_spec=self.config.model_spec,
-            base_model=base_model,
-            backend=self.project_cfg.model.backend,
-            max_seq_len=self.project_cfg.model.max_seq_len,
-            dtype=self.project_cfg.model.dtype,
-            trust_remote_code=self.project_cfg.model.trust_remote_code,
-            lr=self.project_cfg.train.lr,
-            weight_decay=self.project_cfg.train.weight_decay,
-            kl_coeff=self.project_cfg.rft.kl_coeff,
-            normalize_advantage=self.project_cfg.rft.normalize_advantage,
-            lora_r=self.project_cfg.lora.r,
-            lora_alpha=self.project_cfg.lora.alpha,
-            lora_dropout=self.project_cfg.lora.dropout,
-            lora_target_modules=list(self.project_cfg.lora.target_modules or []),
-            lora_num_layers=self.project_cfg.lora.num_layers,
-            weights_dir=self._weights_dir,
-            checkpoint_dir=self._checkpoint_dir,
-            reference_model=self.project_cfg.rft.reference_model,
-        )
-
-        # Create process
-        process = mp.Process(
-            target=run_trainer_worker,
-            args=(trainer_config, self.queue),
-            name="trainer_worker",
-            daemon=False,
-        )
-
-        handle = ProcessHandle(
-            name="trainer",
-            process=process,
-            config=trainer_config,
-        )
-
-        process.start()
-        console.print(f"[green]Spawned trainer worker (PID: {process.pid})[/green]")
-
-        return handle
-
-    def _monitor_processes(self) -> None:
-        """Monitor processes and restart if needed."""
-        current_time = time.time()
-
-        for name, handle in list(self._processes.items()):
-            # Check if process is alive
-            if not handle.process.is_alive():
-                if self._shutdown:
-                    continue
-
-                console.print(f"[red]Process {name} (PID: {handle.process.pid}) died[/red]")
-                handle.healthy = False
-
-                # Check restart limit
-                if handle.restart_count >= self.config.max_restarts:
-                    console.print(f"[red]Process {name} exceeded max restarts[/red]")
-                    continue
-
-                # Check restart delay
-                if current_time - handle.last_restart < self.config.restart_delay:
-                    time.sleep(self.config.restart_delay)
-
-                # Restart process
-                console.print(f"[yellow]Restarting {name}...[/yellow]")
-
-                if name == "inference":
-                    new_handle = self._spawn_inference_worker()
-                elif name == "trainer":
-                    new_handle = self._spawn_trainer_worker()
-                else:
-                    continue
-
-                new_handle.restart_count = handle.restart_count + 1
-                new_handle.last_restart = current_time
-                self._processes[name] = new_handle
-
-    def _health_check(self) -> Dict[str, Any]:
-        """Perform health checks on all processes via queues."""
-        results = {}
-
-        # Check inference via queue
-        if "inference" in self._processes:
-            self.queue.send(
-                "control",
-                MessageType.HEALTH_CHECK,
-                {},
-                source="daemon",
-            )
-            # Response will be processed in main loop
-
-        # Check trainer via queue
-        if "trainer" in self._processes:
-            self.queue.send(
-                "train_batches",  # Trainer reads from train_batches
-                MessageType.HEALTH_CHECK,
-                {},
-                source="daemon",
-            )
-
-        return results
-
-    def _forward_weight_updates(self) -> None:
-        """Forward weight updates from trainer to inference."""
-        # Check for weight updates from trainer
-        msg = self.queue.receive("weight_updates", timeout=0)
-        if msg and msg.msg_type == MessageType.WEIGHT_UPDATE:
-            # Forward to inference worker
-            self.queue.send(
-                "weight_forward",
-                MessageType.WEIGHT_UPDATE,
-                msg.payload,
-                source="daemon",
-            )
-
-            # Also update inference pointer
-            if self._pointer_store:
-                pointer = WeightPointerIPC(
-                    base_model=msg.payload.get("base_model", ""),
-                    adapter_path=msg.payload.get("adapter_path"),
-                    iteration=msg.payload.get("version", 0),
-                    updated_at=now_ts(),
-                    version=msg.payload.get("version", 0),
-                    name="inference",
-                )
-                self._pointer_store.save(pointer)
-                console.print(f"[blue]Forwarded weight update: {pointer.adapter_path}[/blue]")
-
-    def _shutdown_all(self) -> None:
-        """Shutdown all processes gracefully."""
-        console.print("[yellow]Shutting down all processes...[/yellow]")
-
-        # Send shutdown messages
-        for name in self._processes:
-            self.queue.send(
-                "control",
-                MessageType.SHUTDOWN,
-                {},
-                source="daemon",
-            )
-
-        # Wait for processes to terminate
-        for name, handle in self._processes.items():
-            console.print(f"  Waiting for {name}...")
-            handle.process.join(timeout=10.0)
-
-            if handle.process.is_alive():
-                console.print(f"  Force terminating {name}")
-                handle.process.terminate()
-                handle.process.join(timeout=5.0)
-
-            if handle.process.is_alive():
-                handle.process.kill()
-
-        # Stop queue manager
-        self.queue.stop()
-
-        console.print("[green]All processes shutdown[/green]")
-
+
+    def _apply_overrides(self) -> None:
+        """Apply daemon config overrides onto the project config."""
+        self.project_cfg.serve.host = self.config.inference_host
+        self.project_cfg.serve.port = self.config.inference_port
+
+        self.project_cfg.rlm.iterations = self.config.iterations
+        self.project_cfg.rlm.tasks_per_iter = self.config.tasks_per_iter
+        self.project_cfg.rlm.rollouts_per_task = self.config.rollouts_per_task
+        self.project_cfg.rlm.gating = self.config.gating_mode
+        self.project_cfg.rlm.gating_threshold = self.config.gating_threshold
+        self.project_cfg.rlm.gating_ema_alpha = self.config.gating_ema_alpha
+        self.project_cfg.rlm.verifier_backend = self.config.verifier_backend
+        self.project_cfg.rlm.verifier_timeout_s = self.config.verifier_timeout_s
+
     def run_iteration(self, iteration: int) -> bool:
-        """Run a single
-
-
-        """
-        console.print(f"\n[bold blue]=== RLM Iteration {iteration} ===[/bold blue]")
-
-        run = new_run(self.config.project_root, "rlm")
-        snapshot_config(self.project_cfg.model_dump(), run.config_snapshot_path)
-
-        # Phase 1: Generate tasks (via inference worker API)
-        console.print("  [dim]Generating tasks...[/dim]")
-        # Tasks are generated by querying inference worker
-
-        # Phase 2: Collect rollouts (via inference worker)
-        console.print("  [dim]Collecting rollouts...[/dim]")
-        # Rollouts are generated via /internal/rollout endpoint
-
-        # Phase 3: Send training batch to trainer
-        console.print("  [dim]Sending training batch...[/dim]")
-
-        # Phase 4: Wait for training completion
-        console.print("  [dim]Waiting for training...[/dim]")
-
-        # This is a placeholder - actual implementation would
-        # coordinate via queues and the inference worker API
-
-        return True
-
+        """Run a single orchestrated iteration."""
+        return self._orchestrator.run_iteration(iteration)
+
     def run(self) -> None:
-        """Run the
-        self._setup_signal_handlers()
-
+        """Run the orchestrated RLM loop."""
         console.print("[bold green]Starting MLXSmith Orchestrator[/bold green]")
-
-        # Start queue manager
-        self.queue.start()
-        console.print("[dim]Queue manager started[/dim]")
-
-        # Initialize weight pointer store
-        self._pointer_store = WeightPointerStore(self._weights_dir)
-        console.print(f"[dim]Weight store: {self._weights_dir}[/dim]")
-
-        # Spawn worker processes
-        console.print("[dim]Spawning worker processes...[/dim]")
-        self._processes["inference"] = self._spawn_inference_worker()
-        self._processes["trainer"] = self._spawn_trainer_worker()
-
-        # Wait for processes to initialize
-        console.print("[dim]Waiting for workers to initialize...[/dim]")
-        time.sleep(5.0)
-
-        # Load state
-        state = load_state(self._state_path)
-        self._current_iteration = state.last_iteration + 1
-
-        last_health_check = time.time()
-
-        try:
-            # Main orchestrator loop
-            while not self._shutdown:
-                # Monitor processes
-                self._monitor_processes()
-
-                # Health checks
-                current_time = time.time()
-                if current_time - last_health_check > self.config.health_check_interval:
-                    self._health_check()
-                    last_health_check = current_time
-
-                # Forward weight updates
-                self._forward_weight_updates()
-
-                # Process queue messages
-                self._process_queue_messages()
-
-                # Small sleep to prevent busy waiting
-                time.sleep(0.01)
-
-        except KeyboardInterrupt:
-            console.print("[yellow]Interrupted by user[/yellow]")
-        except Exception as e:
-            console.print(f"[red]Orchestrator error: {e}[/red]")
-            traceback.print_exc()
-        finally:
-            self._shutdown_all()
-
-    def _process_queue_messages(self) -> None:
-        """Process pending queue messages."""
-        # Process responses from inference
-        msg = self.queue.receive("rollout_responses", timeout=0)
-        if msg:
-            # Handle rollout response
-            pass
-
-        # Process training completion
-        msg = self.queue.receive("train_complete", timeout=0)
-        if msg:
-            # Handle training completion
-            pass
-
-        # Process checkpoints
-        msg = self.queue.receive("checkpoints", timeout=0)
-        if msg:
-            # Handle checkpoint notification
-            pass
+        self._orchestrator.run()
 
 
 def run_daemon(
@@ -444,6 +111,6 @@ def run_daemon(
         verifier_backend=project_cfg.rlm.verifier_backend,
         verifier_timeout_s=project_cfg.rlm.verifier_timeout_s,
     )
-
+
     daemon = OrchestratorDaemon(config, project_cfg)
    daemon.run()
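With 0.1.3 the daemon keeps only DaemonConfig, OrchestratorDaemon.run_iteration, and OrchestratorDaemon.run as its public surface; process spawning, health checks, and weight forwarding now live in the orchestrated RLM loop. A minimal usage sketch of the refactored API, where the ProjectConfig construction and the model spec value are hypothetical placeholders:

```python
from pathlib import Path

from mlxsmith.config import ProjectConfig
from mlxsmith.orchestrator.daemon import DaemonConfig, OrchestratorDaemon

# Hypothetical: build or load your project config however your project does.
project_cfg = ProjectConfig()

config = DaemonConfig(
    project_root=Path("."),
    model_spec="mlx-community/example-model",  # hypothetical model spec
    inference_port=8080,
    iterations=50,
    gating_mode="strict",
    verifier_backend="pytest",
)

# The wrapper deep-copies project_cfg, applies the DaemonConfig overrides
# (serve.host/port and the rlm.* fields), then delegates to RLMOrchestrator.
daemon = OrchestratorDaemon(config, project_cfg)
daemon.run()
```

Because _apply_overrides mutates a deep copy (model_copy(deep=True)), the caller's ProjectConfig is left untouched; the same DaemonConfig values also flow through run_daemon, whose tail is shown in the hunk above.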
mlxsmith/orchestrator/trainer_worker.py
CHANGED
@@ -37,6 +37,8 @@ class TrainerConfig:
     # Training config
     lr: float = 2e-4
     weight_decay: float = 0.0
+    optimizer: str = "adamw"
+    optimizer_kwargs: Dict[str, Any] = field(default_factory=dict)
     kl_coeff: float = 0.02
     normalize_advantage: bool = True
 
@@ -129,6 +131,8 @@ class TrainerWorker:
         self._optimizer, _ = self._llm.optimizer_and_params(
             lr=self.config.lr,
             weight_decay=self.config.weight_decay,
+            optimizer=self.config.optimizer,
+            optimizer_kwargs=self.config.optimizer_kwargs,
         )
 
         # Load reference model if needed for KL
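These two hunks thread an optimizer choice through to the backend: TrainerConfig gains optimizer and optimizer_kwargs, and TrainerWorker forwards both to optimizer_and_params. A hedged sketch of opting into the newly added Muon optimizer; the name "muon" (suggested by the new mlxsmith/optim/muon.py module), the momentum kwarg, and the placeholder model values are all assumptions, not confirmed by this diff:

```python
from mlxsmith.orchestrator.trainer_worker import TrainerConfig

trainer_config = TrainerConfig(
    model_spec="adapter:latest",               # hypothetical
    base_model="mlx-community/example-model",  # hypothetical
    lr=2e-4,
    weight_decay=0.0,
    optimizer="muon",                     # assumed name for optim/muon.py
    optimizer_kwargs={"momentum": 0.95},  # hypothetical Muon knob, passed through verbatim
)
```

Leaving the new fields at their defaults ("adamw", empty kwargs) preserves the 0.1.2 behavior.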
|