mlxsmith 0.1.0__py3-none-any.whl
This diff shows the content of a publicly available package version released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/train/lora.py
ADDED
@@ -0,0 +1,280 @@
"""LoRA adapter utilities.

This module prefers MLX-LM's LoRA utilities and adapter format when available.
Fallback implementations are provided for environments without MLX/MLX-LM.
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any


def _require_mlx():
    import mlx.core as mx  # type: ignore
    import mlx.nn as nn  # type: ignore
    from mlx.utils import tree_flatten  # type: ignore

    return mx, nn, tree_flatten


def _try_mlx_lm_utils():
    try:
        from mlx_lm.tuner import utils as tuner_utils  # type: ignore
        from mlx_lm.utils import save_config as mlx_save_config  # type: ignore
    except Exception:
        return None, None
    return tuner_utils, mlx_save_config


@dataclass
class LoRAConfig:
    r: int = 16
    alpha: int = 32
    dropout: float = 0.0
    target_modules: list[str] | None = None
    num_layers: int = 0  # 0 => all layers
    scale: float | None = None
    fine_tune_type: str = "lora"  # lora | dora | full


class LoRALinear:
    """Minimal fallback LoRA wrapper for mlx.nn.Linear."""

    def __init__(self, base_linear: Any, *, r: int, alpha: int, dropout: float = 0.0):
        mx, nn, _ = _require_mlx()
        self.nn = nn
        self.mx = mx

        self.base = base_linear
        self.r = int(r)
        self.alpha = float(alpha)
        self.scale = float(alpha) / float(r) if r > 0 else 0.0
        self.dropout = nn.Dropout(p=float(dropout)) if dropout and dropout > 0 else None

        w = getattr(base_linear, "weight")
        out_dim, in_dim = int(w.shape[0]), int(w.shape[1])

        self.A = mx.random.normal((self.r, in_dim)) * 0.01
        self.B = mx.zeros((out_dim, self.r))

    def __call__(self, x):
        mx = self.mx
        base_w = mx.stop_gradient(self.base.weight)
        y = x @ base_w.T
        if getattr(self.base, "bias", None) is not None:
            y = y + mx.stop_gradient(self.base.bias)
        z = x
        if self.dropout is not None:
            z = self.dropout(z)
        z = (z @ self.A.T) @ self.B.T
        return y + z * self.scale


def _iter_named_modules(root: Any, prefix: str = ""):
    for name in dir(root):
        if name.startswith("_"):
            continue
        try:
            obj = getattr(root, name)
        except Exception:
            continue
        if callable(getattr(obj, "__call__", None)) and hasattr(obj, "__dict__"):
            full = f"{prefix}{name}" if not prefix else f"{prefix}.{name}"
            yield full, root, name, obj


def inject_lora(
    model: Any,
    *,
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.0,
    target_modules: list[str] | None = None,
) -> int:
    _mx, _nn, _ = _require_mlx()
    targets = target_modules or ["q_proj", "k_proj", "v_proj", "o_proj"]
    wrapped = 0

    for full, parent, attr, obj in list(_iter_named_modules(model)):
        if not hasattr(obj, "weight"):
            continue
        if not any(attr.endswith(t) or full.endswith(t) for t in targets):
            continue
        cls_name = obj.__class__.__name__.lower()
        if "linear" not in cls_name:
            continue
        try:
            setattr(parent, attr, LoRALinear(obj, r=r, alpha=alpha, dropout=dropout))
            wrapped += 1
        except Exception:
            continue

    return wrapped


def lora_parameters(model: Any) -> dict[str, Any]:
    params: dict[str, Any] = {}
    for full, _parent, _attr, obj in list(_iter_named_modules(model)):
        if obj.__class__.__name__ == "LoRALinear" or (hasattr(obj, "A") and hasattr(obj, "B")):
            try:
                params[f"{full}.A"] = obj.A
                params[f"{full}.B"] = obj.B
            except Exception:
                pass
    return params


def _scale_for_config(cfg: LoRAConfig) -> float:
    if cfg.scale is not None:
        return float(cfg.scale)
    if cfg.r == 0:
        return float(cfg.alpha)
    return float(cfg.alpha) / float(cfg.r)


def _keys_for_target_modules(model: Any, target_modules: list[str]) -> set[str]:
    keys: set[str] = set()
    # Prefer model.named_modules if present (MLX-LM models support this)
    if hasattr(model, "named_modules"):
        for name, _mod in model.named_modules():
            if any(name.endswith(t) for t in target_modules):
                keys.add(name)
        return keys

    # Fallback: walk attributes
    for full, _parent, _attr, obj in list(_iter_named_modules(model)):
        if any(full.endswith(t) for t in target_modules):
            keys.add(full)
    return keys


def apply_lora(model: Any, cfg: LoRAConfig) -> dict:
    """Apply LoRA layers (prefer MLX-LM utilities) and return adapter config."""
    tuner_utils, _save_cfg = _try_mlx_lm_utils()
    scale = _scale_for_config(cfg)
    keys = None
    if cfg.target_modules:
        keys = sorted(_keys_for_target_modules(model, cfg.target_modules))

    if tuner_utils is not None and hasattr(tuner_utils, "linear_to_lora_layers"):
        # MLX-LM format
        config = {
            "rank": int(cfg.r),
            "scale": float(scale),
            "dropout": float(cfg.dropout),
        }
        if keys:
            config["keys"] = keys
        num_layers = int(cfg.num_layers)
        tuner_utils.linear_to_lora_layers(
            model,
            num_layers,
            config,
            use_dora=(cfg.fine_tune_type == "dora"),
        )
        return {
            "fine_tune_type": cfg.fine_tune_type,
            "num_layers": num_layers,
            "lora_parameters": config,
        }

    # Fallback to local LoRA injection
    inject_lora(
        model,
        r=cfg.r,
        alpha=cfg.alpha,
        dropout=cfg.dropout,
        target_modules=cfg.target_modules,
    )
    return {
        "fine_tune_type": "lora",
        "num_layers": 0,
        "lora_parameters": {
            "rank": int(cfg.r),
            "scale": float(scale),
            "dropout": float(cfg.dropout),
        },
    }


def save_adapter(model: Any, out_dir: str | Path, *, adapter_config: dict, metadata: dict | None = None) -> None:
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    tuner_utils, mlx_save_config = _try_mlx_lm_utils()
    try:
        mx, _nn, tree_flatten = _require_mlx()
    except Exception:
        mx = None
        tree_flatten = None

    if mx is not None and hasattr(model, "trainable_parameters") and tree_flatten is not None:
        adapter_weights = dict(tree_flatten(model.trainable_parameters()))
        try:
            mx.save_safetensors(str(out / "adapters.safetensors"), adapter_weights)
        except Exception:
            # fallback to numpy
            import numpy as np

            arrays = {k: mx.array(v).to_numpy() for k, v in adapter_weights.items()}
            np.savez(out / "lora.npz", **arrays)
    else:
        # fallback to local LoRA params
        params = lora_parameters(model)
        if params:
            import numpy as np

            mx_local, _nn_local, _ = _require_mlx()
            arrays = {k: mx_local.array(v).to_numpy() for k, v in params.items()}
            np.savez(out / "lora.npz", **arrays)

    config_path = out / "adapter_config.json"
    if mlx_save_config is not None:
        cfg = dict(adapter_config)
        mlx_save_config(cfg, config_path)
    else:
        config_path.write_text(json.dumps(adapter_config, indent=2), encoding="utf-8")

    if metadata is not None:
        (out / "adapter_metadata.json").write_text(
            json.dumps(metadata, indent=2), encoding="utf-8"
        )


def load_adapter_config(adapter_dir: str | Path) -> dict | None:
    path = Path(adapter_dir) / "adapter_config.json"
    if not path.exists():
        return None
    return json.loads(path.read_text(encoding="utf-8"))


def apply_adapter(model: Any, adapter_dir: str | Path) -> dict | None:
    adapter_dir = Path(adapter_dir)
    adapter_cfg = load_adapter_config(adapter_dir)
    tuner_utils, _ = _try_mlx_lm_utils()
    if adapter_cfg is not None and tuner_utils is not None and hasattr(tuner_utils, "load_adapters"):
        tuner_utils.load_adapters(model, str(adapter_dir))
        return adapter_cfg

    # Fallback: load local lora.npz into LoRALinear wrappers
    lora_file = adapter_dir / "lora.npz"
    if not lora_file.exists():
        return adapter_cfg

    import numpy as np

    weights = dict(np.load(lora_file))
    mx, _nn, _ = _require_mlx()
    # apply
    for full, _parent, _attr, obj in list(_iter_named_modules(model)):
        if hasattr(obj, "A") and hasattr(obj, "B"):
            key_a = f"{full}.A"
            key_b = f"{full}.B"
            if key_a in weights:
                obj.A = mx.array(weights[key_a])
            if key_b in weights:
                obj.B = mx.array(weights[key_b])
    return adapter_cfg
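
A minimal usage sketch of the module above, not taken from the package itself: it assumes an MLX-LM model loaded with mlx_lm.load, and the model id ("mlx-community/SomeModel-4bit") and output directory ("adapters/demo") are placeholders. It walks the apply_lora -> save_adapter -> apply_adapter round trip that the functions above implement.

# Hypothetical usage sketch for mlxsmith.train.lora (not part of the wheel).
# Assumes mlx-lm is installed; model id and paths are illustrative placeholders.
from mlx_lm import load

from mlxsmith.train.lora import LoRAConfig, apply_lora, save_adapter, apply_adapter

model, tokenizer = load("mlx-community/SomeModel-4bit")  # placeholder model id

# Wrap the attention projections with LoRA layers; apply_lora returns the
# adapter_config dict that save_adapter later writes to adapter_config.json.
cfg = LoRAConfig(r=8, alpha=16, dropout=0.05, target_modules=["q_proj", "v_proj"])
adapter_config = apply_lora(model, cfg)

# ... training would update the LoRA parameters in place here ...

# Persist adapters.safetensors (or the lora.npz fallback) plus adapter_config.json.
save_adapter(
    model,
    "adapters/demo",
    adapter_config=adapter_config,
    metadata={"base_model": "mlx-community/SomeModel-4bit"},
)

# Later: reload the base model and re-attach the saved adapter.
model2, _ = load("mlx-community/SomeModel-4bit")
apply_adapter(model2, "adapters/demo")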
mlxsmith/train/pref.py
ADDED
@@ -0,0 +1,180 @@
from __future__ import annotations

import json
import random
from pathlib import Path

from rich.console import Console

from ..accel import get_backend
from ..config import ProjectConfig
from ..models import resolve_model_spec
from ..runs import RunPaths, new_run, snapshot_config
from ..util import write_jsonl, now_ts, tree_add, tree_scale
from ..llm.registry import get_llm_backend
from ..llm.backend import BackendNotAvailable
from .lora import LoRAConfig

console = Console()


def _load_pref_rows(path: Path) -> list[dict]:
    return [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines() if line.strip()]


def run_pref(project_root: Path, cfg: ProjectConfig, data_dir: Path, base_model_path: Path, accel: str) -> RunPaths:
    run = new_run(project_root, "pref")
    snapshot_config(cfg.model_dump(), run.config_snapshot_path)

    backend = get_backend(accel)
    backend.patch()
    console.print(f"[bold]PREF[/bold] run: {run.run_dir.name} algo={cfg.pref.algo} accel={backend.name}")

    prefs_path = data_dir / "train.jsonl"
    if not prefs_path.exists():
        raise RuntimeError("Preference data missing. Expect data/prefs/train.jsonl with {prompt, chosen, rejected}")
    rows = _load_pref_rows(prefs_path)

    llm = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, adapter_meta = resolve_model_spec(project_root, str(base_model_path), cfg)

    try:
        llm.load(
            base_model,
            max_seq_len=cfg.model.max_seq_len,
            dtype=cfg.model.dtype,
            trust_remote_code=cfg.model.trust_remote_code,
        )
        if adapter_path:
            llm.apply_adapter(str(adapter_path))
        else:
            lora_cfg = LoRAConfig(
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                dropout=cfg.lora.dropout,
                target_modules=list(cfg.lora.target_modules or []),
                num_layers=cfg.lora.num_layers,
                scale=cfg.lora.scale,
                fine_tune_type=cfg.lora.fine_tune_type,
            )
            llm.apply_lora_from_config(lora_cfg)
    except BackendNotAvailable as e:
        console.print(f"[yellow]MLX backend unavailable[/yellow]: {e}")
        (run.adapter_dir / "ADAPTER.txt").write_text(
            f"Backend unavailable in this environment.\nbase={base_model}\naccel={backend.name}\n",
            encoding="utf-8",
        )
        return run

    ref_llm = None
    if cfg.pref.reference_model:
        ref_llm = get_llm_backend(cfg.model.backend)
        try:
            ref_llm.load(
                cfg.pref.reference_model,
                max_seq_len=cfg.model.max_seq_len,
                dtype=cfg.model.dtype,
                trust_remote_code=cfg.model.trust_remote_code,
            )
        except BackendNotAvailable:
            ref_llm = None

    opt, _params = llm.optimizer_and_params(lr=cfg.train.lr, weight_decay=cfg.train.weight_decay)

    beta = float(cfg.pref.beta)
    kl_coeff = float(cfg.pref.kl_coeff)
    rng = random.Random(cfg.train.seed)
    total = int(cfg.train.iters)
    grad_accum = max(1, int(cfg.train.grad_accum))
    train_on_prompt = bool(getattr(cfg.train, "train_on_prompt", False))

    accum_grads = None
    for step in range(1, total + 1):
        row = rng.choice(rows)
        prompt = row.get("prompt") or ""
        chosen = row.get("chosen") or row.get("accepted") or ""
        rejected = row.get("rejected") or row.get("rejected_response") or ""
        if not (prompt and chosen and rejected):
            continue

        prompt_ids = llm.encode(prompt)
        chosen_ids = llm.encode(prompt + chosen)
        rejected_ids = llm.encode(prompt + rejected)
        p_len_c = len(prompt_ids)
        p_len_r = len(prompt_ids)
        max_len = int(cfg.model.max_seq_len)
        if max_len:
            if len(chosen_ids) > max_len:
                overflow = len(chosen_ids) - max_len
                chosen_ids = chosen_ids[overflow:]
                p_len_c = max(0, p_len_c - overflow)
            if len(rejected_ids) > max_len:
                overflow = len(rejected_ids) - max_len
                rejected_ids = rejected_ids[overflow:]
                p_len_r = max(0, p_len_r - overflow)

        def loss_fn(_model):
            logp_c = llm.sequence_logprob(chosen_ids, prompt_len=p_len_c)
            logp_r = llm.sequence_logprob(rejected_ids, prompt_len=p_len_r)
            ref_diff = 0.0
            if ref_llm is not None:
                ref_logp_c = ref_llm.sequence_logprob(chosen_ids, prompt_len=p_len_c)
                ref_logp_r = ref_llm.sequence_logprob(rejected_ids, prompt_len=p_len_r)
                ref_diff = ref_logp_c - ref_logp_r
            diff = (logp_c - logp_r) - ref_diff

            if cfg.pref.algo == "orpo":
                # ORPO loss = NLL(chosen) - beta * log(sigmoid(diff))
                nll = llm.sft_loss(chosen_ids, train_on_prompt=train_on_prompt, prompt_len=p_len_c)
                or_loss = -beta * llm.mx.log(llm.mx.sigmoid(diff))  # type: ignore
                loss = nll + or_loss
            else:
                # DPO loss
                scaled = llm.mx.array(beta) * diff  # type: ignore
                loss = llm.mx.log1p(llm.mx.exp(-scaled))  # type: ignore

            if ref_llm is not None and kl_coeff > 0:
                # Simple KL penalty on chosen responses
                kl = (logp_c - ref_logp_c) if ref_llm is not None else 0.0
                loss = loss + llm.mx.array(kl_coeff) * kl  # type: ignore
            return loss

        lval, grads = llm.value_and_grad(loss_fn)
        if grads is not None:
            accum_grads = tree_add(accum_grads, grads)

        if step % grad_accum == 0:
            if accum_grads is not None:
                llm.apply_grads(opt, tree_scale(accum_grads, 1.0 / grad_accum))
            accum_grads = None

        if step % cfg.train.log_every == 0 or step == 1 or step == total:
            write_jsonl(
                run.metrics_path,
                [
                    {
                        "ts": now_ts(),
                        "step": step,
                        "kind": "pref",
                        "algo": cfg.pref.algo,
                        "beta": beta,
                        "kl_coeff": kl_coeff,
                        "loss": float(lval.item()) if hasattr(lval, "item") else float(lval),
                        "accel": backend.name,
                    }
                ],
            )

        if step % cfg.train.save_every == 0 or step == total:
            llm.save_adapter(
                str(run.adapter_dir),
                metadata={
                    "base_model": base_model,
                    "source_adapter": str(adapter_path) if adapter_path else None,
                    "run": run.run_dir.name,
                    "kind": "pref",
                },
            )

    console.print(f"[green]Saved adapter[/green] {run.adapter_dir}")
    return run
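
To make the per-step objective in loss_fn concrete, here is a small self-contained sketch of the same DPO and ORPO formulas on plain Python floats. It is illustrative only: the log-probability and NLL values are made up, and the helper names (dpo_loss, orpo_loss) are not part of the package.

# Standalone sketch of the DPO / ORPO objectives used in loss_fn above.
# Values are illustrative; this mirrors the formulas, not the package API.
import math

def dpo_loss(logp_c, logp_r, ref_logp_c, ref_logp_r, beta=0.1):
    # diff = policy margin minus reference margin; loss = -log(sigmoid(beta * diff))
    diff = (logp_c - logp_r) - (ref_logp_c - ref_logp_r)
    return math.log1p(math.exp(-beta * diff))

def orpo_loss(nll_chosen, logp_c, logp_r, beta=0.1):
    # ORPO: NLL(chosen) - beta * log(sigmoid(logp_c - logp_r)), no reference model
    diff = logp_c - logp_r
    return nll_chosen - beta * math.log(1.0 / (1.0 + math.exp(-diff)))

# Example preference row, matching the format expected in data/prefs/train.jsonl:
# {"prompt": "...", "chosen": "...", "rejected": "..."}
print(dpo_loss(logp_c=-12.3, logp_r=-15.8, ref_logp_c=-13.0, ref_logp_r=-14.5, beta=0.1))
print(orpo_loss(nll_chosen=1.9, logp_c=-12.3, logp_r=-15.8, beta=0.1))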