osmosis-ai 0.1.9__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of osmosis-ai might be problematic; see the package's registry page for more details.

@@ -0,0 +1,410 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Sequence
7
+
8
+ import yaml
9
+ from yaml.representer import SafeRepresenter
10
+
11
+ from .errors import CLIError
12
+ from .shared import coerce_optional_float
13
+
14
+
15
@dataclass(frozen=True)
class ParsedItem:
    """A single rubric-like entry discovered while walking a YAML document.

    ``label`` records where the entry was found (a mapping key, an indexed
    path, or ``document[N]``) and may be ``None``; ``payload`` is the raw
    YAML node exactly as loaded.
    """

    label: Optional[str]  # location hint used in headers and source labels
    payload: Any  # untouched YAML node for this entry
19
+
20
+
21
@dataclass(frozen=True)
class RubricConfig:
    """Resolved configuration for one rubric.

    As populated by ``_build_document_configs``, each value comes from the
    rubric entry itself or, when the entry does not set it, from the
    document's ``default_*`` keys.
    """

    rubric_id: str  # stripped, non-empty 'id' of the entry
    rubric_text: str  # the entry's 'rubric' string
    model_info: dict[str, Any]  # independent (deep-copied) 'model_info' mapping
    score_min: Optional[float]  # lower score bound, if any
    score_max: Optional[float]  # upper score bound, if any
    system_prompt: Optional[str]
    original_input: Optional[str]
    ground_truth: Optional[str]
    source_label: str  # "<path>:<label>" pointing back at the config entry
32
+
33
+
34
@dataclass(frozen=True)
class RubricSuite:
    """Every rubric configuration parsed from one config file."""

    source_path: Path  # config file the suite was loaded from
    version: Optional[int]  # version declared by the file's documents, if any
    configs: dict[str, RubricConfig]  # rubric id -> resolved configuration

    def get(self, rubric_id: str) -> RubricConfig:
        """Return the config registered under *rubric_id*.

        Raises:
            CLIError: if the id is unknown; the message lists the known ids.
        """
        if rubric_id in self.configs:
            return self.configs[rubric_id]
        known_ids = ", ".join(self.available_ids()) or "none"
        raise CLIError(
            f"Rubric '{rubric_id}' not found in '{self.source_path}'. Available IDs: {known_ids}"
        )

    def available_ids(self) -> list[str]:
        """Return the known rubric ids in sorted order."""
        ids = list(self.configs)
        ids.sort()
        return ids
50
+
51
+
52
@dataclass(frozen=True)
class RubricConfigDocumentResult:
    """What a schema produces after parsing one YAML document."""

    configs: dict[str, RubricConfig]  # validated configs keyed by rubric id
    items: list[ParsedItem]  # every candidate entry seen, whether or not it validated
56
+
57
+
58
class RubricConfigDocumentSchema:
    """Interface for parsing one YAML document of a rubric config file.

    Subclasses set ``version`` to the document version they handle (``None``
    for documents without a ``version`` key) and override
    :meth:`parse_document`.
    """

    version: Optional[int] = None

    def parse_document(
        self,
        document: Any,
        *,
        path: Path,
        doc_index: int,
        strict: bool,
    ) -> RubricConfigDocumentResult:
        """Parse *document* into rubric configs; subclasses must implement this."""
        raise NotImplementedError


class LegacyRubricConfigSchema(RubricConfigDocumentSchema):
    """Schema for documents that do not declare an explicit version."""

    version = None

    def parse_document(
        self,
        document: Any,
        *,
        path: Path,
        doc_index: int,
        strict: bool,
    ) -> RubricConfigDocumentResult:
        """Find rubric entries anywhere in *document* and resolve them against its defaults."""
        document_defaults = _extract_config_defaults(document, path, doc_index)
        candidates = _extract_rubric_items(document, context=None, doc_index=doc_index)
        return _build_document_configs(
            candidates,
            document_defaults,
            path=path,
            doc_index=doc_index,
            strict=strict,
        )


class Version1RubricConfigSchema(LegacyRubricConfigSchema):
    """Schema for ``version: 1`` documents; currently the same layout as legacy."""

    version = 1
