osmosis-ai 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of osmosis-ai might be problematic. Click here for more details.

@@ -0,0 +1,229 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import json
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Any, Optional, Sequence
8
+
9
+ from .errors import CLIError
10
+ from .shared import coerce_optional_float, gather_text_fragments
11
+
12
+
13
@dataclass(frozen=True)
class ConversationMessage:
    """One conversation message: a normalized role, its content, and extra raw fields."""

    # Role string; ``from_raw`` stores it stripped and lower-cased.
    role: str
    # Message content exactly as supplied (any JSON-compatible value, or None).
    content: Any
    # Every raw field other than "role" and "content".
    metadata: dict[str, Any]

    def to_payload(self) -> dict[str, Any]:
        """Return a fresh dict holding the metadata, the role and (if set) the content.

        All values are deep-copied, so mutating the result does not affect this
        message. When the content is ``None`` the ``"content"`` key is omitted.
        """
        result: dict[str, Any] = copy.deepcopy(self.metadata)
        result["role"] = self.role
        if self.content is not None:
            result["content"] = copy.deepcopy(self.content)
            return result
        # No content: make sure a stray metadata "content" entry is not emitted.
        result.pop("content", None)
        return result

    def text_fragments(self) -> list[str]:
        """Collect text fragments from the content and then from each metadata value.

        The collection itself is delegated to ``gather_text_fragments``; a single
        ``seen`` set is shared across all of its calls.
        """
        collected: list[str] = []
        visited: set[int] = set()
        gather_text_fragments(self.content, collected, allow_free_strings=True, seen=visited)
        for extra_value in self.metadata.values():
            gather_text_fragments(extra_value, collected, seen=visited)
        return collected

    @classmethod
    def from_raw(cls, raw: dict[str, Any], *, source_label: str, index: int) -> "ConversationMessage":
        """Build a message from a raw dict.

        Args:
            raw: Mapping that must carry a non-blank string under ``"role"``.
            source_label: Name of the enclosing record, used in the error text.
            index: Position of the message, used in the error text.

        Raises:
            CLIError: If ``"role"`` is missing, not a string, or blank.
        """
        role_value = raw.get("role")
        role_is_valid = isinstance(role_value, str) and bool(role_value.strip())
        if not role_is_valid:
            raise CLIError(
                f"Message {index} in {source_label} must include a non-empty string 'role'."
            )

        reserved = {"role", "content"}
        content_copy = copy.deepcopy(raw.get("content"))
        extras: dict[str, Any] = {
            str(field_name): copy.deepcopy(field_value)
            for field_name, field_value in raw.items()
            if field_name not in reserved
        }
        return cls(role=role_value.strip().lower(), content=content_copy, metadata=extras)
