mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/rlm/mutate.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from typing import Iterable, List, Optional
|
|
5
|
+
|
|
6
|
+
from ..util import sha1_text
|
|
7
|
+
from .generate import GeneratedTask, extract_json_objects, task_to_prompt, task_to_tests
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def mutate_tasks(
    llm,
    tasks: Iterable[GeneratedTask],
    *,
    mutations_per_task: int,
    max_total: Optional[int] = None,
    temperature: float = 0.7,
    max_new_tokens: int = 512,
    top_p: float = 1.0,
    top_k: Optional[int] = None,
    require_recursion: bool = False,
) -> List[GeneratedTask]:
    """Ask ``llm`` for mutated variants of each task; return the seeds plus accepted mutants.

    For every seed task, up to ``mutations_per_task`` generations are requested. The first
    JSON object extracted from each generation is converted into a prompt/tests pair with
    ``task_to_prompt`` / ``task_to_tests``. A generation is skipped when it yields:

    * no JSON object, or a first JSON value that is not an object;
    * a prompt shorter than 10 characters;
    * no tests.

    Args:
        llm: Object exposing ``generate(prompt, max_new_tokens=..., temperature=...,
            top_p=..., top_k=...)`` whose result has a ``.text`` attribute.
        tasks: Seed tasks. They are kept, ahead of any mutants, in the result.
        mutations_per_task: Generation attempts per seed. ``<= 0`` disables mutation.
        max_total: Optional cap on the length of the returned list (seeds included).
        temperature: Sampling temperature forwarded to ``llm.generate``.
        max_new_tokens: Token budget forwarded to ``llm.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``llm.generate``.
        top_k: Top-k cutoff forwarded to ``llm.generate``.
        require_recursion: Forwarded to ``task_to_prompt``.

    Returns:
        Seeds followed by accepted mutants, deduplicated by task id (the first occurrence
        wins) and truncated to ``max_total`` when given. When mutation is disabled, or
        ``tasks`` is empty, the seeds are returned as-is: no dedupe and no cap is applied
        on that path.
    """
    tasks_list = list(tasks)
    if mutations_per_task <= 0 or not tasks_list:
        return tasks_list

    def _at_capacity(count: int) -> bool:
        # True once the running list has reached the optional ``max_total`` cap.
        return max_total is not None and count >= max_total

    mutated: List[GeneratedTask] = list(tasks_list)

    for task in tasks_list:
        if _at_capacity(len(mutated)):
            break
        for idx in range(mutations_per_task):
            # Check *before* calling the model, so no generation is wasted once the cap is
            # reached. This includes the case where the seeds alone already fill it.
            if _at_capacity(len(mutated)):
                break

            prompt = (
                "Mutate the following coding task to increase diversity. "
                "Return ONE JSON object with fields: id, description, signature, tests.\n\n"
                f"ORIGINAL_ID: {task.id}\n"
                f"ORIGINAL_PROMPT:\n{task.prompt}\n\n"
                f"ORIGINAL_TESTS:\n{task.tests}\n"
            )
            gen = llm.generate(
                prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
            items = extract_json_objects(gen.text)
            if not items:
                continue

            item = items[0]
            if not isinstance(item, dict):
                # A bare JSON list/string/number has no fields to build a task from.
                continue

            # The prompt shows the model ORIGINAL_ID, so it often echoes it back. Keeping
            # that id would make the dedupe pass below silently discard this mutant, so use
            # a derived id instead. Fall back to a content hash when the seed has no id.
            raw_id = item.get("id")
            if raw_id and str(raw_id) != str(task.id):
                tid = str(raw_id)
            elif task.id:
                tid = f"{task.id}_m{idx}"
            else:
                tid = sha1_text(json.dumps(item, sort_keys=True))[:12]

            task_prompt = task_to_prompt(item, require_recursion=require_recursion)
            tests = task_to_tests(item)
            if len(task_prompt) < 10 or not tests:
                continue

            mutated.append(
                GeneratedTask(
                    id=tid,
                    prompt=task_prompt,
                    tests=tests,
                    description=item.get("description"),
                )
            )

    # Deduplicate by id (falling back to a prompt hash for empty ids); first occurrence wins.
    seen = set()
    deduped: List[GeneratedTask] = []
    for t in mutated:
        key = t.id or sha1_text(t.prompt)
        if key in seen:
            continue
        seen.add(key)
        deduped.append(t)
        if _at_capacity(len(deduped)):
            break

    return deduped
|
mlxsmith/rlm/trainer.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from typing import Iterable, Optional
|
|
5
|
+
|
|
6
|
+
from ..config import ProjectConfig
|
|
7
|
+
from ..util import now_ts, latency_summary_ms
|
|
8
|
+
from .inference import Rollout
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def train_on_rollouts(
    llm,
    rollouts: Iterable[Rollout],
    cfg: ProjectConfig,
    *,
    optimizer: object,
    train_adapter: Optional[str] = None,
    ref_llm: Optional[object] = None,
) -> list[dict]:
    """Run one policy-gradient update per task group and return per-group metrics.

    Rollouts are bucketed by ``task_id`` in first-seen order. For each bucket:

    * Advantages are ``reward - mean(reward)``. When ``cfg.rft.normalize_advantage`` is
      truthy and the population std of the rewards exceeds 1e-6, they are divided by it.
    * Each rollout contributes a loss term:
      - ``-adv * logp`` by default;
      - ``-exp(logp - behavior_logp) * adv`` when the rollout carries ``logprobs`` and its
        ``weight_adapter`` is set and differs from ``train_adapter``.
    * With ``ref_llm`` given and ``cfg.rft.kl_coeff > 0``, a
      ``kl_coeff * (logp - ref_logp)`` term is added per rollout.
    * The summed loss is divided by the bucket size.
    * Gradients from ``llm.value_and_grad`` are applied via ``llm.apply_grads`` unless
      they are ``None``.

    Returns:
        One dict per bucket with: timestamp, task id, reward mean/std, loss, a verifier
        latency summary (mean/p50/p90/p99/max), and the first rollout's ``weight_adapter``.
    """
    buckets = defaultdict(list)
    for rollout in rollouts:
        buckets[rollout.task_id].append(rollout)

    def _reward_stats(group):
        # Population mean/std of the group's rewards plus (optionally normalized) advantages.
        size = max(1, len(group))
        mean = sum(r.reward for r in group) / size
        std = (sum((r.reward - mean) ** 2 for r in group) / size) ** 0.5
        advantages = [r.reward - mean for r in group]
        if bool(cfg.rft.normalize_advantage) and std > 1e-6:
            advantages = [a / std for a in advantages]
        return mean, std, advantages

    results: list[dict] = []

    for task_id, group in buckets.items():
        if not group:
            continue

        mean_r, std_r, advs = _reward_stats(group)

        def loss_fn(_model):
            mx = llm.mx
            total = mx.array(0.0)  # type: ignore
            for rollout, adv in zip(group, advs):
                logp = llm.sequence_logprob(rollout.token_ids, prompt_len=rollout.prompt_len)
                # Rollouts sampled by a different adapter are importance-weighted against
                # the behaviour policy's recorded log-probabilities.
                is_off_policy = (
                    rollout.logprobs
                    and rollout.weight_adapter
                    and rollout.weight_adapter != train_adapter
                )
                if is_off_policy:
                    behavior_logp = mx.array(sum(rollout.logprobs))  # type: ignore
                    ratio = mx.exp(logp - behavior_logp)  # type: ignore
                    term = -ratio * mx.array(float(adv))  # type: ignore
                else:
                    term = -mx.array(float(adv)) * logp  # type: ignore
                if ref_llm is not None and cfg.rft.kl_coeff > 0:
                    ref_logp = ref_llm.sequence_logprob(
                        rollout.token_ids, prompt_len=rollout.prompt_len
                    )
                    term = term + mx.array(cfg.rft.kl_coeff) * (logp - ref_logp)  # type: ignore
                total = total + term
            return total / mx.array(float(len(group)))  # type: ignore

        lval, grads = llm.value_and_grad(loss_fn)
        if grads is not None:
            llm.apply_grads(optimizer, grads)

        latency = latency_summary_ms([float(r.verifier_latency_ms) for r in group])
        row = {
            "ts": now_ts(),
            "task_id": task_id,
            "mean_reward": mean_r,
            "std_reward": std_r,
            "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
            # Un-suffixed key kept alongside the explicit "_mean" variant below.
            "verifier_latency_ms": latency.get("mean", 0.0),
        }
        for stat in ("mean", "p50", "p90", "p99", "max"):
            row[f"verifier_latency_ms_{stat}"] = latency.get(stat, 0.0)
        row["weight_adapter"] = group[0].weight_adapter
        results.append(row)

    return results
|
mlxsmith/rlm/weights.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
"""Weight pointer system for tracking adapter weights across RLM iterations.
|
|
2
|
+
|
|
3
|
+
Extends to support IPC for multi-process orchestration with atomic updates
|
|
4
|
+
and hot-reload capabilities.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
import logging
import multiprocessing as mp
import threading
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Optional, Callable

from ..util import ensure_dir, now_ts
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class WeightPointer:
    """Unversioned record of which adapter sits on top of a base model.

    Written by ``save_pointer`` and read back by ``load_pointer``.
    """

    base_model: str  # identifier of the base model
    adapter_path: Optional[str]  # None when no adapter has been recorded
    iteration: int  # iteration counter; loaders default it to 0
    updated_at: str  # timestamp string (loaders fill it with now_ts() when absent)
    name: Optional[str] = None  # optional pointer name
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass
class WeightPointerIPC:
    """Weight pointer with versioning, for sharing between processes.

    ``version`` is bumped by ``WeightPointerStore.update``. ``WeightWatcher`` compares it
    against the last version it saw to decide when to fire its callback. ``checksum`` is an
    optional integrity value supplied by the writer.
    """

    base_model: str
    adapter_path: Optional[str]
    iteration: int
    updated_at: str
    version: int = 0  # Monotonic version for ordering updates
    checksum: Optional[str] = None  # Optional checksum for integrity
    name: Optional[str] = None

    def to_dict(self) -> dict:
        """Return every field as a JSON-serializable dict, in declaration order."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "WeightPointerIPC":
        """Build a pointer from a dict produced by ``to_dict`` (e.g. parsed JSON).

        ``base_model`` is required; a missing key raises ``KeyError``. Missing or null
        ``iteration`` / ``version`` default to 0 and are coerced with ``int``, so a later
        ``version + 1`` cannot fail on ``None``. A missing or empty ``updated_at`` is
        replaced by the current timestamp, computed only when it is actually needed.
        """
        updated_at = data.get("updated_at")
        return cls(
            base_model=data["base_model"],
            adapter_path=data.get("adapter_path"),
            iteration=int(data.get("iteration") or 0),
            updated_at=updated_at if updated_at else now_ts(),
            version=int(data.get("version") or 0),
            checksum=data.get("checksum"),
            name=data.get("name"),
        )
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
class WeightPointerStore:
    """File-backed store of named ``WeightPointerIPC`` records.

    Each pointer lives at ``<weights_dir>/<name>.json``. Writes go to a temporary file,
    which is then moved over the target with ``Path.replace``, so readers never see a
    half-written pointer.

    NOTE(review): ``self._lock`` is a ``multiprocessing.Lock`` created per instance. It
    serializes access through this instance, including from processes that inherit it
    (e.g. via fork). Independently constructed stores do not share it — for example, the
    ones built by ``load_pointer_ipc`` / ``save_pointer_ipc``.
    """

    def __init__(self, weights_dir: Path):
        self._weights_dir = Path(weights_dir)
        self._lock = mp.Lock()

    def get_path(self, name: str) -> Path:
        """Get the storage path for a named pointer."""
        return self._weights_dir / f"{name}.json"

    def get_atomic_path(self, name: str) -> Path:
        """Get the temporary path used while atomically writing pointer ``name``."""
        return self._weights_dir / f".{name}.tmp"

    @staticmethod
    def _empty_pointer(name: str, base_model: str) -> WeightPointerIPC:
        # Version-0 pointer with no adapter, used when nothing usable is stored yet.
        return WeightPointerIPC(
            base_model=base_model,
            adapter_path=None,
            iteration=0,
            updated_at=now_ts(),
            version=0,
            name=name,
        )

    def _load_unlocked(self, name: str, base_model: str) -> WeightPointerIPC:
        # Read pointer ``name``; the caller must already hold ``self._lock``.
        path = self.get_path(name)
        if not path.exists():
            return self._empty_pointer(name, base_model)
        try:
            data = json.loads(path.read_text(encoding="utf-8"))
            return WeightPointerIPC.from_dict(data)
        except (OSError, ValueError, KeyError, TypeError):
            # Unreadable, malformed, or incomplete file: best-effort fallback.
            # NOTE(review): this resets ``version`` to 0, so a watcher that already saw a
            # higher version ignores updates until the version climbs past it again.
            return self._empty_pointer(name, base_model)

    def _write_unlocked(self, pointer: WeightPointerIPC) -> None:
        # Atomically write ``pointer``; the caller must already hold ``self._lock``.
        key = pointer.name or "default"
        path = self.get_path(key)
        tmp_path = self.get_atomic_path(key)
        tmp_path.write_text(
            json.dumps(pointer.to_dict(), indent=2),
            encoding="utf-8",
        )
        # Path.replace overwrites an existing target on every platform; Path.rename raises
        # on Windows when the target exists. On POSIX both are an atomic rename(2).
        tmp_path.replace(path)

    def load(self, name: str, base_model: str) -> WeightPointerIPC:
        """Load pointer ``name`` from storage.

        If the file is missing or cannot be parsed, a fresh version-0 pointer with no
        adapter is returned for ``base_model``.
        """
        with self._lock:
            return self._load_unlocked(name, base_model)

    def save(self, pointer: WeightPointerIPC) -> None:
        """Atomically save ``pointer`` under ``pointer.name``, or ``"default"`` when unset."""
        ensure_dir(self._weights_dir)
        with self._lock:
            self._write_unlocked(pointer)

    def update(
        self,
        name: str,
        adapter_path: Optional[str] = None,
        iteration: Optional[int] = None,
        checksum: Optional[str] = None,
        base_model: Optional[str] = None,
    ) -> WeightPointerIPC:
        """Update pointer ``name`` and bump its version.

        The read-modify-write runs under a single acquisition of ``self._lock``, so two
        updates through this store cannot both read the same version and lose a bump.

        Args:
            name: Pointer name.
            adapter_path: New adapter path; ``None`` keeps the stored one.
            iteration: New iteration; ``None`` keeps the stored one.
            checksum: Checksum for the new state. Not carried over from the stored pointer.
            base_model: Used only when the stored pointer has no base model, e.g. on first
                creation. An existing base model is always preserved.

        Returns:
            The pointer that was written.
        """
        ensure_dir(self._weights_dir)
        with self._lock:
            current = self._load_unlocked(name, base_model or "")
            new_pointer = WeightPointerIPC(
                base_model=current.base_model or base_model or "",
                adapter_path=adapter_path if adapter_path is not None else current.adapter_path,
                iteration=iteration if iteration is not None else current.iteration,
                updated_at=now_ts(),
                version=current.version + 1,
                checksum=checksum,
                name=name,
            )
            self._write_unlocked(new_pointer)
        return new_pointer

    def watch(
        self,
        name: str,
        base_model: str,
        callback: Callable[[WeightPointerIPC], None],
        poll_interval: float = 1.0,
    ) -> "WeightWatcher":
        """Create (but do not start) a watcher that monitors pointer ``name`` for changes."""
        return WeightWatcher(self, name, base_model, callback, poll_interval)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class WeightWatcher:
    """Polls a weight pointer and invokes a callback whenever its version increases.

    Used by inference workers to hot-reload weights when the trainer publishes updates.
    Polling runs on a daemon thread *inside the calling process*, so the callback's side
    effects (such as swapping weights on an in-memory model) are visible to the caller.
    Running it in a separate process would execute the callback there instead. It would
    also make stopping unreliable, because a flag changed in the parent is never seen by
    a child process.
    """

    def __init__(
        self,
        store: "WeightPointerStore",
        name: str,
        base_model: str,
        callback: Callable[["WeightPointerIPC"], None],
        poll_interval: float = 1.0,
    ):
        self._store = store
        self._name = name
        self._base_model = base_model
        self._callback = callback
        self._poll_interval = poll_interval
        self._last_version = -1  # below any stored version, so the first load fires
        self._running = False
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def start(self) -> None:
        """Start polling on a background daemon thread (no-op if already running)."""
        if self._thread is not None and self._thread.is_alive():
            return
        self._stop_event.clear()
        self._running = True
        self._thread = threading.Thread(
            target=self._watch_loop,
            name=f"weight-watcher-{self._name}",
            daemon=True,
        )
        self._thread.start()

    def stop(self, timeout: float = 5.0) -> None:
        """Signal the polling thread to exit, and wait up to ``timeout`` seconds for it."""
        self._running = False
        self._stop_event.set()
        thread = self._thread
        # Joining the current thread raises RuntimeError, e.g. when stop() is called from
        # inside the callback.
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=timeout)
        self._thread = None

    def _watch_loop(self) -> None:
        """Poll the store until stopped; fire the callback on each version increase."""
        log = logging.getLogger(__name__)
        while not self._stop_event.is_set():
            try:
                pointer = self._store.load(self._name, self._base_model)
                if pointer.version > self._last_version:
                    self._last_version = pointer.version
                    self._callback(pointer)
            except Exception:
                # Keep watching despite errors, but leave a trace instead of dropping them.
                log.exception("Weight watcher for %r failed; retrying", self._name)
            # Event.wait doubles as the poll sleep and returns early once stop() is called.
            self._stop_event.wait(self._poll_interval)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def load_pointer(path: Path, *, base_model: str, name: Optional[str] = None) -> WeightPointer:
    """Load a ``WeightPointer`` from a JSON file written by ``save_pointer``.

    A missing file yields a fresh pointer: no adapter, iteration 0, current timestamp.
    Missing or null fields in an existing file fall back to the ``base_model`` / ``name``
    arguments or to those defaults. Malformed JSON is not caught; ``json.JSONDecodeError``
    propagates to the caller.
    """
    if not path.exists():
        return WeightPointer(
            base_model=base_model,
            adapter_path=None,
            iteration=0,
            updated_at=now_ts(),
            name=name,
        )
    data = json.loads(path.read_text(encoding="utf-8"))
    return WeightPointer(
        base_model=data.get("base_model") or base_model,
        adapter_path=data.get("adapter_path"),
        # ``or 0`` also covers an explicit JSON null, which int() would reject.
        iteration=int(data.get("iteration") or 0),
        updated_at=data.get("updated_at") or now_ts(),
        name=data.get("name") or name,
    )
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def save_pointer(path: Path, pointer: WeightPointer) -> None:
    """Save ``pointer`` to ``path`` as indented JSON, creating parent directories.

    The JSON is first written to a hidden sibling temp file (``.<name>.tmp``), which is
    then moved over ``path`` with ``Path.replace``. A concurrent ``load_pointer`` therefore
    sees either the old or the new file, never a truncated one. This mirrors
    ``WeightPointerStore.save``.
    """
    ensure_dir(path.parent)
    payload = json.dumps(
        {
            "base_model": pointer.base_model,
            "adapter_path": pointer.adapter_path,
            "iteration": pointer.iteration,
            "updated_at": pointer.updated_at,
            "name": pointer.name,
        },
        indent=2,
    )
    tmp_path = path.with_name(f".{path.name}.tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    tmp_path.replace(path)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def load_pointer_ipc(path: Path, base_model: str, name: str) -> WeightPointerIPC:
    """Load the IPC-enabled pointer ``name`` from the directory containing ``path``.

    Only ``path.parent`` is used. The file actually read is ``<path.parent>/<name>.json``,
    whatever ``path.name`` is.
    """
    return WeightPointerStore(path.parent).load(name, base_model)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def save_pointer_ipc(path: Path, pointer: WeightPointerIPC) -> None:
    """Save an IPC-enabled pointer into the directory containing ``path``.

    Only ``path.parent`` is used. The file written is
    ``<path.parent>/<pointer.name or "default">.json``, whatever ``path.name`` is.
    """
    WeightPointerStore(path.parent).save(pointer)
|
mlxsmith/runs.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import yaml
|
|
7
|
+
|
|
8
|
+
from .util import ensure_dir
|
|
9
|
+
|
|
10
|
+
@dataclass
class RunPaths:
    """Filesystem layout of a single run, as produced by ``new_run``."""

    run_dir: Path  # <root>/runs/<kind>_<NNNN>
    logs_dir: Path  # <run_dir>/logs
    checkpoints_dir: Path  # <run_dir>/checkpoints
    adapter_dir: Path  # <run_dir>/adapter
    artifacts_dir: Path  # <run_dir>/artifacts
    metrics_path: Path  # <run_dir>/metrics.jsonl (not created by new_run)
    config_snapshot_path: Path  # <run_dir>/config.snapshot.yaml (not created by new_run)
|
|
19
|
+
|
|
20
|
+
def new_run(root: Path, kind: str) -> RunPaths:
    """Create a fresh run directory ``<root>/runs/<kind>_<NNNN>`` and return its paths.

    The index is one more than the highest existing ``<kind>_<digits>`` directory. A
    deleted earlier run therefore never causes an existing directory (and its metrics) to
    be reused. The run directory is claimed with an exclusive ``mkdir``, so concurrent
    callers cannot receive the same one.

    The sub-directories ``logs``, ``checkpoints``, ``adapter`` and ``artifacts`` are
    created. ``metrics.jsonl`` and ``config.snapshot.yaml`` are returned as paths only.
    """
    import re

    runs_root = ensure_dir(root / "runs")

    # Match exactly "<kind>_<digits>". A bare glob on "<kind>_*" would also count runs of a
    # longer kind that shares the prefix (e.g. "sft_lora_0001" for kind "sft").
    pattern = re.compile(rf"{re.escape(kind)}_(\d+)")
    highest = 0
    for entry in runs_root.iterdir():
        match = pattern.fullmatch(entry.name)
        if match and entry.is_dir():
            highest = max(highest, int(match.group(1)))

    next_idx = highest + 1
    while True:
        run_dir = runs_root / f"{kind}_{next_idx:04d}"
        try:
            run_dir.mkdir()
        except FileExistsError:
            # Claimed concurrently (or a stray file with that name): try the next index.
            next_idx += 1
            continue
        break

    logs_dir = ensure_dir(run_dir / "logs")
    checkpoints_dir = ensure_dir(run_dir / "checkpoints")
    adapter_dir = ensure_dir(run_dir / "adapter")
    artifacts_dir = ensure_dir(run_dir / "artifacts")
    metrics_path = run_dir / "metrics.jsonl"
    config_snapshot_path = run_dir / "config.snapshot.yaml"
    return RunPaths(
        run_dir=run_dir,
        logs_dir=logs_dir,
        checkpoints_dir=checkpoints_dir,
        adapter_dir=adapter_dir,
        artifacts_dir=artifacts_dir,
        metrics_path=metrics_path,
        config_snapshot_path=config_snapshot_path,
    )
|
|
42
|
+
|
|
43
|
+
def snapshot_config(cfg_dict: dict, path: Path):
    """Write ``cfg_dict`` to ``path`` as UTF-8 YAML, keeping the dict's key order."""
    rendered = yaml.safe_dump(cfg_dict, sort_keys=False)
    path.write_text(rendered, encoding="utf-8")
|