mlxsmith 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. mlxsmith/__init__.py +2 -0
  2. mlxsmith/accel/__init__.py +10 -0
  3. mlxsmith/accel/base.py +17 -0
  4. mlxsmith/accel/none.py +13 -0
  5. mlxsmith/accel/zmlx_backend.py +42 -0
  6. mlxsmith/adapters.py +46 -0
  7. mlxsmith/api/__init__.py +48 -0
  8. mlxsmith/api/handlers.py +1217 -0
  9. mlxsmith/api/schemas.py +436 -0
  10. mlxsmith/auth.py +88 -0
  11. mlxsmith/bench.py +102 -0
  12. mlxsmith/cli.py +950 -0
  13. mlxsmith/config.py +543 -0
  14. mlxsmith/config_models.py +261 -0
  15. mlxsmith/data.py +493 -0
  16. mlxsmith/envs/__init__.py +33 -0
  17. mlxsmith/envs/system.py +388 -0
  18. mlxsmith/envs/token_env.py +191 -0
  19. mlxsmith/eval.py +112 -0
  20. mlxsmith/infer.py +140 -0
  21. mlxsmith/llm/__init__.py +16 -0
  22. mlxsmith/llm/backend.py +126 -0
  23. mlxsmith/llm/interface.py +212 -0
  24. mlxsmith/llm/mlx_lm_backend.py +509 -0
  25. mlxsmith/llm/mock_backend.py +228 -0
  26. mlxsmith/llm/registry.py +12 -0
  27. mlxsmith/models.py +257 -0
  28. mlxsmith/orchestrator/__init__.py +25 -0
  29. mlxsmith/orchestrator/daemon.py +454 -0
  30. mlxsmith/orchestrator/inference_worker.py +496 -0
  31. mlxsmith/orchestrator/queue.py +355 -0
  32. mlxsmith/orchestrator/trainer_worker.py +437 -0
  33. mlxsmith/rlm/__init__.py +8 -0
  34. mlxsmith/rlm/corpus.py +74 -0
  35. mlxsmith/rlm/gating.py +90 -0
  36. mlxsmith/rlm/generate.py +249 -0
  37. mlxsmith/rlm/history.py +12 -0
  38. mlxsmith/rlm/inference.py +150 -0
  39. mlxsmith/rlm/loop.py +1297 -0
  40. mlxsmith/rlm/mutate.py +82 -0
  41. mlxsmith/rlm/trainer.py +73 -0
  42. mlxsmith/rlm/weights.py +263 -0
  43. mlxsmith/runs.py +44 -0
  44. mlxsmith/sdk/__init__.py +392 -0
  45. mlxsmith/sdk/future.py +486 -0
  46. mlxsmith/sdk/losses.py +262 -0
  47. mlxsmith/sdk/sampling_client.py +729 -0
  48. mlxsmith/sdk/training_client.py +676 -0
  49. mlxsmith/server.py +376 -0
  50. mlxsmith/train/__init__.py +0 -0
  51. mlxsmith/train/distill.py +279 -0
  52. mlxsmith/train/lora.py +280 -0
  53. mlxsmith/train/pref.py +180 -0
  54. mlxsmith/train/rft.py +458 -0
  55. mlxsmith/train/sft.py +151 -0
  56. mlxsmith/util.py +174 -0
  57. mlxsmith/verifiers/__init__.py +3 -0
  58. mlxsmith/verifiers/compose.py +109 -0
  59. mlxsmith/verifiers/docker_verifier.py +111 -0
  60. mlxsmith/verifiers/jsonschema.py +54 -0
  61. mlxsmith/verifiers/pytest_verifier.py +82 -0
  62. mlxsmith/verifiers/regex.py +15 -0
  63. mlxsmith/verifiers/types.py +10 -0
  64. mlxsmith-0.1.0.dist-info/METADATA +163 -0
  65. mlxsmith-0.1.0.dist-info/RECORD +69 -0
  66. mlxsmith-0.1.0.dist-info/WHEEL +5 -0
  67. mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
  68. mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
  69. mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/data.py ADDED
