mlxsmith 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlxsmith/__init__.py +2 -0
- mlxsmith/accel/__init__.py +10 -0
- mlxsmith/accel/base.py +17 -0
- mlxsmith/accel/none.py +13 -0
- mlxsmith/accel/zmlx_backend.py +42 -0
- mlxsmith/adapters.py +46 -0
- mlxsmith/api/__init__.py +48 -0
- mlxsmith/api/handlers.py +1217 -0
- mlxsmith/api/schemas.py +436 -0
- mlxsmith/auth.py +88 -0
- mlxsmith/bench.py +102 -0
- mlxsmith/cli.py +950 -0
- mlxsmith/config.py +543 -0
- mlxsmith/config_models.py +261 -0
- mlxsmith/data.py +493 -0
- mlxsmith/envs/__init__.py +33 -0
- mlxsmith/envs/system.py +388 -0
- mlxsmith/envs/token_env.py +191 -0
- mlxsmith/eval.py +112 -0
- mlxsmith/infer.py +140 -0
- mlxsmith/llm/__init__.py +16 -0
- mlxsmith/llm/backend.py +126 -0
- mlxsmith/llm/interface.py +212 -0
- mlxsmith/llm/mlx_lm_backend.py +509 -0
- mlxsmith/llm/mock_backend.py +228 -0
- mlxsmith/llm/registry.py +12 -0
- mlxsmith/models.py +257 -0
- mlxsmith/orchestrator/__init__.py +25 -0
- mlxsmith/orchestrator/daemon.py +454 -0
- mlxsmith/orchestrator/inference_worker.py +496 -0
- mlxsmith/orchestrator/queue.py +355 -0
- mlxsmith/orchestrator/trainer_worker.py +437 -0
- mlxsmith/rlm/__init__.py +8 -0
- mlxsmith/rlm/corpus.py +74 -0
- mlxsmith/rlm/gating.py +90 -0
- mlxsmith/rlm/generate.py +249 -0
- mlxsmith/rlm/history.py +12 -0
- mlxsmith/rlm/inference.py +150 -0
- mlxsmith/rlm/loop.py +1297 -0
- mlxsmith/rlm/mutate.py +82 -0
- mlxsmith/rlm/trainer.py +73 -0
- mlxsmith/rlm/weights.py +263 -0
- mlxsmith/runs.py +44 -0
- mlxsmith/sdk/__init__.py +392 -0
- mlxsmith/sdk/future.py +486 -0
- mlxsmith/sdk/losses.py +262 -0
- mlxsmith/sdk/sampling_client.py +729 -0
- mlxsmith/sdk/training_client.py +676 -0
- mlxsmith/server.py +376 -0
- mlxsmith/train/__init__.py +0 -0
- mlxsmith/train/distill.py +279 -0
- mlxsmith/train/lora.py +280 -0
- mlxsmith/train/pref.py +180 -0
- mlxsmith/train/rft.py +458 -0
- mlxsmith/train/sft.py +151 -0
- mlxsmith/util.py +174 -0
- mlxsmith/verifiers/__init__.py +3 -0
- mlxsmith/verifiers/compose.py +109 -0
- mlxsmith/verifiers/docker_verifier.py +111 -0
- mlxsmith/verifiers/jsonschema.py +54 -0
- mlxsmith/verifiers/pytest_verifier.py +82 -0
- mlxsmith/verifiers/regex.py +15 -0
- mlxsmith/verifiers/types.py +10 -0
- mlxsmith-0.1.0.dist-info/METADATA +163 -0
- mlxsmith-0.1.0.dist-info/RECORD +69 -0
- mlxsmith-0.1.0.dist-info/WHEEL +5 -0
- mlxsmith-0.1.0.dist-info/entry_points.txt +2 -0
- mlxsmith-0.1.0.dist-info/licenses/LICENSE +21 -0
- mlxsmith-0.1.0.dist-info/top_level.txt +1 -0
mlxsmith/data.py
ADDED
@@ -0,0 +1,493 @@
from __future__ import annotations

import json
from datetime import datetime, timezone
from pathlib import Path

from .util import ensure_dir


PRESET_DATASETS: dict[str, dict] = {
    "alpaca": {
        "dataset": "tatsu-lab/alpaca",
        "kind": "sft",
        "prompt_field": "instruction",
        "response_field": "output",
        "license": "cc-by-nc-4.0",
    },
    "hh-rlhf": {
        "dataset": "Anthropic/hh-rlhf",
        "kind": "prefs",
        "chosen_field": "chosen",
        "rejected_field": "rejected",
        "license": "mit",
    },
    "ultrachat-200k": {
        "dataset": "HuggingFaceH4/ultrachat_200k",
        "kind": "sft",
        "split": "train_sft",
        "config": "default",
        "license": "mit",
    },
    "ultrafeedback-binarized-prefs": {
        "dataset": "HuggingFaceH4/ultrafeedback_binarized",
        "kind": "prefs",
        "split": "train_prefs",
        "config": "default",
        "prompt_field": "prompt",
        "chosen_field": "chosen",
        "rejected_field": "rejected",
        "license": "mit",
    },
    "ultrafeedback-binarized-sft": {
        "dataset": "HuggingFaceH4/ultrafeedback_binarized",
        "kind": "sft",
        "split": "train_sft",
        "config": "default",
        "license": "mit",
    },
}


def list_presets() -> dict[str, dict]:
    return {k: dict(v) for k, v in PRESET_DATASETS.items()}


def resolve_preset(name: str) -> dict:
    key = name.strip()
    if key not in PRESET_DATASETS:
        options = ", ".join(sorted(PRESET_DATASETS.keys()))
        raise ValueError(f"Unknown preset: {name}. Available: {options}")
    return dict(PRESET_DATASETS[key])


def _utc_now_iso() -> str:
    return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")


def _select_field_value(row: dict, keys: list[str | None]):
    for key in keys:
        if not key:
            continue
        if key in row and row[key] not in (None, ""):
            return row[key]
    return None


def _messages_to_text(msgs: list) -> str:
    parts: list[str] = []
    for msg in msgs:
        if msg is None:
            continue
        if isinstance(msg, dict):
            text = msg.get("content") or msg.get("value") or msg.get("text") or ""
        else:
            text = str(msg)
        if text:
            parts.append(str(text))
    return "\n".join(parts)


def _coerce_text(val) -> str:
    if val is None:
        return ""
    if isinstance(val, str):
        return val
    if isinstance(val, (int, float, bool)):
        return str(val)
    if isinstance(val, list):
        if not val:
            return ""
        if all(isinstance(v, dict) for v in val):
            return _messages_to_text(val)
        return "\n".join(str(v) for v in val if v not in (None, ""))
    if isinstance(val, dict):
        for key in ("text", "content", "value"):
            if key in val and val[key] not in (None, ""):
                return str(val[key])
        if "messages" in val and isinstance(val["messages"], list):
            return _messages_to_text(val["messages"])
    return str(val)


def _messages_from_value(val) -> list | None:
    if isinstance(val, dict) and isinstance(val.get("messages"), list):
        return val.get("messages")
    if isinstance(val, list) and val and all(isinstance(m, dict) for m in val):
        return val
    return None


def _select_field(row: dict, keys: list[str]) -> str:
    val = _select_field_value(row, keys)
    return _coerce_text(val) if val is not None else ""


def _extract_prompt_response_from_messages(val) -> tuple[str, str]:
    msgs = _messages_from_value(val)
    if not msgs:
        return "", ""
    if len(msgs) == 1:
        return "", _coerce_text(msgs[0])
    prompt = _messages_to_text(msgs[:-1])
    response = _coerce_text(msgs[-1])
    return prompt, response


def _row_to_prompt_response(row: dict, prompt_field: str | None = None, response_field: str | None = None) -> tuple[str, str]:
    prompt = _select_field(
        row,
        [
            prompt_field,
            "prompt",
            "instruction",
            "input",
            "question",
        ],
    )
    response = _select_field(
        row,
        [
            response_field,
            "response",
            "output",
            "answer",
            "completion",
        ],
    )
    if "messages" in row and isinstance(row.get("messages"), list):
        msg_prompt, msg_response = _extract_prompt_response_from_messages(row.get("messages"))
        if msg_prompt and (not prompt or len(msg_prompt) > len(prompt)):
            prompt = msg_prompt
        if not response and msg_response:
            response = msg_response
    return prompt, response


def _row_to_pref(
    row: dict,
    prompt_field: str | None = None,
    chosen_field: str | None = None,
    rejected_field: str | None = None,
) -> tuple[str, str, str]:
    prompt = _select_field(
        row,
        [
            prompt_field,
            "prompt",
            "instruction",
            "input",
            "question",
            "query",
            "context",
            "history",
        ],
    )

    chosen_val = _select_field_value(
        row,
        [
            chosen_field,
            "chosen",
            "accepted",
            "preferred",
            "chosen_response",
            "response",
            "output",
            "answer",
        ],
    )
    rejected_val = _select_field_value(
        row,
        [
            rejected_field,
            "rejected",
            "rejected_response",
            "rejected_output",
            "rejected_answer",
            "dispreferred",
        ],
    )

    chosen = _coerce_text(chosen_val)
    rejected = _coerce_text(rejected_val)

    if chosen_val is not None:
        msg_prompt, msg_response = _extract_prompt_response_from_messages(chosen_val)
        if msg_prompt and (not prompt or len(msg_prompt) > len(prompt)):
            prompt = msg_prompt
        if msg_response:
            chosen = msg_response

    if rejected_val is not None:
        msg_prompt, msg_response = _extract_prompt_response_from_messages(rejected_val)
        if msg_prompt and (not prompt or len(msg_prompt) > len(prompt)):
            prompt = msg_prompt
        if msg_response:
            rejected = msg_response

    if (not prompt or not chosen) and chosen_val is not None:
        msg_prompt, msg_response = _extract_prompt_response_from_messages(chosen_val)
        if not prompt and msg_prompt:
            prompt = msg_prompt
        if not chosen and msg_response:
            chosen = msg_response

    if (not prompt or not rejected) and rejected_val is not None:
        msg_prompt, msg_response = _extract_prompt_response_from_messages(rejected_val)
        if not prompt and msg_prompt:
            prompt = msg_prompt
        if not rejected and msg_response:
            rejected = msg_response

    if not prompt and "messages" in row and isinstance(row.get("messages"), list):
        msg_prompt, _msg_response = _extract_prompt_response_from_messages(row.get("messages"))
        if msg_prompt:
            prompt = msg_prompt

    return prompt, chosen, rejected


def _write_provenance(out_dir: Path, metadata: dict) -> Path:
    ensure_dir(out_dir)
    meta_path = out_dir / "metadata.json"
    meta_path.write_text(json.dumps(metadata, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    return meta_path

def import_sharegpt(in_path: Path, out_path: Path):
    """Convert ShareGPT-ish JSONL into {prompt, response} JSONL.

    Expected input lines: {"conversations":[{"from":"human","value":"..."},{"from":"gpt","value":"..."}], ...}
    """
    ensure_dir(out_path.parent)
    n = 0
    with in_path.open("r", encoding="utf-8") as fin, out_path.open("w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            conv = obj.get("conversations") or []
            # naive: take first human then first assistant after it
            prompt = None
            response = None
            for turn in conv:
                frm = (turn.get("from") or "").lower()
                if prompt is None and frm in ("human", "user"):
                    prompt = turn.get("value") or ""
                elif prompt is not None and response is None and frm in ("gpt", "assistant"):
                    response = turn.get("value") or ""
                    break
            if prompt is None or response is None:
                continue
            fout.write(json.dumps({"prompt": prompt, "response": response}, ensure_ascii=False) + "\n")
            n += 1
    return n

def split_jsonl(in_path: Path, out_dir: Path, valid_frac: float, test_frac: float, seed: int = 1337):
    import random
    random.seed(seed)
    rows = [json.loads(line) for line in in_path.read_text(encoding="utf-8").splitlines() if line.strip()]
    random.shuffle(rows)
    n = len(rows)
    n_test = int(n * test_frac)
    n_valid = int(n * valid_frac)
    test = rows[:n_test]
    valid = rows[n_test:n_test + n_valid]
    train = rows[n_test + n_valid:]
    ensure_dir(out_dir)
    for name, part in [("train.jsonl", train), ("valid.jsonl", valid), ("test.jsonl", test)]:
        (out_dir / name).write_text("\n".join(json.dumps(r, ensure_ascii=False) for r in part) + ("\n" if part else ""), encoding="utf-8")
    return {"n": n, "train": len(train), "valid": len(valid), "test": len(test)}


def pull_hf_dataset(
    dataset: str,
    out_dir: Path,
    split: str = "train",
    limit: int | None = None,
    prompt_field: str | None = None,
    response_field: str | None = None,
    chosen_field: str | None = None,
    rejected_field: str | None = None,
    config: str | None = None,
    revision: str | None = None,
    kind: str = "sft",
    license: str | None = None,
    notes: str | None = None,
    preset: str | None = None,
    write_metadata: bool = True,
) -> dict:
    """Download a HF dataset split and write prompt/response JSONL.

    Args:
        dataset: HF dataset name (e.g. "tatsu-lab/alpaca")
        out_dir: Output directory (data/sft or data/prefs)
        split: Dataset split to pull
        limit: Optional max rows to write
        prompt_field: Optional override for prompt field name
        response_field: Optional override for response field name
        config: Optional dataset config/subset name
        revision: Optional dataset revision

    Returns:
        Stats dict with counts.
    """
    try:
        from datasets import load_dataset  # type: ignore
    except Exception as e:  # pragma: no cover
        raise RuntimeError(f"datasets not available: {e}")

    ensure_dir(out_dir)
    ds = load_dataset(dataset, name=config, split=split, revision=revision)
    total = len(ds)
    n = 0
    skipped = 0
    kind_norm = kind.strip().lower()
    if kind_norm in ("pref", "prefs", "preference", "preferences"):
        kind_norm = "prefs"
    elif kind_norm in ("sft", "supervised"):
        kind_norm = "sft"
    else:
        raise ValueError(f"Unsupported kind: {kind}")
    out_path = out_dir / "train.jsonl"
    with out_path.open("w", encoding="utf-8") as fout:
        for row in ds:
            if kind_norm == "prefs":
                prompt, chosen, rejected = _row_to_pref(row, prompt_field, chosen_field, rejected_field)
                if not (prompt and chosen and rejected):
                    skipped += 1
                    continue
                fout.write(
                    json.dumps(
                        {"prompt": prompt, "chosen": chosen, "rejected": rejected},
                        ensure_ascii=False,
                    )
                    + "\n"
                )
            else:
                prompt, response = _row_to_prompt_response(row, prompt_field, response_field)
                if not (prompt and response):
                    skipped += 1
                    continue
                fout.write(json.dumps({"prompt": prompt, "response": response}, ensure_ascii=False) + "\n")
            n += 1
            if limit and n >= limit:
                break
    metadata = {
        "dataset": dataset,
        "preset": preset,
        "config": config,
        "split": split,
        "revision": revision,
        "license": license,
        "notes": notes,
        "kind": kind_norm,
        "total": total,
        "written": n,
        "skipped": skipped,
        "limit": limit,
        "prompt_field": prompt_field,
        "response_field": response_field if kind_norm == "sft" else None,
        "chosen_field": chosen_field if kind_norm == "prefs" else None,
        "rejected_field": rejected_field if kind_norm == "prefs" else None,
        "out": str(out_path),
        "generated_at": _utc_now_iso(),
    }
    if write_metadata:
        meta_path = _write_provenance(out_dir, metadata)
        metadata["metadata"] = str(meta_path)
    return metadata


def _normalize_kind(kind: str | None) -> str | None:
    if not kind:
        return None
    kind_norm = kind.strip().lower()
    if kind_norm in ("pref", "prefs", "preference", "preferences"):
        return "prefs"
    if kind_norm in ("sft", "supervised"):
        return "sft"
    return kind_norm


def _infer_kind_from_row(row: dict) -> str:
    if "chosen" in row or "rejected" in row:
        return "prefs"
    if "response" in row:
        return "sft"
    if "messages" in row:
        return "sft"
    return "unknown"


def analyze_jsonl(path: Path, kind: str | None = None, limit: int | None = None) -> dict:
    stats = {
        "rows": 0,
        "empty_lines": 0,
        "bad_json": 0,
        "missing_prompt": 0,
        "missing_response": 0,
        "missing_chosen": 0,
        "missing_rejected": 0,
        "prompt_chars": 0,
        "response_chars": 0,
        "chosen_chars": 0,
        "rejected_chars": 0,
        "prompt_count": 0,
        "response_count": 0,
        "chosen_count": 0,
        "rejected_count": 0,
        "kind": None,
    }
    kind_norm = _normalize_kind(kind)
    with path.open("r", encoding="utf-8") as fin:
        for line in fin:
            if limit and stats["rows"] >= limit:
                break
            if not line.strip():
                stats["empty_lines"] += 1
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                stats["bad_json"] += 1
                continue
            stats["rows"] += 1
            if not kind_norm:
                kind_norm = _infer_kind_from_row(row)
            prompt = _coerce_text(row.get("prompt"))
            if not prompt:
                stats["missing_prompt"] += 1
            else:
                stats["prompt_chars"] += len(prompt)
                stats["prompt_count"] += 1
            if kind_norm == "prefs":
                chosen = _coerce_text(row.get("chosen"))
                rejected = _coerce_text(row.get("rejected"))
                if not chosen:
                    stats["missing_chosen"] += 1
                else:
                    stats["chosen_chars"] += len(chosen)
                    stats["chosen_count"] += 1
                if not rejected:
                    stats["missing_rejected"] += 1
                else:
                    stats["rejected_chars"] += len(rejected)
                    stats["rejected_count"] += 1
            elif kind_norm == "sft":
                response = _coerce_text(row.get("response"))
                if not response:
                    stats["missing_response"] += 1
                else:
                    stats["response_chars"] += len(response)
                    stats["response_count"] += 1
            else:
                response = _coerce_text(row.get("response"))
                if response:
                    stats["response_chars"] += len(response)
                    stats["response_count"] += 1
                else:
                    stats["missing_response"] += 1
    stats["kind"] = kind_norm or "unknown"
    return stats
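Taken together, data.py gives a small pull → split → analyze pipeline. A minimal sketch of how the public helpers compose, using only the signatures shown above; the data/sft paths and the row limit are illustrative assumptions, and pull_hf_dataset requires the optional datasets dependency:

    from pathlib import Path

    from mlxsmith.data import analyze_jsonl, pull_hf_dataset, resolve_preset, split_jsonl

    # Pull the built-in "alpaca" preset; this writes train.jsonl plus a
    # metadata.json provenance file into out_dir (path is an assumption).
    preset = resolve_preset("alpaca")
    stats = pull_hf_dataset(
        dataset=preset["dataset"],
        out_dir=Path("data/sft"),
        kind=preset["kind"],
        prompt_field=preset.get("prompt_field"),
        response_field=preset.get("response_field"),
        license=preset.get("license"),
        preset="alpaca",
        limit=1000,  # illustrative cap on rows written
    )
    print(f"wrote {stats['written']} rows, skipped {stats['skipped']}")

    # Re-split into train/valid/test. split_jsonl reads the whole input
    # before writing, so re-splitting train.jsonl in place is safe.
    counts = split_jsonl(Path(stats["out"]), Path("data/sft"), valid_frac=0.05, test_frac=0.05)

    # Sanity-check field coverage; pull_hf_dataset already skips rows
    # missing a prompt or response, so these should be zero.
    report = analyze_jsonl(Path("data/sft/train.jsonl"), kind="sft")
    assert report["missing_prompt"] == 0 and report["missing_response"] == 0

    # ShareGPT-style dumps can be converted the same way before splitting
    # (input path is hypothetical):
    # from mlxsmith.data import import_sharegpt
    # import_sharegpt(Path("raw/sharegpt.jsonl"), Path("data/sft/train.jsonl"))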
mlxsmith/envs/__init__.py
ADDED
@@ -0,0 +1,33 @@
from .system import (
    EnvManifest,
    EnvRef,
    init_env,
    install_env,
    list_registry_packages,
    load_manifest,
    package_env,
    pull_env,
    publish_env,
    registry_info,
    resolve_env_path,
)
from .token_env import TokenEnv, TokenEnvStep, load_token_env_spec, create_token_env, StringTaskTokenEnv

__all__ = [
    "EnvManifest",
    "EnvRef",
    "init_env",
    "install_env",
    "list_registry_packages",
    "load_manifest",
    "package_env",
    "pull_env",
    "publish_env",
    "registry_info",
    "resolve_env_path",
    "TokenEnv",
    "TokenEnvStep",
    "load_token_env_spec",
    "create_token_env",
    "StringTaskTokenEnv",
]
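This module only re-exports the env surface; the actual signatures live in mlxsmith/envs/system.py and mlxsmith/envs/token_env.py, which this hunk does not show. A deliberately import-only sketch of consuming that surface from the package root:

    # Import-only on purpose: these names are guaranteed by __all__ above,
    # but their signatures are defined in system.py / token_env.py.
    from mlxsmith.envs import (
        EnvManifest,
        TokenEnv,
        create_token_env,
        init_env,
        publish_env,
    )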