osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of osmosis-ai might be problematic.

@@ -0,0 +1,251 @@
+from __future__ import annotations
+
+import copy
+import sys
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Optional, Sequence
+
+from tqdm import tqdm
+
+from ..rubric_eval import evaluate_rubric
+from ..rubric_types import MissingAPIKeyError, ModelNotFoundError, ProviderRequestError
+from .config import RubricConfig
+from .dataset import DatasetRecord
+from .errors import CLIError
+from .shared import calculate_statistics, coerce_optional_float, collapse_preview_text
+
+
+class RubricEvaluator:
+    """Thin wrapper over evaluate_rubric to enable injection during tests."""
+
+    def __init__(self, evaluate_fn: Any = evaluate_rubric):
+        self._evaluate_fn = evaluate_fn
+
+    def run(self, config: RubricConfig, record: DatasetRecord) -> dict[str, Any]:
+        messages = record.message_payloads()
+        if not messages:
+            label = record.conversation_id or record.rubric_id or "<record>"
+            raise CLIError(f"Record '{label}' must include a non-empty 'messages' list.")
+
+        score_min = coerce_optional_float(
+            record.score_min if record.score_min is not None else config.score_min,
+            "score_min",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+        score_max = coerce_optional_float(
+            record.score_max if record.score_max is not None else config.score_max,
+            "score_max",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+
+        try:
+            return self._evaluate_fn(
+                rubric=config.rubric_text,
+                messages=messages,
+                model_info=copy.deepcopy(config.model_info),
+                ground_truth=record.ground_truth if record.ground_truth is not None else config.ground_truth,
+                system_message=record.system_message if record.system_message is not None else config.system_message,
+                original_input=record.original_input if record.original_input is not None else config.original_input,
+                extra_info=record.merged_extra_info(config.extra_info),
+                score_min=score_min,
+                score_max=score_max,
+                return_details=True,
+            )
+        except (MissingAPIKeyError, ProviderRequestError, ModelNotFoundError) as exc:
+            raise CLIError(str(exc)) from exc
+
+
+@dataclass
+class EvaluationRun:
+    run_index: int
+    status: str
+    score: Optional[float]
+    explanation: Optional[str]
+    preview: Optional[str]
+    duration_seconds: float
+    started_at: datetime
+    completed_at: datetime
+    error: Optional[str]
+    raw: Any
+
+
+@dataclass
+class EvaluationRecordResult:
+    record_index: int
+    record: DatasetRecord
+    conversation_label: str
+    runs: list[EvaluationRun]
+    statistics: dict[str, float]
+
+
+@dataclass
+class EvaluationReport:
+    rubric_config: RubricConfig
+    config_path: Path
+    data_path: Path
+    number: int
+    record_results: list[EvaluationRecordResult]
+    overall_statistics: dict[str, float]
+
+
+class RubricEvaluationEngine:
+    """Executes rubric evaluations across a dataset and aggregates statistics."""
+
+    def __init__(self, evaluator: Optional[RubricEvaluator] = None):
+        self._evaluator = evaluator or RubricEvaluator()
+
+    def execute(
+        self,
+        *,
+        rubric_config: RubricConfig,
+        config_path: Path,
+        data_path: Path,
+        records: Sequence[DatasetRecord],
+        number: int,
+    ) -> EvaluationReport:
+        record_results: list[EvaluationRecordResult] = []
+        aggregate_scores: list[float] = []
+        total_runs = 0
+        total_successes = 0
+
+        progress_total = len(records) * number
+        show_progress = progress_total > 1 and getattr(sys.stderr, "isatty", lambda: False)()
+        progress = (
+            tqdm(
+                total=progress_total,
+                file=sys.stderr,
+                dynamic_ncols=True,
+                leave=False,
+            )
+            if show_progress
+            else None
+        )
+
+        try:
+            for record_index, record in enumerate(records, start=1):
+                conversation_label = record.conversation_label(record_index)
+                fallback_preview = record.assistant_preview()
+
+                runs: list[EvaluationRun] = []
+                scores: list[float] = []
+
+                for attempt in range(1, number + 1):
+                    started_at = datetime.now(timezone.utc)
+                    timer_start = time.perf_counter()
+                    status = "success"
+                    error_message: Optional[str] = None
+                    score_value: Optional[float] = None
+                    explanation_value: Optional[str] = None
+                    preview_value: Optional[str] = None
+                    raw_payload: Any = None
+
+                    try:
+                        result = self._evaluator.run(rubric_config, record)
+                    except CLIError as exc:
+                        status = "error"
+                        error_message = str(exc)
+                        result = None
+                    except Exception as exc:  # pragma: no cover - unexpected path
+                        status = "error"
+                        error_message = f"{type(exc).__name__}: {exc}"
+                        result = None
+
+                    duration_seconds = time.perf_counter() - timer_start
+                    completed_at = datetime.now(timezone.utc)
+
+                    if status == "success" and isinstance(result, dict):
+                        raw_payload = result.get("raw")
+                        score_value = _extract_float(result.get("score"))
+                        explanation_value = _normalize_optional_text(result.get("explanation"))
+                        preview_value = self._resolve_preview_text(result, fallback_preview)
+                        if score_value is not None:
+                            scores.append(score_value)
+                            aggregate_scores.append(score_value)
+                            total_successes += 1
+                    else:
+                        preview_value = fallback_preview
+
+                    total_runs += 1
+
+                    runs.append(
+                        EvaluationRun(
+                            run_index=attempt,
+                            status=status,
+                            score=score_value,
+                            explanation=explanation_value,
+                            preview=preview_value,
+                            duration_seconds=duration_seconds,
+                            started_at=started_at,
+                            completed_at=completed_at,
+                            error=error_message,
+                            raw=raw_payload,
+                        )
+                    )
+
+                    if progress:
+                        progress.update()
+
+                statistics = calculate_statistics(scores)
+                statistics["total_runs"] = len(runs)
+                statistics["success_count"] = len(scores)
+                statistics["failure_count"] = len(runs) - len(scores)
+                record_results.append(
+                    EvaluationRecordResult(
+                        record_index=record_index,
+                        record=record,
+                        conversation_label=conversation_label,
+                        runs=runs,
+                        statistics=statistics,
+                    )
+                )
+        finally:
+            if progress:
+                progress.close()
+
+        overall_statistics = calculate_statistics(aggregate_scores)
+        overall_statistics["total_runs"] = total_runs
+        overall_statistics["success_count"] = total_successes
+        overall_statistics["failure_count"] = total_runs - total_successes
+        return EvaluationReport(
+            rubric_config=rubric_config,
+            config_path=config_path,
+            data_path=data_path,
+            number=number,
+            record_results=record_results,
+            overall_statistics=overall_statistics,
+        )
+
+    @staticmethod
+    def _resolve_preview_text(result: Optional[dict[str, Any]], fallback: Optional[str]) -> Optional[str]:
+        if not isinstance(result, dict):
+            return fallback
+        preview = collapse_preview_text(result.get("preview"))
+        if preview:
+            return preview
+
+        raw_payload = result.get("raw")
+        if isinstance(raw_payload, dict):
+            for key in ("preview", "summary", "text"):
+                preview = collapse_preview_text(raw_payload.get(key))
+                if preview:
+                    return preview
+        return fallback
+
+
+def _extract_float(value: Any) -> Optional[float]:
+    try:
+        if isinstance(value, (int, float)) and not isinstance(value, bool):
+            return float(value)
+        return None
+    except (TypeError, ValueError):
+        return None
+
+
+def _normalize_optional_text(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text or None
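The file above wires per-record retries, timing, and aggregation into a single RubricEvaluationEngine.execute call. A minimal driver sketch follows; the import path and the load_rubric_config / load_dataset_records helpers are assumptions standing in for whatever loaders the CLI actually uses, and only RubricEvaluationEngine and the execute keyword arguments come from this diff.

    from pathlib import Path

    # Hypothetical imports: the diff does not show module paths or any
    # config/dataset loaders, so the names below are placeholders.
    from osmosis_ai.cli.engine import RubricEvaluationEngine  # assumed path
    from my_project.loaders import load_rubric_config, load_dataset_records  # stand-ins

    config_path = Path("rubric_config.yaml")
    data_path = Path("records.jsonl")

    rubric_config = load_rubric_config(config_path)   # -> RubricConfig
    records = load_dataset_records(data_path)         # -> list[DatasetRecord]

    engine = RubricEvaluationEngine()
    report = engine.execute(
        rubric_config=rubric_config,
        config_path=config_path,
        data_path=data_path,
        records=records,
        number=3,  # evaluations per record; also drives the tqdm progress total
    )
    print(report.overall_statistics)

Each record then contributes `number` EvaluationRun entries, and per-record as well as overall statistics come back on the returned EvaluationReport.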
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+
+class CLIError(Exception):
+    """Raised when the CLI encounters a recoverable error."""
+
+    pass
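CLIError is the single recoverable-error type the engine raises for malformed records and provider failures. A typical entry point would catch it and exit with a message rather than a traceback; the sketch below is an assumed pattern (the import path and run_cli stub are placeholders, not code from this release).

    import sys

    from osmosis_ai.cli.errors import CLIError  # assumed import path


    def run_cli() -> None:
        """Stand-in for the package's real command dispatch (not part of this diff)."""
        raise CLIError("example: baseline path 'missing.json' does not exist")


    def main() -> int:
        try:
            run_cli()
        except CLIError as exc:
            # Recoverable CLI errors become a stderr message and a non-zero exit code.
            print(f"error: {exc}", file=sys.stderr)
            return 1
        return 0


    if __name__ == "__main__":
        sys.exit(main())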
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Callable, Optional
+
+from .engine import EvaluationRecordResult, EvaluationReport, EvaluationRun
+from .errors import CLIError
+from .shared import calculate_stat_deltas
+
+
+class TextReportFormatter:
+    """Builds human-readable text lines for an evaluation report."""
+
+    def build(
+        self,
+        report: EvaluationReport,
+        baseline: Optional["BaselineStatistics"] = None,
+    ) -> list[str]:
+        lines: list[str] = []
+        provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+        model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+        lines.append(
+            f"Rubric '{report.rubric_config.rubric_id}' "
+            f"({report.rubric_config.source_label}) -> provider '{provider}' model '{model_name}'"
+        )
+        lines.append(f"Loaded {len(report.record_results)} matching record(s) from {report.data_path}")
+        lines.append(f"Running {report.number} evaluation(s) per record")
+        lines.append("")
+
+        for record_result in report.record_results:
+            lines.extend(self._format_record(record_result))
+        if baseline is not None:
+            lines.extend(self._format_baseline(report, baseline))
+        return lines
+
+    def _format_record(self, record_result: EvaluationRecordResult) -> list[str]:
+        lines: list[str] = [f"[{record_result.conversation_label}]"]
+        total_runs = len(record_result.runs)
+        for index, run in enumerate(record_result.runs):
+            lines.extend(self._format_run(run))
+            if index < total_runs - 1:
+                lines.append("")
+
+        summary_lines = self._format_summary(record_result.statistics, len(record_result.runs))
+        if summary_lines:
+            lines.extend(summary_lines)
+        lines.append("")
+        return lines
+
+    def _format_run(self, run: EvaluationRun) -> list[str]:
+        lines: list[str] = [f"  Run {run.run_index:02d} [{run.status.upper()}]"]
+        if run.status == "success":
+            score_text = "n/a" if run.score is None else f"{run.score:.4f}"
+            lines.append(self._format_detail_line("score", score_text))
+            if run.preview:
+                lines.append(self._format_detail_line("preview", run.preview))
+            explanation = run.explanation or "(no explanation provided)"
+            lines.append(self._format_detail_line("explanation", explanation))
+        else:
+            error_text = run.error or "(no error message provided)"
+            lines.append(self._format_detail_line("error", error_text))
+            if run.preview:
+                lines.append(self._format_detail_line("preview", run.preview))
+            if run.explanation:
+                lines.append(self._format_detail_line("explanation", run.explanation))
+        lines.append(self._format_detail_line("duration", f"{run.duration_seconds:.2f}s"))
+        return lines
+
+    def _format_summary(self, statistics: dict[str, float], total_runs: int) -> list[str]:
+        if not statistics:
+            return []
+        success_count = int(round(statistics.get("success_count", total_runs)))
+        failure_count = int(round(statistics.get("failure_count", total_runs - success_count)))
+        if total_runs <= 1 and failure_count == 0:
+            return []
+
+        lines = ["  Summary:"]
+        lines.append(f"    total: {int(round(statistics.get('total_runs', total_runs)))}")
+        lines.append(f"    successes: {success_count}")
+        lines.append(f"    failures: {failure_count}")
+        if success_count > 0:
+            lines.append(f"    average: {statistics.get('average', 0.0):.4f}")
+            lines.append(f"    variance: {statistics.get('variance', 0.0):.6f}")
+            lines.append(f"    stdev: {statistics.get('stdev', 0.0):.4f}")
+            lines.append(f"    min/max: {statistics.get('min', 0.0):.4f} / {statistics.get('max', 0.0):.4f}")
+        else:
+            lines.append("    average: n/a")
+            lines.append("    variance: n/a")
+            lines.append("    stdev: n/a")
+            lines.append("    min/max: n/a")
+        return lines
+
+    def _format_baseline(
+        self,
+        report: EvaluationReport,
+        baseline: "BaselineStatistics",
+    ) -> list[str]:
+        lines = [f"Baseline comparison (source: {baseline.source_path}):"]
+        deltas = baseline.delta(report.overall_statistics)
+        keys = ["average", "variance", "stdev", "min", "max", "success_count", "failure_count", "total_runs"]
+
+        for key in keys:
+            if key not in baseline.statistics or key not in report.overall_statistics:
+                continue
+            baseline_value = float(baseline.statistics[key])
+            current_value = float(report.overall_statistics[key])
+            delta_value = float(deltas.get(key, current_value - baseline_value))
+
+            if key in {"success_count", "failure_count", "total_runs"}:
+                baseline_str = f"{int(round(baseline_value))}"
+                current_str = f"{int(round(current_value))}"
+                delta_str = f"{delta_value:+.0f}"
+            else:
+                precision = 6 if key == "variance" else 4
+                baseline_str = format(baseline_value, f".{precision}f")
+                current_str = format(current_value, f".{precision}f")
+                delta_str = format(delta_value, f"+.{precision}f")
+
+            lines.append(
+                f"  {key:12s} baseline={baseline_str} current={current_str} delta={delta_str}"
+            )
+        return lines
+
+    @staticmethod
+    def _format_detail_line(label: str, value: str, *, indent: int = 4) -> str:
+        indent_str = " " * indent
+        value_str = str(value)
+        continuation_indent = indent_str + " "
+        value_str = value_str.replace("\n", f"\n{continuation_indent}")
+        return f"{indent_str}{label}: {value_str}"
+
+
+class ConsoleReportRenderer:
+    """Pretty prints evaluation reports to stdout (or any printer function)."""
+
+    def __init__(
+        self,
+        printer: Callable[[str], None] = print,
+        formatter: Optional[TextReportFormatter] = None,
+    ):
+        self._printer = printer
+        self._formatter = formatter or TextReportFormatter()
+
+    def render(
+        self,
+        report: EvaluationReport,
+        baseline: Optional["BaselineStatistics"] = None,
+    ) -> None:
+        for line in self._formatter.build(report, baseline):
+            self._printer(line)
+
+
+class BaselineStatistics:
+    def __init__(self, source_path: Path, statistics: dict[str, float]):
+        self.source_path = source_path
+        self.statistics = statistics
+
+    def delta(self, current: dict[str, float]) -> dict[str, float]:
+        return calculate_stat_deltas(self.statistics, current)
+
+
+class BaselineComparator:
+    """Loads baseline JSON payloads and extracts statistics."""
+
+    def load(self, path: Path) -> BaselineStatistics:
+        if not path.exists():
+            raise CLIError(f"Baseline path '{path}' does not exist.")
+        if path.is_dir():
+            raise CLIError(f"Baseline path '{path}' is a directory, expected JSON file.")
+
+        try:
+            payload = json.loads(path.read_text(encoding="utf-8"))
+        except json.JSONDecodeError as exc:
+            raise CLIError(f"Failed to parse baseline JSON: {exc}") from exc
+
+        if not isinstance(payload, dict):
+            raise CLIError("Baseline JSON must contain an object.")
+
+        source = None
+        if isinstance(payload.get("overall_statistics"), dict):
+            source = payload["overall_statistics"]
+        elif all(key in payload for key in ("average", "variance", "stdev")):
+            source = payload
+        if source is None:
+            raise CLIError(
+                "Baseline JSON must include an 'overall_statistics' object or top-level statistics."
+            )
+
+        statistics: dict[str, float] = {}
+        for key, value in source.items():
+            try:
+                statistics[key] = float(value)
+            except (TypeError, ValueError):
+                continue
+        if not statistics:
+            raise CLIError("Baseline statistics could not be parsed into numeric values.")
+
+        return BaselineStatistics(source_path=path, statistics=statistics)
+
+
+class JsonReportFormatter:
+    """Builds JSON-serialisable payloads for evaluation reports."""
+
+    def build(
+        self,
+        report: EvaluationReport,
+        *,
+        output_identifier: Optional[str],
+        baseline: Optional[BaselineStatistics],
+    ) -> dict[str, Any]:
+        provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+        model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+        generated: dict[str, Any] = {
+            "generated_at": datetime.now(timezone.utc).isoformat(),
+            "rubric_id": report.rubric_config.rubric_id,
+            "rubric_source": report.rubric_config.source_label,
+            "provider": provider,
+            "model": model_name,
+            "number": report.number,
+            "config_path": str(report.config_path),
+            "data_path": str(report.data_path),
+            "overall_statistics": _normalise_statistics(report.overall_statistics),
+            "records": [],
+        }
+        if output_identifier is not None:
+            generated["output_identifier"] = output_identifier
+
+        for record_result in report.record_results:
+            conversation_label = record_result.conversation_label
+            record_identifier = record_result.record.record_identifier(conversation_label)
+
+            record_payload: dict[str, Any] = {
+                "id": record_identifier,
+                "record_index": record_result.record_index,
+                "conversation_id": conversation_label,
+                "input_record": record_result.record.payload,
+                "statistics": _normalise_statistics(record_result.statistics),
+                "runs": [],
+            }
+
+            for run in record_result.runs:
+                record_payload["runs"].append(
+                    {
+                        "run_index": run.run_index,
+                        "status": run.status,
+                        "started_at": run.started_at.isoformat(),
+                        "completed_at": run.completed_at.isoformat(),
+                        "duration_seconds": run.duration_seconds,
+                        "score": run.score,
+                        "explanation": run.explanation,
+                        "preview": run.preview,
+                        "error": run.error,
+                        "raw": run.raw,
+                    }
+                )
+
+            generated["records"].append(record_payload)
+
+        if baseline is not None:
+            generated["baseline_comparison"] = {
+                "source_path": str(baseline.source_path),
+                "baseline_statistics": _normalise_statistics(baseline.statistics),
+                "delta_statistics": _normalise_statistics(baseline.delta(report.overall_statistics)),
+            }
+
+        return generated
+
+
+class JsonReportWriter:
+    """Serialises an evaluation report to disk."""
+
+    def __init__(self, formatter: Optional[JsonReportFormatter] = None):
+        self._formatter = formatter or JsonReportFormatter()
+
+    def write(
+        self,
+        report: EvaluationReport,
+        *,
+        output_path: Path,
+        output_identifier: Optional[str],
+        baseline: Optional[BaselineStatistics],
+    ) -> Path:
+        parent_dir = output_path.parent
+        if parent_dir and not parent_dir.exists():
+            parent_dir.mkdir(parents=True, exist_ok=True)
+
+        payload = self._formatter.build(
+            report,
+            output_identifier=output_identifier,
+            baseline=baseline,
+        )
+        with output_path.open("w", encoding="utf-8") as fh:
+            json.dump(payload, fh, indent=2, ensure_ascii=False)
+        return output_path
+
+
+def _normalise_statistics(stats: dict[str, float]) -> dict[str, Any]:
+    normalised: dict[str, Any] = {}
+    for key, value in stats.items():
+        if key in {"success_count", "failure_count", "total_runs"}:
+            normalised[key] = int(round(value))
+        else:
+            normalised[key] = float(value)
+    return normalised
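To close the loop, the reporting classes above can be combined roughly as follows once an EvaluationReport is in hand. Only the class names and method signatures come from this hunk; the import paths, output path, and identifier below are placeholders.

    from pathlib import Path
    from typing import Optional

    # Assumed import paths; the diff does not show where these modules live.
    from osmosis_ai.cli.engine import EvaluationReport
    from osmosis_ai.cli.reports import (
        BaselineComparator,
        ConsoleReportRenderer,
        JsonReportWriter,
    )


    def publish_report(report: EvaluationReport, baseline_path: Optional[Path] = None) -> Path:
        """Print a report to stdout and write its JSON payload to disk."""
        baseline = BaselineComparator().load(baseline_path) if baseline_path else None

        # Human-readable summary, with baseline deltas when a baseline is supplied.
        ConsoleReportRenderer().render(report, baseline)

        # Full JSON payload, including per-run details and the optional baseline comparison.
        return JsonReportWriter().write(
            report,
            output_path=Path("results/rubric_eval.json"),  # placeholder location
            output_identifier=None,
            baseline=baseline,
        )

Splitting the formatters from the renderer/writer keeps the text and JSON layouts testable in isolation, which matches the injection-friendly constructors used throughout these modules.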