96
+
97
+
98
class RubricConfigParser:
    """Parses rubric configuration files and produces typed suites.

    A file may hold several YAML documents. Documents that declare a
    ``version`` must all agree. The schema registered for that version (or
    the ``None`` schema when no document declares one) then parses every
    non-empty document.
    """

    def __init__(self, *, schemas: Optional[dict[Optional[int], RubricConfigDocumentSchema]] = None):
        # A falsy mapping (None or {}) falls back to the built-in schemas.
        self._schemas = schemas or {
            None: LegacyRubricConfigSchema(),
            1: Version1RubricConfigSchema(),
        }
        if None not in self._schemas:
            raise ValueError("At least one default schema (key=None) must be provided.")

    def parse(self, path: Path, *, strict: bool = True) -> tuple[RubricSuite, list[ParsedItem]]:
        """Parse *path* into a suite plus every candidate item the schema reported.

        Args:
            path: YAML file to read.
            strict: Forwarded to the schema. When True, a file that yields no
                rubric configs also raises.

        Returns:
            ``(suite, items)``, where ``items`` concatenates each document's
            parsed items in document order.

        Raises:
            CLIError: on read/YAML errors, an invalid or conflicting ``version``,
                an unsupported version, a rubric id repeated across documents,
                or (strict only) when no rubric configs were produced.
        """
        documents = _load_yaml_documents(path)
        # Empty documents (e.g. a trailing '---') carry nothing to parse.
        non_empty = [
            (doc_index, document)
            for doc_index, document in enumerate(documents)
            if document
        ]

        detected_version: Optional[int] = None
        for doc_index, document in non_empty:
            doc_version = self._coerce_optional_version(document, path, doc_index)
            if doc_version is None:
                continue
            if detected_version is None:
                detected_version = doc_version
            elif detected_version != doc_version:
                raise CLIError(
                    f"Rubric config '{path}' mixes different version numbers across documents."
                )

        schema = self._select_schema(detected_version)

        configs: dict[str, RubricConfig] = {}
        parsed_items: list[ParsedItem] = []
        for doc_index, document in non_empty:
            result = schema.parse_document(
                document,
                path=path,
                doc_index=doc_index,
                strict=strict,
            )
            parsed_items.extend(result.items)
            for rubric_id, config in result.configs.items():
                if rubric_id in configs:
                    raise CLIError(f"Duplicate rubric id '{rubric_id}' detected in '{path}'.")
                configs[rubric_id] = config

        if strict and not configs:
            raise CLIError(f"No rubric entries found in '{path}'.")

        suite = RubricSuite(source_path=path, version=detected_version, configs=configs)
        return suite, parsed_items

    def _select_schema(self, version: Optional[int]) -> RubricConfigDocumentSchema:
        """Return the schema registered for *version*.

        Raises:
            CLIError: if no schema is registered for a non-None version.
        """
        if version in self._schemas:
            return self._schemas[version]
        if version is None:
            return self._schemas[None]
        raise CLIError(f"Unsupported rubric config version '{version}'.")

    @staticmethod
    def _coerce_optional_version(document: Any, path: Path, doc_index: int) -> Optional[int]:
        """Return the document's ``version`` as a non-negative int.

        Returns None when the document is not a mapping or has no ``version``.

        Raises:
            CLIError: if ``version`` is present but is not a non-negative
                integer. Booleans are rejected even though ``bool`` subclasses
                ``int``; otherwise YAML ``version: true`` would read as version 1.
        """
        if not isinstance(document, dict):
            return None
        version_value = document.get("version")
        if version_value is None:
            return None
        if isinstance(version_value, bool) or not isinstance(version_value, int):
            raise CLIError(
                f"Version field in '{path}' document {doc_index} must be an integer."
            )
        if version_value < 0:
            raise CLIError(
                f"Version number in '{path}' document {doc_index} must be non-negative."
            )
        return version_value
