osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of osmosis-ai might be problematic.
- osmosis_ai/__init__.py +13 -4
- osmosis_ai/cli.py +50 -0
- osmosis_ai/cli_commands.py +181 -0
- osmosis_ai/cli_services/__init__.py +67 -0
- osmosis_ai/cli_services/config.py +407 -0
- osmosis_ai/cli_services/dataset.py +229 -0
- osmosis_ai/cli_services/engine.py +251 -0
- osmosis_ai/cli_services/errors.py +7 -0
- osmosis_ai/cli_services/reporting.py +307 -0
- osmosis_ai/cli_services/session.py +174 -0
- osmosis_ai/cli_services/shared.py +209 -0
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +314 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +498 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +392 -5
- osmosis_ai-0.2.3.dist-info/METADATA +303 -0
- osmosis_ai-0.2.3.dist-info/RECORD +27 -0
- osmosis_ai-0.2.3.dist-info/entry_points.txt +4 -0
- osmosis_ai-0.2.1.dist-info/METADATA +0 -143
- osmosis_ai-0.2.1.dist-info/RECORD +0 -8
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/top_level.txt +0 -0
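The headline change is a new rubric-evaluation layer: a CLI (osmosis_ai/cli.py, cli_commands.py, cli_services/*), provider adapters (providers/*), and a rubric scoring entry point in osmosis_ai/rubric_eval.py. As a rough orientation, the sketch below is inferred from the evaluate_rubric call site in cli_services/engine.py shown further down; the messages and model_info payload shapes are assumptions, not documented API.

from osmosis_ai.rubric_eval import evaluate_rubric

# Hypothetical invocation; keyword names mirror the engine.py call site below,
# while the message/model_info payload shapes are guesses for illustration.
result = evaluate_rubric(
    rubric="The assistant must answer politely and cite a source.",
    messages=[
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ],
    model_info={"provider": "openai", "model": "gpt-4o-mini"},
    return_details=True,
)
# With return_details=True the engine treats the result as a dict with
# "score", "explanation", "preview" and "raw" keys.
print(result.get("score"), result.get("explanation"))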
osmosis_ai/cli_services/engine.py
@@ -0,0 +1,251 @@
+from __future__ import annotations
+
+import copy
+import sys
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Optional, Sequence
+
+from tqdm import tqdm
+
+from ..rubric_eval import evaluate_rubric
+from ..rubric_types import MissingAPIKeyError, ModelNotFoundError, ProviderRequestError
+from .config import RubricConfig
+from .dataset import DatasetRecord
+from .errors import CLIError
+from .shared import calculate_statistics, coerce_optional_float, collapse_preview_text
+
+
+class RubricEvaluator:
+    """Thin wrapper over evaluate_rubric to enable injection during tests."""
+
+    def __init__(self, evaluate_fn: Any = evaluate_rubric):
+        self._evaluate_fn = evaluate_fn
+
+    def run(self, config: RubricConfig, record: DatasetRecord) -> dict[str, Any]:
+        messages = record.message_payloads()
+        if not messages:
+            label = record.conversation_id or record.rubric_id or "<record>"
+            raise CLIError(f"Record '{label}' must include a non-empty 'messages' list.")
+
+        score_min = coerce_optional_float(
+            record.score_min if record.score_min is not None else config.score_min,
+            "score_min",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+        score_max = coerce_optional_float(
+            record.score_max if record.score_max is not None else config.score_max,
+            "score_max",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+
+        try:
+            return self._evaluate_fn(
+                rubric=config.rubric_text,
+                messages=messages,
+                model_info=copy.deepcopy(config.model_info),
+                ground_truth=record.ground_truth if record.ground_truth is not None else config.ground_truth,
+                system_message=record.system_message if record.system_message is not None else config.system_message,
+                original_input=record.original_input if record.original_input is not None else config.original_input,
+                extra_info=record.merged_extra_info(config.extra_info),
+                score_min=score_min,
+                score_max=score_max,
+                return_details=True,
+            )
+        except (MissingAPIKeyError, ProviderRequestError, ModelNotFoundError) as exc:
+            raise CLIError(str(exc)) from exc
+
+
+@dataclass
+class EvaluationRun:
+    run_index: int
+    status: str
+    score: Optional[float]
+    explanation: Optional[str]
+    preview: Optional[str]
+    duration_seconds: float
+    started_at: datetime
+    completed_at: datetime
+    error: Optional[str]
+    raw: Any
+
+
+@dataclass
+class EvaluationRecordResult:
+    record_index: int
+    record: DatasetRecord
+    conversation_label: str
+    runs: list[EvaluationRun]
+    statistics: dict[str, float]
+
+
+@dataclass
+class EvaluationReport:
+    rubric_config: RubricConfig
+    config_path: Path
+    data_path: Path
+    number: int
+    record_results: list[EvaluationRecordResult]
+    overall_statistics: dict[str, float]
+
+
+class RubricEvaluationEngine:
+    """Executes rubric evaluations across a dataset and aggregates statistics."""
+
+    def __init__(self, evaluator: Optional[RubricEvaluator] = None):
+        self._evaluator = evaluator or RubricEvaluator()
+
+    def execute(
+        self,
+        *,
+        rubric_config: RubricConfig,
+        config_path: Path,
+        data_path: Path,
+        records: Sequence[DatasetRecord],
+        number: int,
+    ) -> EvaluationReport:
+        record_results: list[EvaluationRecordResult] = []
+        aggregate_scores: list[float] = []
+        total_runs = 0
+        total_successes = 0
+
+        progress_total = len(records) * number
+        show_progress = progress_total > 1 and getattr(sys.stderr, "isatty", lambda: False)()
+        progress = (
+            tqdm(
+                total=progress_total,
+                file=sys.stderr,
+                dynamic_ncols=True,
+                leave=False,
+            )
+            if show_progress
+            else None
+        )
+
+        try:
+            for record_index, record in enumerate(records, start=1):
+                conversation_label = record.conversation_label(record_index)
+                fallback_preview = record.assistant_preview()
+
+                runs: list[EvaluationRun] = []
+                scores: list[float] = []
+
+                for attempt in range(1, number + 1):
+                    started_at = datetime.now(timezone.utc)
+                    timer_start = time.perf_counter()
+                    status = "success"
+                    error_message: Optional[str] = None
+                    score_value: Optional[float] = None
+                    explanation_value: Optional[str] = None
+                    preview_value: Optional[str] = None
+                    raw_payload: Any = None
+
+                    try:
+                        result = self._evaluator.run(rubric_config, record)
+                    except CLIError as exc:
+                        status = "error"
+                        error_message = str(exc)
+                        result = None
+                    except Exception as exc:  # pragma: no cover - unexpected path
+                        status = "error"
+                        error_message = f"{type(exc).__name__}: {exc}"
+                        result = None
+
+                    duration_seconds = time.perf_counter() - timer_start
+                    completed_at = datetime.now(timezone.utc)
+
+                    if status == "success" and isinstance(result, dict):
+                        raw_payload = result.get("raw")
+                        score_value = _extract_float(result.get("score"))
+                        explanation_value = _normalize_optional_text(result.get("explanation"))
+                        preview_value = self._resolve_preview_text(result, fallback_preview)
+                        if score_value is not None:
+                            scores.append(score_value)
+                            aggregate_scores.append(score_value)
+                        total_successes += 1
+                    else:
+                        preview_value = fallback_preview
+
+                    total_runs += 1
+
+                    runs.append(
+                        EvaluationRun(
+                            run_index=attempt,
+                            status=status,
+                            score=score_value,
+                            explanation=explanation_value,
+                            preview=preview_value,
+                            duration_seconds=duration_seconds,
+                            started_at=started_at,
+                            completed_at=completed_at,
+                            error=error_message,
+                            raw=raw_payload,
+                        )
+                    )
+
+                    if progress:
+                        progress.update()
+
+                statistics = calculate_statistics(scores)
+                statistics["total_runs"] = len(runs)
+                statistics["success_count"] = len(scores)
+                statistics["failure_count"] = len(runs) - len(scores)
+                record_results.append(
+                    EvaluationRecordResult(
+                        record_index=record_index,
+                        record=record,
+                        conversation_label=conversation_label,
+                        runs=runs,
+                        statistics=statistics,
+                    )
+                )
+        finally:
+            if progress:
+                progress.close()
+
+        overall_statistics = calculate_statistics(aggregate_scores)
+        overall_statistics["total_runs"] = total_runs
+        overall_statistics["success_count"] = total_successes
+        overall_statistics["failure_count"] = total_runs - total_successes
+        return EvaluationReport(
+            rubric_config=rubric_config,
+            config_path=config_path,
+            data_path=data_path,
+            number=number,
+            record_results=record_results,
+            overall_statistics=overall_statistics,
+        )
+
+    @staticmethod
+    def _resolve_preview_text(result: Optional[dict[str, Any]], fallback: Optional[str]) -> Optional[str]:
+        if not isinstance(result, dict):
+            return fallback
+        preview = collapse_preview_text(result.get("preview"))
+        if preview:
+            return preview
+
+        raw_payload = result.get("raw")
+        if isinstance(raw_payload, dict):
+            for key in ("preview", "summary", "text"):
+                preview = collapse_preview_text(raw_payload.get(key))
+                if preview:
+                    return preview
+        return fallback
+
+
+def _extract_float(value: Any) -> Optional[float]:
+    try:
+        if isinstance(value, (int, float)) and not isinstance(value, bool):
+            return float(value)
+        return None
+    except (TypeError, ValueError):
+        return None
+
+
+def _normalize_optional_text(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text or None
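engine.py keeps the network-bound call behind an injection seam: RubricEvaluator accepts any evaluate_fn, and RubricEvaluationEngine accepts a custom evaluator. A minimal sketch of that seam follows; the stubbed return value simply mirrors the keys execute() reads, and everything else here is illustrative, not package documentation.

from osmosis_ai.cli_services.engine import RubricEvaluationEngine, RubricEvaluator

def fake_evaluate(**kwargs):
    # Same dict shape execute() consumes when return_details=True.
    return {"score": 0.5, "explanation": "stubbed", "preview": None, "raw": {"summary": "stub run"}}

engine = RubricEvaluationEngine(evaluator=RubricEvaluator(evaluate_fn=fake_evaluate))
# engine.execute(rubric_config=..., config_path=..., data_path=..., records=[...], number=3)
# would then aggregate per-record and overall statistics without touching a provider.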
osmosis_ai/cli_services/reporting.py
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Callable, Optional
+
+from .engine import EvaluationRecordResult, EvaluationReport, EvaluationRun
+from .errors import CLIError
+from .shared import calculate_stat_deltas
+
+
+class TextReportFormatter:
+    """Builds human-readable text lines for an evaluation report."""
+
+    def build(
+        self,
+        report: EvaluationReport,
+        baseline: Optional["BaselineStatistics"] = None,
+    ) -> list[str]:
+        lines: list[str] = []
+        provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+        model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+        lines.append(
+            f"Rubric '{report.rubric_config.rubric_id}' "
+            f"({report.rubric_config.source_label}) -> provider '{provider}' model '{model_name}'"
+        )
+        lines.append(f"Loaded {len(report.record_results)} matching record(s) from {report.data_path}")
+        lines.append(f"Running {report.number} evaluation(s) per record")
+        lines.append("")
+
+        for record_result in report.record_results:
+            lines.extend(self._format_record(record_result))
+        if baseline is not None:
+            lines.extend(self._format_baseline(report, baseline))
+        return lines
+
+    def _format_record(self, record_result: EvaluationRecordResult) -> list[str]:
+        lines: list[str] = [f"[{record_result.conversation_label}]"]
+        total_runs = len(record_result.runs)
+        for index, run in enumerate(record_result.runs):
+            lines.extend(self._format_run(run))
+            if index < total_runs - 1:
+                lines.append("")
+
+        summary_lines = self._format_summary(record_result.statistics, len(record_result.runs))
+        if summary_lines:
+            lines.extend(summary_lines)
+        lines.append("")
+        return lines
+
+    def _format_run(self, run: EvaluationRun) -> list[str]:
+        lines: list[str] = [f" Run {run.run_index:02d} [{run.status.upper()}]"]
+        if run.status == "success":
+            score_text = "n/a" if run.score is None else f"{run.score:.4f}"
+            lines.append(self._format_detail_line("score", score_text))
+            if run.preview:
+                lines.append(self._format_detail_line("preview", run.preview))
+            explanation = run.explanation or "(no explanation provided)"
+            lines.append(self._format_detail_line("explanation", explanation))
+        else:
+            error_text = run.error or "(no error message provided)"
+            lines.append(self._format_detail_line("error", error_text))
+            if run.preview:
+                lines.append(self._format_detail_line("preview", run.preview))
+            if run.explanation:
+                lines.append(self._format_detail_line("explanation", run.explanation))
+        lines.append(self._format_detail_line("duration", f"{run.duration_seconds:.2f}s"))
+        return lines
+
+    def _format_summary(self, statistics: dict[str, float], total_runs: int) -> list[str]:
+        if not statistics:
+            return []
+        success_count = int(round(statistics.get("success_count", total_runs)))
+        failure_count = int(round(statistics.get("failure_count", total_runs - success_count)))
+        if total_runs <= 1 and failure_count == 0:
+            return []
+
+        lines = [" Summary:"]
+        lines.append(f" total: {int(round(statistics.get('total_runs', total_runs)))}")
+        lines.append(f" successes: {success_count}")
+        lines.append(f" failures: {failure_count}")
+        if success_count > 0:
+            lines.append(f" average: {statistics.get('average', 0.0):.4f}")
+            lines.append(f" variance: {statistics.get('variance', 0.0):.6f}")
+            lines.append(f" stdev: {statistics.get('stdev', 0.0):.4f}")
+            lines.append(f" min/max: {statistics.get('min', 0.0):.4f} / {statistics.get('max', 0.0):.4f}")
+        else:
+            lines.append(" average: n/a")
+            lines.append(" variance: n/a")
+            lines.append(" stdev: n/a")
+            lines.append(" min/max: n/a")
+        return lines
+
+    def _format_baseline(
+        self,
+        report: EvaluationReport,
+        baseline: "BaselineStatistics",
+    ) -> list[str]:
+        lines = [f"Baseline comparison (source: {baseline.source_path}):"]
+        deltas = baseline.delta(report.overall_statistics)
+        keys = ["average", "variance", "stdev", "min", "max", "success_count", "failure_count", "total_runs"]
+
+        for key in keys:
+            if key not in baseline.statistics or key not in report.overall_statistics:
+                continue
+            baseline_value = float(baseline.statistics[key])
+            current_value = float(report.overall_statistics[key])
+            delta_value = float(deltas.get(key, current_value - baseline_value))
+
+            if key in {"success_count", "failure_count", "total_runs"}:
+                baseline_str = f"{int(round(baseline_value))}"
+                current_str = f"{int(round(current_value))}"
+                delta_str = f"{delta_value:+.0f}"
+            else:
+                precision = 6 if key == "variance" else 4
+                baseline_str = format(baseline_value, f".{precision}f")
+                current_str = format(current_value, f".{precision}f")
+                delta_str = format(delta_value, f"+.{precision}f")
+
+            lines.append(
+                f" {key:12s} baseline={baseline_str} current={current_str} delta={delta_str}"
+            )
+        return lines
+
+    @staticmethod
+    def _format_detail_line(label: str, value: str, *, indent: int = 4) -> str:
+        indent_str = " " * indent
+        value_str = str(value)
+        continuation_indent = indent_str + " "
+        value_str = value_str.replace("\n", f"\n{continuation_indent}")
+        return f"{indent_str}{label}: {value_str}"
+
+
+class ConsoleReportRenderer:
+    """Pretty prints evaluation reports to stdout (or any printer function)."""
+
+    def __init__(
+        self,
+        printer: Callable[[str], None] = print,
+        formatter: Optional[TextReportFormatter] = None,
+    ):
+        self._printer = printer
+        self._formatter = formatter or TextReportFormatter()
+
+    def render(
+        self,
+        report: EvaluationReport,
+        baseline: Optional["BaselineStatistics"] = None,
+    ) -> None:
+        for line in self._formatter.build(report, baseline):
+            self._printer(line)
+
+
+class BaselineStatistics:
+    def __init__(self, source_path: Path, statistics: dict[str, float]):
+        self.source_path = source_path
+        self.statistics = statistics
+
+    def delta(self, current: dict[str, float]) -> dict[str, float]:
+        return calculate_stat_deltas(self.statistics, current)
+
+
+class BaselineComparator:
+    """Loads baseline JSON payloads and extracts statistics."""
+
+    def load(self, path: Path) -> BaselineStatistics:
+        if not path.exists():
+            raise CLIError(f"Baseline path '{path}' does not exist.")
+        if path.is_dir():
+            raise CLIError(f"Baseline path '{path}' is a directory, expected JSON file.")
+
+        try:
+            payload = json.loads(path.read_text(encoding="utf-8"))
+        except json.JSONDecodeError as exc:
+            raise CLIError(f"Failed to parse baseline JSON: {exc}") from exc
+
+        if not isinstance(payload, dict):
+            raise CLIError("Baseline JSON must contain an object.")
+
+        source = None
+        if isinstance(payload.get("overall_statistics"), dict):
+            source = payload["overall_statistics"]
+        elif all(key in payload for key in ("average", "variance", "stdev")):
+            source = payload
+        if source is None:
+            raise CLIError(
+                "Baseline JSON must include an 'overall_statistics' object or top-level statistics."
+            )
+
+        statistics: dict[str, float] = {}
+        for key, value in source.items():
+            try:
+                statistics[key] = float(value)
+            except (TypeError, ValueError):
+                continue
+        if not statistics:
+            raise CLIError("Baseline statistics could not be parsed into numeric values.")
+
+        return BaselineStatistics(source_path=path, statistics=statistics)
+
+
+class JsonReportFormatter:
+    """Builds JSON-serialisable payloads for evaluation reports."""
+
+    def build(
+        self,
+        report: EvaluationReport,
+        *,
+        output_identifier: Optional[str],
+        baseline: Optional[BaselineStatistics],
+    ) -> dict[str, Any]:
+        provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+        model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+        generated: dict[str, Any] = {
+            "generated_at": datetime.now(timezone.utc).isoformat(),
+            "rubric_id": report.rubric_config.rubric_id,
+            "rubric_source": report.rubric_config.source_label,
+            "provider": provider,
+            "model": model_name,
+            "number": report.number,
+            "config_path": str(report.config_path),
+            "data_path": str(report.data_path),
+            "overall_statistics": _normalise_statistics(report.overall_statistics),
+            "records": [],
+        }
+        if output_identifier is not None:
+            generated["output_identifier"] = output_identifier
+
+        for record_result in report.record_results:
+            conversation_label = record_result.conversation_label
+            record_identifier = record_result.record.record_identifier(conversation_label)
+
+            record_payload: dict[str, Any] = {
+                "id": record_identifier,
+                "record_index": record_result.record_index,
+                "conversation_id": conversation_label,
+                "input_record": record_result.record.payload,
+                "statistics": _normalise_statistics(record_result.statistics),
+                "runs": [],
+            }
+
+            for run in record_result.runs:
+                record_payload["runs"].append(
+                    {
+                        "run_index": run.run_index,
+                        "status": run.status,
+                        "started_at": run.started_at.isoformat(),
+                        "completed_at": run.completed_at.isoformat(),
+                        "duration_seconds": run.duration_seconds,
+                        "score": run.score,
+                        "explanation": run.explanation,
+                        "preview": run.preview,
+                        "error": run.error,
+                        "raw": run.raw,
+                    }
+                )
+
+            generated["records"].append(record_payload)
+
+        if baseline is not None:
+            generated["baseline_comparison"] = {
+                "source_path": str(baseline.source_path),
+                "baseline_statistics": _normalise_statistics(baseline.statistics),
+                "delta_statistics": _normalise_statistics(baseline.delta(report.overall_statistics)),
+            }
+
+        return generated
+
+
+class JsonReportWriter:
+    """Serialises an evaluation report to disk."""
+
+    def __init__(self, formatter: Optional[JsonReportFormatter] = None):
+        self._formatter = formatter or JsonReportFormatter()
+
+    def write(
+        self,
+        report: EvaluationReport,
+        *,
+        output_path: Path,
+        output_identifier: Optional[str],
+        baseline: Optional[BaselineStatistics],
+    ) -> Path:
+        parent_dir = output_path.parent
+        if parent_dir and not parent_dir.exists():
+            parent_dir.mkdir(parents=True, exist_ok=True)
+
+        payload = self._formatter.build(
+            report,
+            output_identifier=output_identifier,
+            baseline=baseline,
+        )
+        with output_path.open("w", encoding="utf-8") as fh:
+            json.dump(payload, fh, indent=2, ensure_ascii=False)
+        return output_path
+
+def _normalise_statistics(stats: dict[str, float]) -> dict[str, Any]:
+    normalised: dict[str, Any] = {}
+    for key, value in stats.items():
+        if key in {"success_count", "failure_count", "total_runs"}:
+            normalised[key] = int(round(value))
+        else:
+            normalised[key] = float(value)
+    return normalised
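BaselineComparator.load accepts either a report previously written by JsonReportWriter (it reads the nested "overall_statistics" object) or a bare statistics object at the top level. A small illustrative round trip follows, assuming osmosis-ai 0.2.3 is installed; the file name and statistic values are hypothetical.

import json
from pathlib import Path
from osmosis_ai.cli_services.reporting import BaselineComparator

baseline_path = Path("baseline.json")  # hypothetical file name
baseline_path.write_text(json.dumps({
    "overall_statistics": {
        "average": 0.82, "variance": 0.004, "stdev": 0.063,
        "min": 0.70, "max": 0.90,
        "total_runs": 10, "success_count": 10, "failure_count": 0,
    }
}), encoding="utf-8")

# Non-numeric entries are skipped; an empty or malformed payload raises CLIError.
stats = BaselineComparator().load(baseline_path).statistics
print(stats["average"])  # 0.82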