mlxsmith-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/orchestrator/trainer_worker.py
ADDED

@@ -0,0 +1,437 @@
+"""Trainer Worker Process for MLXSmith Orchestrator.
+
+Runs as a separate process for training.
+Consumes training batches from queue.
+Publishes adapter checkpoints and signals weight updates.
+"""
+
+from __future__ import annotations
+
+import json
+import signal
+import sys
+import time
+import traceback
+from collections import defaultdict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from ..config import ProjectConfig
+from ..llm.registry import get_llm_backend
+from ..rlm.inference import Rollout
+from ..rlm.weights import WeightPointerStore, WeightPointerIPC
+from ..train.lora import LoRAConfig
+from ..util import ensure_dir, now_ts
+from .queue import MessageQueue, MessageType, Message
+
+
+@dataclass
+class TrainerConfig:
+    """Configuration for trainer worker."""
+    model_spec: str
+    base_model: str
+    backend: str = "mlx-lm"
+    max_seq_len: int = 8192
+    dtype: str = "bf16"
+    trust_remote_code: bool = False
+
+    # Training config
+    lr: float = 2e-4
+    weight_decay: float = 0.0
+    kl_coeff: float = 0.02
+    normalize_advantage: bool = True
+
+    # LoRA config
+    lora_r: int = 16
+    lora_alpha: int = 32
+    lora_dropout: float = 0.05
+    lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "o_proj"])
+    lora_num_layers: int = 0
+
+    # Paths
+    weights_dir: Optional[Path] = None
+    checkpoint_dir: Optional[Path] = None
+
+    # Reference model for KL penalty
+    reference_model: Optional[str] = None
+
+
+@dataclass
+class TrainingBatch:
+    """A batch of rollouts for training."""
+    iteration: int
+    run_id: str
+    rollouts: List[Rollout]
+    task_metadata: List[Dict[str, Any]] = field(default_factory=list)
+
+
+class TrainerWorker:
+    """Trainer worker process for RLM training.
+
+    Runs in a separate process, handles:
+    - Consuming training batches from queue
+    - Running forward/backward passes
+    - Publishing adapter checkpoints
+    - Signaling weight updates to orchestrator
+    """
+
+    def __init__(
+        self,
+        config: TrainerConfig,
+        queue: Optional[MessageQueue] = None,
+    ):
+        self.config = config
+        self.queue = queue
+        self._llm = None
+        self._ref_llm = None
+        self._optimizer = None
+        self._pointer_store: Optional[WeightPointerStore] = None
+        self._current_iteration = 0
+        self._current_adapter: Optional[str] = None
+        self._shutdown = False
+        self._metrics: List[Dict] = []
+
+    def _load_model(self) -> None:
+        """Load the model and optimizer."""
+        print("[TrainerWorker] Loading model...")
+        self._llm = get_llm_backend(self.config.backend)
+
+        # Load base model
+        self._llm.load(
+            self.config.base_model,
+            max_seq_len=self.config.max_seq_len,
+            dtype=self.config.dtype,
+            trust_remote_code=self.config.trust_remote_code,
+        )
+
+        # Setup weight pointer store
+        if self.config.weights_dir:
+            self._pointer_store = WeightPointerStore(self.config.weights_dir)
+            pointer = self._pointer_store.load("trainer", self.config.base_model)
+
+            if pointer.adapter_path:
+                print(f"[TrainerWorker] Loading adapter: {pointer.adapter_path}")
+                self._llm.apply_adapter(pointer.adapter_path)
+                self._current_adapter = pointer.adapter_path
+                self._current_iteration = pointer.iteration
+            else:
+                # Initialize new LoRA adapter
+                print("[TrainerWorker] Initializing LoRA adapter...")
+                lora_cfg = LoRAConfig(
+                    r=self.config.lora_r,
+                    alpha=self.config.lora_alpha,
+                    dropout=self.config.lora_dropout,
+                    target_modules=list(self.config.lora_target_modules),
+                    num_layers=self.config.lora_num_layers,
+                )
+                self._llm.apply_lora_from_config(lora_cfg)
+
+        # Setup optimizer
+        self._optimizer, _ = self._llm.optimizer_and_params(
+            lr=self.config.lr,
+            weight_decay=self.config.weight_decay,
+        )
+
+        # Load reference model if needed for KL
+        if self.config.reference_model and self.config.kl_coeff > 0:
+            print("[TrainerWorker] Loading reference model...")
+            self._ref_llm = get_llm_backend(self.config.backend)
+            self._ref_llm.load(
+                self.config.reference_model,
+                max_seq_len=self.config.max_seq_len,
+                dtype=self.config.dtype,
+                trust_remote_code=self.config.trust_remote_code,
+            )
+
+        print("[TrainerWorker] Model loaded successfully")
+
+    def _train_on_batch(self, batch: TrainingBatch) -> Dict[str, Any]:
+        """Train on a batch of rollouts."""
+        rollouts = batch.rollouts
+        if not rollouts:
+            return {"status": "empty", "loss": 0.0}
+
+        # Group rollouts by task
+        grouped = defaultdict(list)
+        for r in rollouts:
+            grouped[r.task_id].append(r)
+
+        total_loss = 0.0
+        num_tasks = 0
+
+        for task_id, task_rollouts in grouped.items():
+            if not task_rollouts:
+                continue
+
+            # Compute advantages
+            mean_r = sum(r.reward for r in task_rollouts) / len(task_rollouts)
+            std_r = (sum((r.reward - mean_r) ** 2 for r in task_rollouts) / len(task_rollouts)) ** 0.5
+            advs = [r.reward - mean_r for r in task_rollouts]
+
+            if self.config.normalize_advantage and std_r > 1e-6:
+                advs = [a / std_r for a in advs]
+
+            # Define loss function
+            def loss_fn(_model):
+                loss = self._llm.mx.array(0.0)
+                for rollout, adv in zip(task_rollouts, advs):
+                    logp = self._llm.sequence_logprob(
+                        rollout.token_ids,
+                        prompt_len=rollout.prompt_len,
+                    )
+
+                    # Importance sampling if rollout was from different policy
+                    if rollout.logprobs and rollout.weight_adapter and rollout.weight_adapter != self._current_adapter:
+                        behavior_logp = self._llm.mx.array(sum(rollout.logprobs))
+                        ratio = self._llm.mx.exp(logp - behavior_logp)
+                        pg = -ratio * self._llm.mx.array(float(adv))
+                    else:
+                        pg = -self._llm.mx.array(float(adv)) * logp
+
+                    # KL penalty from reference model
+                    if self._ref_llm is not None and self.config.kl_coeff > 0:
+                        ref_logp = self._ref_llm.sequence_logprob(
+                            rollout.token_ids,
+                            prompt_len=rollout.prompt_len,
+                        )
+                        pg = pg + self._llm.mx.array(self.config.kl_coeff) * (logp - ref_logp)
+
+                    loss = loss + pg
+
+                return loss / self._llm.mx.array(float(len(task_rollouts)))
+
+            # Compute gradients and update
+            lval, grads = self._llm.value_and_grad(loss_fn)
+            if grads is not None:
+                self._llm.apply_grads(self._optimizer, grads)
+
+            loss_val = float(lval.item()) if hasattr(lval, "item") else float(lval)
+            total_loss += loss_val
+            num_tasks += 1
+
+            # Record metrics
+            self._metrics.append({
+                "ts": now_ts(),
+                "iteration": batch.iteration,
+                "task_id": task_id,
+                "mean_reward": mean_r,
+                "std_reward": std_r,
+                "loss": loss_val,
+                "num_rollouts": len(task_rollouts),
+            })
+
+        avg_loss = total_loss / max(1, num_tasks)
+        return {
+            "status": "success",
+            "loss": avg_loss,
+            "num_tasks": num_tasks,
+            "num_rollouts": len(rollouts),
+        }
+
+    def _save_checkpoint(self, iteration: int) -> Optional[str]:
+        """Save adapter checkpoint and return path."""
+        if not self.config.checkpoint_dir:
+            return None
+
+        checkpoint_path = ensure_dir(self.config.checkpoint_dir / f"iter_{iteration:04d}")
+
+        self._llm.save_adapter(
+            str(checkpoint_path),
+            metadata={
+                "base_model": self.config.base_model,
+                "source_adapter": self._current_adapter,
+                "iteration": iteration,
+                "kind": "rlm",
+            },
+        )
+
+        # Save metrics
+        metrics_path = checkpoint_path / "training_metrics.jsonl"
+        with open(metrics_path, "w") as f:
+            for m in self._metrics:
+                f.write(json.dumps(m) + "\n")
+
+        self._current_adapter = str(checkpoint_path)
+        self._current_iteration = iteration
+
+        return str(checkpoint_path)
+
+    def _update_weight_pointer(self, adapter_path: str, iteration: int) -> None:
+        """Update the weight pointer for inference to pick up."""
+        if not self._pointer_store:
+            return
+
+        pointer = WeightPointerIPC(
+            base_model=self.config.base_model,
+            adapter_path=adapter_path,
+            iteration=iteration,
+            updated_at=now_ts(),
+            version=iteration,  # Use iteration as version
+            name="trainer",
+        )
+
+        self._pointer_store.save(pointer)
+        print(f"[TrainerWorker] Updated weight pointer: {adapter_path} (iter {iteration})")
+
+    def _handle_train_batch(self, msg: Message) -> Message:
+        """Handle a training batch message."""
+        payload = msg.payload
+
+        # Deserialize rollouts
+        rollout_data = payload.get("rollouts", [])
+        rollouts = []
+        for r in rollout_data:
+            rollouts.append(Rollout(
+                task_id=r["task_id"],
+                prompt=r["prompt"],
+                completion=r["completion"],
+                token_ids=r["token_ids"],
+                prompt_len=r["prompt_len"],
+                logprobs=r.get("logprobs"),
+                passed=r["passed"],
+                reward=r["reward"],
+                verifier_latency_ms=r["verifier_latency_ms"],
+                weight_adapter=r.get("weight_adapter"),
+            ))
+
+        batch = TrainingBatch(
+            iteration=payload.get("iteration", 0),
+            run_id=payload.get("run_id", ""),
+            rollouts=rollouts,
+            task_metadata=payload.get("task_metadata", []),
+        )
+
+        # Train
+        result = self._train_on_batch(batch)
+
+        # Save checkpoint
+        checkpoint_path = None
+        if payload.get("save_checkpoint", False):
+            checkpoint_path = self._save_checkpoint(batch.iteration)
+
+        # Update weight pointer for hot-reload
+        if checkpoint_path:
+            self._update_weight_pointer(checkpoint_path, batch.iteration)
+
+        return Message(
+            msg_type=MessageType.TRAIN_COMPLETE,
+            payload={
+                "request_id": msg.msg_id,
+                "iteration": batch.iteration,
+                "run_id": batch.run_id,
+                "result": result,
+                "checkpoint_path": checkpoint_path,
+            },
+            source="trainer",
+        )
+
+    def _handle_health_check(self, msg: Message) -> Message:
+        """Handle a health check request."""
+        return Message(
+            msg_type=MessageType.HEALTH_RESPONSE,
+            payload={
+                "status": "healthy",
+                "base_model": self.config.base_model,
+                "current_adapter": self._current_adapter,
+                "current_iteration": self._current_iteration,
+                "metrics_count": len(self._metrics),
+            },
+            source="trainer",
+        )
+
+    def _process_message(self, msg: Message) -> Optional[Message]:
+        """Process a single message."""
+        try:
+            if msg.msg_type == MessageType.TRAIN_BATCH:
+                return self._handle_train_batch(msg)
+            elif msg.msg_type == MessageType.HEALTH_CHECK:
+                return self._handle_health_check(msg)
+            elif msg.msg_type == MessageType.SHUTDOWN:
+                self._shutdown = True
+                return None
+            else:
+                print(f"[TrainerWorker] Unknown message type: {msg.msg_type}")
+                return None
+        except Exception as e:
+            print(f"[TrainerWorker] Error processing message: {e}")
+            traceback.print_exc()
+            return Message(
+                msg_type=MessageType.TRAIN_COMPLETE,
+                payload={
+                    "request_id": msg.msg_id,
+                    "status": "error",
+                    "error": str(e),
+                },
+                source="trainer",
+            )
+
+    def run(self) -> None:
+        """Run the trainer worker loop."""
+        # Setup signal handlers
+        def signal_handler(sig, frame):
+            print("[TrainerWorker] Shutting down...")
+            self._shutdown = True
+
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
+        # Load model
+        self._load_model()
+
+        if not self.queue:
+            print("[TrainerWorker] No queue provided, running in standalone mode")
+            # Standalone mode - just wait for shutdown
+            while not self._shutdown:
+                time.sleep(0.1)
+            return
+
+        print("[TrainerWorker] Starting training loop...")
+
+        # Main training loop
+        while not self._shutdown:
+            try:
+                # Check for training batches
+                msg = self.queue.receive("train_batches", timeout=1.0)
+
+                if msg:
+                    print(f"[TrainerWorker] Received training batch: {msg.msg_id}")
+                    response = self._process_message(msg)
+
+                    if response:
+                        self.queue.get_queue("train_complete").put(response.to_dict())
+
+                        # Signal weight update if checkpoint was saved
+                        if response.payload.get("checkpoint_path"):
+                            weight_update = Message(
+                                msg_type=MessageType.WEIGHT_UPDATE,
+                                payload={
+                                    "adapter_path": response.payload["checkpoint_path"],
+                                    "version": response.payload.get("iteration", 0),
+                                    "base_model": self.config.base_model,
+                                },
+                                source="trainer",
+                            )
+                            self.queue.get_queue("weight_updates").put(weight_update.to_dict())
+                            self.queue.get_queue("checkpoints").put(weight_update.to_dict())
+
+                # Check control queue
+                control_msg = self.queue.receive("control", timeout=0)
+                if control_msg:
+                    self._process_message(control_msg)
+
+            except Exception as e:
+                print(f"[TrainerWorker] Error in main loop: {e}")
+                traceback.print_exc()
+                time.sleep(1.0)
+
+        print("[TrainerWorker] Shutdown complete")
+
+
+def run_trainer_worker(
+    config: TrainerConfig,
+    queue: Optional[MessageQueue] = None,
+) -> None:
+    """Entry point for trainer worker process."""
+    worker = TrainerWorker(config, queue)
+    worker.run()
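The module docstring above describes the intended flow (consume batches, train, checkpoint, signal weight updates). The sketch below is a minimal, hypothetical way to launch the worker standalone; it only calls names defined in this file, but the base model path, run directories, and the choice to spawn it with multiprocessing are assumptions. The package's own wiring lives in orchestrator/daemon.py, which is not shown in this hunk.

# Sketch: launch the trainer worker standalone (queue=None falls back to an idle wait loop).
# The paths and base model below are placeholders, not values from the package.
from multiprocessing import Process
from pathlib import Path

from mlxsmith.orchestrator.trainer_worker import TrainerConfig, run_trainer_worker

config = TrainerConfig(
    model_spec="base",                        # placeholder
    base_model="/models/example-base-model",  # placeholder
    weights_dir=Path("runs/demo/weights"),
    checkpoint_dir=Path("runs/demo/checkpoints"),
)

# The worker installs SIGTERM/SIGINT handlers and runs until it is signalled.
proc = Process(target=run_trainer_worker, args=(config,))
proc.start()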
mlxsmith/rlm/__init__.py
ADDED
mlxsmith/rlm/corpus.py
ADDED
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Iterable, List
+
+from ..util import ensure_dir
+
+
+def _read_jsonl(path: Path) -> List[dict]:
+    if not path.exists():
+        return []
+    rows = []
+    for line in path.read_text(encoding="utf-8").splitlines():
+        line = line.strip()
+        if not line:
+            continue
+        try:
+            rows.append(json.loads(line))
+        except Exception:
+            continue
+    return rows
+
+
+def load_corpus(path: Path, *, max_size: int | None = None) -> List[dict]:
+    rows = _read_jsonl(path)
+    if max_size is not None and len(rows) > max_size:
+        return rows[-max_size:]
+    return rows
+
+
+def append_corpus(path: Path, rows: Iterable[dict], *, max_size: int) -> None:
+    ensure_dir(path.parent)
+    with path.open("a", encoding="utf-8") as f:
+        for row in rows:
+            f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+    if max_size <= 0:
+        return
+
+    all_rows = _read_jsonl(path)
+    if len(all_rows) > max_size:
+        trimmed = all_rows[-max_size:]
+        path.write_text("\n".join(json.dumps(r, ensure_ascii=False) for r in trimmed) + "\n", encoding="utf-8")
+
+
+def sample_corpus(rows: List[dict], *, n: int, hard_ratio: float = 0.0) -> List[dict]:
+    if n <= 0 or not rows:
+        return []
+    if n >= len(rows):
+        return list(rows)
+
+    hard_n = int(round(n * max(0.0, min(1.0, hard_ratio))))
+    easy_n = n - hard_n
+
+    # Hard samples = longest prompts (proxy for difficulty).
+    sorted_rows = sorted(rows, key=lambda r: len((r.get("prompt") or "")), reverse=True)
+    hard = sorted_rows[:hard_n] if hard_n > 0 else []
+
+    # Easy samples = shortest prompts.
+    easy = sorted_rows[-easy_n:] if easy_n > 0 else []
+
+    # Deduplicate while preserving order.
+    seen = set()
+    out = []
+    for row in hard + easy:
+        key = row.get("id") or row.get("hash") or (row.get("prompt"), row.get("response"))
+        if key in seen:
+            continue
+        seen.add(key)
+        out.append(row)
+        if len(out) >= n:
+            break
+    return out
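A brief usage sketch for these corpus helpers follows. The corpus path, row contents, and size cap are illustrative assumptions; the function calls match the signatures above, and the "id"/"prompt"/"response" keys are the ones the sampler's dedup key actually reads.

# Sketch: maintain a capped JSONL corpus and draw a mixed hard/easy sample.
from pathlib import Path

from mlxsmith.rlm.corpus import append_corpus, load_corpus, sample_corpus

corpus_path = Path("runs/demo/corpus.jsonl")  # placeholder location

append_corpus(
    corpus_path,
    [
        {"id": "t1", "prompt": "Write a haiku about autumn.", "response": "Leaves drift on cold wind"},
        {"id": "t2", "prompt": "Explain what a LoRA adapter is, in one sentence.", "response": "A low-rank weight update."},
    ],
    max_size=10_000,  # keep at most 10k rows on disk
)

rows = load_corpus(corpus_path, max_size=10_000)
# hard_ratio=0.3 draws part of the sample from the longest prompts (a difficulty
# proxy) and the rest from the shortest, deduplicated by id/hash or (prompt, response).
batch = sample_corpus(rows, n=8, hard_ratio=0.3)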
mlxsmith/rlm/gating.py
ADDED
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+from ..util import ensure_dir
+
+
+@dataclass
+class GatingState:
+    best_score: Optional[float] = None
+    best_adapter: Optional[str] = None
+    ema_score: Optional[float] = None
+    last_iteration: int = 0
+    current_adapter: Optional[str] = None
+
+
+def load_state(path: Path) -> GatingState:
+    if not path.exists():
+        return GatingState()
+    data = json.loads(path.read_text(encoding="utf-8"))
+    return GatingState(
+        best_score=data.get("best_score"),
+        best_adapter=data.get("best_adapter"),
+        ema_score=data.get("ema_score"),
+        last_iteration=int(data.get("last_iteration", 0)),
+        current_adapter=data.get("current_adapter"),
+    )
+
+
+def save_state(path: Path, state: GatingState) -> None:
+    ensure_dir(path.parent)
+    path.write_text(
+        json.dumps(
+            {
+                "best_score": state.best_score,
+                "best_adapter": state.best_adapter,
+                "ema_score": state.ema_score,
+                "last_iteration": state.last_iteration,
+                "current_adapter": state.current_adapter,
+            },
+            indent=2,
+        ),
+        encoding="utf-8",
+    )
+
+
+def should_accept(
+    score: float,
+    state: GatingState,
+    *,
+    mode: str,
+    threshold: float = 0.0,
+    ema_alpha: float = 0.2,
+) -> bool:
+    if state.best_score is None:
+        return True
+    mode = (mode or "strict").lower()
+    if mode == "threshold":
+        return score >= float(state.best_score) + float(threshold)
+    if mode == "ema":
+        ema = state.ema_score if state.ema_score is not None else state.best_score
+        return score >= float(ema)
+    # strict
+    return score > float(state.best_score)
+
+
+def update_state(
+    state: GatingState,
+    *,
+    iteration: int,
+    score: float,
+    adapter_path: str,
+    accepted: bool,
+    ema_alpha: float = 0.2,
+) -> GatingState:
+    state.last_iteration = iteration
+    if state.ema_score is None:
+        state.ema_score = score
+    else:
+        state.ema_score = float(ema_alpha) * score + (1.0 - float(ema_alpha)) * float(state.ema_score)
+
+    if accepted:
+        state.current_adapter = adapter_path
+        if state.best_score is None or score > float(state.best_score):
+            state.best_score = score
+            state.best_adapter = adapter_path
+    return state
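A sketch of the accept/update cycle these gating helpers support. The state path, adapter path, iteration number, and eval score below are placeholders; only the function names and signatures are taken from the code above.

# Sketch: gate a new adapter checkpoint on its eval score, then persist the state.
from pathlib import Path

from mlxsmith.rlm.gating import load_state, save_state, should_accept, update_state

state_path = Path("runs/demo/gating.json")        # placeholder
adapter_path = "runs/demo/checkpoints/iter_0003"  # placeholder

state = load_state(state_path)
score = 0.62  # hypothetical eval score for this iteration

# "threshold" mode accepts only if the score beats the best score by the given margin.
accepted = should_accept(score, state, mode="threshold", threshold=0.01)
state = update_state(
    state,
    iteration=3,
    score=score,
    adapter_path=adapter_path,
    accepted=accepted,
)
save_state(state_path, state)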