179
+
180
+
181
def _build_document_configs(
    entries: Sequence[ParsedItem],
    defaults: dict[str, Any],
    *,
    path: Path,
    doc_index: int,
    strict: bool,
) -> RubricConfigDocumentResult:
    """Validate the rubric entries of one YAML document and build their configs.

    Every entry is echoed into ``items``. Only entries that pass validation
    become ``RubricConfig`` objects. A per-entry key that is absent falls back
    to the matching value in ``defaults``.

    Args:
        entries: Candidate rubric entries extracted from the document.
        defaults: Document-level fallbacks keyed by field name.
        path: Config file path, used in messages and source labels.
        doc_index: Zero-based position of the document in the file.
        strict: When True, an invalid entry raises ``CLIError``; when False it
            is skipped. Duplicate ids within the document always raise.

    Returns:
        The configs keyed by rubric id, plus all entries as ``ParsedItem``s.
    """
    configs: dict[str, RubricConfig] = {}
    parsed_items: list[ParsedItem] = []

    for item in entries:
        payload = item.payload
        parsed_items.append(ParsedItem(label=item.label, payload=payload))
        if not isinstance(payload, dict):
            continue
        if "extra_info" in payload:
            if strict:
                raise CLIError(
                    f"Rubric entry in '{path}' (document {doc_index + 1}) must not include 'extra_info'."
                )
            continue

        rubric_key_raw = payload.get("id")
        if not isinstance(rubric_key_raw, str) or not rubric_key_raw.strip():
            if strict:
                # Use the same 1-based document number as the 'extra_info' message.
                raise CLIError(
                    f"Rubric entry in '{path}' (document {doc_index + 1}) is missing a non-empty 'id'."
                )
            continue
        rubric_key = rubric_key_raw.strip()
        if rubric_key in configs:
            raise CLIError(f"Duplicate rubric id '{rubric_key}' detected in '{path}'.")

        rubric_text = payload.get("rubric")
        if not isinstance(rubric_text, str) or not rubric_text.strip():
            if strict:
                raise CLIError(
                    f"Rubric '{rubric_key}' in '{path}' must include a non-empty 'rubric' string."
                )
            continue

        # Entry keys override document defaults. An explicit null on the entry
        # counts as "set" and does not fall back to the default.
        model_info = payload.get("model_info", defaults.get("model_info"))
        if not isinstance(model_info, dict):
            if strict:
                raise CLIError(
                    f"Rubric '{rubric_key}' in '{path}' must include a 'model_info' mapping."
                )
            continue

        score_context = f"rubric '{rubric_key}' in {path}"
        try:
            score_min = coerce_optional_float(
                payload.get("score_min", defaults.get("score_min")),
                "score_min",
                score_context,
            )
            score_max = coerce_optional_float(
                payload.get("score_max", defaults.get("score_max")),
                "score_max",
                score_context,
            )
        except CLIError:
            if strict:
                raise
            continue

        system_prompt = payload.get("system_prompt", defaults.get("system_prompt"))
        original_input = payload.get("original_input", defaults.get("original_input"))
        ground_truth = payload.get("ground_truth", defaults.get("ground_truth"))
        label = item.label or f"document[{doc_index}]"

        configs[rubric_key] = RubricConfig(
            rubric_id=rubric_key,
            rubric_text=rubric_text,
            # Deep copy so the config does not share the parsed YAML mapping.
            model_info=copy.deepcopy(model_info),
            score_min=score_min,
            score_max=score_max,
            # Non-string optional text values are dropped rather than rejected.
            system_prompt=system_prompt if isinstance(system_prompt, str) else None,
            original_input=original_input if isinstance(original_input, str) else None,
            ground_truth=ground_truth if isinstance(ground_truth, str) else None,
            source_label=f"{path}:{label}",
        )

    return RubricConfigDocumentResult(configs=configs, items=parsed_items)
272
+
273
+
274
def discover_rubric_config_path(config_arg: Optional[str], data_path: Path) -> Path:
    """Resolve which rubric config file to use.

    An explicit *config_arg* (``~`` expanded) must name an existing,
    non-directory path. Otherwise ``rubric_configs.yaml`` is looked up next to
    *data_path*, then in the working directory, then in ``./examples``.

    Raises:
        CLIError: if the explicit path is missing or a directory, or if no
            candidate file exists (the message lists the paths checked).
    """
    if config_arg:
        explicit = Path(config_arg).expanduser()
        if not explicit.exists():
            raise CLIError(f"Rubric config path '{explicit}' does not exist.")
        if explicit.is_dir():
            raise CLIError(f"Rubric config path '{explicit}' is a directory.")
        return explicit

    filename = "rubric_configs.yaml"
    cwd = Path.cwd()
    # dict.fromkeys de-duplicates (e.g. data next to cwd) while keeping search order.
    search_order = list(
        dict.fromkeys((data_path.parent / filename, cwd / filename, cwd / "examples" / filename))
    )
    for location in search_order:
        if location.is_file():
            return location

    raise CLIError(
        "Unable to locate a rubric config file. Provide --config explicitly. "
        f"Paths checked: {', '.join(map(str, search_order))}"
    )
