mlxsmith-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/sdk/__init__.py
ADDED
@@ -0,0 +1,392 @@
"""MLXSmith SDK for training and inference.

This module provides high-level interfaces for:
- Model loading and sampling
- Training operations (SFT, preference, RL)
- Async futures-based API
- Checkpoint management

Example:
    >>> from mlxsmith.sdk import load_model, TrainingClient, SamplingClient
    >>>
    >>> # Load model
    >>> loaded = load_model("mlx-community/Llama-3.2-1B-Instruct-4bit", cfg)
    >>>
    >>> # Create training client
    >>> trainer = TrainingClient(loaded.backend)
    >>> trainer.create_optimizer(lr=1e-4).result()
    >>>
    >>> # Training loop
    >>> batch = TrainingBatch(prompts=[...], responses=[...])
    >>> result = trainer.forward_backward(batch).result()
    >>> trainer.optim_step(result.grads).result()
    >>>
    >>> # Sampling
    >>> sampler = SamplingClient(backend=loaded.backend)
    >>> result = sampler.sample("Hello", logprobs_k=5)
"""

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, List, Optional, Sequence, Tuple

from ..config import ProjectConfig
from ..llm.backend import DecodingConfig, Generation
from ..llm.registry import get_llm_backend
from ..models import resolve_model_spec
from ..train.lora import LoRAConfig
from .losses import (
    LOSS_REGISTRY,
    get_loss,
    cross_entropy_loss,
    dpo_loss,
    orpo_loss,
    preference_loss,
    importance_sampling_loss,
    ppo_loss,
    cispo_loss,
    dro_loss,
)
from .future import (
    APIFuture,
    APIFutureState,
    SdkFuturePool,
    completed_future,
    failed_future,
    cancelled_future,
)
from .training_client import (
    TrainingClient,
    TrainingBatch,
    ForwardBackwardResult,
    OptimizerStepResult,
    CheckpointResult,
    WeightsResult,
    DistillationTrainingClient,
)
from .sampling_client import (
    SamplingClient,
    SampleResult,
    SampleBatchResult,
    DistillationSampler,
)


@dataclass
class LoadedModel:
    backend: Any
    base_model: str
    adapter_path: Optional[str]
    adapter_meta: Optional[dict]


def _truncate_ids(ids: Sequence[int], prompt_len: int, max_seq_len: Optional[int]) -> Tuple[List[int], int]:
    if max_seq_len is None:
        return list(ids), prompt_len
    max_len = int(max_seq_len)
    if max_len <= 0:
        return list(ids), prompt_len
    ids_list = list(ids)
    if len(ids_list) <= max_len:
        return ids_list, prompt_len
    overflow = len(ids_list) - max_len
    ids_list = ids_list[overflow:]
    prompt_len = max(0, prompt_len - overflow)
    return ids_list, prompt_len


def load_model(
    model_id_or_path: str,
    cfg: ProjectConfig,
    *,
    apply_lora_if_missing: bool = False,
    lora_config: Optional[LoRAConfig] = None,
) -> LoadedModel:
    """Load a model with the configured backend.

    Args:
        model_id_or_path: Model identifier or path
        cfg: Project configuration
        apply_lora_if_missing: Whether to apply LoRA if no adapter found
        lora_config: Optional LoRA configuration

    Returns:
        LoadedModel with backend and metadata
    """
    backend = get_llm_backend(cfg.model.backend)
    base_model, adapter_path, adapter_meta = resolve_model_spec(Path.cwd(), model_id_or_path, cfg)
    backend.load(
        base_model,
        max_seq_len=cfg.model.max_seq_len,
        dtype=cfg.model.dtype,
        trust_remote_code=cfg.model.trust_remote_code,
    )
    if adapter_path:
        backend.apply_adapter(str(adapter_path))
    elif apply_lora_if_missing:
        if lora_config is None:
            lora_config = LoRAConfig(
                r=cfg.lora.r,
                alpha=cfg.lora.alpha,
                dropout=cfg.lora.dropout,
                target_modules=list(cfg.lora.target_modules or []),
                num_layers=cfg.lora.num_layers,
                scale=cfg.lora.scale,
                fine_tune_type=cfg.lora.fine_tune_type,
            )
        backend.apply_lora_from_config(lora_config)

    return LoadedModel(
        backend=backend,
        base_model=base_model,
        adapter_path=str(adapter_path) if adapter_path else None,
        adapter_meta=adapter_meta,
    )


def sample(backend: Any, prompts: Iterable[str], decoding: DecodingConfig) -> List[Generation]:
    """Sample completions from the backend.

    Args:
        backend: LLM backend instance
        prompts: Iterable of prompt strings
        decoding: Decoding configuration

    Returns:
        List of Generation results
    """
    results: List[Generation] = []
    for prompt in prompts:
        results.append(
            backend.generate(
                prompt,
                max_new_tokens=decoding.max_new_tokens,
                temperature=decoding.temperature,
                top_p=decoding.top_p,
                top_k=decoding.top_k,
                seed=decoding.seed,
            )
        )
    return results


def logprobs(
    backend: Any,
    prompts: Iterable[str],
    completions: Iterable[str],
    *,
    max_seq_len: Optional[int] = None,
) -> List[Any]:
    """Compute logprobs for prompt-completion pairs.

    Args:
        backend: LLM backend instance
        prompts: Iterable of prompt strings
        completions: Iterable of completion strings
        max_seq_len: Maximum sequence length

    Returns:
        List of logprob values
    """
    results: List[Any] = []
    for prompt, completion in zip(prompts, completions):
        prompt_ids = backend.encode(prompt)
        ids = backend.encode(prompt + completion)
        ids, prompt_len = _truncate_ids(ids, len(prompt_ids), max_seq_len)
        results.append(backend.sequence_logprob(ids, prompt_len=prompt_len))
    return results


def forward_backward(backend: Any, loss_fn) -> Tuple[Any, Any | None]:
    """Execute forward and backward pass.

    Args:
        backend: LLM backend instance
        loss_fn: Loss function

    Returns:
        Tuple of (loss, gradients)
    """
    return backend.value_and_grad(loss_fn)


def forward_backward_custom(backend: Any, loss_fn) -> Tuple[Any, Any | None]:
    """Execute custom forward and backward pass."""
    return backend.value_and_grad(loss_fn)


def sft_forward_backward(
    backend: Any,
    prompt: str,
    response: str,
    *,
    train_on_prompt: bool = False,
    max_seq_len: Optional[int] = None,
) -> Tuple[Any, Any | None]:
    """Execute SFT forward/backward pass.

    Args:
        backend: LLM backend instance
        prompt: Prompt string
        response: Response string
        train_on_prompt: Whether to compute loss on prompt tokens
        max_seq_len: Maximum sequence length

    Returns:
        Tuple of (loss, gradients)
    """
    prompt_ids = backend.encode(prompt)
    ids = backend.encode(prompt + response)
    ids, prompt_len = _truncate_ids(ids, len(prompt_ids), max_seq_len)

    def loss_fn(_model):
        return backend.sft_loss(ids, train_on_prompt=train_on_prompt, prompt_len=prompt_len)

    return backend.value_and_grad(loss_fn)


def preference_forward_backward(
    backend: Any,
    prompt: str,
    chosen: str,
    rejected: str,
    *,
    algo: str = "dpo",
    beta: float = 0.1,
    reference_backend: Optional[Any] = None,
    kl_coeff: float = 0.0,
    train_on_prompt: bool = False,
    max_seq_len: Optional[int] = None,
) -> Tuple[Any, Any | None]:
    """Execute preference-based forward/backward pass.

    Args:
        backend: LLM backend instance
        prompt: Prompt string
        chosen: Chosen (preferred) response
        rejected: Rejected response
        algo: Algorithm - "dpo", "orpo", etc.
        beta: Temperature parameter for preference loss
        reference_backend: Optional reference model backend
        kl_coeff: KL divergence coefficient
        train_on_prompt: Whether to compute loss on prompt tokens
        max_seq_len: Maximum sequence length

    Returns:
        Tuple of (loss, gradients)
    """
    prompt_ids = backend.encode(prompt)
    chosen_ids = backend.encode(prompt + chosen)
    rejected_ids = backend.encode(prompt + rejected)

    chosen_ids, prompt_len_c = _truncate_ids(chosen_ids, len(prompt_ids), max_seq_len)
    rejected_ids, prompt_len_r = _truncate_ids(rejected_ids, len(prompt_ids), max_seq_len)

    def loss_fn(_model):
        return preference_loss(
            backend,
            chosen_ids,
            rejected_ids,
            prompt_len_chosen=prompt_len_c,
            prompt_len_rejected=prompt_len_r,
            algo=algo,
            beta=beta,
            reference_backend=reference_backend,
            kl_coeff=kl_coeff,
            train_on_prompt=train_on_prompt,
        )

    return backend.value_and_grad(loss_fn)


def create_optimizer(backend: Any, *, lr: float, weight_decay: float = 0.0) -> Tuple[Any, Any]:
    """Create optimizer for training.

    Args:
        backend: LLM backend instance
        lr: Learning rate
        weight_decay: Weight decay coefficient

    Returns:
        Tuple of (optimizer, parameters)
    """
    return backend.optimizer_and_params(lr=lr, weight_decay=weight_decay)


def optim_step(backend: Any, optimizer: Any, grads: Any) -> None:
    """Execute optimizer step.

    Args:
        backend: LLM backend instance
        optimizer: Optimizer instance
        grads: Gradients
    """
    if grads is None:
        return
    backend.apply_grads(optimizer, grads)


def save_adapter(backend: Any, out_dir: str, *, metadata: Optional[dict] = None) -> None:
    """Save adapter weights.

    Args:
        backend: LLM backend instance
        out_dir: Output directory
        metadata: Optional metadata to save
    """
    backend.save_adapter(out_dir, metadata=metadata)


__all__ = [
    # Core classes
    "LoadedModel",
    "load_model",

    # Operations
    "sample",
    "logprobs",
    "forward_backward",
    "forward_backward_custom",
    "sft_forward_backward",
    "preference_forward_backward",
    "create_optimizer",
    "optim_step",
    "save_adapter",

    # Futures
    "APIFuture",
    "APIFutureState",
    "SdkFuturePool",
    "completed_future",
    "failed_future",
    "cancelled_future",

    # Training Client
    "TrainingClient",
    "TrainingBatch",
    "ForwardBackwardResult",
    "OptimizerStepResult",
    "CheckpointResult",
    "WeightsResult",
    "DistillationTrainingClient",

    # Sampling Client
    "SamplingClient",
    "SampleResult",
    "SampleBatchResult",
    "DistillationSampler",

    # Losses
    "LOSS_REGISTRY",
    "get_loss",
    "cross_entropy_loss",
    "dpo_loss",
    "orpo_loss",
    "preference_loss",
    "importance_sampling_loss",
    "ppo_loss",
    "cispo_loss",
    "dro_loss",
]
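
For orientation, a minimal end-to-end sketch using the functional helpers added in this file (load_model, create_optimizer, sft_forward_backward, optim_step, save_adapter, sample). The ProjectConfig() default construction, the DecodingConfig keyword arguments, and the output path are illustrative assumptions, not confirmed by this diff.

    from mlxsmith.config import ProjectConfig
    from mlxsmith.llm.backend import DecodingConfig
    from mlxsmith.sdk import (
        load_model,
        sample,
        create_optimizer,
        sft_forward_backward,
        optim_step,
        save_adapter,
    )

    # Assumed: default project configuration is sufficient for a local run.
    cfg = ProjectConfig()
    loaded = load_model("mlx-community/Llama-3.2-1B-Instruct-4bit", cfg, apply_lora_if_missing=True)

    # One supervised fine-tuning step: compute loss and grads, then apply them.
    optimizer, _params = create_optimizer(loaded.backend, lr=1e-4)
    loss, grads = sft_forward_backward(loaded.backend, "Q: 2+2?\nA:", " 4", max_seq_len=512)
    optim_step(loaded.backend, optimizer, grads)
    save_adapter(loaded.backend, "adapters/example", metadata={"step": 1})

    # Sampling via the fields that sample() reads from DecodingConfig.
    decoding = DecodingConfig(max_new_tokens=64, temperature=0.0, top_p=1.0, top_k=0, seed=0)
    print(sample(loaded.backend, ["Hello"], decoding)[0])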