mlxsmith 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between versions as they appear in their respective public registries.
mlxsmith/rlm/loop.py CHANGED
@@ -49,6 +49,28 @@ from .weights import (
 console = Console()
 
 
+def _run_task_verifier(cfg: ProjectConfig, task_prompt: str, completion: str, workdir: Path) -> tuple[bool, float, float]:
+    """Execute verifier for a task and return (passed, reward, latency_ms)."""
+    t0 = time.time()
+    if cfg.rlm.verifier_backend == "docker":
+        res = docker_verify(
+            task_prompt,
+            completion,
+            str(workdir),
+            timeout_s=int(cfg.rlm.verifier_timeout_s),
+            image=cfg.rlm.docker_image,
+            memory_mb=int(cfg.rlm.docker_memory_mb),
+            cpus=float(cfg.rlm.docker_cpus),
+            pids=int(cfg.rlm.docker_pids),
+        )
+    else:
+        res = pytest_verify(task_prompt, completion, str(workdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
+    latency_ms = (time.time() - t0) * 1000.0
+    passed = bool(getattr(res, "passed", False))
+    reward = float(getattr(res, "reward", 0.0))
+    return passed, reward, latency_ms
+
+
 def _score_from_eval(result_path: Path) -> float:
     try:
         data = json.loads(result_path.read_text(encoding="utf-8"))
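
The new helper duck-types the verifier result through `getattr`, so any object exposing `passed` and `reward` satisfies its contract. A minimal sketch of that contract, handy for stubbing the verifier in tests (`FakeVerifyResult` is hypothetical, not part of mlxsmith):

```python
from dataclasses import dataclass

@dataclass
class FakeVerifyResult:
    """Hypothetical stand-in for a verifier result object."""
    passed: bool = True   # read by the helper via getattr(res, "passed", False)
    reward: float = 1.0   # read by the helper via getattr(res, "reward", 0.0)
```
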
@@ -199,7 +221,12 @@ def run_rlm(
         trust_remote_code=cfg.model.trust_remote_code,
     )
 
-    opt, _params = train_llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+    opt, _params = train_llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
 
     corpus_rows = load_corpus(corpus_path, max_size=int(rlm_cfg.corpus_max))
     existing_prompts = [row.get("prompt", "") for row in corpus_rows if row.get("prompt")]
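
The training call sites now forward two new config fields to the backend. A hedged sketch of plausible values, assuming nothing about the real ProjectConfig schema beyond the field names used above (the optimizer names echo the `create_optimizer` docstring later in this diff):

```python
# Illustrative values only; the actual ProjectConfig is defined elsewhere in mlxsmith.
train_overrides = {
    "optimizer": "adamw",                        # backend optimizer selected by name
    "optimizer_kwargs": {"betas": (0.9, 0.95)},  # hypothetical extra keyword arguments
}
```
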
@@ -371,9 +398,14 @@ def run_rlm(
 # Multi-Process Orchestrated RLM
 # =============================================================================
 
-from ..orchestrator.queue import MessageQueue, MessageType, Message  # noqa: E402
-from ..orchestrator.inference_worker import InferenceConfig, run_inference_worker  # noqa: E402
-from ..orchestrator.trainer_worker import TrainerConfig, run_trainer_worker  # noqa: E402
+def _lazy_import_orchestrator():
+    """Lazy import to break circular dependency with orchestrator module."""
+    global MessageQueue, MessageType, Message
+    global InferenceConfig, run_inference_worker
+    global TrainerConfig, run_trainer_worker
+    from ..orchestrator.queue import MessageQueue, MessageType, Message  # noqa: E402
+    from ..orchestrator.inference_worker import InferenceConfig, run_inference_worker  # noqa: E402
+    from ..orchestrator.trainer_worker import TrainerConfig, run_trainer_worker  # noqa: E402
 
 
 @dataclass
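
The pattern above leans on a Python scoping rule: a name declared `global` inside a function is bound at module scope by a function-level import, so the first call to `_lazy_import_orchestrator()` populates the module-level names the rest of the file uses. A self-contained illustration of just that mechanism:

```python
def _lazy():
    global json
    import json  # with the global declaration, this binds the module-level name

_lazy()
print(json.dumps({"ok": True}))  # json is now defined at module scope
```
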
@@ -402,6 +434,7 @@ class RLMOrchestrator:
         iterations: int = 50,
         resume: bool = False,
     ):
+        _lazy_import_orchestrator()
         self.project_root = project_root
         self.cfg = cfg
         self.model_spec = model_spec
@@ -492,6 +525,8 @@ class RLMOrchestrator:
             trust_remote_code=self.cfg.model.trust_remote_code,
             lr=self.cfg.train.lr,
             weight_decay=self.cfg.train.weight_decay,
+            optimizer=self.cfg.train.optimizer,
+            optimizer_kwargs=self.cfg.train.optimizer_kwargs,
             kl_coeff=self.cfg.rft.kl_coeff,
             normalize_advantage=self.cfg.rft.normalize_advantage,
             lora_r=self.cfg.lora.r,
@@ -602,30 +637,12 @@ class RLMOrchestrator:
             tests_dir = ensure_dir(wdir / "tests")
             (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
 
-            t0 = time.time()
-            if self.cfg.rlm.verifier_backend == "docker":
-                res = docker_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                    image=self.cfg.rlm.docker_image,
-                    memory_mb=int(self.cfg.rlm.docker_memory_mb),
-                    cpus=float(self.cfg.rlm.docker_cpus),
-                    pids=int(self.cfg.rlm.docker_pids),
-                )
-            else:
-                from ..verifiers.pytest_verifier import verify as pytest_verify
-                res = pytest_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                )
-            latency_ms = (time.time() - t0) * 1000.0
-
-            passed = bool(getattr(res, "passed", False))
-            reward = float(getattr(res, "reward", 0.0))
+            passed, reward, latency_ms = _run_task_verifier(
+                self.cfg,
+                task.prompt,
+                completion,
+                wdir,
+            )
 
             rollouts.append(Rollout(
                 task_id=task.id,
@@ -706,30 +723,12 @@ class RLMOrchestrator:
             tests_dir = ensure_dir(wdir / "tests")
             (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
 
-            t0 = time.time()
-            if self.cfg.rlm.verifier_backend == "docker":
-                res = docker_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                    image=self.cfg.rlm.docker_image,
-                    memory_mb=int(self.cfg.rlm.docker_memory_mb),
-                    cpus=float(self.cfg.rlm.docker_cpus),
-                    pids=int(self.cfg.rlm.docker_pids),
-                )
-            else:
-                from ..verifiers.pytest_verifier import verify as pytest_verify
-                res = pytest_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                )
-            latency_ms = (time.time() - t0) * 1000.0
-
-            passed = bool(getattr(res, "passed", False))
-            reward = float(getattr(res, "reward", 0.0))
+            passed, reward, latency_ms = _run_task_verifier(
+                self.cfg,
+                task.prompt,
+                completion,
+                wdir,
+            )
 
             rollouts.append(
                 Rollout(
@@ -1171,28 +1170,7 @@ def collect_rollouts_via_api(
         wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
         (wdir / "main.py").write_text(completion, encoding="utf-8")
         (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")
-        t0 = time.time()
-        if verifier_backend == "docker":
-            res = docker_verify(
-                task.prompt,
-                completion,
-                str(wdir),
-                timeout_s=int(cfg.rlm.verifier_timeout_s),
-                image=cfg.rlm.docker_image,
-                memory_mb=int(cfg.rlm.docker_memory_mb),
-                cpus=float(cfg.rlm.docker_cpus),
-                pids=int(cfg.rlm.docker_pids),
-            )
-        else:
-            res = pytest_verify(
-                task.prompt,
-                completion,
-                str(wdir),
-                timeout_s=int(cfg.rlm.verifier_timeout_s),
-            )
-        latency_ms = (time.time() - t0) * 1000.0
-        passed = bool(getattr(res, "passed", False))
-        reward = float(getattr(res, "reward", 0.0))
+        passed, reward, latency_ms = _run_task_verifier(cfg, task.prompt, completion, wdir)
         rollouts.append(
             Rollout(
                 task_id=task.id,
@@ -1247,24 +1225,7 @@ def collect_rollouts_via_api(
         (wdir / "main.py").write_text(completion, encoding="utf-8")
         (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")
 
-        t0 = time.time()
-        if verifier_backend == "docker":
-            res = docker_verify(
-                task.prompt,
-                completion,
-                str(wdir),
-                timeout_s=int(cfg.rlm.verifier_timeout_s),
-                image=cfg.rlm.docker_image,
-                memory_mb=int(cfg.rlm.docker_memory_mb),
-                cpus=float(cfg.rlm.docker_cpus),
-                pids=int(cfg.rlm.docker_pids),
-            )
-        else:
-            res = pytest_verify(task.prompt, completion, str(wdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
-        latency_ms = (time.time() - t0) * 1000.0
-
-        passed = bool(getattr(res, "passed", False))
-        reward = float(getattr(res, "reward", 0.0))
+        passed, reward, latency_ms = _run_task_verifier(cfg, task.prompt, completion, wdir)
 
         rollouts.append(Rollout(
             task_id=task.id,
mlxsmith/sdk/__init__.py CHANGED
@@ -259,6 +259,7 @@ def preference_forward_backward(
     kl_coeff: float = 0.0,
     train_on_prompt: bool = False,
     max_seq_len: Optional[int] = None,
+    delta: float = 0.0,
 ) -> Tuple[Any, Any | None]:
     """Execute preference-based forward/backward pass.
 
@@ -296,23 +297,38 @@
         reference_backend=reference_backend,
         kl_coeff=kl_coeff,
         train_on_prompt=train_on_prompt,
+        delta=delta,
     )
 
     return backend.value_and_grad(loss_fn)
 
 
-def create_optimizer(backend: Any, *, lr: float, weight_decay: float = 0.0) -> Tuple[Any, Any]:
+def create_optimizer(
+    backend: Any,
+    *,
+    lr: float,
+    weight_decay: float = 0.0,
+    optimizer: Optional[str] = None,
+    optimizer_kwargs: Optional[dict] = None,
+) -> Tuple[Any, Any]:
     """Create optimizer for training.
 
     Args:
         backend: LLM backend instance
         lr: Learning rate
        weight_decay: Weight decay coefficient
+        optimizer: Optimizer name
+        optimizer_kwargs: Extra optimizer kwargs
 
     Returns:
         Tuple of (optimizer, parameters)
     """
-    return backend.optimizer_and_params(lr=lr, weight_decay=weight_decay)
+    return backend.optimizer_and_params(
+        lr=lr,
+        weight_decay=weight_decay,
+        optimizer=optimizer,
+        optimizer_kwargs=optimizer_kwargs,
+    )
 
 
 def optim_step(backend: Any, optimizer: Any, grads: Any) -> None:
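
A hedged usage sketch of the widened facade; `backend` stands in for any mlxsmith LLM backend exposing `optimizer_and_params()`, and the keyword values are illustrative (the optimizer names come from the TrainingClient docstring later in this diff):

```python
opt, params = create_optimizer(
    backend,
    lr=1e-4,
    weight_decay=0.01,
    optimizer="adamw",                        # e.g. adamw, adam, qhadam, muon
    optimizer_kwargs={"betas": (0.9, 0.95)},  # hypothetical backend-specific kwargs
)
```
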
mlxsmith/sdk/losses.py CHANGED
@@ -70,6 +70,76 @@ def preference_diff(
     return (logp_c - logp_r) - ref_diff
 
 
+@register_loss("cpo")
+def cpo_loss(
+    backend,
+    chosen_ids: Sequence[int],
+    rejected_ids: Sequence[int],
+    *,
+    prompt_len_chosen: int,
+    prompt_len_rejected: int,
+    beta: float = 0.1,
+) -> Any:
+    mx = _require_mx(backend)
+    diff = preference_diff(
+        backend,
+        chosen_ids,
+        rejected_ids,
+        prompt_len_chosen=prompt_len_chosen,
+        prompt_len_rejected=prompt_len_rejected,
+        reference_backend=None,
+    )
+    scaled = _to_mx_scalar(mx, beta) * diff
+    return mx.log1p(mx.exp(-scaled))
+
+
+@register_loss("ipo")
+def ipo_loss(
+    backend,
+    chosen_ids: Sequence[int],
+    rejected_ids: Sequence[int],
+    *,
+    prompt_len_chosen: int,
+    prompt_len_rejected: int,
+    beta: float = 0.1,
+    reference_backend: Optional[Any] = None,
+) -> Any:
+    mx = _require_mx(backend)
+    diff = preference_diff(
+        backend,
+        chosen_ids,
+        rejected_ids,
+        prompt_len_chosen=prompt_len_chosen,
+        prompt_len_rejected=prompt_len_rejected,
+        reference_backend=reference_backend,
+    )
+    target = _to_mx_scalar(mx, 1.0 / (2.0 * float(beta))) if beta != 0 else _to_mx_scalar(mx, 0.0)
+    return (diff - target) ** 2
+
+
+@register_loss("hinge")
+def hinge_loss(
+    backend,
+    chosen_ids: Sequence[int],
+    rejected_ids: Sequence[int],
+    *,
+    prompt_len_chosen: int,
+    prompt_len_rejected: int,
+    delta: float = 0.0,
+    reference_backend: Optional[Any] = None,
+) -> Any:
+    mx = _require_mx(backend)
+    diff = preference_diff(
+        backend,
+        chosen_ids,
+        rejected_ids,
+        prompt_len_chosen=prompt_len_chosen,
+        prompt_len_rejected=prompt_len_rejected,
+        reference_backend=reference_backend,
+    )
+    return mx.maximum(_to_mx_scalar(mx, delta) - diff, _to_mx_scalar(mx, 0.0))
+
+
 @register_loss("dpo")
 def dpo_loss(
     backend,
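
Writing d for the value returned by `preference_diff` (the chosen-minus-rejected log-probability gap, reference-adjusted when a reference backend is supplied), the three new objectives reduce to standard preference-learning losses; note that `cpo_loss` passes `reference_backend=None`, matching CPO's reference-free formulation, and that `mx.log1p(mx.exp(-x))` is softplus(-x) = -log σ(x):

```latex
\mathcal{L}_{\mathrm{CPO}}   = \log\bigl(1 + e^{-\beta d}\bigr) = -\log\sigma(\beta d), \qquad
\mathcal{L}_{\mathrm{IPO}}   = \Bigl(d - \tfrac{1}{2\beta}\Bigr)^{2}, \qquad
\mathcal{L}_{\mathrm{hinge}} = \max(\delta - d,\; 0).
```

So `delta` acts as a margin: the hinge loss vanishes once the chosen completion is preferred over the rejected one by at least δ in log probability.
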
@@ -149,8 +219,10 @@ def preference_loss(
     reference_backend: Optional[Any] = None,
     kl_coeff: float = 0.0,
     train_on_prompt: bool = False,
+    delta: float = 0.0,
 ) -> Any:
-    if algo.lower() == "orpo":
+    algo_l = algo.lower()
+    if algo_l == "orpo":
         return orpo_loss(
             backend,
             chosen_ids,
@@ -162,6 +234,35 @@
         kl_coeff=kl_coeff,
         train_on_prompt=train_on_prompt,
     )
+    if algo_l == "cpo":
+        return cpo_loss(
+            backend,
+            chosen_ids,
+            rejected_ids,
+            prompt_len_chosen=prompt_len_chosen,
+            prompt_len_rejected=prompt_len_rejected,
+            beta=beta,
+        )
+    if algo_l == "ipo":
+        return ipo_loss(
+            backend,
+            chosen_ids,
+            rejected_ids,
+            prompt_len_chosen=prompt_len_chosen,
+            prompt_len_rejected=prompt_len_rejected,
+            beta=beta,
+            reference_backend=reference_backend,
+        )
+    if algo_l == "hinge":
+        return hinge_loss(
+            backend,
+            chosen_ids,
+            rejected_ids,
+            prompt_len_chosen=prompt_len_chosen,
+            prompt_len_rejected=prompt_len_rejected,
+            delta=delta,
+            reference_backend=reference_backend,
+        )
     return dpo_loss(
         backend,
         chosen_ids,
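
A hedged call sketch through the dispatcher above. The keyword names are taken from the hunks; `algo` is assumed to be a parameter of `preference_loss` (only its use via `algo.lower()` appears in this diff), and `plc`/`plr` are placeholder prompt lengths:

```python
loss = preference_loss(
    backend,
    chosen_ids,
    rejected_ids,
    algo="hinge",             # "orpo", "cpo", "ipo", "hinge"; anything else falls through to DPO
    prompt_len_chosen=plc,
    prompt_len_rejected=plr,
    delta=0.5,                # hinge margin, new in this release (defaults to 0.0)
)
```
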
@@ -210,6 +210,7 @@ class TrainingClient:
             beta=batch.extra.get("beta", 0.1),
             reference_backend=batch.extra.get("reference_backend"),
             kl_coeff=batch.extra.get("kl_coeff", 0.0),
+            delta=batch.extra.get("delta", 0.0),
             train_on_prompt=batch.train_on_prompt,
             max_seq_len=batch.max_seq_len,
         )
@@ -514,12 +515,20 @@
     # Utility Methods
     # ========================================================================
 
-    def create_optimizer(self, lr: float = 1e-4, weight_decay: float = 0.0) -> APIFuture[Any]:
+    def create_optimizer(
+        self,
+        lr: float = 1e-4,
+        weight_decay: float = 0.0,
+        optimizer: Optional[str] = None,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> APIFuture[Any]:
         """Create optimizer for training.
 
         Args:
             lr: Learning rate
             weight_decay: Weight decay coefficient
+            optimizer: Optimizer name (e.g., adamw, adam, qhadam, muon)
+            optimizer_kwargs: Extra optimizer kwargs
 
         Returns:
             APIFuture resolving to optimizer instance
@@ -535,6 +544,8 @@
             self.backend,
             lr=lr,
             weight_decay=weight_decay,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
         )
         return self.optimizer
 
@@ -588,10 +599,18 @@
         if len(grads_list) == 1:
             return grads_list[0]
 
-        # Average gradients
-        # This is backend-specific; for now return first grad
-        # In practice, MLX would average the arrays
-        return grads_list[0]
+        from ..util import tree_add, tree_scale
+
+        agg = None
+        count = 0
+        for grads in grads_list:
+            if grads is None:
+                continue
+            agg = tree_add(agg, grads)
+            count += 1
+        if agg is None or count == 0:
+            return None
+        return tree_scale(agg, 1.0 / float(count))
 
     def shutdown(self) -> None:
         """Shutdown the client and its thread pool."""
mlxsmith/train/distill.py CHANGED
@@ -109,7 +109,12 @@ def run_distill(
     )
     student.apply_lora_from_config(lora_cfg)
 
-    opt, _params = student.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)
+    opt, _params = student.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
 
     rng = random.Random(cfg.train.seed)
     total = int(cfg.train.iters)