299
+
300
+
301
def load_rubric_configs(path: Path) -> list[ParsedItem]:
    """Leniently parse *path* and return the candidate rubric items found.

    Parsing uses ``strict=False``, so entries the schema cannot validate are
    skipped instead of raising where the parser allows it.
    """
    _suite, parsed_items = RubricConfigParser().parse(path, strict=False)
    return parsed_items


def load_rubric_suite(path: Path) -> RubricSuite:
    """Strictly parse *path* and return the resulting :class:`RubricSuite`."""
    suite, _parsed_items = RubricConfigParser().parse(path)
    return suite
311
+
312
+
313
def render_yaml_items(items: Sequence[ParsedItem], label: str) -> str:
    """Render *items* as numbered YAML snippets separated by blank lines.

    Each snippet is a ``"<label> #<n>"`` header (with ``" (<item label>)"``
    appended when the item has one) followed by the payload dumped as YAML.
    """
    rendered: list[str] = []
    for position, entry in enumerate(items, start=1):
        title = f"{label} #{position}"
        if entry.label:
            title = f"{title} ({entry.label})"
        body = yaml.dump(
            entry.payload,
            Dumper=_LiteralSafeDumper,
            sort_keys=False,
            indent=2,
            allow_unicode=True,
        ).rstrip()
        rendered.append(f"{title}\n{body}")
    # A blank line between consecutive snippets, none after the last.
    return "\n\n".join(rendered)
335
+
336
+
337
def _load_yaml_documents(path: Path) -> list[Any]:
    """Load every YAML document in *path* using the safe loader.

    Raises:
        CLIError: if the YAML is malformed or the file cannot be read.
    """
    try:
        with path.open("r", encoding="utf-8") as stream:
            loaded = list(yaml.safe_load_all(stream))
    except yaml.YAMLError as exc:
        raise CLIError(f"Failed to parse YAML in '{path}': {exc}") from exc
    except OSError as exc:
        raise CLIError(f"Unable to read rubric config '{path}': {exc}") from exc
    return loaded
345
+
346
+
347
+ def _extract_config_defaults(document: Any, path: Path, doc_index: int) -> dict[str, Any]:
348
+ if not isinstance(document, dict):
349
+ return {
350
+ "model_info": None,
351
+ "score_min": None,
352
+ "score_max": None,
353
+ "system_prompt": None,
354
+ "original_input": None,
355
+ "ground_truth": None,
356
+ }
357
+
358
+ source = f"document[{doc_index}] in {path}"
359
+
360
+ defaults: dict[str, Any] = {}
361
+ if "default_extra_info" in document:
362
+ raise CLIError(
363
+ f"Rubric config document {doc_index + 1} in {path} must not include 'default_extra_info'; extra_info is no longer supported."
364
+ )
365
+ defaults["model_info"] = document.get("default_model_info")
366
+ defaults["score_min"] = coerce_optional_float(
367
+ document.get("default_score_min"), "default_score_min", source
368
+ )
369
+ defaults["score_max"] = coerce_optional_float(
370
+ document.get("default_score_max"), "default_score_max", source
371
+ )
372
+ defaults["system_prompt"] = document.get("default_system_prompt")
373
+ defaults["original_input"] = document.get("default_original_input")
374
+ defaults["ground_truth"] = document.get("default_ground_truth")
375
+ return defaults
376
+
377
+
378
def _extract_rubric_items(node: Any, context: Optional[str], doc_index: int) -> list[ParsedItem]:
    """Recursively collect rubric entries from a loaded YAML node.

    A mapping whose ``rubric`` value is a string is itself an entry. Any other
    mapping is searched value by value, and a string key becomes the label
    context for what lies beneath it.

    List elements are labelled ``<context>[<index>]``. When a list has no
    enclosing string key (for example, the document itself is a list), the
    base is ``document[<doc_index>]``. This keeps sibling entries
    distinguishable; previously they all shared the label ``document[N]``.

    Args:
        node: YAML node to search; ``None`` yields no items.
        context: Label inherited from the enclosing key or list path.
        doc_index: Zero-based document position, used in fallback labels.

    Returns:
        The entries found, in document order.
    """
    if node is None:
        return []

    items: list[ParsedItem] = []
    if isinstance(node, dict):
        if isinstance(node.get("rubric"), str):
            items.append(ParsedItem(label=context or f"document[{doc_index}]", payload=node))
        else:
            for key, value in node.items():
                next_context = str(key) if isinstance(key, str) else context
                items.extend(_extract_rubric_items(value, context=next_context, doc_index=doc_index))
    elif isinstance(node, list):
        base = context or f"document[{doc_index}]"
        for index, value in enumerate(node):
            items.extend(
                _extract_rubric_items(value, context=f"{base}[{index}]", doc_index=doc_index)
            )
    return items
