mlxsmith-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/sdk/__init__.py
@@ -0,0 +1,392 @@
+ """MLXSmith SDK for training and inference.
+
+ This module provides high-level interfaces for:
+ - Model loading and sampling
+ - Training operations (SFT, preference, RL)
+ - Async futures-based API
+ - Checkpoint management
+
+ Example:
+     >>> from mlxsmith.sdk import load_model, TrainingClient, SamplingClient
+     >>>
+     >>> # Load model
+     >>> loaded = load_model("mlx-community/Llama-3.2-1B-Instruct-4bit", cfg)
+     >>>
+     >>> # Create training client
+     >>> trainer = TrainingClient(loaded.backend)
+     >>> trainer.create_optimizer(lr=1e-4).result()
+     >>>
+     >>> # Training loop
+     >>> batch = TrainingBatch(prompts=[...], responses=[...])
+     >>> result = trainer.forward_backward(batch).result()
+     >>> trainer.optim_step(result.grads).result()
+     >>>
+     >>> # Sampling
+     >>> sampler = SamplingClient(backend=loaded.backend)
+     >>> result = sampler.sample("Hello", logprobs_k=5)
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Iterable, List, Optional, Sequence, Tuple
+
+ from ..config import ProjectConfig
+ from ..llm.backend import DecodingConfig, Generation
+ from ..llm.registry import get_llm_backend
+ from ..models import resolve_model_spec
+ from ..train.lora import LoRAConfig
+ from .losses import (
+     LOSS_REGISTRY,
+     get_loss,
+     cross_entropy_loss,
+     dpo_loss,
+     orpo_loss,
+     preference_loss,
+     importance_sampling_loss,
+     ppo_loss,
+     cispo_loss,
+     dro_loss,
+ )
+ from .future import (
+     APIFuture,
+     APIFutureState,
+     SdkFuturePool,
+     completed_future,
+     failed_future,
+     cancelled_future,
+ )
+ from .training_client import (
+     TrainingClient,
+     TrainingBatch,
+     ForwardBackwardResult,
+     OptimizerStepResult,
+     CheckpointResult,
+     WeightsResult,
+     DistillationTrainingClient,
+ )
+ from .sampling_client import (
+     SamplingClient,
+     SampleResult,
+     SampleBatchResult,
+     DistillationSampler,
+ )
+
+
+ @dataclass
+ class LoadedModel:
+     backend: Any
+     base_model: str
+     adapter_path: Optional[str]
+     adapter_meta: Optional[dict]
+
+
+ def _truncate_ids(ids: Sequence[int], prompt_len: int, max_seq_len: Optional[int]) -> Tuple[List[int], int]:
+     if max_seq_len is None:
+         return list(ids), prompt_len
+     max_len = int(max_seq_len)
+     if max_len <= 0:
+         return list(ids), prompt_len
+     ids_list = list(ids)
+     if len(ids_list) <= max_len:
+         return ids_list, prompt_len
+     overflow = len(ids_list) - max_len
+     ids_list = ids_list[overflow:]
+     prompt_len = max(0, prompt_len - overflow)
+     return ids_list, prompt_len
+
+
+ def load_model(
+     model_id_or_path: str,
+     cfg: ProjectConfig,
+     *,
+     apply_lora_if_missing: bool = False,
+     lora_config: Optional[LoRAConfig] = None,
+ ) -> LoadedModel:
+     """Load a model with the configured backend.
+
+     Args:
+         model_id_or_path: Model identifier or path
+         cfg: Project configuration
+         apply_lora_if_missing: Whether to apply LoRA if no adapter found
+         lora_config: Optional LoRA configuration
+
+     Returns:
+         LoadedModel with backend and metadata
+     """
+     backend = get_llm_backend(cfg.model.backend)
+     base_model, adapter_path, adapter_meta = resolve_model_spec(Path.cwd(), model_id_or_path, cfg)
+     backend.load(
+         base_model,
+         max_seq_len=cfg.model.max_seq_len,
+         dtype=cfg.model.dtype,
+         trust_remote_code=cfg.model.trust_remote_code,
+     )
+     if adapter_path:
+         backend.apply_adapter(str(adapter_path))
+     elif apply_lora_if_missing:
+         if lora_config is None:
+             lora_config = LoRAConfig(
+                 r=cfg.lora.r,
+                 alpha=cfg.lora.alpha,
+                 dropout=cfg.lora.dropout,
+                 target_modules=list(cfg.lora.target_modules or []),
+                 num_layers=cfg.lora.num_layers,
+                 scale=cfg.lora.scale,
+                 fine_tune_type=cfg.lora.fine_tune_type,
+             )
+         backend.apply_lora_from_config(lora_config)
+
+     return LoadedModel(
+         backend=backend,
+         base_model=base_model,
+         adapter_path=str(adapter_path) if adapter_path else None,
+         adapter_meta=adapter_meta,
+     )
+
+
+ def sample(backend: Any, prompts: Iterable[str], decoding: DecodingConfig) -> List[Generation]:
+     """Sample completions from the backend.
+
+     Args:
+         backend: LLM backend instance
+         prompts: Iterable of prompt strings
+         decoding: Decoding configuration
+
+     Returns:
+         List of Generation results
+     """
+     results: List[Generation] = []
+     for prompt in prompts:
+         results.append(
+             backend.generate(
+                 prompt,
+                 max_new_tokens=decoding.max_new_tokens,
+                 temperature=decoding.temperature,
+                 top_p=decoding.top_p,
+                 top_k=decoding.top_k,
+                 seed=decoding.seed,
+             )
+         )
+     return results
+
+
+ def logprobs(
+     backend: Any,
+     prompts: Iterable[str],
+     completions: Iterable[str],
+     *,
+     max_seq_len: Optional[int] = None,
+ ) -> List[Any]:
+     """Compute logprobs for prompt-completion pairs.
+
+     Args:
+         backend: LLM backend instance
+         prompts: Iterable of prompt strings
+         completions: Iterable of completion strings
+         max_seq_len: Maximum sequence length
+
+     Returns:
+         List of logprob values
+     """
+     results: List[Any] = []
+     for prompt, completion in zip(prompts, completions):
+         prompt_ids = backend.encode(prompt)
+         ids = backend.encode(prompt + completion)
+         ids, prompt_len = _truncate_ids(ids, len(prompt_ids), max_seq_len)
+         results.append(backend.sequence_logprob(ids, prompt_len=prompt_len))
+     return results
+
+
+ def forward_backward(backend: Any, loss_fn) -> Tuple[Any, Any | None]:
+     """Execute forward and backward pass.
+
+     Args:
+         backend: LLM backend instance
+         loss_fn: Loss function
+
+     Returns:
+         Tuple of (loss, gradients)
+     """
+     return backend.value_and_grad(loss_fn)
+
+
+ def forward_backward_custom(backend: Any, loss_fn) -> Tuple[Any, Any | None]:
+     """Execute custom forward and backward pass."""
+     return backend.value_and_grad(loss_fn)
+
+
+ def sft_forward_backward(
+     backend: Any,
+     prompt: str,
+     response: str,
+     *,
+     train_on_prompt: bool = False,
+     max_seq_len: Optional[int] = None,
+ ) -> Tuple[Any, Any | None]:
+     """Execute SFT forward/backward pass.
+
+     Args:
+         backend: LLM backend instance
+         prompt: Prompt string
+         response: Response string
+         train_on_prompt: Whether to compute loss on prompt tokens
+         max_seq_len: Maximum sequence length
+
+     Returns:
+         Tuple of (loss, gradients)
+     """
+     prompt_ids = backend.encode(prompt)
+     ids = backend.encode(prompt + response)
+     ids, prompt_len = _truncate_ids(ids, len(prompt_ids), max_seq_len)
+
+     def loss_fn(_model):
+         return backend.sft_loss(ids, train_on_prompt=train_on_prompt, prompt_len=prompt_len)
+
+     return backend.value_and_grad(loss_fn)
+
+
+ def preference_forward_backward(
+     backend: Any,
+     prompt: str,
+     chosen: str,
+     rejected: str,
+     *,
+     algo: str = "dpo",
+     beta: float = 0.1,
+     reference_backend: Optional[Any] = None,
+     kl_coeff: float = 0.0,
+     train_on_prompt: bool = False,
+     max_seq_len: Optional[int] = None,
+ ) -> Tuple[Any, Any | None]:
+     """Execute preference-based forward/backward pass.
+
+     Args:
+         backend: LLM backend instance
+         prompt: Prompt string
+         chosen: Chosen (preferred) response
+         rejected: Rejected response
+         algo: Algorithm - "dpo", "orpo", etc.
+         beta: Temperature parameter for preference loss
+         reference_backend: Optional reference model backend
+         kl_coeff: KL divergence coefficient
+         train_on_prompt: Whether to compute loss on prompt tokens
+         max_seq_len: Maximum sequence length
+
+     Returns:
+         Tuple of (loss, gradients)
+     """
+     prompt_ids = backend.encode(prompt)
+     chosen_ids = backend.encode(prompt + chosen)
+     rejected_ids = backend.encode(prompt + rejected)
+
+     chosen_ids, prompt_len_c = _truncate_ids(chosen_ids, len(prompt_ids), max_seq_len)
+     rejected_ids, prompt_len_r = _truncate_ids(rejected_ids, len(prompt_ids), max_seq_len)
+
+     def loss_fn(_model):
+         return preference_loss(
+             backend,
+             chosen_ids,
+             rejected_ids,
+             prompt_len_chosen=prompt_len_c,
+             prompt_len_rejected=prompt_len_r,
+             algo=algo,
+             beta=beta,
+             reference_backend=reference_backend,
+             kl_coeff=kl_coeff,
+             train_on_prompt=train_on_prompt,
+         )
+
+     return backend.value_and_grad(loss_fn)
+
+
+ def create_optimizer(backend: Any, *, lr: float, weight_decay: float = 0.0) -> Tuple[Any, Any]:
+     """Create optimizer for training.
+
+     Args:
+         backend: LLM backend instance
+         lr: Learning rate
+         weight_decay: Weight decay coefficient
+
+     Returns:
+         Tuple of (optimizer, parameters)
+     """
+     return backend.optimizer_and_params(lr=lr, weight_decay=weight_decay)
+
+
+ def optim_step(backend: Any, optimizer: Any, grads: Any) -> None:
+     """Execute optimizer step.
+
+     Args:
+         backend: LLM backend instance
+         optimizer: Optimizer instance
+         grads: Gradients
+     """
+     if grads is None:
+         return
+     backend.apply_grads(optimizer, grads)
+
+
+ def save_adapter(backend: Any, out_dir: str, *, metadata: Optional[dict] = None) -> None:
+     """Save adapter weights.
+
+     Args:
+         backend: LLM backend instance
+         out_dir: Output directory
+         metadata: Optional metadata to save
+     """
+     backend.save_adapter(out_dir, metadata=metadata)
+
+
+ __all__ = [
+     # Core classes
+     "LoadedModel",
+     "load_model",
+
+     # Operations
+     "sample",
+     "logprobs",
+     "forward_backward",
+     "forward_backward_custom",
+     "sft_forward_backward",
+     "preference_forward_backward",
+     "create_optimizer",
+     "optim_step",
+     "save_adapter",
+
+     # Futures
+     "APIFuture",
+     "APIFutureState",
+     "SdkFuturePool",
+     "completed_future",
+     "failed_future",
+     "cancelled_future",
+
+     # Training Client
+     "TrainingClient",
+     "TrainingBatch",
+     "ForwardBackwardResult",
+     "OptimizerStepResult",
+     "CheckpointResult",
+     "WeightsResult",
+     "DistillationTrainingClient",
+
+     # Sampling Client
+     "SamplingClient",
+     "SampleResult",
+     "SampleBatchResult",
+     "DistillationSampler",
+
+     # Losses
+     "LOSS_REGISTRY",
+     "get_loss",
+     "cross_entropy_loss",
+     "dpo_loss",
+     "orpo_loss",
+     "preference_loss",
+     "importance_sampling_loss",
+     "ppo_loss",
+     "cispo_loss",
+     "dro_loss",
+ ]
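
For orientation, the sketch below strings together the functional API added in mlxsmith/sdk/__init__.py (load_model, create_optimizer, sft_forward_backward, optim_step, save_adapter, sample). It is assembled only from the signatures visible in this diff: how ProjectConfig and DecodingConfig are actually constructed lives elsewhere in the package and is assumed here, and the model id, prompts, and output directory are placeholders.

    # Hypothetical usage sketch -- values not shown in the diff (config
    # construction, model id, prompts, output directory) are placeholders.
    from mlxsmith.config import ProjectConfig
    from mlxsmith.llm.backend import DecodingConfig
    from mlxsmith.sdk import (
        create_optimizer,
        load_model,
        optim_step,
        sample,
        save_adapter,
        sft_forward_backward,
    )

    # Assumption: ProjectConfig can be built with defaults; in practice it is
    # loaded from the project's configuration (see mlxsmith/config.py).
    cfg = ProjectConfig()

    # Load the base model; apply a fresh LoRA adapter if none is resolved.
    loaded = load_model(
        "mlx-community/Llama-3.2-1B-Instruct-4bit",
        cfg,
        apply_lora_if_missing=True,
    )
    backend = loaded.backend

    # One SFT step: forward/backward on a prompt/response pair, then an
    # optimizer update (optim_step is a no-op when grads is None).
    optimizer, _params = create_optimizer(backend, lr=1e-4, weight_decay=0.0)
    loss, grads = sft_forward_backward(
        backend,
        "Q: What is 2 + 2?\nA: ",
        "4",
        max_seq_len=cfg.model.max_seq_len,
    )
    optim_step(backend, optimizer, grads)
    print("sft loss:", loss)

    # Persist the adapter weights.
    save_adapter(backend, "adapters/sft-demo", metadata={"example": True})

    # Sample from the updated model. Assumption: DecodingConfig accepts these
    # fields as constructor arguments (the diff only shows them being read).
    decoding = DecodingConfig(
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.95,
        top_k=0,
        seed=0,
    )
    for generation in sample(backend, ["Q: What is 3 + 3?\nA: "], decoding):
        print(generation)

The TrainingClient / SamplingClient classes shown in the module docstring expose the same backend operations through a futures-based interface.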