mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/orchestrator/trainer_worker.py ADDED
@@ -0,0 +1,437 @@
+"""Trainer Worker Process for MLXSmith Orchestrator.
+
+Runs as a separate process for training.
+Consumes training batches from queue.
+Publishes adapter checkpoints and signals weight updates.
+"""
+
+from __future__ import annotations
+
+import json
+import signal
+import sys
+import time
+import traceback
+from collections import defaultdict
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from ..config import ProjectConfig
+from ..llm.registry import get_llm_backend
+from ..rlm.inference import Rollout
+from ..rlm.weights import WeightPointerStore, WeightPointerIPC
+from ..train.lora import LoRAConfig
+from ..util import ensure_dir, now_ts
+from .queue import MessageQueue, MessageType, Message
+
+
+@dataclass
+class TrainerConfig:
+    """Configuration for trainer worker."""
+    model_spec: str
+    base_model: str
+    backend: str = "mlx-lm"
+    max_seq_len: int = 8192
+    dtype: str = "bf16"
+    trust_remote_code: bool = False
+
+    # Training config
+    lr: float = 2e-4
+    weight_decay: float = 0.0
+    kl_coeff: float = 0.02
+    normalize_advantage: bool = True
+
+    # LoRA config
+    lora_r: int = 16
+    lora_alpha: int = 32
+    lora_dropout: float = 0.05
+    lora_target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj", "o_proj"])
+    lora_num_layers: int = 0
+
+    # Paths
+    weights_dir: Optional[Path] = None
+    checkpoint_dir: Optional[Path] = None
+
+    # Reference model for KL penalty
+    reference_model: Optional[str] = None
+
+
+@dataclass
+class TrainingBatch:
+    """A batch of rollouts for training."""
+    iteration: int
+    run_id: str
+    rollouts: List[Rollout]
+    task_metadata: List[Dict[str, Any]] = field(default_factory=list)
+
+
+class TrainerWorker:
+    """Trainer worker process for RLM training.
+
+    Runs in a separate process, handles:
+    - Consuming training batches from queue
+    - Running forward/backward passes
+    - Publishing adapter checkpoints
+    - Signaling weight updates to orchestrator
+    """
+
+    def __init__(
+        self,
+        config: TrainerConfig,
+        queue: Optional[MessageQueue] = None,
+    ):
+        self.config = config
+        self.queue = queue
+        self._llm = None
+        self._ref_llm = None
+        self._optimizer = None
+        self._pointer_store: Optional[WeightPointerStore] = None
+        self._current_iteration = 0
+        self._current_adapter: Optional[str] = None
+        self._shutdown = False
+        self._metrics: List[Dict] = []
+
+    def _load_model(self) -> None:
+        """Load the model and optimizer."""
+        print("[TrainerWorker] Loading model...")
+        self._llm = get_llm_backend(self.config.backend)
+
+        # Load base model
+        self._llm.load(
+            self.config.base_model,
+            max_seq_len=self.config.max_seq_len,
+            dtype=self.config.dtype,
+            trust_remote_code=self.config.trust_remote_code,
+        )
+
+        # Setup weight pointer store
+        if self.config.weights_dir:
+            self._pointer_store = WeightPointerStore(self.config.weights_dir)
+            pointer = self._pointer_store.load("trainer", self.config.base_model)
+
+            if pointer.adapter_path:
+                print(f"[TrainerWorker] Loading adapter: {pointer.adapter_path}")
+                self._llm.apply_adapter(pointer.adapter_path)
+                self._current_adapter = pointer.adapter_path
+                self._current_iteration = pointer.iteration
+            else:
+                # Initialize new LoRA adapter
+                print("[TrainerWorker] Initializing LoRA adapter...")
+                lora_cfg = LoRAConfig(
+                    r=self.config.lora_r,
+                    alpha=self.config.lora_alpha,
+                    dropout=self.config.lora_dropout,
+                    target_modules=list(self.config.lora_target_modules),
+                    num_layers=self.config.lora_num_layers,
+                )
+                self._llm.apply_lora_from_config(lora_cfg)
+
+        # Setup optimizer
+        self._optimizer, _ = self._llm.optimizer_and_params(
+            lr=self.config.lr,
+            weight_decay=self.config.weight_decay,
+        )
+
+        # Load reference model if needed for KL
+        if self.config.reference_model and self.config.kl_coeff > 0:
+            print("[TrainerWorker] Loading reference model...")
+            self._ref_llm = get_llm_backend(self.config.backend)
+            self._ref_llm.load(
+                self.config.reference_model,
+                max_seq_len=self.config.max_seq_len,
+                dtype=self.config.dtype,
+                trust_remote_code=self.config.trust_remote_code,
+            )
+
+        print("[TrainerWorker] Model loaded successfully")
+
+    def _train_on_batch(self, batch: TrainingBatch) -> Dict[str, Any]:
+        """Train on a batch of rollouts."""
+        rollouts = batch.rollouts
+        if not rollouts:
+            return {"status": "empty", "loss": 0.0}
+
+        # Group rollouts by task
+        grouped = defaultdict(list)
+        for r in rollouts:
+            grouped[r.task_id].append(r)
+
+        total_loss = 0.0
+        num_tasks = 0
+
+        for task_id, task_rollouts in grouped.items():
+            if not task_rollouts:
+                continue
+
+            # Compute advantages
+            mean_r = sum(r.reward for r in task_rollouts) / len(task_rollouts)
+            std_r = (sum((r.reward - mean_r) ** 2 for r in task_rollouts) / len(task_rollouts)) ** 0.5
+            advs = [r.reward - mean_r for r in task_rollouts]
+
+            if self.config.normalize_advantage and std_r > 1e-6:
+                advs = [a / std_r for a in advs]
+
+            # Define loss function
+            def loss_fn(_model):
+                loss = self._llm.mx.array(0.0)
+                for rollout, adv in zip(task_rollouts, advs):
+                    logp = self._llm.sequence_logprob(
+                        rollout.token_ids,
+                        prompt_len=rollout.prompt_len,
+                    )
+
+                    # Importance sampling if rollout was from different policy
+                    if rollout.logprobs and rollout.weight_adapter and rollout.weight_adapter != self._current_adapter:
+                        behavior_logp = self._llm.mx.array(sum(rollout.logprobs))
+                        ratio = self._llm.mx.exp(logp - behavior_logp)
+                        pg = -ratio * self._llm.mx.array(float(adv))
+                    else:
+                        pg = -self._llm.mx.array(float(adv)) * logp
+
+                    # KL penalty from reference model
+                    if self._ref_llm is not None and self.config.kl_coeff > 0:
+                        ref_logp = self._ref_llm.sequence_logprob(
+                            rollout.token_ids,
+                            prompt_len=rollout.prompt_len,
+                        )
+                        pg = pg + self._llm.mx.array(self.config.kl_coeff) * (logp - ref_logp)
+
+                    loss = loss + pg
+
+                return loss / self._llm.mx.array(float(len(task_rollouts)))
+
+            # Compute gradients and update
+            lval, grads = self._llm.value_and_grad(loss_fn)
+            if grads is not None:
+                self._llm.apply_grads(self._optimizer, grads)
+
+            loss_val = float(lval.item()) if hasattr(lval, "item") else float(lval)
+            total_loss += loss_val
+            num_tasks += 1
+
+            # Record metrics
+            self._metrics.append({
+                "ts": now_ts(),
+                "iteration": batch.iteration,
+                "task_id": task_id,
+                "mean_reward": mean_r,
+                "std_reward": std_r,
+                "loss": loss_val,
+                "num_rollouts": len(task_rollouts),
+            })
+
+        avg_loss = total_loss / max(1, num_tasks)
+        return {
+            "status": "success",
+            "loss": avg_loss,
+            "num_tasks": num_tasks,
+            "num_rollouts": len(rollouts),
+        }
+
+    def _save_checkpoint(self, iteration: int) -> Optional[str]:
+        """Save adapter checkpoint and return path."""
+        if not self.config.checkpoint_dir:
+            return None
+
+        checkpoint_path = ensure_dir(self.config.checkpoint_dir / f"iter_{iteration:04d}")
+
+        self._llm.save_adapter(
+            str(checkpoint_path),
+            metadata={
+                "base_model": self.config.base_model,
+                "source_adapter": self._current_adapter,
+                "iteration": iteration,
+                "kind": "rlm",
+            },
+        )
+
+        # Save metrics
+        metrics_path = checkpoint_path / "training_metrics.jsonl"
+        with open(metrics_path, "w") as f:
+            for m in self._metrics:
+                f.write(json.dumps(m) + "\n")
+
+        self._current_adapter = str(checkpoint_path)
+        self._current_iteration = iteration
+
+        return str(checkpoint_path)
+
+    def _update_weight_pointer(self, adapter_path: str, iteration: int) -> None:
+        """Update the weight pointer for inference to pick up."""
+        if not self._pointer_store:
+            return
+
+        pointer = WeightPointerIPC(
+            base_model=self.config.base_model,
+            adapter_path=adapter_path,
+            iteration=iteration,
+            updated_at=now_ts(),
+            version=iteration,  # Use iteration as version
+            name="trainer",
+        )
+
+        self._pointer_store.save(pointer)
+        print(f"[TrainerWorker] Updated weight pointer: {adapter_path} (iter {iteration})")
+
+    def _handle_train_batch(self, msg: Message) -> Message:
+        """Handle a training batch message."""
+        payload = msg.payload
+
+        # Deserialize rollouts
+        rollout_data = payload.get("rollouts", [])
+        rollouts = []
+        for r in rollout_data:
+            rollouts.append(Rollout(
+                task_id=r["task_id"],
+                prompt=r["prompt"],
+                completion=r["completion"],
+                token_ids=r["token_ids"],
+                prompt_len=r["prompt_len"],
+                logprobs=r.get("logprobs"),
+                passed=r["passed"],
+                reward=r["reward"],
+                verifier_latency_ms=r["verifier_latency_ms"],
+                weight_adapter=r.get("weight_adapter"),
+            ))
+
+        batch = TrainingBatch(
+            iteration=payload.get("iteration", 0),
+            run_id=payload.get("run_id", ""),
+            rollouts=rollouts,
+            task_metadata=payload.get("task_metadata", []),
+        )
+
+        # Train
+        result = self._train_on_batch(batch)
+
+        # Save checkpoint
+        checkpoint_path = None
+        if payload.get("save_checkpoint", False):
+            checkpoint_path = self._save_checkpoint(batch.iteration)
+
+        # Update weight pointer for hot-reload
+        if checkpoint_path:
+            self._update_weight_pointer(checkpoint_path, batch.iteration)
+
+        return Message(
+            msg_type=MessageType.TRAIN_COMPLETE,
+            payload={
+                "request_id": msg.msg_id,
+                "iteration": batch.iteration,
+                "run_id": batch.run_id,
+                "result": result,
+                "checkpoint_path": checkpoint_path,
+            },
+            source="trainer",
+        )
+
+    def _handle_health_check(self, msg: Message) -> Message:
+        """Handle a health check request."""
+        return Message(
+            msg_type=MessageType.HEALTH_RESPONSE,
+            payload={
+                "status": "healthy",
+                "base_model": self.config.base_model,
+                "current_adapter": self._current_adapter,
+                "current_iteration": self._current_iteration,
+                "metrics_count": len(self._metrics),
+            },
+            source="trainer",
+        )
+
+    def _process_message(self, msg: Message) -> Optional[Message]:
+        """Process a single message."""
+        try:
+            if msg.msg_type == MessageType.TRAIN_BATCH:
+                return self._handle_train_batch(msg)
+            elif msg.msg_type == MessageType.HEALTH_CHECK:
+                return self._handle_health_check(msg)
+            elif msg.msg_type == MessageType.SHUTDOWN:
+                self._shutdown = True
+                return None
+            else:
+                print(f"[TrainerWorker] Unknown message type: {msg.msg_type}")
+                return None
+        except Exception as e:
+            print(f"[TrainerWorker] Error processing message: {e}")
+            traceback.print_exc()
+            return Message(
+                msg_type=MessageType.TRAIN_COMPLETE,
+                payload={
+                    "request_id": msg.msg_id,
+                    "status": "error",
+                    "error": str(e),
+                },
+                source="trainer",
+            )
+
+    def run(self) -> None:
+        """Run the trainer worker loop."""
+        # Setup signal handlers
+        def signal_handler(sig, frame):
+            print("[TrainerWorker] Shutting down...")
+            self._shutdown = True
+
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
+        # Load model
+        self._load_model()
+
+        if not self.queue:
+            print("[TrainerWorker] No queue provided, running in standalone mode")
+            # Standalone mode - just wait for shutdown
+            while not self._shutdown:
+                time.sleep(0.1)
+            return
+
+        print("[TrainerWorker] Starting training loop...")
+
+        # Main training loop
+        while not self._shutdown:
+            try:
+                # Check for training batches
+                msg = self.queue.receive("train_batches", timeout=1.0)
+
+                if msg:
+                    print(f"[TrainerWorker] Received training batch: {msg.msg_id}")
+                    response = self._process_message(msg)
+
+                    if response:
+                        self.queue.get_queue("train_complete").put(response.to_dict())
+
+                        # Signal weight update if checkpoint was saved
+                        if response.payload.get("checkpoint_path"):
+                            weight_update = Message(
+                                msg_type=MessageType.WEIGHT_UPDATE,
+                                payload={
+                                    "adapter_path": response.payload["checkpoint_path"],
+                                    "version": response.payload.get("iteration", 0),
+                                    "base_model": self.config.base_model,
+                                },
+                                source="trainer",
+                            )
+                            self.queue.get_queue("weight_updates").put(weight_update.to_dict())
+                            self.queue.get_queue("checkpoints").put(weight_update.to_dict())
+
+                # Check control queue
+                control_msg = self.queue.receive("control", timeout=0)
+                if control_msg:
+                    self._process_message(control_msg)
+
+            except Exception as e:
+                print(f"[TrainerWorker] Error in main loop: {e}")
+                traceback.print_exc()
+                time.sleep(1.0)
+
+        print("[TrainerWorker] Shutdown complete")
+
+
+def run_trainer_worker(
+    config: TrainerConfig,
+    queue: Optional[MessageQueue] = None,
+) -> None:
+    """Entry point for trainer worker process."""
+    worker = TrainerWorker(config, queue)
+    worker.run()
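The entry point above is presumably spawned by the orchestrator daemon. As a minimal sketch of launching it manually in its own process, assuming illustrative values throughout (the model id, paths, and multiprocessing wiring below are not taken from this package; only TrainerConfig and run_trainer_worker are real names from this module):

# Hypothetical launch of the trainer worker in a separate process.
from multiprocessing import Process
from pathlib import Path

from mlxsmith.orchestrator.trainer_worker import TrainerConfig, run_trainer_worker

config = TrainerConfig(
    model_spec="base",                              # illustrative value
    base_model="mlx-community/Example-Model-4bit",  # illustrative model id
    lr=2e-4,
    kl_coeff=0.02,
    weights_dir=Path("runs/demo/weights"),          # illustrative path
    checkpoint_dir=Path("runs/demo/checkpoints"),   # illustrative path
)

# With no MessageQueue the worker loads the model and idles until SIGTERM/SIGINT;
# with a queue it consumes TRAIN_BATCH messages and emits TRAIN_COMPLETE and
# WEIGHT_UPDATE messages as shown in run() above.
proc = Process(target=run_trainer_worker, args=(config, None))
proc.start()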
mlxsmith/rlm/__init__.py ADDED
@@ -0,0 +1,8 @@
+"""Recursive Language Model (RLM) module for MLXSmith.
+
+Provides both single-process and multi-process orchestrated RLM training loops.
+"""
+
+from .loop import run_rlm, run_rlm_orchestrated
+
+__all__ = ["run_rlm", "run_rlm_orchestrated"]
mlxsmith/rlm/corpus.py ADDED
@@ -0,0 +1,74 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Iterable, List
6
+
7
+ from ..util import ensure_dir
8
+
9
+
10
+ def _read_jsonl(path: Path) -> List[dict]:
11
+ if not path.exists():
12
+ return []
13
+ rows = []
14
+ for line in path.read_text(encoding="utf-8").splitlines():
15
+ line = line.strip()
16
+ if not line:
17
+ continue
18
+ try:
19
+ rows.append(json.loads(line))
20
+ except Exception:
21
+ continue
22
+ return rows
23
+
24
+
25
+ def load_corpus(path: Path, *, max_size: int | None = None) -> List[dict]:
26
+ rows = _read_jsonl(path)
27
+ if max_size is not None and len(rows) > max_size:
28
+ return rows[-max_size:]
29
+ return rows
30
+
31
+
32
+ def append_corpus(path: Path, rows: Iterable[dict], *, max_size: int) -> None:
33
+ ensure_dir(path.parent)
34
+ with path.open("a", encoding="utf-8") as f:
35
+ for row in rows:
36
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
37
+
38
+ if max_size <= 0:
39
+ return
40
+
41
+ all_rows = _read_jsonl(path)
42
+ if len(all_rows) > max_size:
43
+ trimmed = all_rows[-max_size:]
44
+ path.write_text("\n".join(json.dumps(r, ensure_ascii=False) for r in trimmed) + "\n", encoding="utf-8")
45
+
46
+
47
+ def sample_corpus(rows: List[dict], *, n: int, hard_ratio: float = 0.0) -> List[dict]:
48
+ if n <= 0 or not rows:
49
+ return []
50
+ if n >= len(rows):
51
+ return list(rows)
52
+
53
+ hard_n = int(round(n * max(0.0, min(1.0, hard_ratio))))
54
+ easy_n = n - hard_n
55
+
56
+ # Hard samples = longest prompts (proxy for difficulty).
57
+ sorted_rows = sorted(rows, key=lambda r: len((r.get("prompt") or "")), reverse=True)
58
+ hard = sorted_rows[:hard_n] if hard_n > 0 else []
59
+
60
+ # Easy samples = shortest prompts.
61
+ easy = sorted_rows[-easy_n:] if easy_n > 0 else []
62
+
63
+ # Deduplicate while preserving order.
64
+ seen = set()
65
+ out = []
66
+ for row in hard + easy:
67
+ key = row.get("id") or row.get("hash") or (row.get("prompt"), row.get("response"))
68
+ if key in seen:
69
+ continue
70
+ seen.add(key)
71
+ out.append(row)
72
+ if len(out) >= n:
73
+ break
74
+ return out
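These helpers back a simple JSONL corpus: append with trimming to max_size, load with an optional cap, and sample with a prompt-length-based hard/easy split. A minimal usage sketch, assuming a hypothetical corpus path and made-up rows:

from pathlib import Path

from mlxsmith.rlm.corpus import append_corpus, load_corpus, sample_corpus

corpus_path = Path("runs/demo/corpus.jsonl")  # hypothetical location

# Append rows, keeping at most 1000 (older rows are trimmed from the front).
append_corpus(corpus_path, [{"prompt": "2+2=?", "response": "4"}], max_size=1000)

rows = load_corpus(corpus_path, max_size=1000)

# Draw 8 rows: ~25% from the longest prompts ("hard"), the rest from the shortest.
batch = sample_corpus(rows, n=8, hard_ratio=0.25)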
mlxsmith/rlm/gating.py ADDED
@@ -0,0 +1,90 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Optional
7
+
8
+ from ..util import ensure_dir
9
+
10
+
11
+ @dataclass
12
+ class GatingState:
13
+ best_score: Optional[float] = None
14
+ best_adapter: Optional[str] = None
15
+ ema_score: Optional[float] = None
16
+ last_iteration: int = 0
17
+ current_adapter: Optional[str] = None
18
+
19
+
20
+ def load_state(path: Path) -> GatingState:
21
+ if not path.exists():
22
+ return GatingState()
23
+ data = json.loads(path.read_text(encoding="utf-8"))
24
+ return GatingState(
25
+ best_score=data.get("best_score"),
26
+ best_adapter=data.get("best_adapter"),
27
+ ema_score=data.get("ema_score"),
28
+ last_iteration=int(data.get("last_iteration", 0)),
29
+ current_adapter=data.get("current_adapter"),
30
+ )
31
+
32
+
33
+ def save_state(path: Path, state: GatingState) -> None:
34
+ ensure_dir(path.parent)
35
+ path.write_text(
36
+ json.dumps(
37
+ {
38
+ "best_score": state.best_score,
39
+ "best_adapter": state.best_adapter,
40
+ "ema_score": state.ema_score,
41
+ "last_iteration": state.last_iteration,
42
+ "current_adapter": state.current_adapter,
43
+ },
44
+ indent=2,
45
+ ),
46
+ encoding="utf-8",
47
+ )
48
+
49
+
50
+ def should_accept(
51
+ score: float,
52
+ state: GatingState,
53
+ *,
54
+ mode: str,
55
+ threshold: float = 0.0,
56
+ ema_alpha: float = 0.2,
57
+ ) -> bool:
58
+ if state.best_score is None:
59
+ return True
60
+ mode = (mode or "strict").lower()
61
+ if mode == "threshold":
62
+ return score >= float(state.best_score) + float(threshold)
63
+ if mode == "ema":
64
+ ema = state.ema_score if state.ema_score is not None else state.best_score
65
+ return score >= float(ema)
66
+ # strict
67
+ return score > float(state.best_score)
68
+
69
+
70
+ def update_state(
71
+ state: GatingState,
72
+ *,
73
+ iteration: int,
74
+ score: float,
75
+ adapter_path: str,
76
+ accepted: bool,
77
+ ema_alpha: float = 0.2,
78
+ ) -> GatingState:
79
+ state.last_iteration = iteration
80
+ if state.ema_score is None:
81
+ state.ema_score = score
82
+ else:
83
+ state.ema_score = float(ema_alpha) * score + (1.0 - float(ema_alpha)) * float(state.ema_score)
84
+
85
+ if accepted:
86
+ state.current_adapter = adapter_path
87
+ if state.best_score is None or score > float(state.best_score):
88
+ state.best_score = score
89
+ state.best_adapter = adapter_path
90
+ return state
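The gating helpers keep a best-so-far score plus an EMA and decide whether a new adapter checkpoint should be promoted. A minimal sketch of the accept/update cycle, with a hypothetical state file, score, and adapter path (none of these values come from the package):

from pathlib import Path

from mlxsmith.rlm.gating import load_state, save_state, should_accept, update_state

state_path = Path("runs/demo/gating.json")     # hypothetical location
state = load_state(state_path)

score = 0.71                                   # e.g. eval pass rate of the candidate
adapter = "runs/demo/checkpoints/iter_0005"    # hypothetical checkpoint

# "threshold" mode accepts only if score beats best_score by at least `threshold`.
accepted = should_accept(score, state, mode="threshold", threshold=0.01)
state = update_state(state, iteration=5, score=score, adapter_path=adapter, accepted=accepted)
save_state(state_path, state)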