mlxsmith-0.1.2-py3-none-any.whl → mlxsmith-0.1.3-py3-none-any.whl

This diff compares the published contents of two versions of this package as they appear in their public registry. It is provided for informational purposes only.
@@ -1,33 +1,18 @@
  """Orchestrator Daemon for MLXSmith Multi-Process RLM.

- Central queue-based job scheduler that coordinates between:
- - Inference server (generates rollouts)
- - Trainer worker (consumes batches and updates weights)
-
- Manages rollout requests, training batches, and weight updates.
+ Thin wrapper around the orchestrated RLM loop to keep the legacy daemon API
+ usable while delegating implementation to the maintained orchestrator.
  """

  from __future__ import annotations

- import multiprocessing as mp
- import signal
- import time
- import traceback
- from dataclasses import dataclass, field
+ from dataclasses import dataclass
  from pathlib import Path
- from typing import Any, Dict, List, Optional
+ from typing import Optional

  from rich.console import Console

  from ..config import ProjectConfig
- from ..rlm.gating import load_state
- from ..rlm.weights import WeightPointerStore, WeightPointerIPC
- from ..runs import new_run, snapshot_config
- from ..util import ensure_dir, now_ts
- from .queue import MessageQueue, MessageType
- from .inference_worker import InferenceConfig, run_inference_worker
- from .trainer_worker import TrainerConfig, run_trainer_worker
-

  console = Console()

@@ -35,394 +20,76 @@ console = Console()
  @dataclass
  class DaemonConfig:
      """Configuration for orchestrator daemon."""
+
      project_root: Path
      model_spec: str
-
-     # Process management
+
+     # Process management (reserved for future extensions)
      inference_port: int = 8080
      inference_host: str = "0.0.0.0"
      max_restarts: int = 3
      restart_delay: float = 5.0
      health_check_interval: float = 10.0
-
+
      # Training config
      iterations: int = 50
      tasks_per_iter: int = 80
      rollouts_per_task: int = 8
      batch_size: int = 32
-
-     # Paths
+
+     # Paths (currently derived from project_root in the orchestrator)
      weights_dir: Optional[Path] = None
      checkpoint_dir: Optional[Path] = None
-
+
      # Gating
      gating_mode: str = "strict"
      gating_threshold: float = 0.0
      gating_ema_alpha: float = 0.2
-
+
      # Verifier
      verifier_backend: str = "pytest"
      verifier_timeout_s: int = 30


- @dataclass
- class ProcessHandle:
-     """Handle for a managed process."""
-     name: str
-     process: mp.Process
-     config: Any
-     restart_count: int = 0
-     last_restart: float = 0.0
-     healthy: bool = True
-     start_time: float = field(default_factory=time.time)
-
-
  class OrchestratorDaemon:
-     """Orchestrator daemon for multi-process RLM.
-
-     Responsibilities:
-     - Spawn and manage inference and trainer processes
-     - Coordinate rollout requests and training batches
-     - Manage weight pointer updates
-     - Handle process lifecycle, monitoring, and restarts
-     - Graceful shutdown handling
-     """
-
+     """Orchestrator daemon wrapper for multi-process RLM."""
+
      def __init__(self, config: DaemonConfig, project_cfg: ProjectConfig):
          self.config = config
-         self.project_cfg = project_cfg
-         self.queue = MessageQueue(maxsize=10000)
-         self._processes: Dict[str, ProcessHandle] = {}
-         self._shutdown = False
-         self._pointer_store: Optional[WeightPointerStore] = None
-         self._current_iteration = 0
-         self._metrics: List[Dict] = []
-
-         # Setup paths
-         self._weights_dir = config.weights_dir or (config.project_root / "runs" / "rlm_weights")
-         self._checkpoint_dir = config.checkpoint_dir or (config.project_root / "runs" / "rlm_checkpoints")
-         self._state_path = config.project_root / "runs" / "rlm_state.json"
-         self._history_path = config.project_root / "runs" / "rlm_history.jsonl"
-         self._corpus_path = config.project_root / "runs" / "rlm_corpus.jsonl"
-
-         ensure_dir(self._weights_dir)
-         ensure_dir(self._checkpoint_dir)
-
-     def _setup_signal_handlers(self) -> None:
-         """Setup signal handlers for graceful shutdown."""
-         def signal_handler(sig, frame):
-             console.print("[yellow]Orchestrator received shutdown signal[/yellow]")
-             self._shutdown = True
-
-         signal.signal(signal.SIGTERM, signal_handler)
-         signal.signal(signal.SIGINT, signal_handler)
-
-     def _spawn_inference_worker(self) -> ProcessHandle:
-         """Spawn the inference worker process."""
-         inf_config = InferenceConfig(
+         # Work on a copy so daemon overrides don't mutate caller state.
+         self.project_cfg = project_cfg.model_copy(deep=True)
+         self._apply_overrides()
+         from ..rlm.loop import RLMOrchestrator
+         self._orchestrator = RLMOrchestrator(
+             project_root=self.config.project_root,
+             cfg=self.project_cfg,
              model_spec=self.config.model_spec,
-             backend=self.project_cfg.model.backend,
-             host=self.config.inference_host,
-             port=self.config.inference_port,
-             max_seq_len=self.project_cfg.model.max_seq_len,
-             dtype=self.project_cfg.model.dtype,
-             trust_remote_code=self.project_cfg.model.trust_remote_code,
-             use_chat_template=self.project_cfg.model.use_chat_template,
-             weights_dir=self._weights_dir,
-             hot_reload=True,
+             iterations=self.config.iterations,
+             resume=False,
          )
-
-         # Create process
-         process = mp.Process(
-             target=run_inference_worker,
-             args=(inf_config, self.queue),
-             name="inference_worker",
-             daemon=False,
-         )
-
-         handle = ProcessHandle(
-             name="inference",
-             process=process,
-             config=inf_config,
-         )
-
-         process.start()
-         console.print(f"[green]Spawned inference worker (PID: {process.pid})[/green]")
-
-         return handle
-
-     def _spawn_trainer_worker(self) -> ProcessHandle:
-         """Spawn the trainer worker process."""
-         # Resolve base model
-         from ..models import resolve_model_spec
-         base_model, adapter_path, _ = resolve_model_spec(
-             self.config.project_root, self.config.model_spec, self.project_cfg
-         )
-
-         trainer_config = TrainerConfig(
-             model_spec=self.config.model_spec,
-             base_model=base_model,
-             backend=self.project_cfg.model.backend,
-             max_seq_len=self.project_cfg.model.max_seq_len,
-             dtype=self.project_cfg.model.dtype,
-             trust_remote_code=self.project_cfg.model.trust_remote_code,
-             lr=self.project_cfg.train.lr,
-             weight_decay=self.project_cfg.train.weight_decay,
-             kl_coeff=self.project_cfg.rft.kl_coeff,
-             normalize_advantage=self.project_cfg.rft.normalize_advantage,
-             lora_r=self.project_cfg.lora.r,
-             lora_alpha=self.project_cfg.lora.alpha,
-             lora_dropout=self.project_cfg.lora.dropout,
-             lora_target_modules=list(self.project_cfg.lora.target_modules or []),
-             lora_num_layers=self.project_cfg.lora.num_layers,
-             weights_dir=self._weights_dir,
-             checkpoint_dir=self._checkpoint_dir,
-             reference_model=self.project_cfg.rft.reference_model,
-         )
-
-         # Create process
-         process = mp.Process(
-             target=run_trainer_worker,
-             args=(trainer_config, self.queue),
-             name="trainer_worker",
-             daemon=False,
-         )
-
-         handle = ProcessHandle(
-             name="trainer",
-             process=process,
-             config=trainer_config,
-         )
-
-         process.start()
-         console.print(f"[green]Spawned trainer worker (PID: {process.pid})[/green]")
-
-         return handle
-
-     def _monitor_processes(self) -> None:
-         """Monitor processes and restart if needed."""
-         current_time = time.time()
-
-         for name, handle in list(self._processes.items()):
-             # Check if process is alive
-             if not handle.process.is_alive():
-                 if self._shutdown:
-                     continue
-
-                 console.print(f"[red]Process {name} (PID: {handle.process.pid}) died[/red]")
-                 handle.healthy = False
-
-                 # Check restart limit
-                 if handle.restart_count >= self.config.max_restarts:
-                     console.print(f"[red]Process {name} exceeded max restarts[/red]")
-                     continue
-
-                 # Check restart delay
-                 if current_time - handle.last_restart < self.config.restart_delay:
-                     time.sleep(self.config.restart_delay)
-
-                 # Restart process
-                 console.print(f"[yellow]Restarting {name}...[/yellow]")
-
-                 if name == "inference":
-                     new_handle = self._spawn_inference_worker()
-                 elif name == "trainer":
-                     new_handle = self._spawn_trainer_worker()
-                 else:
-                     continue
-
-                 new_handle.restart_count = handle.restart_count + 1
-                 new_handle.last_restart = current_time
-                 self._processes[name] = new_handle
-
-     def _health_check(self) -> Dict[str, Any]:
-         """Perform health checks on all processes via queues."""
-         results = {}
-
-         # Check inference via queue
-         if "inference" in self._processes:
-             self.queue.send(
-                 "control",
-                 MessageType.HEALTH_CHECK,
-                 {},
-                 source="daemon",
-             )
-             # Response will be processed in main loop
-
-         # Check trainer via queue
-         if "trainer" in self._processes:
-             self.queue.send(
-                 "train_batches",  # Trainer reads from train_batches
-                 MessageType.HEALTH_CHECK,
-                 {},
-                 source="daemon",
-             )
-
-         return results
-
-     def _forward_weight_updates(self) -> None:
-         """Forward weight updates from trainer to inference."""
-         # Check for weight updates from trainer
-         msg = self.queue.receive("weight_updates", timeout=0)
-         if msg and msg.msg_type == MessageType.WEIGHT_UPDATE:
-             # Forward to inference worker
-             self.queue.send(
-                 "weight_forward",
-                 MessageType.WEIGHT_UPDATE,
-                 msg.payload,
-                 source="daemon",
-             )
-
-             # Also update inference pointer
-             if self._pointer_store:
-                 pointer = WeightPointerIPC(
-                     base_model=msg.payload.get("base_model", ""),
-                     adapter_path=msg.payload.get("adapter_path"),
-                     iteration=msg.payload.get("version", 0),
-                     updated_at=now_ts(),
-                     version=msg.payload.get("version", 0),
-                     name="inference",
-                 )
-                 self._pointer_store.save(pointer)
-                 console.print(f"[blue]Forwarded weight update: {pointer.adapter_path}[/blue]")
-
-     def _shutdown_all(self) -> None:
-         """Shutdown all processes gracefully."""
-         console.print("[yellow]Shutting down all processes...[/yellow]")
-
-         # Send shutdown messages
-         for name in self._processes:
-             self.queue.send(
-                 "control",
-                 MessageType.SHUTDOWN,
-                 {},
-                 source="daemon",
-             )
-
-         # Wait for processes to terminate
-         for name, handle in self._processes.items():
-             console.print(f" Waiting for {name}...")
-             handle.process.join(timeout=10.0)
-
-             if handle.process.is_alive():
-                 console.print(f" Force terminating {name}")
-                 handle.process.terminate()
-                 handle.process.join(timeout=5.0)
-
-             if handle.process.is_alive():
-                 handle.process.kill()
-
-         # Stop queue manager
-         self.queue.stop()
-
-         console.print("[green]All processes shutdown[/green]")
-
+
+     def _apply_overrides(self) -> None:
+         """Apply daemon config overrides onto the project config."""
+         self.project_cfg.serve.host = self.config.inference_host
+         self.project_cfg.serve.port = self.config.inference_port
+
+         self.project_cfg.rlm.iterations = self.config.iterations
+         self.project_cfg.rlm.tasks_per_iter = self.config.tasks_per_iter
+         self.project_cfg.rlm.rollouts_per_task = self.config.rollouts_per_task
+         self.project_cfg.rlm.gating = self.config.gating_mode
+         self.project_cfg.rlm.gating_threshold = self.config.gating_threshold
+         self.project_cfg.rlm.gating_ema_alpha = self.config.gating_ema_alpha
+         self.project_cfg.rlm.verifier_backend = self.config.verifier_backend
+         self.project_cfg.rlm.verifier_timeout_s = self.config.verifier_timeout_s
+
      def run_iteration(self, iteration: int) -> bool:
-         """Run a single RLM iteration.
-
-         Returns True if iteration completed successfully.
-         """
-         console.print(f"\n[bold blue]=== RLM Iteration {iteration} ===[/bold blue]")
-
-         run = new_run(self.config.project_root, "rlm")
-         snapshot_config(self.project_cfg.model_dump(), run.config_snapshot_path)
-
-         # Phase 1: Generate tasks (via inference worker API)
-         console.print(" [dim]Generating tasks...[/dim]")
-         # Tasks are generated by querying inference worker
-
-         # Phase 2: Collect rollouts (via inference worker)
-         console.print(" [dim]Collecting rollouts...[/dim]")
-         # Rollouts are generated via /internal/rollout endpoint
-
-         # Phase 3: Send training batch to trainer
-         console.print(" [dim]Sending training batch...[/dim]")
-
-         # Phase 4: Wait for training completion
-         console.print(" [dim]Waiting for training...[/dim]")
-
-         # This is a placeholder - actual implementation would
-         # coordinate via queues and the inference worker API
-
-         return True
-
+         """Run a single orchestrated iteration."""
+         return self._orchestrator.run_iteration(iteration)
+
      def run(self) -> None:
-         """Run the orchestrator daemon."""
-         self._setup_signal_handlers()
-
+         """Run the orchestrated RLM loop."""
          console.print("[bold green]Starting MLXSmith Orchestrator[/bold green]")
-
-         # Start queue manager
-         self.queue.start()
-         console.print("[dim]Queue manager started[/dim]")
-
-         # Initialize weight pointer store
-         self._pointer_store = WeightPointerStore(self._weights_dir)
-         console.print(f"[dim]Weight store: {self._weights_dir}[/dim]")
-
-         # Spawn worker processes
-         console.print("[dim]Spawning worker processes...[/dim]")
-         self._processes["inference"] = self._spawn_inference_worker()
-         self._processes["trainer"] = self._spawn_trainer_worker()
-
-         # Wait for processes to initialize
-         console.print("[dim]Waiting for workers to initialize...[/dim]")
-         time.sleep(5.0)
-
-         # Load state
-         state = load_state(self._state_path)
-         self._current_iteration = state.last_iteration + 1
-
-         last_health_check = time.time()
-
-         try:
-             # Main orchestrator loop
-             while not self._shutdown:
-                 # Monitor processes
-                 self._monitor_processes()
-
-                 # Health checks
-                 current_time = time.time()
-                 if current_time - last_health_check > self.config.health_check_interval:
-                     self._health_check()
-                     last_health_check = current_time
-
-                 # Forward weight updates
-                 self._forward_weight_updates()
-
-                 # Process queue messages
-                 self._process_queue_messages()
-
-                 # Small sleep to prevent busy waiting
-                 time.sleep(0.01)
-
-         except KeyboardInterrupt:
-             console.print("[yellow]Interrupted by user[/yellow]")
-         except Exception as e:
-             console.print(f"[red]Orchestrator error: {e}[/red]")
-             traceback.print_exc()
-         finally:
-             self._shutdown_all()
-
-     def _process_queue_messages(self) -> None:
-         """Process pending queue messages."""
-         # Process responses from inference
-         msg = self.queue.receive("rollout_responses", timeout=0)
-         if msg:
-             # Handle rollout response
-             pass
-
-         # Process training completion
-         msg = self.queue.receive("train_complete", timeout=0)
-         if msg:
-             # Handle training completion
-             pass
-
-         # Process checkpoints
-         msg = self.queue.receive("checkpoints", timeout=0)
-         if msg:
-             # Handle checkpoint notification
-             pass
+         self._orchestrator.run()


  def run_daemon(
@@ -444,6 +111,6 @@ def run_daemon(
          verifier_backend=project_cfg.rlm.verifier_backend,
          verifier_timeout_s=project_cfg.rlm.verifier_timeout_s,
      )
-
+
      daemon = OrchestratorDaemon(config, project_cfg)
      daemon.run()
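
Net effect of the daemon-file hunks above: OrchestratorDaemon no longer spawns, monitors, or restarts worker processes itself. It deep-copies the caller's ProjectConfig, writes the daemon overrides onto the copy via _apply_overrides, and delegates both run_iteration and run to RLMOrchestrator. A minimal sketch of the preserved legacy API under these changes; the daemon module's dotted import path and the ProjectConfig loading step are not visible in this diff, so load_project_config is a hypothetical stand-in:

# Hedged sketch, not verbatim package API: load_project_config and the literal
# argument values are hypothetical; DaemonConfig, OrchestratorDaemon, and the
# delegation to RLMOrchestrator are taken from the diff above.
from pathlib import Path

project_root = Path(".")                         # hypothetical project root
project_cfg = load_project_config(project_root)  # hypothetical loader

cfg = DaemonConfig(
    project_root=project_root,
    model_spec="base",    # hypothetical model spec
    inference_port=8090,  # copied onto project_cfg.serve.port by _apply_overrides
    iterations=10,        # forwarded as RLMOrchestrator(iterations=...)
)

daemon = OrchestratorDaemon(cfg, project_cfg)  # caller's project_cfg is not mutated
ok = daemon.run_iteration(1)                   # delegates to RLMOrchestrator.run_iteration
daemon.run()                                   # delegates to RLMOrchestrator.run

Note that the process-management fields (max_restarts, restart_delay, health_check_interval) survive on DaemonConfig but, per the comment added in the diff, are reserved for future extensions and no longer consumed in this file. The remaining two hunks come from the trainer worker module (the trainer_worker import removed above) and make the optimizer configurable.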
@@ -37,6 +37,8 @@ class TrainerConfig:
      # Training config
      lr: float = 2e-4
      weight_decay: float = 0.0
+     optimizer: str = "adamw"
+     optimizer_kwargs: Dict[str, Any] = field(default_factory=dict)
      kl_coeff: float = 0.02
      normalize_advantage: bool = True

@@ -129,6 +131,8 @@ class TrainerWorker:
          self._optimizer, _ = self._llm.optimizer_and_params(
              lr=self.config.lr,
              weight_decay=self.config.weight_decay,
+             optimizer=self.config.optimizer,
+             optimizer_kwargs=self.config.optimizer_kwargs,
          )

          # Load reference model if needed for KL
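
Together, the two trainer hunks make optimizer selection configurable instead of implicit: the new TrainerConfig fields are forwarded verbatim to optimizer_and_params. A hedged sketch follows; the field names and the pass-through are taken from the hunks above, while the model_spec/base_model values and the optimizer_kwargs contents are illustrative only, since which optimizer names and kwargs the backend accepts is not visible in this diff:

# Sketch under stated assumptions: only optimizer/optimizer_kwargs and their
# pass-through to optimizer_and_params are confirmed by the diff; everything
# else here is hypothetical filler.
trainer_cfg = TrainerConfig(
    model_spec="base",                        # hypothetical
    base_model="org/base-model",              # hypothetical
    lr=2e-4,                                  # existing field, unchanged default
    weight_decay=0.0,                         # existing field, unchanged default
    optimizer="adamw",                        # new field, default "adamw"
    optimizer_kwargs={"betas": (0.9, 0.95)},  # new field, passed through verbatim
)

# Inside TrainerWorker this now resolves to:
#     self._optimizer, _ = self._llm.optimizer_and_params(
#         lr=trainer_cfg.lr,
#         weight_decay=trainer_cfg.weight_decay,
#         optimizer=trainer_cfg.optimizer,
#         optimizer_kwargs=trainer_cfg.optimizer_kwargs,
#     )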