@@ -0,0 +1,493 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from datetime import datetime, timezone
5
+ from pathlib import Path
6
+
7
+ from .util import ensure_dir
8
+
9
+
10
# Registry of curated Hugging Face dataset presets, keyed by short preset name.
# Each spec records the HF dataset id, the data kind ("sft" or "prefs"),
# optional split/config/field overrides, and the dataset license string.
# Exposed read-only via list_presets() / resolve_preset(), which hand out copies.
PRESET_DATASETS: dict[str, dict] = {
    "alpaca": {
        "dataset": "tatsu-lab/alpaca",
        "kind": "sft",
        "prompt_field": "instruction",
        "response_field": "output",
        "license": "cc-by-nc-4.0",
    },
    "hh-rlhf": {
        "dataset": "Anthropic/hh-rlhf",
        "kind": "prefs",
        "chosen_field": "chosen",
        "rejected_field": "rejected",
        "license": "mit",
    },
    "ultrachat-200k": {
        "dataset": "HuggingFaceH4/ultrachat_200k",
        "kind": "sft",
        "split": "train_sft",
        "config": "default",
        "license": "mit",
    },
    "ultrafeedback-binarized-prefs": {
        "dataset": "HuggingFaceH4/ultrafeedback_binarized",
        "kind": "prefs",
        "split": "train_prefs",
        "config": "default",
        "prompt_field": "prompt",
        "chosen_field": "chosen",
        "rejected_field": "rejected",
        "license": "mit",
    },
    "ultrafeedback-binarized-sft": {
        "dataset": "HuggingFaceH4/ultrafeedback_binarized",
        "kind": "sft",
        "split": "train_sft",
        "config": "default",
        "license": "mit",
    },
}
50
+
51
+
52
def list_presets() -> dict[str, dict]:
    """Return the preset registry as a fresh dict of copied specs."""
    out: dict[str, dict] = {}
    for preset_name, spec in PRESET_DATASETS.items():
        out[preset_name] = dict(spec)
    return out
54
+
55
+
56
def resolve_preset(name: str) -> dict:
    """Look up a preset by (stripped) name and return a copy of its spec.

    Raises:
        ValueError: if *name* is not a known preset, listing the options.
    """
    key = name.strip()
    if key in PRESET_DATASETS:
        return dict(PRESET_DATASETS[key])
    options = ", ".join(sorted(PRESET_DATASETS))
    raise ValueError(f"Unknown preset: {name}. Available: {options}")
62
+
63
+
64
+ def _utc_now_iso() -> str:
65
+ return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
66
+
67
+
68
+ def _select_field_value(row: dict, keys: list[str | None]):
69
+ for key in keys:
70
+ if not key:
71
+ continue
72
+ if key in row and row[key] not in (None, ""):
73
+ return row[key]
74
+ return None
75
+
76
+
77
+ def _messages_to_text(msgs: list) -> str:
78
+ parts: list[str] = []
79
+ for msg in msgs:
80
+ if msg is None:
81
+ continue
82
+ if isinstance(msg, dict):
83
+ text = msg.get("content") or msg.get("value") or msg.get("text") or ""
84
+ else:
85
+ text = str(msg)
86
+ if text:
87
+ parts.append(str(text))
88
+ return "\n".join(parts)
89
+
90
+
91
def _coerce_text(val) -> str:
    """Best-effort conversion of an arbitrary dataset value to plain text.

    Handles None, strings, scalars, lists (message lists flatten via
    _messages_to_text), and dicts with text-like or "messages" fields;
    anything else falls back to str().
    """
    if val is None:
        return ""
    if isinstance(val, str):
        return val
    if isinstance(val, (int, float, bool)):
        return str(val)
    if isinstance(val, list):
        if not val:
            return ""
        if all(isinstance(item, dict) for item in val):
            return _messages_to_text(val)
        return "\n".join(str(item) for item in val if item not in (None, ""))
    if isinstance(val, dict):
        for field in ("text", "content", "value"):
            candidate = val.get(field)
            if candidate not in (None, ""):
                return str(candidate)
        messages = val.get("messages")
        if isinstance(messages, list):
            return _messages_to_text(messages)
    return str(val)
