mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,454 @@
+"""Orchestrator Daemon for MLXSmith Multi-Process RLM.
+
+Central queue-based job scheduler that coordinates between:
+- Inference server (generates rollouts)
+- Trainer worker (consumes batches and updates weights)
+
+Manages rollout requests, training batches, and weight updates.
+"""
+
+from __future__ import annotations
+
+import json
+import multiprocessing as mp
+import signal
+import sys
+import time
+import traceback
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Callable
+
+from rich.console import Console
+
+from ..config import ProjectConfig
+from ..rlm.corpus import append_corpus, load_corpus, sample_corpus
+from ..rlm.gating import load_state, save_state, should_accept, update_state
+from ..rlm.history import append_history
+from ..rlm.inference import Rollout, build_tasks
+from ..rlm.weights import WeightPointerStore, WeightPointerIPC
+from ..runs import new_run, snapshot_config
+from ..util import ensure_dir, now_ts, write_jsonl
+from .queue import MessageQueue, MessageType, Message
+from .inference_worker import InferenceConfig, run_inference_worker
+from .trainer_worker import TrainerConfig, run_trainer_worker
+
+
+console = Console()
+
+
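For reference while reading the methods below, the daemon touches seven queue topics: control (health probes and SHUTDOWN), train_batches (work for the trainer, which also receives health probes there), weight_updates (trainer to daemon), weight_forward (daemon to inference), and rollout_responses, train_complete, and checkpoints (worker notifications drained in _process_queue_messages).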
+@dataclass
+class DaemonConfig:
+    """Configuration for orchestrator daemon."""
+    project_root: Path
+    model_spec: str
+
+    # Process management
+    inference_port: int = 8080
+    inference_host: str = "0.0.0.0"
+    max_restarts: int = 3
+    restart_delay: float = 5.0
+    health_check_interval: float = 10.0
+
+    # Training config
+    iterations: int = 50
+    tasks_per_iter: int = 80
+    rollouts_per_task: int = 8
+    batch_size: int = 32
+
+    # Paths
+    weights_dir: Optional[Path] = None
+    checkpoint_dir: Optional[Path] = None
+
+    # Gating
+    gating_mode: str = "strict"
+    gating_threshold: float = 0.0
+    gating_ema_alpha: float = 0.2
+
+    # Verifier
+    verifier_backend: str = "pytest"
+    verifier_timeout_s: int = 30
+
+
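The three gating fields form an accept/reject filter over candidate weights: a candidate must beat a baseline by gating_threshold, and the baseline is tracked as an exponential moving average smoothed by gating_ema_alpha. The actual rule lives in mlxsmith/rlm/gating.py (listed above but not shown in this diff), so the following is only an illustrative sketch; the function and its names are assumptions:

    def ema_gate(mean_reward: float, baseline: float,
                 threshold: float = 0.0, alpha: float = 0.2) -> tuple[bool, float]:
        """Accept if the candidate beats the EMA baseline by `threshold`,
        then fold its reward into the baseline (defaults mirror the
        gating_threshold / gating_ema_alpha defaults above)."""
        accepted = mean_reward >= baseline + threshold
        new_baseline = alpha * mean_reward + (1.0 - alpha) * baseline
        return accepted, new_baseline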
+@dataclass
+class ProcessHandle:
+    """Handle for a managed process."""
+    name: str
+    process: mp.Process
+    config: Any
+    restart_count: int = 0
+    last_restart: float = 0.0
+    healthy: bool = True
+    start_time: float = field(default_factory=time.time)
+
+
+class OrchestratorDaemon:
+    """Orchestrator daemon for multi-process RLM.
+
+    Responsibilities:
+    - Spawn and manage inference and trainer processes
+    - Coordinate rollout requests and training batches
+    - Manage weight pointer updates
+    - Handle process lifecycle, monitoring, and restarts
+    - Graceful shutdown handling
+    """
+
+    def __init__(self, config: DaemonConfig, project_cfg: ProjectConfig):
+        self.config = config
+        self.project_cfg = project_cfg
+        self.queue = MessageQueue(maxsize=10000)
+        self._processes: Dict[str, ProcessHandle] = {}
+        self._shutdown = False
+        self._pointer_store: Optional[WeightPointerStore] = None
+        self._current_iteration = 0
+        self._metrics: List[Dict] = []
+
+        # Setup paths
+        self._weights_dir = config.weights_dir or (config.project_root / "runs" / "rlm_weights")
+        self._checkpoint_dir = config.checkpoint_dir or (config.project_root / "runs" / "rlm_checkpoints")
+        self._state_path = config.project_root / "runs" / "rlm_state.json"
+        self._history_path = config.project_root / "runs" / "rlm_history.jsonl"
+        self._corpus_path = config.project_root / "runs" / "rlm_corpus.jsonl"
+
+        ensure_dir(self._weights_dir)
+        ensure_dir(self._checkpoint_dir)
+
+    def _setup_signal_handlers(self) -> None:
+        """Setup signal handlers for graceful shutdown."""
+        def signal_handler(sig, frame):
+            console.print("[yellow]Orchestrator received shutdown signal[/yellow]")
+            self._shutdown = True
+
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
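Note that the handler only flips _shutdown; the loop in run() notices the flag and performs the teardown, so no cleanup happens in signal context. Since the workers run with daemon=False and share the terminal's process group, a Ctrl-C is typically delivered to them as well; a common counterpart pattern (an assumption about the workers, which this diff does not show) is to ignore SIGINT there and exit only on the daemon's SHUTDOWN message:

    import signal

    def worker_main(worker_config, queue) -> None:
        # Let the orchestrator own Ctrl-C; exit only on an explicit
        # SHUTDOWN message from the daemon (illustrative sketch).
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        while True:
            msg = queue.receive("control", timeout=1)
            if msg is not None and msg.msg_type == MessageType.SHUTDOWN:
                break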
+    def _spawn_inference_worker(self) -> ProcessHandle:
+        """Spawn the inference worker process."""
+        inf_config = InferenceConfig(
+            model_spec=self.config.model_spec,
+            backend=self.project_cfg.model.backend,
+            host=self.config.inference_host,
+            port=self.config.inference_port,
+            max_seq_len=self.project_cfg.model.max_seq_len,
+            dtype=self.project_cfg.model.dtype,
+            trust_remote_code=self.project_cfg.model.trust_remote_code,
+            use_chat_template=self.project_cfg.model.use_chat_template,
+            weights_dir=self._weights_dir,
+            hot_reload=True,
+        )
+
+        # Create process
+        process = mp.Process(
+            target=run_inference_worker,
+            args=(inf_config, self.queue),
+            name="inference_worker",
+            daemon=False,
+        )
+
+        handle = ProcessHandle(
+            name="inference",
+            process=process,
+            config=inf_config,
+        )
+
+        process.start()
+        console.print(f"[green]Spawned inference worker (PID: {process.pid})[/green]")
+
+        return handle
+
+    def _spawn_trainer_worker(self) -> ProcessHandle:
+        """Spawn the trainer worker process."""
+        # Resolve base model
+        from ..models import resolve_model_spec
+        base_model, adapter_path, _ = resolve_model_spec(
+            self.config.project_root, self.config.model_spec, self.project_cfg
+        )
+
+        trainer_config = TrainerConfig(
+            model_spec=self.config.model_spec,
+            base_model=base_model,
+            backend=self.project_cfg.model.backend,
+            max_seq_len=self.project_cfg.model.max_seq_len,
+            dtype=self.project_cfg.model.dtype,
+            trust_remote_code=self.project_cfg.model.trust_remote_code,
+            lr=self.project_cfg.train.lr,
+            weight_decay=self.project_cfg.train.weight_decay,
+            kl_coeff=self.project_cfg.rft.kl_coeff,
+            normalize_advantage=self.project_cfg.rft.normalize_advantage,
+            lora_r=self.project_cfg.lora.r,
+            lora_alpha=self.project_cfg.lora.alpha,
+            lora_dropout=self.project_cfg.lora.dropout,
+            lora_target_modules=list(self.project_cfg.lora.target_modules or []),
+            lora_num_layers=self.project_cfg.lora.num_layers,
+            weights_dir=self._weights_dir,
+            checkpoint_dir=self._checkpoint_dir,
+            reference_model=self.project_cfg.rft.reference_model,
+        )
+
+        # Create process
+        process = mp.Process(
+            target=run_trainer_worker,
+            args=(trainer_config, self.queue),
+            name="trainer_worker",
+            daemon=False,
+        )
+
+        handle = ProcessHandle(
+            name="trainer",
+            process=process,
+            config=trainer_config,
+        )
+
+        process.start()
+        console.print(f"[green]Spawned trainer worker (PID: {process.pid})[/green]")
+
+        return handle
+
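The two spawn methods are structurally identical once their config objects are built; a generic helper would collapse the duplication (a refactoring sketch of the code above, not something the package ships):

    def _spawn(self, name: str, target, worker_config) -> ProcessHandle:
        # Shared tail of _spawn_inference_worker / _spawn_trainer_worker.
        process = mp.Process(
            target=target,
            args=(worker_config, self.queue),
            name=f"{name}_worker",
            daemon=False,
        )
        handle = ProcessHandle(name=name, process=process, config=worker_config)
        process.start()
        console.print(f"[green]Spawned {name} worker (PID: {process.pid})[/green]")
        return handle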
+    def _monitor_processes(self) -> None:
+        """Monitor processes and restart if needed."""
+        current_time = time.time()
+
+        for name, handle in list(self._processes.items()):
+            # Check if process is alive
+            if not handle.process.is_alive():
+                if self._shutdown:
+                    continue
+
+                console.print(f"[red]Process {name} (PID: {handle.process.pid}) died[/red]")
+                handle.healthy = False
+
+                # Check restart limit
+                if handle.restart_count >= self.config.max_restarts:
+                    console.print(f"[red]Process {name} exceeded max restarts[/red]")
+                    continue
+
+                # Check restart delay
+                if current_time - handle.last_restart < self.config.restart_delay:
+                    time.sleep(self.config.restart_delay)
+
+                # Restart process
+                console.print(f"[yellow]Restarting {name}...[/yellow]")
+
+                if name == "inference":
+                    new_handle = self._spawn_inference_worker()
+                elif name == "trainer":
+                    new_handle = self._spawn_trainer_worker()
+                else:
+                    continue
+
+                new_handle.restart_count = handle.restart_count + 1
+                new_handle.last_restart = current_time
+                self._processes[name] = new_handle
+
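One subtlety here: when a process dies within restart_delay of its last restart, the time.sleep() call blocks the daemon's main thread, stalling health checks and weight forwarding for those seconds. A non-blocking variant (a sketch, not the shipped behavior) would skip the restart and let a later loop pass retry:

    def _monitor_processes_nonblocking(self) -> None:
        # Sketch: instead of time.sleep(), defer the restart to a later
        # pass of the main loop, which runs roughly every 10 ms.
        now = time.time()
        for name, handle in list(self._processes.items()):
            if handle.process.is_alive() or self._shutdown:
                continue
            if handle.restart_count >= self.config.max_restarts:
                continue
            if now - handle.last_restart < self.config.restart_delay:
                continue  # not yet; retry on a later pass
            new_handle = (self._spawn_inference_worker() if name == "inference"
                          else self._spawn_trainer_worker())
            new_handle.restart_count = handle.restart_count + 1
            new_handle.last_restart = now
            self._processes[name] = new_handle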
+    def _health_check(self) -> Dict[str, Any]:
+        """Perform health checks on all processes via queues."""
+        results = {}
+
+        # Check inference via queue
+        if "inference" in self._processes:
+            self.queue.send(
+                "control",
+                MessageType.HEALTH_CHECK,
+                {},
+                source="daemon",
+            )
+            # Response will be processed in main loop
+
+        # Check trainer via queue
+        if "trainer" in self._processes:
+            self.queue.send(
+                "train_batches",  # Trainer reads from train_batches
+                MessageType.HEALTH_CHECK,
+                {},
+                source="daemon",
+            )
+
+        return results
+
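As written, _health_check() always returns its empty results dict: the probes are fire-and-forget, and any replies would surface later through the main loop. To genuinely correlate probes with replies, one option is to tag each probe and match the echo; in this sketch the probe_id payload key and the reply handling are assumptions, and only queue.send/receive and MessageType.HEALTH_CHECK come from the file:

    import uuid

    def send_probe(queue, topic: str, pending: set) -> None:
        # Tag each HEALTH_CHECK with a unique ID so its reply can be matched.
        probe_id = str(uuid.uuid4())
        pending.add(probe_id)
        queue.send(topic, MessageType.HEALTH_CHECK, {"probe_id": probe_id}, source="daemon")

    def note_reply(msg, pending: set) -> None:
        # Call from the main loop for each received health reply;
        # an empty `pending` set means every probed worker answered.
        pending.discard(msg.payload.get("probe_id"))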
+    def _forward_weight_updates(self) -> None:
+        """Forward weight updates from trainer to inference."""
+        # Check for weight updates from trainer
+        msg = self.queue.receive("weight_updates", timeout=0)
+        if msg and msg.msg_type == MessageType.WEIGHT_UPDATE:
+            # Forward to inference worker
+            self.queue.send(
+                "weight_forward",
+                MessageType.WEIGHT_UPDATE,
+                msg.payload,
+                source="daemon",
+            )
+
+            # Also update inference pointer
+            if self._pointer_store:
+                pointer = WeightPointerIPC(
+                    base_model=msg.payload.get("base_model", ""),
+                    adapter_path=msg.payload.get("adapter_path"),
+                    iteration=msg.payload.get("version", 0),
+                    updated_at=now_ts(),
+                    version=msg.payload.get("version", 0),
+                    name="inference",
+                )
+                self._pointer_store.save(pointer)
+                console.print(f"[blue]Forwarded weight update: {pointer.adapter_path}[/blue]")
+
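Weight updates thus travel two channels at once: a forwarded WEIGHT_UPDATE message on weight_forward, and a WeightPointerIPC record saved to disk, which the inference worker (spawned with hot_reload=True) can pick up independently. The .get() calls show the expected payload keys; a trainer-side publish shaped to match would look like this (values are illustrative, inferred from the reads above rather than taken from the trainer's code):

    queue.send(
        "weight_updates",
        MessageType.WEIGHT_UPDATE,
        {
            "base_model": "org/model-id",                  # illustrative
            "adapter_path": "runs/rlm_weights/iter_0007",  # illustrative
            "version": 7,
        },
        source="trainer",
    )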
+    def _shutdown_all(self) -> None:
+        """Shutdown all processes gracefully."""
+        console.print("[yellow]Shutting down all processes...[/yellow]")
+
+        # Send shutdown messages
+        for name in self._processes:
+            self.queue.send(
+                "control",
+                MessageType.SHUTDOWN,
+                {},
+                source="daemon",
+            )
+
+        # Wait for processes to terminate
+        for name, handle in self._processes.items():
+            console.print(f" Waiting for {name}...")
+            handle.process.join(timeout=10.0)
+
+            if handle.process.is_alive():
+                console.print(f" Force terminating {name}")
+                handle.process.terminate()
+                handle.process.join(timeout=5.0)
+
+                if handle.process.is_alive():
+                    handle.process.kill()
+
+        # Stop queue manager
+        self.queue.stop()
+
+        console.print("[green]All processes shutdown[/green]")
+
+    def run_iteration(self, iteration: int) -> bool:
+        """Run a single RLM iteration.
+
+        Returns True if iteration completed successfully.
+        """
+        console.print(f"\n[bold blue]=== RLM Iteration {iteration} ===[/bold blue]")
+
+        run = new_run(self.config.project_root, "rlm")
+        snapshot_config(self.project_cfg.model_dump(), run.config_snapshot_path)
+
+        # Phase 1: Generate tasks (via inference worker API)
+        console.print(" [dim]Generating tasks...[/dim]")
+        # Tasks are generated by querying inference worker
+
+        # Phase 2: Collect rollouts (via inference worker)
+        console.print(" [dim]Collecting rollouts...[/dim]")
+        # Rollouts are generated via /internal/rollout endpoint
+
+        # Phase 3: Send training batch to trainer
+        console.print(" [dim]Sending training batch...[/dim]")
+
+        # Phase 4: Wait for training completion
+        console.print(" [dim]Waiting for training...[/dim]")
+
+        # This is a placeholder - actual implementation would
+        # coordinate via queues and the inference worker API
+
+        return True
+
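run_iteration() is explicitly a placeholder: it records a run, prints the four phases, and returns True without touching the queue. Using only topics that appear elsewhere in this file, the phase 3/4 handshake might be sketched as below; the TRAIN_BATCH member name and the payload keys are assumptions:

    # Sketch of phases 3-4 inside run_iteration (names hedged as noted above).
    self.queue.send(
        "train_batches",
        MessageType.TRAIN_BATCH,                         # assumed enum member
        {"iteration": iteration, "rollouts": rollouts},  # assumed payload shape
        source="daemon",
    )
    while not self._shutdown:
        msg = self.queue.receive("train_complete", timeout=1.0)
        if msg is not None:
            return bool(msg.payload.get("ok", True))
    return False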
+    def run(self) -> None:
+        """Run the orchestrator daemon."""
+        self._setup_signal_handlers()
+
+        console.print("[bold green]Starting MLXSmith Orchestrator[/bold green]")
+
+        # Start queue manager
+        self.queue.start()
+        console.print("[dim]Queue manager started[/dim]")
+
+        # Initialize weight pointer store
+        self._pointer_store = WeightPointerStore(self._weights_dir)
+        console.print(f"[dim]Weight store: {self._weights_dir}[/dim]")
+
+        # Spawn worker processes
+        console.print("[dim]Spawning worker processes...[/dim]")
+        self._processes["inference"] = self._spawn_inference_worker()
+        self._processes["trainer"] = self._spawn_trainer_worker()
+
+        # Wait for processes to initialize
+        console.print("[dim]Waiting for workers to initialize...[/dim]")
+        time.sleep(5.0)
+
+        # Load state
+        state = load_state(self._state_path)
+        self._current_iteration = state.last_iteration + 1
+
+        last_health_check = time.time()
+
+        try:
+            # Main orchestrator loop
+            while not self._shutdown:
+                # Monitor processes
+                self._monitor_processes()
+
+                # Health checks
+                current_time = time.time()
+                if current_time - last_health_check > self.config.health_check_interval:
+                    self._health_check()
+                    last_health_check = current_time
+
+                # Forward weight updates
+                self._forward_weight_updates()
+
+                # Process queue messages
+                self._process_queue_messages()
+
+                # Small sleep to prevent busy waiting
+                time.sleep(0.01)
+
+        except KeyboardInterrupt:
+            console.print("[yellow]Interrupted by user[/yellow]")
+        except Exception as e:
+            console.print(f"[red]Orchestrator error: {e}[/red]")
+            traceback.print_exc()
+        finally:
+            self._shutdown_all()
+
+    def _process_queue_messages(self) -> None:
+        """Process pending queue messages."""
+        # Process responses from inference
+        msg = self.queue.receive("rollout_responses", timeout=0)
+        if msg:
+            # Handle rollout response
+            pass
+
+        # Process training completion
+        msg = self.queue.receive("train_complete", timeout=0)
+        if msg:
+            # Handle training completion
+            pass
+
+        # Process checkpoints
+        msg = self.queue.receive("checkpoints", timeout=0)
+        if msg:
+            # Handle checkpoint notification
+            pass
+
+
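All three handlers above are stubs: each drains at most one message per topic per loop pass and discards it. Once real handling lands, a dispatch table keeps this method flat; a sketch of that shape (the _on_* handler names are illustrative):

    def _process_queue_messages(self) -> None:
        handlers = {
            "rollout_responses": self._on_rollout_response,  # illustrative names
            "train_complete": self._on_train_complete,
            "checkpoints": self._on_checkpoint,
        }
        for topic, handler in handlers.items():
            msg = self.queue.receive(topic, timeout=0)
            if msg is not None:
                handler(msg)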
+def run_daemon(
+    project_root: Path,
+    project_cfg: ProjectConfig,
+    model_spec: Optional[str] = None,
+    iterations: Optional[int] = None,
+) -> None:
+    """Run the orchestrator daemon."""
+    config = DaemonConfig(
+        project_root=project_root,
+        model_spec=model_spec or project_cfg.model.id,
+        iterations=iterations or project_cfg.rlm.iterations,
+        tasks_per_iter=project_cfg.rlm.tasks_per_iter,
+        rollouts_per_task=project_cfg.rlm.rollouts_per_task,
+        gating_mode=project_cfg.rlm.gating,
+        gating_threshold=project_cfg.rlm.gating_threshold,
+        gating_ema_alpha=project_cfg.rlm.gating_ema_alpha,
+        verifier_backend=project_cfg.rlm.verifier_backend,
+        verifier_timeout_s=project_cfg.rlm.verifier_timeout_s,
+    )
+
+    daemon = OrchestratorDaemon(config, project_cfg)
+    daemon.run()
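For orientation, the smallest possible caller of this module looks roughly like the following; how a ProjectConfig is constructed or loaded is an assumption (mlxsmith/config.py is listed above but not shown), and in practice the mlxsmith CLI presumably drives this entry point:

    from pathlib import Path

    from mlxsmith.config import ProjectConfig  # loading API assumed
    from mlxsmith.orchestrator.daemon import run_daemon

    cfg = ProjectConfig()  # assumes usable defaults; real projects load a config file
    run_daemon(Path("."), cfg, model_spec=None, iterations=10)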