mlxsmith 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff compares publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- mlxsmith/bench.py +12 -2
- mlxsmith/cli.py +187 -1
- mlxsmith/config_models.py +15 -1
- mlxsmith/integrations/__init__.py +19 -0
- mlxsmith/integrations/mlx_lm_lora.py +117 -0
- mlxsmith/llm/backend.py +8 -1
- mlxsmith/llm/mlx_lm_backend.py +59 -2
- mlxsmith/llm/mock_backend.py +8 -1
- mlxsmith/optim/__init__.py +3 -0
- mlxsmith/optim/muon.py +93 -0
- mlxsmith/orchestrator/daemon.py +44 -377
- mlxsmith/orchestrator/trainer_worker.py +4 -0
- mlxsmith/rlm/loop.py +53 -92
- mlxsmith/sdk/__init__.py +18 -2
- mlxsmith/sdk/losses.py +102 -1
- mlxsmith/sdk/training_client.py +24 -5
- mlxsmith/train/distill.py +6 -1
- mlxsmith/train/online_dpo.py +249 -0
- mlxsmith/train/pref.py +31 -29
- mlxsmith/train/rft.py +123 -38
- mlxsmith/train/self_verify.py +199 -0
- mlxsmith/train/sft.py +13 -2
- mlxsmith/verifiers/llm_judge.py +278 -0
- mlxsmith/verifiers/prime.py +127 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/METADATA +27 -1
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/RECORD +30 -22
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/WHEEL +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/entry_points.txt +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {mlxsmith-0.1.2.dist-info → mlxsmith-0.1.3.dist-info}/top_level.txt +0 -0
mlxsmith/rlm/loop.py
CHANGED
@@ -49,6 +49,28 @@ from .weights import (
 console = Console()
 
 
+def _run_task_verifier(cfg: ProjectConfig, task_prompt: str, completion: str, workdir: Path) -> tuple[bool, float, float]:
+    """Execute verifier for a task and return (passed, reward, latency_ms)."""
+    t0 = time.time()
+    if cfg.rlm.verifier_backend == "docker":
+        res = docker_verify(
+            task_prompt,
+            completion,
+            str(workdir),
+            timeout_s=int(cfg.rlm.verifier_timeout_s),
+            image=cfg.rlm.docker_image,
+            memory_mb=int(cfg.rlm.docker_memory_mb),
+            cpus=float(cfg.rlm.docker_cpus),
+            pids=int(cfg.rlm.docker_pids),
+        )
+    else:
+        res = pytest_verify(task_prompt, completion, str(workdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
+    latency_ms = (time.time() - t0) * 1000.0
+    passed = bool(getattr(res, "passed", False))
+    reward = float(getattr(res, "reward", 0.0))
+    return passed, reward, latency_ms
+
+
 def _score_from_eval(result_path: Path) -> float:
     try:
         data = json.loads(result_path.read_text(encoding="utf-8"))
@@ -199,7 +221,12 @@ def run_rlm(
         trust_remote_code=cfg.model.trust_remote_code,
     )
 
-    opt, _params = train_llm.optimizer_and_params(
+    opt, _params = train_llm.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
 
     corpus_rows = load_corpus(corpus_path, max_size=int(rlm_cfg.corpus_max))
     existing_prompts = [row.get("prompt", "") for row in corpus_rows if row.get("prompt")]
@@ -371,9 +398,14 @@ def run_rlm(
 # Multi-Process Orchestrated RLM
 # =============================================================================
 
-
-
-
+def _lazy_import_orchestrator():
+    """Lazy import to break circular dependency with orchestrator module."""
+    global MessageQueue, MessageType, Message
+    global InferenceConfig, run_inference_worker
+    global TrainerConfig, run_trainer_worker
+    from ..orchestrator.queue import MessageQueue, MessageType, Message  # noqa: E402
+    from ..orchestrator.inference_worker import InferenceConfig, run_inference_worker  # noqa: E402
+    from ..orchestrator.trainer_worker import TrainerConfig, run_trainer_worker  # noqa: E402
 
 
 @dataclass
@@ -402,6 +434,7 @@ class RLMOrchestrator:
         iterations: int = 50,
         resume: bool = False,
     ):
+        _lazy_import_orchestrator()
         self.project_root = project_root
         self.cfg = cfg
         self.model_spec = model_spec
@@ -492,6 +525,8 @@ class RLMOrchestrator:
             trust_remote_code=self.cfg.model.trust_remote_code,
             lr=self.cfg.train.lr,
             weight_decay=self.cfg.train.weight_decay,
+            optimizer=self.cfg.train.optimizer,
+            optimizer_kwargs=self.cfg.train.optimizer_kwargs,
             kl_coeff=self.cfg.rft.kl_coeff,
             normalize_advantage=self.cfg.rft.normalize_advantage,
             lora_r=self.cfg.lora.r,
@@ -602,30 +637,12 @@ class RLMOrchestrator:
             tests_dir = ensure_dir(wdir / "tests")
             (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
 
-
-
-
-
-
-
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                    image=self.cfg.rlm.docker_image,
-                    memory_mb=int(self.cfg.rlm.docker_memory_mb),
-                    cpus=float(self.cfg.rlm.docker_cpus),
-                    pids=int(self.cfg.rlm.docker_pids),
-                )
-            else:
-                from ..verifiers.pytest_verifier import verify as pytest_verify
-                res = pytest_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                )
-            latency_ms = (time.time() - t0) * 1000.0
-
-            passed = bool(getattr(res, "passed", False))
-            reward = float(getattr(res, "reward", 0.0))
+            passed, reward, latency_ms = _run_task_verifier(
+                self.cfg,
+                task.prompt,
+                completion,
+                wdir,
+            )
 
             rollouts.append(Rollout(
                 task_id=task.id,
@@ -706,30 +723,12 @@ class RLMOrchestrator:
             tests_dir = ensure_dir(wdir / "tests")
             (tests_dir / "test_task.py").write_text(task.tests, encoding="utf-8")
 
-
-
-
-
-
-
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                    image=self.cfg.rlm.docker_image,
-                    memory_mb=int(self.cfg.rlm.docker_memory_mb),
-                    cpus=float(self.cfg.rlm.docker_cpus),
-                    pids=int(self.cfg.rlm.docker_pids),
-                )
-            else:
-                from ..verifiers.pytest_verifier import verify as pytest_verify
-                res = pytest_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(self.cfg.rlm.verifier_timeout_s),
-                )
-            latency_ms = (time.time() - t0) * 1000.0
-
-            passed = bool(getattr(res, "passed", False))
-            reward = float(getattr(res, "reward", 0.0))
+            passed, reward, latency_ms = _run_task_verifier(
+                self.cfg,
+                task.prompt,
+                completion,
+                wdir,
+            )
 
             rollouts.append(
                 Rollout(
@@ -1171,28 +1170,7 @@ def collect_rollouts_via_api(
             wdir = ensure_dir(artifacts_dir / task.id / f"rollout_{k:02d}")
             (wdir / "main.py").write_text(completion, encoding="utf-8")
             (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")
-
-            if verifier_backend == "docker":
-                res = docker_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(cfg.rlm.verifier_timeout_s),
-                    image=cfg.rlm.docker_image,
-                    memory_mb=int(cfg.rlm.docker_memory_mb),
-                    cpus=float(cfg.rlm.docker_cpus),
-                    pids=int(cfg.rlm.docker_pids),
-                )
-            else:
-                res = pytest_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(cfg.rlm.verifier_timeout_s),
-                )
-            latency_ms = (time.time() - t0) * 1000.0
-            passed = bool(getattr(res, "passed", False))
-            reward = float(getattr(res, "reward", 0.0))
+            passed, reward, latency_ms = _run_task_verifier(cfg, task.prompt, completion, wdir)
             rollouts.append(
                 Rollout(
                     task_id=task.id,
@@ -1247,24 +1225,7 @@ def collect_rollouts_via_api(
             (wdir / "main.py").write_text(completion, encoding="utf-8")
             (ensure_dir(wdir / "tests") / "test_task.py").write_text(task.tests, encoding="utf-8")
 
-
-            if verifier_backend == "docker":
-                res = docker_verify(
-                    task.prompt,
-                    completion,
-                    str(wdir),
-                    timeout_s=int(cfg.rlm.verifier_timeout_s),
-                    image=cfg.rlm.docker_image,
-                    memory_mb=int(cfg.rlm.docker_memory_mb),
-                    cpus=float(cfg.rlm.docker_cpus),
-                    pids=int(cfg.rlm.docker_pids),
-                )
-            else:
-                res = pytest_verify(task.prompt, completion, str(wdir), timeout_s=int(cfg.rlm.verifier_timeout_s))
-            latency_ms = (time.time() - t0) * 1000.0
-
-            passed = bool(getattr(res, "passed", False))
-            reward = float(getattr(res, "reward", 0.0))
+            passed, reward, latency_ms = _run_task_verifier(cfg, task.prompt, completion, wdir)
 
             rollouts.append(Rollout(
                 task_id=task.id,
mlxsmith/sdk/__init__.py
CHANGED
@@ -259,6 +259,7 @@ def preference_forward_backward(
     kl_coeff: float = 0.0,
     train_on_prompt: bool = False,
     max_seq_len: Optional[int] = None,
+    delta: float = 0.0,
 ) -> Tuple[Any, Any | None]:
     """Execute preference-based forward/backward pass.
 
@@ -296,23 +297,38 @@ def preference_forward_backward(
         reference_backend=reference_backend,
         kl_coeff=kl_coeff,
         train_on_prompt=train_on_prompt,
+        delta=delta,
     )
 
     return backend.value_and_grad(loss_fn)
 
 
-def create_optimizer(
+def create_optimizer(
+    backend: Any,
+    *,
+    lr: float,
+    weight_decay: float = 0.0,
+    optimizer: Optional[str] = None,
+    optimizer_kwargs: Optional[dict] = None,
+) -> Tuple[Any, Any]:
     """Create optimizer for training.
 
     Args:
         backend: LLM backend instance
         lr: Learning rate
         weight_decay: Weight decay coefficient
+        optimizer: Optimizer name
+        optimizer_kwargs: Extra optimizer kwargs
 
     Returns:
         Tuple of (optimizer, parameters)
     """
-    return backend.optimizer_and_params(
+    return backend.optimizer_and_params(
+        lr=lr,
+        weight_decay=weight_decay,
+        optimizer=optimizer,
+        optimizer_kwargs=optimizer_kwargs,
+    )
 
 
 def optim_step(backend: Any, optimizer: Any, grads: Any) -> None:
mlxsmith/sdk/losses.py
CHANGED
@@ -70,6 +70,76 @@ def preference_diff(
     return (logp_c - logp_r) - ref_diff
 
 
+@register_loss("cpo")
+def cpo_loss(
+    backend,
+    chosen_ids: Sequence[int],
+    rejected_ids: Sequence[int],
+    *,
+    prompt_len_chosen: int,
+    prompt_len_rejected: int,
+    beta: float = 0.1,
+) -> Any:
+    mx = _require_mx(backend)
+    diff = preference_diff(
+        backend,
+        chosen_ids,
+        rejected_ids,
+        prompt_len_chosen=prompt_len_chosen,
+        prompt_len_rejected=prompt_len_rejected,
+        reference_backend=None,
+    )
+    scaled = _to_mx_scalar(mx, beta) * diff
+    return mx.log1p(mx.exp(-scaled))
+
+
+@register_loss("ipo")
+def ipo_loss(
+    backend,
+    chosen_ids: Sequence[int],
+    rejected_ids: Sequence[int],
+    *,
+    prompt_len_chosen: int,
+    prompt_len_rejected: int,
+    beta: float = 0.1,
+    reference_backend: Optional[Any] = None,
+) -> Any:
+    mx = _require_mx(backend)
+    diff = preference_diff(
+        backend,
+        chosen_ids,
+        rejected_ids,
+        prompt_len_chosen=prompt_len_chosen,
+        prompt_len_rejected=prompt_len_rejected,
+        reference_backend=reference_backend,
+    )
+    target = _to_mx_scalar(mx, 1.0 / (2.0 * float(beta))) if beta != 0 else _to_mx_scalar(mx, 0.0)
+    return (diff - target) ** 2
+
+
+@register_loss("hinge")
+def hinge_loss(
+    backend,
+    chosen_ids: Sequence[int],
+    rejected_ids: Sequence[int],
+    *,
+    prompt_len_chosen: int,
+    prompt_len_rejected: int,
+    delta: float = 0.0,
+    reference_backend: Optional[Any] = None,
+) -> Any:
+    mx = _require_mx(backend)
+    diff = preference_diff(
+        backend,
+        chosen_ids,
+        rejected_ids,
+        prompt_len_chosen=prompt_len_chosen,
+        prompt_len_rejected=prompt_len_rejected,
+        reference_backend=reference_backend,
+    )
+    return mx.maximum(_to_mx_scalar(mx, delta) - diff, _to_mx_scalar(mx, 0.0))
+
+
 @register_loss("dpo")
 def dpo_loss(
     backend,
@@ -149,8 +219,10 @@ def preference_loss(
     reference_backend: Optional[Any] = None,
     kl_coeff: float = 0.0,
     train_on_prompt: bool = False,
+    delta: float = 0.0,
 ) -> Any:
-
+    algo_l = algo.lower()
+    if algo_l == "orpo":
         return orpo_loss(
             backend,
             chosen_ids,
@@ -162,6 +234,35 @@ def preference_loss(
             kl_coeff=kl_coeff,
             train_on_prompt=train_on_prompt,
         )
+    if algo_l == "cpo":
+        return cpo_loss(
+            backend,
+            chosen_ids,
+            rejected_ids,
+            prompt_len_chosen=prompt_len_chosen,
+            prompt_len_rejected=prompt_len_rejected,
+            beta=beta,
+        )
+    if algo_l == "ipo":
+        return ipo_loss(
+            backend,
+            chosen_ids,
+            rejected_ids,
+            prompt_len_chosen=prompt_len_chosen,
+            prompt_len_rejected=prompt_len_rejected,
+            beta=beta,
+            reference_backend=reference_backend,
+        )
+    if algo_l == "hinge":
+        return hinge_loss(
+            backend,
+            chosen_ids,
+            rejected_ids,
+            prompt_len_chosen=prompt_len_chosen,
+            prompt_len_rejected=prompt_len_rejected,
+            delta=delta,
+            reference_backend=reference_backend,
+        )
     return dpo_loss(
         backend,
         chosen_ids,
mlxsmith/sdk/training_client.py
CHANGED
@@ -210,6 +210,7 @@ class TrainingClient:
             beta=batch.extra.get("beta", 0.1),
             reference_backend=batch.extra.get("reference_backend"),
             kl_coeff=batch.extra.get("kl_coeff", 0.0),
+            delta=batch.extra.get("delta", 0.0),
             train_on_prompt=batch.train_on_prompt,
             max_seq_len=batch.max_seq_len,
         )
@@ -514,12 +515,20 @@ class TrainingClient:
     # Utility Methods
     # ========================================================================
 
-    def create_optimizer(
+    def create_optimizer(
+        self,
+        lr: float = 1e-4,
+        weight_decay: float = 0.0,
+        optimizer: Optional[str] = None,
+        optimizer_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> APIFuture[Any]:
         """Create optimizer for training.
 
         Args:
            lr: Learning rate
            weight_decay: Weight decay coefficient
+           optimizer: Optimizer name (e.g., adamw, adam, qhadam, muon)
+           optimizer_kwargs: Extra optimizer kwargs
 
         Returns:
            APIFuture resolving to optimizer instance
@@ -535,6 +544,8 @@ class TrainingClient:
             self.backend,
             lr=lr,
             weight_decay=weight_decay,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
         )
         return self.optimizer
 
@@ -588,10 +599,18 @@ class TrainingClient:
         if len(grads_list) == 1:
             return grads_list[0]
 
-
-
-
-
+        from ..util import tree_add, tree_scale
+
+        agg = None
+        count = 0
+        for grads in grads_list:
+            if grads is None:
+                continue
+            agg = tree_add(agg, grads)
+            count += 1
+        if agg is None or count == 0:
+            return None
+        return tree_scale(agg, 1.0 / float(count))
 
     def shutdown(self) -> None:
         """Shutdown the client and its thread pool."""
mlxsmith/train/distill.py
CHANGED
@@ -109,7 +109,12 @@ def run_distill(
     )
     student.apply_lora_from_config(lora_cfg)
 
-    opt, _params = student.optimizer_and_params(
+    opt, _params = student.optimizer_and_params(
+        lr=cfg.train.lr,
+        weight_decay=cfg.train.weight_decay,
+        optimizer=cfg.train.optimizer,
+        optimizer_kwargs=cfg.train.optimizer_kwargs,
+    )
 
     rng = random.Random(cfg.train.seed)
     total = int(cfg.train.iters)