398
+
399
+
400
class _LiteralSafeDumper(yaml.SafeDumper):
    """Safe YAML dumper that writes multi-line strings as literal (``|``) blocks."""


def _represent_str(dumper: yaml.Dumper, data: str):
    """Use literal block style for strings containing newlines.

    Single-line strings use the standard safe representation.
    """
    if "\n" not in data:
        return SafeRepresenter.represent_str(dumper, data)
    return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")


# Registered on the subclass only, so yaml.SafeDumper itself is unaffected.
_LiteralSafeDumper.add_representer(str, _represent_str)
@@ -0,0 +1,175 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import json
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Optional, Sequence
8
+
9
+ from .errors import CLIError
10
+ from .shared import coerce_optional_float
11
+
12
+
13
@dataclass(frozen=True)
class DatasetRecord:
    """One JSONL dataset row prepared for rubric evaluation.

    ``payload`` keeps the raw JSON object. The other fields hold the values
    extracted from it.
    """

    payload: dict[str, Any]
    rubric_id: str
    conversation_id: Optional[str]
    record_id: Optional[str]
    solution_str: str
    ground_truth: Optional[str]
    original_input: Optional[str]
    metadata: Optional[dict[str, Any]]
    extra_info: Optional[dict[str, Any]]
    score_min: Optional[float]
    score_max: Optional[float]

    def merged_extra_info(self) -> Optional[dict[str, Any]]:
        """Return a deep copy of ``extra_info`` with ``metadata`` added.

        ``metadata`` goes under ``dataset_metadata`` unless ``extra_info``
        already has that key. Returns ``None`` when the result would be empty.
        """
        merged: dict[str, Any] = {}
        if isinstance(self.extra_info, dict):
            merged.update(copy.deepcopy(self.extra_info))
        if isinstance(self.metadata, dict) and self.metadata:
            merged.setdefault("dataset_metadata", copy.deepcopy(self.metadata))
        return merged or None

    def assistant_preview(self, *, max_length: int = 140) -> Optional[str]:
        """Return ``solution_str`` with whitespace collapsed, capped at *max_length* chars.

        Longer text is cut and suffixed with ``"..."``. When *max_length* is
        too small to fit the ellipsis, the text is hard-truncated instead.
        Returns ``None`` when there is nothing to show.
        """
        # split() with no argument also drops leading/trailing whitespace.
        preview = " ".join(self.solution_str.split())
        if not preview:
            return None
        if len(preview) <= max_length:
            return preview
        if max_length < 3:
            # No room for "..." here. The old preview[: max_length - 3] used a
            # negative bound and returned text longer than max_length.
            return preview[: max(max_length, 0)] or None
        return preview[: max_length - 3].rstrip() + "..."

    def conversation_label(self, fallback_index: int) -> str:
        """Return the stripped conversation id, or ``record[<fallback_index>]``."""
        if isinstance(self.conversation_id, str) and self.conversation_id.strip():
            return self.conversation_id.strip()
        return f"record[{fallback_index}]"

    def record_identifier(self, conversation_label: str) -> str:
        """Return the best id for this record.

        Tries ``record_id``, then the payload's ``id`` (stringified if not a
        string), then *conversation_label*.
        """
        if isinstance(self.record_id, str) and self.record_id.strip():
            return self.record_id.strip()
        raw_id = self.payload.get("id")
        if isinstance(raw_id, str) and raw_id.strip():
            return raw_id.strip()
        if raw_id is not None:
            return str(raw_id)
        return conversation_label