52
+
53
+
54
@dataclass(frozen=True)
class DatasetRecord:
    """One dataset row: the raw JSON payload plus the fields extracted from it."""

    # The parsed JSON object this record was created from.
    payload: dict[str, Any]
    rubric_id: str
    conversation_id: Optional[str]
    record_id: Optional[str]
    messages: tuple[ConversationMessage, ...]
    ground_truth: Optional[str]
    system_message: Optional[str]
    original_input: Optional[str]
    metadata: Optional[dict[str, Any]]
    extra_info: Optional[dict[str, Any]]
    score_min: Optional[float]
    score_max: Optional[float]

    def message_payloads(self) -> list[dict[str, Any]]:
        """Return messages as provider-ready payloads."""
        return [message.to_payload() for message in self.messages]

    def merged_extra_info(self, config_extra: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
        """Merge config-level and record-level extra info into one new dict.

        Record-level ``extra_info`` overrides ``config_extra`` on key clashes.
        A non-empty ``metadata`` dict is added under ``"dataset_metadata"``
        unless that key is already present. All values are deep-copied.

        Returns:
            The merged dict, or ``None`` when it would be empty.
        """
        merged: dict[str, Any] = {}
        if isinstance(config_extra, dict):
            merged.update(copy.deepcopy(config_extra))
        if isinstance(self.extra_info, dict):
            merged.update(copy.deepcopy(self.extra_info))
        if isinstance(self.metadata, dict) and self.metadata:
            merged.setdefault("dataset_metadata", copy.deepcopy(self.metadata))
        return merged or None

    def assistant_preview(self, *, max_length: int = 140) -> Optional[str]:
        """Return a one-line preview of the last assistant message that has text.

        Messages are scanned from last to first. Whitespace runs are collapsed
        to single spaces. A preview longer than ``max_length`` is cut and ends
        with ``"..."``.

        Args:
            max_length: Maximum length of the returned preview. For values
                below 3 there is no room for the ellipsis, so the text is
                hard-truncated (a value <= 0 yields an empty string).

        Returns:
            The preview, or ``None`` when no assistant message yields text.
        """
        ellipsis = "..."
        for message in reversed(self.messages):
            if message.role != "assistant":
                continue
            fragments = message.text_fragments()
            if not fragments:
                continue
            preview = " ".join(" ".join(fragments).split())
            if not preview:
                continue
            if len(preview) > max_length:
                if max_length < len(ellipsis):
                    # A negative slice bound would keep almost the whole text;
                    # clamp instead so the limit is actually honoured.
                    preview = preview[: max(max_length, 0)]
                else:
                    preview = preview[: max_length - len(ellipsis)].rstrip() + ellipsis
            return preview
        return None

    def conversation_label(self, fallback_index: int) -> str:
        """Return the stripped conversation id, or ``record[<fallback_index>]`` if blank."""
        if isinstance(self.conversation_id, str) and self.conversation_id.strip():
            return self.conversation_id.strip()
        return f"record[{fallback_index}]"

    def record_identifier(self, conversation_label: str) -> str:
        """Return the best available identifier for this record.

        Preference order: a non-blank ``record_id``, then the payload's
        ``"id"`` (stripped when it is a non-blank string, otherwise
        ``str()`` of any non-``None`` value), then ``conversation_label``.
        """
        if isinstance(self.record_id, str) and self.record_id.strip():
            return self.record_id.strip()
        raw_id = self.payload.get("id")
        if isinstance(raw_id, str) and raw_id.strip():
            return raw_id.strip()
        if raw_id is not None:
            return str(raw_id)
        return conversation_label
112
+
113
+
114
class DatasetLoader:
    """Loads dataset records from JSONL files."""

    def load(self, path: Path) -> list[DatasetRecord]:
        """Parse ``path`` as JSONL and return one ``DatasetRecord`` per non-blank line.

        Raises:
            CLIError: If a line is not valid JSON, is not a JSON object, or the
                file contains no records at all.
        """
        records: list[DatasetRecord] = []
        # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading byte-order
        # mark; with plain "utf-8" the BOM would reach json.loads and make the
        # first line fail to parse.
        with path.open("r", encoding="utf-8-sig") as fh:
            for line_number, raw_line in enumerate(fh, start=1):
                stripped = raw_line.strip()
                if not stripped:
                    continue
                try:
                    payload = json.loads(stripped)
                except json.JSONDecodeError as exc:
                    raise CLIError(
                        f"Invalid JSON on line {line_number} of '{path}': {exc.msg}"
                    ) from exc
                if not isinstance(payload, dict):
                    raise CLIError(
                        f"Expected JSON object on line {line_number} of '{path}'."
                    )

                records.append(self._create_record(payload))

        if not records:
            raise CLIError(f"No JSON records found in '{path}'.")

        return records

    @staticmethod
    def _create_record(payload: dict[str, Any]) -> DatasetRecord:
        """Extract the typed fields of a ``DatasetRecord`` from a parsed JSON object.

        String fields that are missing or of the wrong type become ``None``
        (``rubric_id`` becomes ``""``); ``metadata`` and ``extra_info`` are kept
        only when they are dicts.
        """
        rubric_id = payload.get("rubric_id")
        rubric_id_str = str(rubric_id).strip() if isinstance(rubric_id, str) else ""

        conversation_id_raw = payload.get("conversation_id")
        conversation_id = None
        if isinstance(conversation_id_raw, str) and conversation_id_raw.strip():
            conversation_id = conversation_id_raw.strip()

        record_id_raw = payload.get("id")
        record_id = str(record_id_raw).strip() if isinstance(record_id_raw, str) else None

        # Context label handed to coerce_optional_float for both score bounds.
        score_label = f"record '{conversation_id or rubric_id or '<record>'}'"
        score_min = coerce_optional_float(payload.get("score_min"), "score_min", score_label)
        score_max = coerce_optional_float(payload.get("score_max"), "score_max", score_label)

        metadata = payload.get("metadata") if isinstance(payload.get("metadata"), dict) else None
        extra_info = payload.get("extra_info") if isinstance(payload.get("extra_info"), dict) else None
        record_label = conversation_id or record_id or rubric_id_str or "<record>"
        messages = _parse_messages(payload.get("messages"), source_label=record_label)

        return DatasetRecord(
            payload=payload,
            rubric_id=rubric_id_str,
            conversation_id=conversation_id,
            record_id=record_id,
            messages=messages,
            ground_truth=payload.get("ground_truth") if isinstance(payload.get("ground_truth"), str) else None,
            system_message=payload.get("system_message") if isinstance(payload.get("system_message"), str) else None,
            original_input=payload.get("original_input") if isinstance(payload.get("original_input"), str) else None,
            metadata=metadata,
            extra_info=extra_info,
            score_min=score_min,
            score_max=score_max,
        )
181
+
182
+
183
def load_jsonl_records(path: Path) -> list[dict[str, Any]]:
    """Read a JSONL file and return its records as plain dicts.

    Blank lines are skipped. A leading UTF-8 byte-order mark is tolerated.

    Args:
        path: Location of the JSONL file.

    Returns:
        One dict per non-blank line, in file order.

    Raises:
        CLIError: If a line is not valid JSON, is not a JSON object, or the
            file contains no records at all.
    """
    records: list[dict[str, Any]] = []
    # "utf-8-sig" reads plain UTF-8 unchanged but drops a leading BOM; with
    # plain "utf-8" the BOM would reach json.loads and break the first line.
    with path.open("r", encoding="utf-8-sig") as fh:
        for line_number, raw_line in enumerate(fh, start=1):
            stripped = raw_line.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError as exc:
                raise CLIError(f"Invalid JSON on line {line_number} of '{path}': {exc.msg}") from exc
            if not isinstance(record, dict):
                raise CLIError(f"Expected JSON object on line {line_number} of '{path}'.")
            records.append(record)

    if not records:
        raise CLIError(f"No JSON records found in '{path}'.")

    return records
202
+
203
+
204
def render_json_records(records: Sequence[dict[str, Any]]) -> str:
    """Format records as numbered, pretty-printed JSON blocks.

    Each record becomes a ``JSONL record #<n>`` header line (1-based) followed
    by its JSON rendered with a two-space indent and non-ASCII characters kept
    as-is. Consecutive blocks are separated by one blank line; an empty input
    yields an empty string.
    """
    blocks = [
        "\n".join((f"JSONL record #{position}", json.dumps(item, indent=2, ensure_ascii=False)))
        for position, item in enumerate(records, start=1)
    ]
    return "\n\n".join(blocks)
216
+
217
+
218
def _parse_messages(messages: Any, *, source_label: str) -> tuple[ConversationMessage, ...]:
    """Validate a raw ``messages`` value and convert it to ``ConversationMessage`` objects.

    Args:
        messages: Expected to be a non-empty list of dicts.
        source_label: Record name used in error messages.

    Raises:
        CLIError: If ``messages`` is not a non-empty list, or an entry is not a dict.
    """
    if not (isinstance(messages, list) and messages):
        raise CLIError(f"Record '{source_label}' must include a non-empty 'messages' list.")

    parsed: list[ConversationMessage] = []
    for position, item in enumerate(messages):
        if isinstance(item, dict):
            parsed.append(ConversationMessage.from_raw(item, source_label=source_label, index=position))
            continue
        raise CLIError(
            f"Message {position} in {source_label} must be an object, got {type(item).__name__}."
        )
    return tuple(parsed)
@@ -0,0 +1,251 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import sys
5
+ import time
6
+ from dataclasses import dataclass
7
+ from datetime import datetime, timezone
8
+ from pathlib import Path
9
+ from typing import Any, Optional, Sequence
10
+
11
+ from tqdm import tqdm
12
+
13
+ from ..rubric_eval import evaluate_rubric
14
+ from ..rubric_types import MissingAPIKeyError, ModelNotFoundError, ProviderRequestError
15
+ from .config import RubricConfig
16
+ from .dataset import DatasetRecord
17
+ from .errors import CLIError
18
+ from .shared import calculate_statistics, coerce_optional_float, collapse_preview_text
19
+
20
+
21
class RubricEvaluator:
    """Adapter around ``evaluate_rubric`` whose callable can be swapped out in tests."""

    def __init__(self, evaluate_fn: Any = evaluate_rubric):
        self._evaluate_fn = evaluate_fn

    def run(self, config: RubricConfig, record: DatasetRecord) -> dict[str, Any]:
        """Evaluate one record against the rubric configuration.

        A record-level value, when not ``None``, takes precedence over the
        matching config value for the score bounds, ground truth, system
        message and original input.

        Raises:
            CLIError: If the record yields no message payloads, or the
                evaluation function raises ``MissingAPIKeyError``,
                ``ProviderRequestError`` or ``ModelNotFoundError``.
        """
        payloads = record.message_payloads()
        if not payloads:
            missing_label = record.conversation_id or record.rubric_id or "<record>"
            raise CLIError(f"Record '{missing_label}' must include a non-empty 'messages' list.")

        lower_bound = record.score_min
        if lower_bound is None:
            lower_bound = config.score_min
        score_min = coerce_optional_float(
            lower_bound,
            "score_min",
            f"record '{record.conversation_id or '<record>'}'",
        )

        upper_bound = record.score_max
        if upper_bound is None:
            upper_bound = config.score_max
        score_max = coerce_optional_float(
            upper_bound,
            "score_max",
            f"record '{record.conversation_id or '<record>'}'",
        )

        ground_truth = record.ground_truth
        if ground_truth is None:
            ground_truth = config.ground_truth

        system_message = record.system_message
        if system_message is None:
            system_message = config.system_message

        original_input = record.original_input
        if original_input is None:
            original_input = config.original_input

        try:
            return self._evaluate_fn(
                rubric=config.rubric_text,
                messages=payloads,
                # Hand the callee its own copy of the model settings.
                model_info=copy.deepcopy(config.model_info),
                ground_truth=ground_truth,
                system_message=system_message,
                original_input=original_input,
                extra_info=record.merged_extra_info(config.extra_info),
                score_min=score_min,
                score_max=score_max,
                return_details=True,
            )
        except (MissingAPIKeyError, ProviderRequestError, ModelNotFoundError) as exc:
            raise CLIError(str(exc)) from exc
59
+
60
+
61
@dataclass
class EvaluationRun:
    """Outcome of a single evaluation attempt for one record."""

    # 1-based attempt number within the record.
    run_index: int
    # "success" or "error".
    status: str
    # Numeric score, or None when none could be extracted.
    score: Optional[float]
    explanation: Optional[str]
    preview: Optional[str]
    # Elapsed time of the attempt in seconds.
    duration_seconds: float
    started_at: datetime
    completed_at: datetime
    # Error text for failed attempts, otherwise None.
    error: Optional[str]
    # The "raw" entry of the evaluator result, if any.
    raw: Any
73
+
74
+
75
@dataclass
class EvaluationRecordResult:
    """All evaluation attempts for one dataset record, with their statistics."""

    # 1-based position of the record in the evaluated sequence.
    record_index: int
    record: DatasetRecord
    conversation_label: str
    # One entry per attempt, in execution order.
    runs: list[EvaluationRun]
    # Score statistics plus total_runs / success_count / failure_count.
    statistics: dict[str, float]
82
+
83
+
84
@dataclass
class EvaluationReport:
    """Complete result of evaluating a dataset against one rubric configuration."""

    rubric_config: RubricConfig
    config_path: Path
    data_path: Path
    # Number of evaluation attempts requested per record.
    number: int
    # One entry per evaluated record, in dataset order.
    record_results: list[EvaluationRecordResult]
    # Score statistics over every run plus total_runs / success_count / failure_count.
    overall_statistics: dict[str, float]
92
+
93
+
94
class RubricEvaluationEngine:
    """Runs rubric evaluations over a dataset and aggregates the resulting scores."""

    def __init__(self, evaluator: Optional[RubricEvaluator] = None):
        self._evaluator = evaluator or RubricEvaluator()

    def execute(
        self,
        *,
        rubric_config: RubricConfig,
        config_path: Path,
        data_path: Path,
        records: Sequence[DatasetRecord],
        number: int,
    ) -> EvaluationReport:
        """Evaluate every record ``number`` times and build the report.

        A failing attempt is recorded as an ``"error"`` run and does not stop
        the loop. Only runs that yield a numeric score count towards
        ``success_count``; every other run counts towards ``failure_count``.
        A progress bar is written to stderr when more than one run is
        scheduled and stderr is a TTY.
        """
        per_record: list[EvaluationRecordResult] = []
        all_scores: list[float] = []
        run_count = 0
        scored_count = 0

        expected_runs = len(records) * number
        progress = None
        if expected_runs > 1 and getattr(sys.stderr, "isatty", lambda: False)():
            progress = tqdm(
                total=expected_runs,
                file=sys.stderr,
                dynamic_ncols=True,
                leave=False,
            )

        try:
            for record_index, record in enumerate(records, start=1):
                conversation_label = record.conversation_label(record_index)
                default_preview = record.assistant_preview()

                runs: list[EvaluationRun] = []
                record_scores: list[float] = []

                for attempt in range(1, number + 1):
                    began = datetime.now(timezone.utc)
                    clock_start = time.perf_counter()
                    status = "success"
                    failure_text: Optional[str] = None
                    outcome: Any = None

                    try:
                        outcome = self._evaluator.run(rubric_config, record)
                    except CLIError as exc:
                        status = "error"
                        failure_text = str(exc)
                    except Exception as exc:  # pragma: no cover - unexpected path
                        status = "error"
                        failure_text = f"{type(exc).__name__}: {exc}"

                    elapsed = time.perf_counter() - clock_start
                    finished = datetime.now(timezone.utc)

                    # Defaults for failed attempts and for non-dict results.
                    score: Optional[float] = None
                    explanation: Optional[str] = None
                    raw_payload: Any = None
                    preview: Optional[str] = default_preview

                    if status == "success" and isinstance(outcome, dict):
                        raw_payload = outcome.get("raw")
                        score = _extract_float(outcome.get("score"))
                        explanation = _normalize_optional_text(outcome.get("explanation"))
                        preview = self._resolve_preview_text(outcome, default_preview)
                        if score is not None:
                            record_scores.append(score)
                            all_scores.append(score)
                            scored_count += 1

                    run_count += 1
                    runs.append(
                        EvaluationRun(
                            run_index=attempt,
                            status=status,
                            score=score,
                            explanation=explanation,
                            preview=preview,
                            duration_seconds=elapsed,
                            started_at=began,
                            completed_at=finished,
                            error=failure_text,
                            raw=raw_payload,
                        )
                    )

                    if progress:
                        progress.update()

                record_stats = calculate_statistics(record_scores)
                record_stats["total_runs"] = len(runs)
                record_stats["success_count"] = len(record_scores)
                record_stats["failure_count"] = len(runs) - len(record_scores)
                per_record.append(
                    EvaluationRecordResult(
                        record_index=record_index,
                        record=record,
                        conversation_label=conversation_label,
                        runs=runs,
                        statistics=record_stats,
                    )
                )
        finally:
            # Always release the progress bar, even if a record fails unexpectedly.
            if progress:
                progress.close()

        overall = calculate_statistics(all_scores)
        overall["total_runs"] = run_count
        overall["success_count"] = scored_count
        overall["failure_count"] = run_count - scored_count
        return EvaluationReport(
            rubric_config=rubric_config,
            config_path=config_path,
            data_path=data_path,
            number=number,
            record_results=per_record,
            overall_statistics=overall,
        )

    @staticmethod
    def _resolve_preview_text(result: Optional[dict[str, Any]], fallback: Optional[str]) -> Optional[str]:
        """Pick the first usable preview text from an evaluator result.

        Candidates are tried in order: ``result["preview"]``, then the
        ``"preview"``, ``"summary"`` and ``"text"`` entries of ``result["raw"]``
        when that is a dict. Each is passed through ``collapse_preview_text``
        and the first truthy outcome wins. ``fallback`` is returned when
        nothing qualifies or ``result`` is not a dict.
        """
        if not isinstance(result, dict):
            return fallback

        candidates: list[Any] = [result.get("preview")]
        raw_section = result.get("raw")
        if isinstance(raw_section, dict):
            candidates.extend(raw_section.get(field) for field in ("preview", "summary", "text"))

        for candidate in candidates:
            collapsed = collapse_preview_text(candidate)
            if collapsed:
                return collapsed
        return fallback
236
+
237
+
238
def _extract_float(value: Any) -> Optional[float]:
    """Return ``value`` as a float when it is a real number, otherwise ``None``.

    Booleans are rejected even though ``bool`` is a subclass of ``int``.
    An integer too large to be represented as a float also yields ``None``
    instead of raising.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int beyond the float range.
        return None
245
+
246
+
247
def _normalize_optional_text(value: Any) -> Optional[str]:
    """Convert ``value`` to stripped text, mapping ``None`` and blank results to ``None``."""
    cleaned = "" if value is None else str(value).strip()
    return cleaned if cleaned else None
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+
3
+
4
class CLIError(Exception):
    """Recoverable error raised by the CLI."""