mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/llm/mlx_lm_backend.py
@@ -0,0 +1,509 @@
+from __future__ import annotations
+
+import inspect
+from typing import Sequence, Any, List, Dict, Optional
+
+from .backend import Generation, BackendNotAvailable
+from ..train.lora import apply_adapter, apply_lora, LoRAConfig, save_adapter
+
+
+class MlxLMBackend:
+    """Backend built on MLX + mlx-lm."""
+
+    name = "mlx-lm"
+
+    def __init__(self, *, lora_config: dict | None = None):
+        self.lora_config = lora_config or {}
+        self.model = None
+        self.tokenizer = None
+        self.nn = None
+        self.mx = None
+        self.optim = None
+        self._lora_applied = False
+        self._adapter_config: dict | None = None
+
+    def _require(self):
+        try:
+            import mlx.core as mx  # type: ignore
+            import mlx.nn as nn  # type: ignore
+            import mlx.optimizers as optim  # type: ignore
+        except Exception as e:  # pragma: no cover
+            raise BackendNotAvailable(
+                "MLX is not installed. Try: pip install -e '.[mlx,llm]'"
+            ) from e
+        self.mx = mx
+        self.nn = nn
+        self.optim = optim
+
+        try:
+            import mlx_lm  # type: ignore
+        except Exception as e:  # pragma: no cover
+            raise BackendNotAvailable(
+                "mlx-lm is not installed. Try: pip install -e '.[llm]'"
+            ) from e
+        return mlx_lm
+
+    def _call_with_supported_kwargs(self, fn, *args, **kwargs):
+        try:
+            sig = inspect.signature(fn)
+            supported = {}
+            for k, v in kwargs.items():
+                if k in sig.parameters:
+                    supported[k] = v
+            return fn(*args, **supported)
+        except Exception:
+            return fn(*args, **kwargs)
+
+    def load(
+        self,
+        model_id_or_path: str,
+        *,
+        max_seq_len: int | None = None,
+        dtype: str | None = None,
+        tokenizer_config: dict | None = None,
+        model_config: dict | None = None,
+        adapter_path: str | None = None,
+        trust_remote_code: bool | None = None,
+    ) -> None:
+        mlx_lm = self._require()
+
+        if tokenizer_config is None:
+            tokenizer_config = {}
+        if trust_remote_code is not None:
+            tokenizer_config = dict(tokenizer_config)
+            tokenizer_config.setdefault("trust_remote_code", trust_remote_code)
+
+        if model_config is None:
+            model_config = {}
+        # Pass dtype/max_seq_len as hints when supported in model config.
+        if dtype is not None:
+            model_config = dict(model_config)
+            model_config.setdefault("dtype", dtype)
+        if max_seq_len is not None:
+            model_config = dict(model_config)
+            model_config.setdefault("max_seq_len", max_seq_len)
+
+        load_fn = getattr(mlx_lm, "load", None)
+        if callable(load_fn):
+            model, tokenizer = self._call_with_supported_kwargs(
+                load_fn,
+                model_id_or_path,
+                tokenizer_config=tokenizer_config,
+                model_config=model_config,
+                adapter_path=adapter_path,
+            )
+        else:  # pragma: no cover
+            utils = getattr(mlx_lm, "utils", None)
+            if utils is None or not callable(getattr(utils, "load", None)):
+                raise BackendNotAvailable("Could not find mlx_lm.load(...) API")
+            model, tokenizer = self._call_with_supported_kwargs(
+                utils.load,
+                model_id_or_path,
+                tokenizer_config=tokenizer_config,
+                model_config=model_config,
+                adapter_path=adapter_path,
+            )
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self._lora_applied = False
+        self._adapter_config = None
+
+    def apply_adapter(self, adapter_path: str) -> None:
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        self._adapter_config = apply_adapter(self.model, adapter_path)
+        self._lora_applied = True
+
+    def apply_lora_from_config(self, cfg: LoRAConfig) -> dict:
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        adapter_cfg = apply_lora(self.model, cfg)
+        self._lora_applied = True
+        self._adapter_config = adapter_cfg
+        return adapter_cfg
+
+    def encode(self, text: str) -> list[int]:
+        if self.tokenizer is None:
+            raise RuntimeError("Backend not loaded")
+        tok = self.tokenizer
+        if hasattr(tok, "encode"):
+            out = tok.encode(text)
+            if isinstance(out, dict) and "input_ids" in out:
+                return list(out["input_ids"])
+            if isinstance(out, (list, tuple)):
+                return list(out)
+        if hasattr(tok, "__call__"):
+            out = tok(text)
+            if isinstance(out, dict) and "input_ids" in out:
+                return list(out["input_ids"])
+        raise RuntimeError("Tokenizer does not support encode")
+
+    def decode(self, ids: Sequence[int]) -> str:
+        if self.tokenizer is None:
+            raise RuntimeError("Backend not loaded")
+        tok = self.tokenizer
+        if hasattr(tok, "decode"):
+            return tok.decode(list(ids))
+        raise RuntimeError("Tokenizer does not support decode")
+
+    def _forward_logits(self, ids: Sequence[int]):
+        assert self.mx is not None
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        mx = self.mx
+        x = mx.array([list(ids)], dtype=mx.int32)
+        return self.model(x)
+
+    def _response_logprobs(self, ids: Sequence[int], *, prompt_len: int) -> list[float]:
+        assert self.mx is not None
+        mx = self.mx
+        if not ids:
+            return []
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([list(ids)[1:]], dtype=mx.int32)
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+        start = max(0, prompt_len - 1)
+        if start >= int(getattr(logp, "size", len(ids) - 1)):
+            return []
+        values = logp[:, start:]
+        try:
+            flat = values.flatten().tolist()
+        except Exception:
+            try:
+                flat = [float(v) for v in values.reshape(-1)]
+            except Exception:
+                flat = [float(v) for v in values]
+        return [float(v) for v in flat]
+
+    def _extract_top_k_logprobs(
+        self,
+        logits: Any,
+        k: int,
+        sampled_ids: Optional[Sequence[int]] = None,
+    ) -> List[Dict[str, float]]:
+        """Extract top-k logprobs from logits.
+
+        Args:
+            logits: Logits array [batch, seq_len, vocab_size]
+            k: Number of top logprobs to extract
+            sampled_ids: Optional token IDs that were actually sampled
+
+        Returns:
+            List of dicts mapping token string to logprob for each position
+        """
+        assert self.mx is not None
+        mx = self.mx
+
+        # Get log softmax
+        log_probs = mx.log(mx.softmax(logits, axis=-1))
+
+        # Get top-k indices and values
+        # MLX doesn't have topk directly, so we use argsort
+        sorted_indices = mx.argsort(-log_probs, axis=-1)
+
+        results = []
+        batch_size, seq_len, vocab_size = log_probs.shape
+
+        # Limit k to vocab size
+        k = min(k, vocab_size)
+
+        for b in range(batch_size):
+            for t in range(seq_len):
+                # Get top-k for this position
+                top_k_indices = sorted_indices[b, t, :k]
+                top_k_logprobs = log_probs[b, t, top_k_indices]
+
+                # Build dict
+                token_logprobs = {}
+                for idx, logprob in zip(top_k_indices.tolist(), top_k_logprobs.tolist()):
+                    # Try to decode token
+                    try:
+                        token_str = self.decode([idx])
+                        # Escape special characters for JSON compatibility
+                        token_str = token_str.replace('\n', '\\n').replace('\t', '\\t')
+                    except Exception:
+                        token_str = f"<token_{idx}>"
+                    token_logprobs[token_str] = float(logprob)
+
+                results.append(token_logprobs)
+
+        return results
+
+    def generate(
+        self,
+        prompt: str,
+        *,
+        max_new_tokens: int = 256,
+        temperature: float = 0.8,
+        top_p: float = 1.0,
+        top_k: int | None = None,
+        seed: int | None = None,
+    ) -> Generation:
+        assert self.mx is not None
+        mx = self.mx
+        if seed is not None:
+            mx.random.seed(seed)
+
+        prompt_ids = self.encode(prompt)
+        ids = list(prompt_ids)
+        prompt_len = len(prompt_ids)
+
+        # Prefer mlx_lm sampler if available
+        sampler = None
+        try:
+            from mlx_lm.sample_utils import make_sampler  # type: ignore
+
+            sampler = make_sampler(
+                temp=float(temperature),
+                top_p=float(top_p),
+                top_k=int(top_k or 0),
+            )
+        except Exception:
+            sampler = None
+
+        for _ in range(max_new_tokens):
+            logits = self._forward_logits(ids)
+            last = logits[:, -1, :]
+            if temperature <= 0:
+                next_id = int(mx.argmax(last, axis=-1).item())
+            elif sampler is not None:
+                next_id = int(sampler(last).item())
+            else:
+                probs = mx.softmax(last / float(temperature), axis=-1)
+                next_id = int(mx.random.categorical(mx.log(probs)).item())
+            ids.append(next_id)
+        text = self.decode(ids)
+        return Generation(text=text, token_ids=ids, prompt_len=prompt_len)
+
+    def generate_with_logprobs(
+        self,
+        prompt: str,
+        *,
+        max_new_tokens: int = 256,
+        temperature: float = 0.8,
+        top_p: float = 1.0,
+        top_k_sampling: int | None = None,
+        seed: int | None = None,
+        logprobs: int = 0,  # Number of top logprobs to return per token
+    ) -> Generation:
+        """Generate with logprobs support including top-k logprobs.
+
+        Args:
+            prompt: Input prompt
+            max_new_tokens: Maximum tokens to generate
+            temperature: Sampling temperature
+            top_p: Nucleus sampling parameter
+            top_k_sampling: Top-k sampling parameter (named to avoid conflict with logprobs)
+            seed: Random seed
+            logprobs: Number of top logprobs to return per token (0 = only sampled token)
+
+        Returns:
+            Generation with logprobs and optionally top_k_logprobs
+        """
+        assert self.mx is not None
+        mx = self.mx
+        if seed is not None:
+            mx.random.seed(seed)
+
+        prompt_ids = self.encode(prompt)
+        ids = list(prompt_ids)
+        prompt_len = len(prompt_ids)
+
+        # Storage for per-token info
+        per_token_logprobs: list[float] = []
+        per_token_top_k: list[dict[str, float]] = []
+
+        # Prefer mlx_lm sampler if available
+        sampler = None
+        try:
+            from mlx_lm.sample_utils import make_sampler  # type: ignore
+
+            sampler = make_sampler(
+                temp=float(temperature),
+                top_p=float(top_p),
+                top_k=int(top_k_sampling or 0),
+            )
+        except Exception:
+            sampler = None
+
+        for _ in range(max_new_tokens):
+            logits = self._forward_logits(ids)
+            last = logits[:, -1, :]  # [batch=1, vocab_size]
+
+            # Get log probabilities for this position
+            log_probs = mx.log(mx.softmax(last, axis=-1))
+
+            if temperature <= 0:
+                next_id = int(mx.argmax(last, axis=-1).item())
+            elif sampler is not None:
+                next_id = int(sampler(last).item())
+            else:
+                probs = mx.softmax(last / float(temperature), axis=-1)
+                next_id = int(mx.random.categorical(mx.log(probs)).item())
+
+            # Get logprob of sampled token
+            sampled_logprob = float(log_probs[0, next_id].item())
+            per_token_logprobs.append(sampled_logprob)
+
+            # Get top-k logprobs if requested
+            if logprobs > 0:
+                top_k_logprobs = self._extract_top_k_logprobs(
+                    last,
+                    k=logprobs,
+                    sampled_ids=[next_id]
+                )
+                if top_k_logprobs:
+                    per_token_top_k.append(top_k_logprobs[0])
+
+            ids.append(next_id)
+
+        text = self.decode(ids)
+
+        return Generation(
+            text=text,
+            token_ids=ids,
+            prompt_len=prompt_len,
+            logprobs=per_token_logprobs,
+            top_k_logprobs=per_token_top_k if per_token_top_k else None,
+        )
+
+    def sft_loss(self, token_ids: Sequence[int], *, train_on_prompt: bool, prompt_len: int) -> Any:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+
+        if not train_on_prompt:
+            mask = [0] * max(0, prompt_len - 1) + [1] * (len(ids) - prompt_len)
+            mask = mx.array([mask], dtype=mx.float32)
+        else:
+            mask = mx.ones(labels.shape, dtype=mx.float32)
+
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        nll = (lse - chosen) * mask
+        denom = mx.maximum(mask.sum(), mx.array(1.0))
+        return nll.sum() / denom
+
+    def rl_loss(self, token_ids: Sequence[int], *, prompt_len: int, advantage: float) -> Any:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+
+        start = max(0, prompt_len - 1)
+        logp_resp = logp[:, start:]
+        return -float(advantage) * logp_resp.sum() / mx.maximum(mx.array(1.0), mx.array(logp_resp.size))
+
+    def sequence_logprob(self, token_ids: Sequence[int], *, prompt_len: int) -> Any:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+        start = max(0, prompt_len - 1)
+        return logp[:, start:].sum()
+
+    def token_logprobs(
+        self,
+        token_ids: Sequence[int],
+        *,
+        prompt_len: int,
+        top_k: int = 0,
+        include_prompt: bool = False,
+    ) -> tuple[list[float], list[dict[str, float]] | None]:
+        assert self.mx is not None
+        mx = self.mx
+        ids = list(token_ids)
+        if len(ids) < 2:
+            return [], [] if top_k > 0 else None
+
+        logits = self._forward_logits(ids)
+        logits = logits[:, :-1, :]
+        labels = mx.array([ids[1:]], dtype=mx.int32)
+        lse = mx.logsumexp(logits, axis=-1)
+        chosen = mx.take_along_axis(logits, labels[..., None], axis=-1).squeeze(-1)
+        logp = chosen - lse
+
+        start = 0 if include_prompt else max(0, prompt_len - 1)
+        values = logp[:, start:]
+        try:
+            flat = values.flatten().tolist()
+        except Exception:
+            try:
+                flat = [float(v) for v in values.reshape(-1)]
+            except Exception:
+                flat = [float(v) for v in values]
+        logprobs = [float(v) for v in flat]
+
+        if top_k <= 0:
+            return logprobs, None
+
+        top_k_all = self._extract_top_k_logprobs(logits, k=int(top_k))
+        top_k_list = top_k_all[start:] if top_k_all else []
+        return logprobs, top_k_list
+
+    def value_and_grad(self, loss_fn):
+        if self.nn is None or self.model is None:
+            return loss_fn(self.model), None
+        vag = getattr(self.nn, "value_and_grad", None)
+        if callable(vag):
+            return vag(self.model, loss_fn)(self.model)
+        return loss_fn(self.model), None
+
+    def optimizer_and_params(self, *, lr: float, weight_decay: float = 0.0) -> tuple[Any, Any]:
+        assert self.optim is not None
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+
+        params = None
+        if hasattr(self.model, "trainable_parameters"):
+            params = self.model.trainable_parameters()
+        if params is None or not params:
+            # fallback: train LoRA params if injected
+            from ..train.lora import lora_parameters
+
+            params = lora_parameters(self.model)
+        if not params:
+            params = getattr(self.model, "parameters", lambda: self.model)()
+
+        opt = self.optim.AdamW(learning_rate=lr, weight_decay=weight_decay)
+        opt.init(params)
+        return opt, params
+
+    def apply_grads(self, optimizer: Any, grads: Any) -> None:
+        assert self.mx is not None
+        mx = self.mx
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        optimizer.update(self.model, grads)
+        try:
+            mx.eval(self.model.parameters(), optimizer.state)
+        except Exception:  # pragma: no cover
+            pass
+
+    def save_adapter(self, out_dir: str, *, metadata: dict | None = None) -> None:
+        if self.model is None:
+            raise RuntimeError("Backend not loaded")
+        adapter_cfg = self._adapter_config or {
+            "fine_tune_type": "lora",
+            "num_layers": 0,
+            "lora_parameters": {},
+        }
+        save_adapter(self.model, out_dir, adapter_config=adapter_cfg, metadata=metadata)
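
For orientation, the sketch below shows how the MlxLMBackend added in this release might be driven, based only on the method signatures visible in the diff above. It is a minimal usage sketch, not taken from the package docs: the model id is a placeholder, LoRAConfig() is assumed to have usable defaults, and MLX plus mlx-lm must be installed for load() to succeed.

# Hypothetical usage sketch of mlxsmith.llm.mlx_lm_backend.MlxLMBackend,
# inferred from the signatures shown in this diff; not from the package docs.
from mlxsmith.llm.mlx_lm_backend import MlxLMBackend
from mlxsmith.train.lora import LoRAConfig

backend = MlxLMBackend()
# Placeholder model id; any mlx-lm compatible local path or hub id is assumed to work.
backend.load("some-mlx-model-or-local-path", max_seq_len=2048)

# Plain generation: returns a Generation with text, token_ids, and prompt_len.
gen = backend.generate("Hello", max_new_tokens=32, temperature=0.7, seed=0)
print(gen.text)

# Generation that also records per-token logprobs and the top-3 alternatives per position.
gen = backend.generate_with_logprobs("Hello", max_new_tokens=16, logprobs=3)
print(gen.logprobs, gen.top_k_logprobs)

# Inject LoRA layers (assuming LoRAConfig's defaults are acceptable), then save the adapter.
adapter_cfg = backend.apply_lora_from_config(LoRAConfig())
backend.save_adapter("out/adapter", metadata={"note": "usage sketch"})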