111
+
112
+
113
+ def _messages_from_value(val) -> list | None:
114
+ if isinstance(val, dict) and isinstance(val.get("messages"), list):
115
+ return val.get("messages")
116
+ if isinstance(val, list) and val and all(isinstance(m, dict) for m in val):
117
+ return val
118
+ return None
119
+
120
+
121
def _select_field(row: dict, keys: list[str]) -> str:
    """Like _select_field_value, but coerced to text ('' when nothing matches)."""
    found = _select_field_value(row, keys)
    if found is None:
        return ""
    return _coerce_text(found)
124
+
125
+
126
def _extract_prompt_response_from_messages(val) -> tuple[str, str]:
    """Split a chat message list into a (prompt, response) text pair.

    The final message becomes the response; everything before it becomes
    the prompt. A single message yields an empty prompt; a value without
    messages yields ("", "").
    """
    msgs = _messages_from_value(val)
    if not msgs:
        return "", ""
    *context, last = msgs
    if not context:
        return "", _coerce_text(last)
    return _messages_to_text(context), _coerce_text(last)
135
+
136
+
137
def _row_to_prompt_response(row: dict, prompt_field: str | None = None, response_field: str | None = None) -> tuple[str, str]:
    """Map a raw dataset row to a (prompt, response) pair of strings.

    Explicit field overrides are tried first, then common field-name
    fallbacks. A chat-style "messages" list may supply a missing response
    or replace the prompt when it yields a longer one.
    """
    prompt_keys = [prompt_field, "prompt", "instruction", "input", "question"]
    response_keys = [response_field, "response", "output", "answer", "completion"]
    prompt = _select_field(row, prompt_keys)
    response = _select_field(row, response_keys)

    msgs = row.get("messages")
    if isinstance(msgs, list):
        msg_prompt, msg_response = _extract_prompt_response_from_messages(msgs)
        # Prefer the transcript-derived prompt when it is longer (richer context).
        if msg_prompt and (not prompt or len(msg_prompt) > len(prompt)):
            prompt = msg_prompt
        if msg_response and not response:
            response = msg_response
    return prompt, response
165
+
166
+
167
def _row_to_pref(
    row: dict,
    prompt_field: str | None = None,
    chosen_field: str | None = None,
    rejected_field: str | None = None,
) -> tuple[str, str, str]:
    """Map a raw dataset row to a (prompt, chosen, rejected) preference triple.

    Explicit field overrides win, then common fallback field names. When the
    chosen/rejected values are chat-style message lists, the trailing message
    becomes the completion and the leading turns may supply (or lengthen) the
    prompt. A top-level "messages" list is a last-resort prompt source.
    """
    prompt = _select_field(
        row,
        [
            prompt_field,
            "prompt",
            "instruction",
            "input",
            "question",
            "query",
            "context",
            "history",
        ],
    )

    chosen_val = _select_field_value(
        row,
        [
            chosen_field,
            "chosen",
            "accepted",
            "preferred",
            "chosen_response",
            "response",
            "output",
            "answer",
        ],
    )
    rejected_val = _select_field_value(
        row,
        [
            rejected_field,
            "rejected",
            "rejected_response",
            "rejected_output",
            "rejected_answer",
            "dispreferred",
        ],
    )

    chosen = _coerce_text(chosen_val)
    rejected = _coerce_text(rejected_val)

    # When a value is a chat transcript, its last message is the completion
    # and the preceding turns can provide a (possibly longer) prompt.
    if chosen_val is not None:
        msg_prompt, msg_response = _extract_prompt_response_from_messages(chosen_val)
        if msg_prompt and (not prompt or len(msg_prompt) > len(prompt)):
            prompt = msg_prompt
        if msg_response:
            chosen = msg_response

    if rejected_val is not None:
        msg_prompt, msg_response = _extract_prompt_response_from_messages(rejected_val)
        if msg_prompt and (not prompt or len(msg_prompt) > len(prompt)):
            prompt = msg_prompt
        if msg_response:
            rejected = msg_response

    # NOTE: the original re-ran the same message extraction in two extra
    # "if still missing" passes; those passes could never change the result
    # (the passes above already applied every value they checked), so the
    # dead re-extraction was removed.

    if not prompt and isinstance(row.get("messages"), list):
        msg_prompt, _msg_response = _extract_prompt_response_from_messages(row.get("messages"))
        if msg_prompt:
            prompt = msg_prompt

    return prompt, chosen, rejected
