mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/sdk/losses.py ADDED
@@ -0,0 +1,262 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Callable, Optional, Sequence
4
+
5
+
6
# Maps a loss name to its implementation; populated by ``@register_loss``.
LOSS_REGISTRY: dict[str, Callable[..., Any]] = {}


def register_loss(name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Return a decorator that registers a loss function under *name*.

    The decorated function is stored in ``LOSS_REGISTRY`` and returned
    unchanged, so it stays directly callable under its own name. Registering
    a name that already exists replaces the previous entry.
    """

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        LOSS_REGISTRY[name] = fn
        return fn

    return decorator


def get_loss(name: str) -> Callable[..., Any]:
    """Look up a registered loss function by *name*.

    Raises:
        KeyError: if *name* has not been registered. The message lists the
            currently registered names to make typos easy to spot.
    """
    try:
        return LOSS_REGISTRY[name]
    except KeyError:
        available = ", ".join(sorted(LOSS_REGISTRY)) or "<none>"
        raise KeyError(f"Unknown loss: {name} (available: {available})") from None
21
+
22
+
23
def _require_mx(backend) -> Any:
    """Return the array module exposed by *backend* as ``backend.mx``.

    Every loss in this module that does array math goes through this helper,
    so the error message is deliberately generic rather than naming one
    family of losses.

    Raises:
        RuntimeError: if ``backend`` has no ``mx`` attribute, or it is ``None``.
    """
    mx = getattr(backend, "mx", None)
    if mx is None:
        raise RuntimeError("Backend does not expose mx; cannot compute losses.")
    return mx
28
+
29
+
30
def _to_mx_scalar(mx: Any, value: Any) -> Any:
    """Best-effort conversion of *value* into an ``mx`` array.

    Anything exposing ``.item()`` is first collapsed to a plain Python
    scalar; if ``.item()`` fails, the original value is kept. The result is
    then wrapped with ``mx.array``. If that wrapping fails, the (possibly
    collapsed) value is returned as-is.

    NOTE(review): because of the ``.item()`` step the returned array is a
    fresh constant; for MLX arrays this presumably drops autograd history, so
    avoid it for values that must stay differentiable.
    """
    item_fn = getattr(value, "item", None)
    if item_fn is not None:
        try:
            value = item_fn()
        except Exception:
            pass  # keep the original object when it cannot be collapsed

    try:
        converted = mx.array(value)
    except Exception:
        converted = value
    return converted


def _coerce_logprob(mx: Any, value: Any) -> Any:
    """Reduce a list/tuple of per-token log-probabilities to one scalar.

    Each element is converted with ``float()``. Elements that refuse go
    through ``_to_mx_scalar`` first. The running float total is returned
    wrapped by ``_to_mx_scalar``. Values that are not a list or tuple are
    returned untouched.
    """
    if not isinstance(value, (list, tuple)):
        return value

    running = 0.0
    for token_lp in value:
        try:
            as_float = float(token_lp)
        except Exception:
            as_float = float(_to_mx_scalar(mx, token_lp))
        running += as_float
    return _to_mx_scalar(mx, running)
52
+
53
+
54
def preference_diff(
    backend,
    chosen_ids: Sequence[int],
    rejected_ids: Sequence[int],
    *,
    prompt_len_chosen: int,
    prompt_len_rejected: int,
    reference_backend: Optional[Any] = None,
) -> Any:
    """Policy log-prob margin of *chosen* over *rejected*, optionally reference-adjusted.

    Returns ``(logp(chosen) - logp(rejected))`` under *backend*, minus the
    same margin under *reference_backend* when one is supplied (otherwise
    minus ``0.0``). Log-probs come from ``sequence_logprob`` called with the
    matching ``prompt_len``. The policy is scored before the reference, chosen
    before rejected.
    """

    def _margin(model: Any) -> Any:
        chosen_lp = model.sequence_logprob(chosen_ids, prompt_len=prompt_len_chosen)
        rejected_lp = model.sequence_logprob(rejected_ids, prompt_len=prompt_len_rejected)
        return chosen_lp - rejected_lp

    policy_margin = _margin(backend)
    reference_margin = 0.0 if reference_backend is None else _margin(reference_backend)
    return policy_margin - reference_margin
71
+
72
+
73
@register_loss("dpo")
def dpo_loss(
    backend,
    chosen_ids: Sequence[int],
    rejected_ids: Sequence[int],
    *,
    prompt_len_chosen: int,
    prompt_len_rejected: int,
    beta: float = 0.1,
    reference_backend: Optional[Any] = None,
    kl_coeff: float = 0.0,
) -> Any:
    """Direct Preference Optimization loss for one chosen/rejected pair.

    ``loss = softplus(-beta * diff)`` (equivalently ``-log(sigmoid(beta * diff))``).
    ``diff`` is the policy's chosen-minus-rejected log-prob margin, minus the
    reference backend's margin when *reference_backend* is given. When a
    reference is given and ``kl_coeff > 0``, the loss also adds
    ``kl_coeff * (logp_chosen - ref_logp_chosen)``.

    Args:
        backend: policy backend; must expose ``mx`` and ``sequence_logprob``.
        chosen_ids / rejected_ids: token ids of the two full sequences.
        prompt_len_chosen / prompt_len_rejected: prompt lengths passed to
            ``sequence_logprob`` for each sequence.
        beta: scale applied to the margin.
        reference_backend: optional frozen reference model.
        kl_coeff: weight of the chosen-sequence KL-style penalty.

    Returns:
        The loss as produced by ``backend.mx`` operations.

    Raises:
        RuntimeError: if *backend* does not expose ``mx``.
    """
    mx = _require_mx(backend)

    # Score each sequence once (same call order as preference_diff). The
    # optional KL term below reuses these values instead of running extra
    # forward passes.
    logp_c = backend.sequence_logprob(chosen_ids, prompt_len=prompt_len_chosen)
    logp_r = backend.sequence_logprob(rejected_ids, prompt_len=prompt_len_rejected)
    diff = logp_c - logp_r
    ref_logp_c = None
    if reference_backend is not None:
        ref_logp_c = reference_backend.sequence_logprob(chosen_ids, prompt_len=prompt_len_chosen)
        ref_logp_r = reference_backend.sequence_logprob(rejected_ids, prompt_len=prompt_len_rejected)
        diff = diff - (ref_logp_c - ref_logp_r)

    scaled = _to_mx_scalar(mx, beta) * diff
    # softplus(-x) as logaddexp(0, -x) stays finite for large negative
    # margins, whereas log1p(exp(-x)) overflows to inf there.
    if hasattr(mx, "logaddexp"):
        loss = mx.logaddexp(0.0, -scaled)
    else:
        loss = mx.log1p(mx.exp(-scaled))

    if reference_backend is not None and kl_coeff > 0:
        loss = loss + _to_mx_scalar(mx, kl_coeff) * (logp_c - ref_logp_c)

    return loss
103
+
104
+
105
@register_loss("orpo")
def orpo_loss(
    backend,
    chosen_ids: Sequence[int],
    rejected_ids: Sequence[int],
    *,
    prompt_len_chosen: int,
    prompt_len_rejected: int,
    beta: float = 0.1,
    reference_backend: Optional[Any] = None,
    kl_coeff: float = 0.0,
    train_on_prompt: bool = False,
) -> Any:
    """ORPO-style loss: supervised NLL on *chosen* plus a preference term.

    ``loss = nll(chosen) + beta * softplus(-diff)``, where
    ``softplus(-diff) == -log(sigmoid(diff))``. ``diff`` is the policy's
    chosen-minus-rejected log-prob margin, minus the reference margin when
    *reference_backend* is given. When a reference is given and
    ``kl_coeff > 0``, the loss also adds
    ``kl_coeff * (logp_chosen - ref_logp_chosen)``.

    NOTE(review): the published ORPO objective uses log-odds of
    length-normalised likelihoods and needs no reference model; this variant
    uses the raw log-prob margin. Confirm this is intended.

    Args:
        backend: policy backend; must expose ``mx``, ``sequence_logprob`` and
            ``sft_loss``.
        chosen_ids / rejected_ids: token ids of the two full sequences.
        prompt_len_chosen / prompt_len_rejected: prompt lengths for each.
        beta: weight of the preference term.
        reference_backend: optional frozen reference model.
        kl_coeff: weight of the chosen-sequence KL-style penalty.
        train_on_prompt: forwarded to ``backend.sft_loss``.

    Returns:
        The loss as produced by ``backend.mx`` operations.

    Raises:
        RuntimeError: if *backend* does not expose ``mx``.
    """
    mx = _require_mx(backend)

    # Score each sequence once; the KL term below reuses these values.
    logp_c = backend.sequence_logprob(chosen_ids, prompt_len=prompt_len_chosen)
    logp_r = backend.sequence_logprob(rejected_ids, prompt_len=prompt_len_rejected)
    diff = logp_c - logp_r
    ref_logp_c = None
    if reference_backend is not None:
        ref_logp_c = reference_backend.sequence_logprob(chosen_ids, prompt_len=prompt_len_chosen)
        ref_logp_r = reference_backend.sequence_logprob(rejected_ids, prompt_len=prompt_len_rejected)
        diff = diff - (ref_logp_c - ref_logp_r)

    nll = backend.sft_loss(chosen_ids, train_on_prompt=train_on_prompt, prompt_len=prompt_len_chosen)

    # -log(sigmoid(diff)) == logaddexp(0, -diff); the latter does not hit
    # log(0) when sigmoid underflows.
    if hasattr(mx, "logaddexp"):
        neg_log_sigmoid = mx.logaddexp(0.0, -diff)
    else:
        neg_log_sigmoid = -mx.log(mx.sigmoid(diff))
    or_term = _to_mx_scalar(mx, beta) * neg_log_sigmoid

    # Add nll directly: routing it through _to_mx_scalar calls .item(), which
    # turns it into a constant and silently drops the SFT gradient. Plain
    # floats from simple backends still add correctly onto an mx array.
    loss = or_term + nll

    if reference_backend is not None and kl_coeff > 0:
        loss = loss + _to_mx_scalar(mx, kl_coeff) * (logp_c - ref_logp_c)

    return loss
137
+
138
+
139
@register_loss("preference")
def preference_loss(
    backend,
    chosen_ids: Sequence[int],
    rejected_ids: Sequence[int],
    *,
    prompt_len_chosen: int,
    prompt_len_rejected: int,
    algo: str = "dpo",
    beta: float = 0.1,
    reference_backend: Optional[Any] = None,
    kl_coeff: float = 0.0,
    train_on_prompt: bool = False,
) -> Any:
    """Dispatch a pairwise preference loss by name.

    ``algo`` is compared case-insensitively. ``"orpo"`` selects ``orpo_loss``;
    every other value falls back to ``dpo_loss``. ``train_on_prompt`` only
    reaches the ORPO path, because ``dpo_loss`` has no such parameter.
    """
    shared = {
        "prompt_len_chosen": prompt_len_chosen,
        "prompt_len_rejected": prompt_len_rejected,
        "beta": beta,
        "reference_backend": reference_backend,
        "kl_coeff": kl_coeff,
    }
    if algo.lower() != "orpo":
        return dpo_loss(backend, chosen_ids, rejected_ids, **shared)
    return orpo_loss(
        backend,
        chosen_ids,
        rejected_ids,
        train_on_prompt=train_on_prompt,
        **shared,
    )
175
+
176
+
177
@register_loss("cross_entropy")
def cross_entropy_loss(
    backend,
    token_ids: Sequence[int],
    *,
    prompt_len: int,
    train_on_prompt: bool = False,
) -> Any:
    """Supervised next-token loss over *token_ids*.

    Thin registry adapter: forwards the ids, ``prompt_len`` and
    ``train_on_prompt`` to ``backend.sft_loss`` and returns its result
    unchanged.
    """
    sft = backend.sft_loss
    return sft(token_ids, prompt_len=prompt_len, train_on_prompt=train_on_prompt)
186
+
187
+
188
def _mx_clip(mx: Any, x: Any, lo: float, hi: float) -> Any:
    """Clamp *x* into ``[lo, hi]``.

    Uses element-wise ``mx.maximum`` / ``mx.minimum``, with the bounds
    wrapped by ``_to_mx_scalar``, when *mx* provides both functions. Otherwise
    it falls back to Python's built-in ``max`` / ``min`` on the raw bounds.
    """
    vectorised = hasattr(mx, "minimum") and hasattr(mx, "maximum")
    if not vectorised:
        return min(max(x, lo), hi)
    floor = _to_mx_scalar(mx, lo)
    ceiling = _to_mx_scalar(mx, hi)
    return mx.minimum(mx.maximum(x, floor), ceiling)
192
+
193
+
194
@register_loss("importance_sampling")
def importance_sampling_loss(
    backend,
    token_ids: Sequence[int],
    *,
    prompt_len: int,
    advantage: float,
    behavior_logprob: Optional[Any] = None,
) -> Any:
    """Importance-weighted policy-gradient loss for one sequence.

    ``loss = -exp(logp - behavior_logprob) * advantage``, where ``logp`` is the
    current policy's ``sequence_logprob``. A list/tuple *behavior_logprob* is
    summed into one scalar by ``_coerce_logprob``.

    When *behavior_logprob* is ``None`` the sample is treated as on-policy.
    The denominator is then the current ``logp`` with its gradient stopped
    (when ``mx.stop_gradient`` exists). The ratio still evaluates to 1, but its
    gradient is ``grad(logp)``, giving the REINFORCE update. The previous
    ``exp(logp - logp)`` was constant and produced a zero gradient.

    Raises:
        RuntimeError: if *backend* does not expose ``mx``.
    """
    mx = _require_mx(backend)
    logp = backend.sequence_logprob(token_ids, prompt_len=prompt_len)
    if behavior_logprob is None:
        stop_gradient = getattr(mx, "stop_gradient", None)
        behavior_logprob = stop_gradient(logp) if stop_gradient is not None else logp
    behavior_logprob = _coerce_logprob(mx, behavior_logprob)
    ratio = mx.exp(logp - behavior_logprob)
    return -ratio * _to_mx_scalar(mx, advantage)
210
+
211
+
212
@register_loss("ppo")
def ppo_loss(
    backend,
    token_ids: Sequence[int],
    *,
    prompt_len: int,
    advantage: float,
    behavior_logprob: Any,
    clip: float = 0.2,
) -> Any:
    """PPO clipped-surrogate loss for one sequence.

    ``ratio = exp(logp - behavior_logprob)``, and the loss is
    ``-min(ratio * A, clip(ratio, 1 - clip, 1 + clip) * A)``.
    *behavior_logprob* may be a scalar/array or a list/tuple of per-token
    values. Lists are summed via ``_coerce_logprob``, matching
    ``importance_sampling_loss``.

    Raises:
        RuntimeError: if *backend* does not expose ``mx``.
    """
    mx = _require_mx(backend)
    logp = backend.sequence_logprob(token_ids, prompt_len=prompt_len)
    # Non-sequence values pass through _coerce_logprob unchanged.
    behavior_logprob = _coerce_logprob(mx, behavior_logprob)
    ratio = mx.exp(logp - behavior_logprob)
    adv = _to_mx_scalar(mx, advantage)
    clipped = _mx_clip(mx, ratio, 1.0 - clip, 1.0 + clip)
    return -mx.minimum(ratio * adv, clipped * adv)
228
+
229
+
230
@register_loss("cispo")
def cispo_loss(
    backend,
    token_ids: Sequence[int],
    *,
    prompt_len: int,
    advantage: float,
    behavior_logprob: Any,
    clip: float = 0.2,
    penalty: float = 0.1,
) -> Any:
    """Clipped importance-sampling loss with a quadratic out-of-range penalty.

    ``loss = -clip(ratio) * A + penalty * (ratio - clip(ratio)) ** 2``, with
    ``ratio = exp(logp - behavior_logprob)`` and clipping to
    ``[1 - clip, 1 + clip]``. List/tuple *behavior_logprob* values are summed
    via ``_coerce_logprob``, matching ``importance_sampling_loss``.

    NOTE(review): the published CISPO objective stop-gradients the clipped
    weight and multiplies it by ``logp``; this is a different variant. Confirm
    it is intended.

    Raises:
        RuntimeError: if *backend* does not expose ``mx``.
    """
    mx = _require_mx(backend)
    logp = backend.sequence_logprob(token_ids, prompt_len=prompt_len)
    # Non-sequence values pass through _coerce_logprob unchanged.
    behavior_logprob = _coerce_logprob(mx, behavior_logprob)
    ratio = mx.exp(logp - behavior_logprob)
    adv = _to_mx_scalar(mx, advantage)
    clipped = _mx_clip(mx, ratio, 1.0 - clip, 1.0 + clip)
    penalty_term = _to_mx_scalar(mx, penalty) * (ratio - clipped) ** 2
    return -(clipped * adv) + penalty_term
248
+
249
+
250
@register_loss("dro")
def dro_loss(
    backend,
    token_ids: Sequence[int],
    *,
    prompt_len: int,
    advantage: float,
    temperature: float = 1.0,
) -> Any:
    """Advantage-weighted log-likelihood loss.

    ``loss = -exp(advantage / temperature) * logp``, where ``logp`` is the
    policy's ``sequence_logprob`` for *token_ids*.

    Raises:
        RuntimeError: if *backend* does not expose ``mx``.
        ValueError: if *temperature* is not positive. Zero would divide by
            zero (inf/NaN weights), and a negative value would invert the
            advantage weighting.
    """
    mx = _require_mx(backend)
    # Validate before the forward pass so a bad config fails fast and cheaply.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    logp = backend.sequence_logprob(token_ids, prompt_len=prompt_len)
    weight = mx.exp(_to_mx_scalar(mx, advantage) / _to_mx_scalar(mx, temperature))
    return -weight * logp