60
+
61
+
62
class DatasetLoader:
    """Loads dataset records from JSONL files."""

    def load(self, path: Path) -> list[DatasetRecord]:
        """Read *path* as JSONL and build one record per non-blank line.

        Raises:
            CLIError: if a line is not valid JSON or not a JSON object, if a
                record fails validation, or if the file holds no records.
        """
        records: list[DatasetRecord] = []
        with path.open("r", encoding="utf-8") as fh:
            for line_number, raw_line in enumerate(fh, start=1):
                stripped = raw_line.strip()
                if not stripped:
                    continue
                try:
                    payload = json.loads(stripped)
                except json.JSONDecodeError as exc:
                    raise CLIError(
                        f"Invalid JSON on line {line_number} of '{path}': {exc.msg}"
                    ) from exc
                if not isinstance(payload, dict):
                    raise CLIError(
                        f"Expected JSON object on line {line_number} of '{path}'."
                    )
                records.append(self._create_record(payload))

        if not records:
            raise CLIError(f"No JSON records found in '{path}'.")

        return records

    @staticmethod
    def _create_record(payload: dict[str, Any]) -> DatasetRecord:
        """Validate one JSON object and convert it into a record.

        String ids are stripped. Non-string optional values become ``None``
        (``""`` for ``rubric_id``). ``original_input`` falls back to
        ``extra_info["original_input"]``.

        Raises:
            CLIError: if ``solution_str`` is missing or blank. Anything raised
                by ``coerce_optional_float`` for a bad score bound also
                propagates.
        """
        rubric_id = payload.get("rubric_id")
        rubric_id_str = rubric_id.strip() if isinstance(rubric_id, str) else ""

        conversation_id_raw = payload.get("conversation_id")
        conversation_id = None
        if isinstance(conversation_id_raw, str) and conversation_id_raw.strip():
            conversation_id = conversation_id_raw.strip()

        record_id_raw = payload.get("id")
        record_id = record_id_raw.strip() if isinstance(record_id_raw, str) else None

        # One label for every error about this record. The score errors
        # previously built their own label from the raw, possibly non-string
        # or unstripped, rubric_id and ignored the record id.
        record_label = conversation_id or record_id or rubric_id_str or "<record>"
        score_context = f"record '{record_label}'"

        score_min = coerce_optional_float(payload.get("score_min"), "score_min", score_context)
        score_max = coerce_optional_float(payload.get("score_max"), "score_max", score_context)

        metadata = payload.get("metadata") if isinstance(payload.get("metadata"), dict) else None
        extra_info = payload.get("extra_info") if isinstance(payload.get("extra_info"), dict) else None

        solution_raw = payload.get("solution_str")
        if not isinstance(solution_raw, str) or not solution_raw.strip():
            raise CLIError(f"Record '{record_label}' must include a non-empty 'solution_str' string.")

        original_input_raw = payload.get("original_input")
        original_input = original_input_raw if isinstance(original_input_raw, str) else None
        if original_input is None and extra_info is not None:
            extra_original_input = extra_info.get("original_input")
            if isinstance(extra_original_input, str):
                original_input = extra_original_input

        ground_truth_raw = payload.get("ground_truth")
        return DatasetRecord(
            payload=payload,
            rubric_id=rubric_id_str,
            conversation_id=conversation_id,
            record_id=record_id,
            solution_str=solution_raw,
            ground_truth=ground_truth_raw if isinstance(ground_truth_raw, str) else None,
            original_input=original_input,
            metadata=metadata,
            extra_info=extra_info,
            score_min=score_min,
            score_max=score_max,
        )
141
+
142
+
143
def load_jsonl_records(path: Path) -> list[dict[str, Any]]:
    """Read *path* as JSONL and return one dict per non-blank line.

    Raises:
        CLIError: if a line is not valid JSON or not a JSON object, or if the
            file contains no records.
    """
    parsed_rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for lineno, line in enumerate(handle, start=1):
            text = line.strip()
            if not text:
                continue
            try:
                row = json.loads(text)
            except json.JSONDecodeError as err:
                raise CLIError(f"Invalid JSON on line {lineno} of '{path}': {err.msg}") from err
            if not isinstance(row, dict):
                raise CLIError(f"Expected JSON object on line {lineno} of '{path}'.")
            parsed_rows.append(row)

    if not parsed_rows:
        raise CLIError(f"No JSON records found in '{path}'.")
    return parsed_rows
162
+
163
+
164
def render_json_records(records: Sequence[dict[str, Any]]) -> str:
    """Render *records* as numbered, pretty-printed JSON snippets.

    Each snippet is a ``"JSONL record #<n>"`` header followed by the record
    dumped with 2-space indentation and non-ASCII characters kept as-is.
    Consecutive snippets are separated by a blank line.
    """
    return "\n\n".join(
        f"JSONL record #{position}\n{json.dumps(row, indent=2, ensure_ascii=False)}"
        for position, row in enumerate(records, start=1)
    )