249
+
250
+
251
def _write_provenance(out_dir: Path, metadata: dict) -> Path:
    """Persist dataset provenance as pretty-printed metadata.json in *out_dir*."""
    ensure_dir(out_dir)
    target = out_dir / "metadata.json"
    payload = json.dumps(metadata, indent=2, ensure_ascii=False)
    target.write_text(payload + "\n", encoding="utf-8")
    return target
256
+
257
def import_sharegpt(in_path: Path, out_path: Path):
    """Convert ShareGPT-ish JSONL into {prompt, response} JSONL.

    Expected input lines: {"conversations":[{"from":"human","value":"..."},{"from":"gpt","value":"..."}], ...}

    Returns the number of pairs written; rows without a complete
    human→assistant exchange are dropped.
    """
    ensure_dir(out_path.parent)
    written = 0
    with in_path.open("r", encoding="utf-8") as src, out_path.open("w", encoding="utf-8") as dst:
        for raw in src:
            raw = raw.strip()
            if not raw:
                continue
            record = json.loads(raw)
            turns = record.get("conversations") or []
            # naive: take first human then first assistant after it
            prompt = response = None
            for turn in turns:
                speaker = (turn.get("from") or "").lower()
                if prompt is None:
                    if speaker in ("human", "user"):
                        prompt = turn.get("value") or ""
                elif response is None and speaker in ("gpt", "assistant"):
                    response = turn.get("value") or ""
                    break
            if prompt is None or response is None:
                continue
            dst.write(json.dumps({"prompt": prompt, "response": response}, ensure_ascii=False) + "\n")
            written += 1
    return written
286
+
287
def split_jsonl(in_path: Path, out_dir: Path, valid_frac: float, test_frac: float, seed: int = 1337):
    """Shuffle a JSONL file and split it into train/valid/test JSONL files.

    Test rows are carved off first, then validation, with the remainder as
    train. Returns a dict of row counts.
    """
    import random

    random.seed(seed)
    lines = in_path.read_text(encoding="utf-8").splitlines()
    rows = [json.loads(line) for line in lines if line.strip()]
    random.shuffle(rows)
    total = len(rows)
    n_test = int(total * test_frac)
    n_valid = int(total * valid_frac)
    test_rows = rows[:n_test]
    valid_rows = rows[n_test : n_test + n_valid]
    train_rows = rows[n_test + n_valid :]
    ensure_dir(out_dir)
    for filename, subset in (
        ("train.jsonl", train_rows),
        ("valid.jsonl", valid_rows),
        ("test.jsonl", test_rows),
    ):
        body = "\n".join(json.dumps(r, ensure_ascii=False) for r in subset)
        (out_dir / filename).write_text(body + ("\n" if subset else ""), encoding="utf-8")
    return {"n": total, "train": len(train_rows), "valid": len(valid_rows), "test": len(test_rows)}
302
+
303
+
304
def pull_hf_dataset(
    dataset: str,
    out_dir: Path,
    split: str = "train",
    limit: int | None = None,
    prompt_field: str | None = None,
    response_field: str | None = None,
    chosen_field: str | None = None,
    rejected_field: str | None = None,
    config: str | None = None,
    revision: str | None = None,
    kind: str = "sft",
    license: str | None = None,
    notes: str | None = None,
    preset: str | None = None,
    write_metadata: bool = True,
) -> dict:
    """Download a HF dataset split and write prompt/response JSONL.

    Rows that cannot be mapped to the required fields are skipped (and
    counted). Output goes to ``out_dir / "train.jsonl"``; a provenance
    ``metadata.json`` is written alongside unless ``write_metadata`` is False.

    Args:
        dataset: HF dataset name (e.g. "tatsu-lab/alpaca")
        out_dir: Output directory (data/sft or data/prefs)
        split: Dataset split to pull
        limit: Optional max rows to write
        prompt_field: Optional override for prompt field name
        response_field: Optional override for response field name (sft only)
        chosen_field: Optional override for chosen field name (prefs only)
        rejected_field: Optional override for rejected field name (prefs only)
        config: Optional dataset config/subset name
        revision: Optional dataset revision
        kind: "sft" (prompt/response rows) or "prefs" (prompt/chosen/rejected)
        license: Optional license string recorded in metadata (parameter name
            shadows the ``license`` builtin; kept for interface compatibility)
        notes: Optional free-form note recorded in metadata
        preset: Optional preset name recorded in metadata
        write_metadata: When True, also write metadata.json provenance

    Returns:
        Stats dict with counts.

    Raises:
        RuntimeError: if the ``datasets`` package is not importable.
        ValueError: if ``kind`` is not a recognized sft/prefs alias.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError(f"datasets not available: {e}")

    ensure_dir(out_dir)
    ds = load_dataset(dataset, name=config, split=split, revision=revision)
    total = len(ds)
    n = 0
    skipped = 0
    # Normalize kind aliases inline (mirrors _normalize_kind, but this
    # variant rejects unknown kinds outright).
    kind_norm = kind.strip().lower()
    if kind_norm in ("pref", "prefs", "preference", "preferences"):
        kind_norm = "prefs"
    elif kind_norm in ("sft", "supervised"):
        kind_norm = "sft"
    else:
        raise ValueError(f"Unsupported kind: {kind}")
    out_path = out_dir / "train.jsonl"
    with out_path.open("w", encoding="utf-8") as fout:
        for row in ds:
            if kind_norm == "prefs":
                prompt, chosen, rejected = _row_to_pref(row, prompt_field, chosen_field, rejected_field)
                # Require all three fields; otherwise count as skipped.
                if not (prompt and chosen and rejected):
                    skipped += 1
                    continue
                fout.write(
                    json.dumps(
                        {"prompt": prompt, "chosen": chosen, "rejected": rejected},
                        ensure_ascii=False,
                    )
                    + "\n"
                )
            else:
                prompt, response = _row_to_prompt_response(row, prompt_field, response_field)
                if not (prompt and response):
                    skipped += 1
                    continue
                fout.write(json.dumps({"prompt": prompt, "response": response}, ensure_ascii=False) + "\n")
            n += 1
            # Stop once the requested number of rows has been written.
            if limit and n >= limit:
                break
    # Provenance record; kind-irrelevant field overrides are nulled out.
    metadata = {
        "dataset": dataset,
        "preset": preset,
        "config": config,
        "split": split,
        "revision": revision,
        "license": license,
        "notes": notes,
        "kind": kind_norm,
        "total": total,
        "written": n,
        "skipped": skipped,
        "limit": limit,
        "prompt_field": prompt_field,
        "response_field": response_field if kind_norm == "sft" else None,
        "chosen_field": chosen_field if kind_norm == "prefs" else None,
        "rejected_field": rejected_field if kind_norm == "prefs" else None,
        "out": str(out_path),
        "generated_at": _utc_now_iso(),
    }
    if write_metadata:
        meta_path = _write_provenance(out_dir, metadata)
        metadata["metadata"] = str(meta_path)
    return metadata
401
+
402
+
403
+ def _normalize_kind(kind: str | None) -> str | None:
404
+ if not kind:
405
+ return None
406
+ kind_norm = kind.strip().lower()
407
+ if kind_norm in ("pref", "prefs", "preference", "preferences"):
408
+ return "prefs"
409
+ if kind_norm in ("sft", "supervised"):
410
+ return "sft"
411
+ return kind_norm
412
+
413
+
414
+ def _infer_kind_from_row(row: dict) -> str:
415
+ if "chosen" in row or "rejected" in row:
416
+ return "prefs"
417
+ if "response" in row:
418
+ return "sft"
419
+ if "messages" in row:
420
+ return "sft"
421
+ return "unknown"
422
+
423
+
424
def analyze_jsonl(path: Path, kind: str | None = None, limit: int | None = None) -> dict:
    """Scan a JSONL dataset file and collect per-field quality statistics.

    Args:
        path: JSONL file to analyze.
        kind: Optional dataset kind ("sft"/"prefs" or an alias); inferred
            from the first parseable row when omitted.
        limit: Optional cap on the number of parsed rows to inspect.

    Returns:
        Dict of counters: parsed rows, empty/bad-JSON lines, missing-field
        counts, character and occurrence totals per field, and the resolved
        "kind" ("unknown" when it could not be determined).
    """
    stats = {
        "rows": 0,
        "empty_lines": 0,
        "bad_json": 0,
        "missing_prompt": 0,
        "missing_response": 0,
        "missing_chosen": 0,
        "missing_rejected": 0,
        "prompt_chars": 0,
        "response_chars": 0,
        "chosen_chars": 0,
        "rejected_chars": 0,
        "prompt_count": 0,
        "response_count": 0,
        "chosen_count": 0,
        "rejected_count": 0,
        "kind": None,
    }
    kind_norm = _normalize_kind(kind)
    with path.open("r", encoding="utf-8") as fin:
        for line in fin:
            if limit and stats["rows"] >= limit:
                break
            if not line.strip():
                stats["empty_lines"] += 1
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                stats["bad_json"] += 1
                continue
            stats["rows"] += 1
            if not kind_norm:
                # Infer the kind from the first row that parses.
                kind_norm = _infer_kind_from_row(row)
            prompt = _coerce_text(row.get("prompt"))
            if not prompt:
                stats["missing_prompt"] += 1
            else:
                stats["prompt_chars"] += len(prompt)
                stats["prompt_count"] += 1
            if kind_norm == "prefs":
                chosen = _coerce_text(row.get("chosen"))
                rejected = _coerce_text(row.get("rejected"))
                if not chosen:
                    stats["missing_chosen"] += 1
                else:
                    stats["chosen_chars"] += len(chosen)
                    stats["chosen_count"] += 1
                if not rejected:
                    stats["missing_rejected"] += 1
                else:
                    stats["rejected_chars"] += len(rejected)
                    stats["rejected_count"] += 1
            else:
                # "sft" and unknown kinds count the response field identically;
                # the original spelled this out as two duplicate branches.
                response = _coerce_text(row.get("response"))
                if not response:
                    stats["missing_response"] += 1
                else:
                    stats["response_chars"] += len(response)
                    stats["response_count"] += 1
    stats["kind"] = kind_norm or "unknown"
    return stats
@@ -0,0 +1,33 @@
1
# Public surface of the mlxsmith.envs package: environment packaging and
# registry helpers from .system, plus the token-environment API.
from .system import (
    EnvManifest,
    EnvRef,
    init_env,
    install_env,
    list_registry_packages,
    load_manifest,
    package_env,
    pull_env,
    publish_env,
    registry_info,
    resolve_env_path,
)
from .token_env import TokenEnv, TokenEnvStep, load_token_env_spec, create_token_env, StringTaskTokenEnv

# Explicit export list; keep in sync with the imports above.
__all__ = [
    "EnvManifest",
    "EnvRef",
    "init_env",
    "install_env",
    "list_registry_packages",
    "load_manifest",
    "package_env",
    "pull_env",
    "publish_env",
    "registry_info",
    "resolve_env_path",
    "TokenEnv",
    "TokenEnvStep",
    "load_token_env_spec",
    "create_token_env",
    "StringTaskTokenEnv",
]