osmosis-ai 0.1.9__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of osmosis-ai might be problematic.
- osmosis_ai/__init__.py +19 -132
- osmosis_ai/cli.py +50 -0
- osmosis_ai/cli_commands.py +181 -0
- osmosis_ai/cli_services/__init__.py +60 -0
- osmosis_ai/cli_services/config.py +410 -0
- osmosis_ai/cli_services/dataset.py +175 -0
- osmosis_ai/cli_services/engine.py +421 -0
- osmosis_ai/cli_services/errors.py +7 -0
- osmosis_ai/cli_services/reporting.py +307 -0
- osmosis_ai/cli_services/session.py +174 -0
- osmosis_ai/cli_services/shared.py +209 -0
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +314 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +356 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +258 -5
- osmosis_ai-0.2.4.dist-info/METADATA +314 -0
- osmosis_ai-0.2.4.dist-info/RECORD +27 -0
- osmosis_ai-0.2.4.dist-info/entry_points.txt +4 -0
- osmosis_ai-0.1.9.dist-info/METADATA +0 -143
- osmosis_ai-0.1.9.dist-info/RECORD +0 -8
- {osmosis_ai-0.1.9.dist-info → osmosis_ai-0.2.4.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.1.9.dist-info → osmosis_ai-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.1.9.dist-info → osmosis_ai-0.2.4.dist-info}/top_level.txt +0 -0
osmosis_ai/cli_services/reporting.py
ADDED
@@ -0,0 +1,307 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Callable, Optional
+
+from .engine import EvaluationRecordResult, EvaluationReport, EvaluationRun
+from .errors import CLIError
+from .shared import calculate_stat_deltas
+
+
+class TextReportFormatter:
+    """Builds human-readable text lines for an evaluation report."""
+
+    def build(
+        self,
+        report: EvaluationReport,
+        baseline: Optional["BaselineStatistics"] = None,
+    ) -> list[str]:
+        lines: list[str] = []
+        provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+        model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+        lines.append(
+            f"Rubric '{report.rubric_config.rubric_id}' "
+            f"({report.rubric_config.source_label}) -> provider '{provider}' model '{model_name}'"
+        )
+        lines.append(f"Loaded {len(report.record_results)} matching record(s) from {report.data_path}")
+        lines.append(f"Running {report.number} evaluation(s) per record")
+        lines.append("")
+
+        for record_result in report.record_results:
+            lines.extend(self._format_record(record_result))
+        if baseline is not None:
+            lines.extend(self._format_baseline(report, baseline))
+        return lines
+
+    def _format_record(self, record_result: EvaluationRecordResult) -> list[str]:
+        lines: list[str] = [f"[{record_result.conversation_label}]"]
+        total_runs = len(record_result.runs)
+        for index, run in enumerate(record_result.runs):
+            lines.extend(self._format_run(run))
+            if index < total_runs - 1:
+                lines.append("")
+
+        summary_lines = self._format_summary(record_result.statistics, len(record_result.runs))
+        if summary_lines:
+            lines.extend(summary_lines)
+        lines.append("")
+        return lines
+
+    def _format_run(self, run: EvaluationRun) -> list[str]:
+        lines: list[str] = [f" Run {run.run_index:02d} [{run.status.upper()}]"]
+        if run.status == "success":
+            score_text = "n/a" if run.score is None else f"{run.score:.4f}"
+            lines.append(self._format_detail_line("score", score_text))
+            if run.preview:
+                lines.append(self._format_detail_line("preview", run.preview))
+            explanation = run.explanation or "(no explanation provided)"
+            lines.append(self._format_detail_line("explanation", explanation))
+        else:
+            error_text = run.error or "(no error message provided)"
+            lines.append(self._format_detail_line("error", error_text))
+            if run.preview:
+                lines.append(self._format_detail_line("preview", run.preview))
+            if run.explanation:
+                lines.append(self._format_detail_line("explanation", run.explanation))
+        lines.append(self._format_detail_line("duration", f"{run.duration_seconds:.2f}s"))
+        return lines
+
+    def _format_summary(self, statistics: dict[str, float], total_runs: int) -> list[str]:
+        if not statistics:
+            return []
+        success_count = int(round(statistics.get("success_count", total_runs)))
+        failure_count = int(round(statistics.get("failure_count", total_runs - success_count)))
+        if total_runs <= 1 and failure_count == 0:
+            return []
+
+        lines = [" Summary:"]
+        lines.append(f" total: {int(round(statistics.get('total_runs', total_runs)))}")
+        lines.append(f" successes: {success_count}")
+        lines.append(f" failures: {failure_count}")
+        if success_count > 0:
+            lines.append(f" average: {statistics.get('average', 0.0):.4f}")
+            lines.append(f" variance: {statistics.get('variance', 0.0):.6f}")
+            lines.append(f" stdev: {statistics.get('stdev', 0.0):.4f}")
+            lines.append(f" min/max: {statistics.get('min', 0.0):.4f} / {statistics.get('max', 0.0):.4f}")
+        else:
+            lines.append(" average: n/a")
+            lines.append(" variance: n/a")
+            lines.append(" stdev: n/a")
+            lines.append(" min/max: n/a")
+        return lines
+
+    def _format_baseline(
+        self,
+        report: EvaluationReport,
+        baseline: "BaselineStatistics",
+    ) -> list[str]:
+        lines = [f"Baseline comparison (source: {baseline.source_path}):"]
+        deltas = baseline.delta(report.overall_statistics)
+        keys = ["average", "variance", "stdev", "min", "max", "success_count", "failure_count", "total_runs"]
+
+        for key in keys:
+            if key not in baseline.statistics or key not in report.overall_statistics:
+                continue
+            baseline_value = float(baseline.statistics[key])
+            current_value = float(report.overall_statistics[key])
+            delta_value = float(deltas.get(key, current_value - baseline_value))
+
+            if key in {"success_count", "failure_count", "total_runs"}:
+                baseline_str = f"{int(round(baseline_value))}"
+                current_str = f"{int(round(current_value))}"
+                delta_str = f"{delta_value:+.0f}"
+            else:
+                precision = 6 if key == "variance" else 4
+                baseline_str = format(baseline_value, f".{precision}f")
+                current_str = format(current_value, f".{precision}f")
+                delta_str = format(delta_value, f"+.{precision}f")
+
+            lines.append(
+                f" {key:12s} baseline={baseline_str} current={current_str} delta={delta_str}"
+            )
+        return lines
+
+    @staticmethod
+    def _format_detail_line(label: str, value: str, *, indent: int = 4) -> str:
+        indent_str = " " * indent
+        value_str = str(value)
+        continuation_indent = indent_str + " "
+        value_str = value_str.replace("\n", f"\n{continuation_indent}")
+        return f"{indent_str}{label}: {value_str}"
+
+
+class ConsoleReportRenderer:
+    """Pretty prints evaluation reports to stdout (or any printer function)."""
+
+    def __init__(
+        self,
+        printer: Callable[[str], None] = print,
+        formatter: Optional[TextReportFormatter] = None,
+    ):
+        self._printer = printer
+        self._formatter = formatter or TextReportFormatter()
+
+    def render(
+        self,
+        report: EvaluationReport,
+        baseline: Optional["BaselineStatistics"] = None,
+    ) -> None:
+        for line in self._formatter.build(report, baseline):
+            self._printer(line)
+
+
+class BaselineStatistics:
+    def __init__(self, source_path: Path, statistics: dict[str, float]):
+        self.source_path = source_path
+        self.statistics = statistics
+
+    def delta(self, current: dict[str, float]) -> dict[str, float]:
+        return calculate_stat_deltas(self.statistics, current)
+
+
+class BaselineComparator:
+    """Loads baseline JSON payloads and extracts statistics."""
+
+    def load(self, path: Path) -> BaselineStatistics:
+        if not path.exists():
+            raise CLIError(f"Baseline path '{path}' does not exist.")
+        if path.is_dir():
+            raise CLIError(f"Baseline path '{path}' is a directory, expected JSON file.")
+
+        try:
+            payload = json.loads(path.read_text(encoding="utf-8"))
+        except json.JSONDecodeError as exc:
+            raise CLIError(f"Failed to parse baseline JSON: {exc}") from exc
+
+        if not isinstance(payload, dict):
+            raise CLIError("Baseline JSON must contain an object.")
+
+        source = None
+        if isinstance(payload.get("overall_statistics"), dict):
+            source = payload["overall_statistics"]
+        elif all(key in payload for key in ("average", "variance", "stdev")):
+            source = payload
+        if source is None:
+            raise CLIError(
+                "Baseline JSON must include an 'overall_statistics' object or top-level statistics."
+            )
+
+        statistics: dict[str, float] = {}
+        for key, value in source.items():
+            try:
+                statistics[key] = float(value)
+            except (TypeError, ValueError):
+                continue
+        if not statistics:
+            raise CLIError("Baseline statistics could not be parsed into numeric values.")
+
+        return BaselineStatistics(source_path=path, statistics=statistics)
+
+
+class JsonReportFormatter:
+    """Builds JSON-serialisable payloads for evaluation reports."""
+
+    def build(
+        self,
+        report: EvaluationReport,
+        *,
+        output_identifier: Optional[str],
+        baseline: Optional[BaselineStatistics],
+    ) -> dict[str, Any]:
+        provider = str(report.rubric_config.model_info.get("provider", "")).strip() or "<unknown>"
+        model_name = str(report.rubric_config.model_info.get("model", "")).strip() or "<unspecified>"
+
+        generated: dict[str, Any] = {
+            "generated_at": datetime.now(timezone.utc).isoformat(),
+            "rubric_id": report.rubric_config.rubric_id,
+            "rubric_source": report.rubric_config.source_label,
+            "provider": provider,
+            "model": model_name,
+            "number": report.number,
+            "config_path": str(report.config_path),
+            "data_path": str(report.data_path),
+            "overall_statistics": _normalise_statistics(report.overall_statistics),
+            "records": [],
+        }
+        if output_identifier is not None:
+            generated["output_identifier"] = output_identifier
+
+        for record_result in report.record_results:
+            conversation_label = record_result.conversation_label
+            record_identifier = record_result.record.record_identifier(conversation_label)
+
+            record_payload: dict[str, Any] = {
+                "id": record_identifier,
+                "record_index": record_result.record_index,
+                "conversation_id": conversation_label,
+                "input_record": record_result.record.payload,
+                "statistics": _normalise_statistics(record_result.statistics),
+                "runs": [],
+            }
+
+            for run in record_result.runs:
+                record_payload["runs"].append(
+                    {
+                        "run_index": run.run_index,
+                        "status": run.status,
+                        "started_at": run.started_at.isoformat(),
+                        "completed_at": run.completed_at.isoformat(),
+                        "duration_seconds": run.duration_seconds,
+                        "score": run.score,
+                        "explanation": run.explanation,
+                        "preview": run.preview,
+                        "error": run.error,
+                        "raw": run.raw,
+                    }
+                )
+
+            generated["records"].append(record_payload)
+
+        if baseline is not None:
+            generated["baseline_comparison"] = {
+                "source_path": str(baseline.source_path),
+                "baseline_statistics": _normalise_statistics(baseline.statistics),
+                "delta_statistics": _normalise_statistics(baseline.delta(report.overall_statistics)),
+            }
+
+        return generated
+
+
+class JsonReportWriter:
+    """Serialises an evaluation report to disk."""
+
+    def __init__(self, formatter: Optional[JsonReportFormatter] = None):
+        self._formatter = formatter or JsonReportFormatter()
+
+    def write(
+        self,
+        report: EvaluationReport,
+        *,
+        output_path: Path,
+        output_identifier: Optional[str],
+        baseline: Optional[BaselineStatistics],
+    ) -> Path:
+        parent_dir = output_path.parent
+        if parent_dir and not parent_dir.exists():
+            parent_dir.mkdir(parents=True, exist_ok=True)
+
+        payload = self._formatter.build(
+            report,
+            output_identifier=output_identifier,
+            baseline=baseline,
+        )
+        with output_path.open("w", encoding="utf-8") as fh:
+            json.dump(payload, fh, indent=2, ensure_ascii=False)
+        return output_path
+
+def _normalise_statistics(stats: dict[str, float]) -> dict[str, Any]:
+    normalised: dict[str, Any] = {}
+    for key, value in stats.items():
+        if key in {"success_count", "failure_count", "total_runs"}:
+            normalised[key] = int(round(value))
+        else:
+            normalised[key] = float(value)
+    return normalised
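For reference, BaselineComparator.load accepts either a payload whose "overall_statistics" key holds the statistics object, or a payload with the statistic keys (average, variance, stdev, ...) at the top level. Below is a minimal usage sketch; the file name and statistic values are illustrative, and the commented render call assumes an EvaluationReport obtained from the engine.

import json
from pathlib import Path

from osmosis_ai.cli_services.reporting import BaselineComparator, ConsoleReportRenderer

# Hypothetical baseline produced by an earlier run; only the statistics shape matters.
baseline_path = Path("baseline.json")
baseline_path.write_text(
    json.dumps(
        {
            "overall_statistics": {
                "average": 0.82,
                "variance": 0.004,
                "stdev": 0.0632,
                "min": 0.7,
                "max": 0.9,
                "success_count": 10,
                "failure_count": 0,
                "total_runs": 10,
            }
        }
    ),
    encoding="utf-8",
)

baseline = BaselineComparator().load(baseline_path)
print(baseline.statistics["average"])  # 0.82

# Given an EvaluationReport produced by the evaluation engine, the renderer appends
# a "Baseline comparison" section after the per-record output:
# ConsoleReportRenderer().render(report, baseline)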
osmosis_ai/cli_services/session.py
ADDED
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Optional, Sequence
+
+from ..rubric_eval import ensure_api_key_available
+from ..rubric_types import MissingAPIKeyError
+from .config import RubricConfig, RubricSuite, discover_rubric_config_path, load_rubric_suite
+from .dataset import DatasetLoader, DatasetRecord
+from .engine import RubricEvaluationEngine, EvaluationReport
+from .errors import CLIError
+from .reporting import BaselineComparator, BaselineStatistics, JsonReportWriter
+
+
+_CACHE_ROOT = Path("~/.cache/osmosis/eval_result").expanduser()
+
+
+def _sanitise_rubric_folder(rubric_id: str) -> str:
+    """Produce a filesystem-safe folder name for the rubric id."""
+    clean = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in rubric_id.strip())
+    clean = clean.strip("_") or "rubric"
+    return clean.lower()
+
+
+@dataclass(frozen=True)
+class EvaluationSessionRequest:
+    rubric_id: str
+    data_path: Path
+    number: int = 1
+    config_path: Optional[Path] = None
+    output_path: Optional[Path] = None
+    output_identifier: Optional[str] = None
+    baseline_path: Optional[Path] = None
+
+
+@dataclass
+class EvaluationSessionResult:
+    request: EvaluationSessionRequest
+    config_path: Path
+    data_path: Path
+    rubric_config: RubricConfig
+    records: Sequence[DatasetRecord]
+    report: EvaluationReport
+    baseline: Optional[BaselineStatistics]
+    written_path: Optional[Path]
+    output_identifier: Optional[str]
+
+
+class EvaluationSession:
+    """Coordinates rubric evaluation end-to-end for reusable orchestration."""
+
+    def __init__(
+        self,
+        *,
+        config_locator: Callable[[Optional[str], Path], Path] = discover_rubric_config_path,
+        suite_loader: Callable[[Path], RubricSuite] = load_rubric_suite,
+        dataset_loader: Optional[DatasetLoader] = None,
+        engine: Optional[RubricEvaluationEngine] = None,
+        baseline_comparator: Optional[BaselineComparator] = None,
+        report_writer: Optional[JsonReportWriter] = None,
+        identifier_factory: Optional[Callable[[], str]] = None,
+    ):
+        self._config_locator = config_locator
+        self._suite_loader = suite_loader
+        self._dataset_loader = dataset_loader or DatasetLoader()
+        self._engine = engine or RubricEvaluationEngine()
+        self._baseline_comparator = baseline_comparator or BaselineComparator()
+        self._report_writer = report_writer or JsonReportWriter()
+        self._identifier_factory = identifier_factory or self._default_identifier
+
+    def execute(self, request: EvaluationSessionRequest) -> EvaluationSessionResult:
+        rubric_id = request.rubric_id.strip()
+        if not rubric_id:
+            raise CLIError("Rubric identifier cannot be empty.")
+
+        number_value = request.number if request.number is not None else 1
+        number = int(number_value)
+        if number < 1:
+            raise CLIError("Number of runs must be a positive integer.")
+
+        data_path = request.data_path.expanduser()
+        if not data_path.exists():
+            raise CLIError(f"Data path '{data_path}' does not exist.")
+        if data_path.is_dir():
+            raise CLIError(f"Expected a JSONL file but received directory '{data_path}'.")
+
+        config_override = str(request.config_path.expanduser()) if request.config_path else None
+        config_path = self._config_locator(config_override, data_path)
+        suite = self._suite_loader(config_path)
+        rubric_config = suite.get(rubric_id)
+
+        try:
+            ensure_api_key_available(rubric_config.model_info)
+        except (MissingAPIKeyError, TypeError) as exc:
+            raise CLIError(str(exc)) from exc
+
+        all_records = self._dataset_loader.load(data_path)
+        matching_records = [
+            record for record in all_records if record.rubric_id.lower() == rubric_id.lower()
+        ]
+        if not matching_records:
+            raise CLIError(f"No records in '{data_path}' reference rubric '{rubric_id}'.")
+
+        baseline_stats = self._load_baseline(request.baseline_path)
+
+        resolved_output_path, resolved_identifier = self._resolve_output_path(
+            request.output_path,
+            request.output_identifier,
+            rubric_id=rubric_id,
+        )
+
+        report = self._engine.execute(
+            rubric_config=rubric_config,
+            config_path=config_path,
+            data_path=data_path,
+            records=matching_records,
+            number=number,
+        )
+
+        written_path = None
+        if resolved_output_path is not None:
+            written_path = self._report_writer.write(
+                report,
+                output_path=resolved_output_path,
+                output_identifier=resolved_identifier,
+                baseline=baseline_stats,
+            )
+
+        return EvaluationSessionResult(
+            request=request,
+            config_path=config_path,
+            data_path=data_path,
+            rubric_config=rubric_config,
+            records=matching_records,
+            report=report,
+            baseline=baseline_stats,
+            written_path=written_path,
+            output_identifier=resolved_identifier,
+        )
+
+    def _load_baseline(self, baseline_path: Optional[Path]) -> Optional[BaselineStatistics]:
+        if baseline_path is None:
+            return None
+        resolved = baseline_path.expanduser()
+        return self._baseline_comparator.load(resolved)
+
+    def _resolve_output_path(
+        self,
+        output_candidate: Optional[Path],
+        output_identifier: Optional[str],
+        *,
+        rubric_id: str,
+    ) -> tuple[Optional[Path], Optional[str]]:
+        if output_candidate is None:
+            identifier = output_identifier or self._identifier_factory()
+            target_dir = _CACHE_ROOT / _sanitise_rubric_folder(rubric_id)
+            target_dir.mkdir(parents=True, exist_ok=True)
+            return target_dir / f"rubric_eval_result_{identifier}.json", identifier
+
+        candidate = output_candidate.expanduser()
+        if candidate.suffix:
+            if candidate.exists() and candidate.is_dir():
+                raise CLIError(f"Output path '{candidate}' is a directory.")
+            return candidate, output_identifier
+
+        candidate.mkdir(parents=True, exist_ok=True)
+        identifier = output_identifier or self._identifier_factory()
+        return candidate / f"rubric_eval_result_{identifier}.json", identifier
+
+    @staticmethod
+    def _default_identifier() -> str:
+        return str(int(time.time()))
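EvaluationSession wires the config discovery, dataset loading, engine, and report writer together. A minimal sketch of invoking it programmatically; the rubric id and JSONL path are placeholders, and the rubric config is discovered relative to the data file unless config_path is passed explicitly.

from pathlib import Path

from osmosis_ai.cli_services.session import EvaluationSession, EvaluationSessionRequest

# Placeholder rubric id and dataset path; both must exist in your project.
request = EvaluationSessionRequest(
    rubric_id="helpfulness",
    data_path=Path("conversations.jsonl"),
    number=3,
)
result = EvaluationSession().execute(request)

# With no output_path, the JSON report is written under
# ~/.cache/osmosis/eval_result/<rubric>/rubric_eval_result_<unix-timestamp>.json
print(result.written_path)
print(result.report.overall_statistics)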
osmosis_ai/cli_services/shared.py
ADDED
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+from statistics import mean, pvariance, pstdev
+from typing import Any, Collection, Optional, Set
+
+from .errors import CLIError
+
+
+def coerce_optional_float(value: Any, field_name: str, source_label: str) -> Optional[float]:
+    if value is None:
+        return None
+    if isinstance(value, (int, float)) and not isinstance(value, bool):
+        return float(value)
+    raise CLIError(
+        f"Expected '{field_name}' in {source_label} to be numeric, got {type(value).__name__}."
+    )
+
+
+def collapse_preview_text(value: Any, *, max_length: int = 140) -> Optional[str]:
+    if not isinstance(value, str):
+        return None
+    collapsed = " ".join(value.strip().split())
+    if not collapsed:
+        return None
+    if len(collapsed) > max_length:
+        collapsed = collapsed[: max_length - 3].rstrip() + "..."
+    return collapsed
+
+
+def calculate_statistics(scores: list[float]) -> dict[str, float]:
+    if not scores:
+        return {
+            "average": 0.0,
+            "variance": 0.0,
+            "stdev": 0.0,
+            "min": 0.0,
+            "max": 0.0,
+        }
+    average = mean(scores)
+    variance = pvariance(scores)
+    std_dev = pstdev(scores)
+    return {
+        "average": average,
+        "variance": variance,
+        "stdev": std_dev,
+        "min": min(scores),
+        "max": max(scores),
+    }
+
+
+def calculate_stat_deltas(baseline: dict[str, float], current: dict[str, float]) -> dict[str, float]:
+    delta: dict[str, float] = {}
+    for key, current_value in current.items():
+        if key not in baseline:
+            continue
+        try:
+            baseline_value = float(baseline[key])
+            current_numeric = float(current_value)
+        except (TypeError, ValueError):
+            continue
+        delta[key] = current_numeric - baseline_value
+    return delta
+
+
+def gather_text_fragments(
+    node: Any,
+    fragments: list[str],
+    *,
+    allow_free_strings: bool = False,
+    seen: Optional[Set[int]] = None,
+    string_key_allowlist: Optional[Collection[str]] = None,
+) -> None:
+    """Collect textual snippets from nested message-like structures.
+
+    The traversal favours common chat-completions shapes (e.g. ``{"type": "text"}``
+    blocks) and avoids indiscriminately pulling in metadata values such as IDs.
+    ``allow_free_strings`` controls whether bare strings encountered at the current
+    level should be considered textual content (useful for raw message content but
+    typically disabled for metadata fields).
+    """
+
+    if seen is None:
+        seen = set()
+
+    if isinstance(node, str):
+        if allow_free_strings:
+            stripped = node.strip()
+            if stripped:
+                fragments.append(stripped)
+        return
+
+    if isinstance(node, list):
+        for item in node:
+            gather_text_fragments(
+                item,
+                fragments,
+                allow_free_strings=allow_free_strings,
+                seen=seen,
+                string_key_allowlist=string_key_allowlist,
+            )
+        return
+
+    if not isinstance(node, dict):
+        return
+
+    node_id = id(node)
+    if node_id in seen:
+        return
+    seen.add(node_id)
+
+    allowlist = {"text", "value", "message"}
+    if string_key_allowlist is not None:
+        allowlist = {key.lower() for key in string_key_allowlist}
+    else:
+        allowlist = {key.lower() for key in allowlist}
+
+    prioritized_keys = ("text", "value")
+    handled_keys: Set[str] = {
+        "text",
+        "value",
+        "content",
+        "message",
+        "parts",
+        "input_text",
+        "output_text",
+        "type",
+        "role",
+        "name",
+        "id",
+        "index",
+        "finish_reason",
+        "reason",
+        "tool_call_id",
+        "metadata",
+    }
+
+    for key in prioritized_keys:
+        if key not in node:
+            continue
+        before_count = len(fragments)
+        gather_text_fragments(
+            node[key],
+            fragments,
+            allow_free_strings=True,
+            seen=seen,
+            string_key_allowlist=string_key_allowlist,
+        )
+        if len(fragments) > before_count:
+            break
+
+    if node.get("type") == "tool_result" and "content" in node:
+        gather_text_fragments(
+            node["content"],
+            fragments,
+            allow_free_strings=True,
+            seen=seen,
+            string_key_allowlist=string_key_allowlist,
+        )
+    elif "content" in node:
+        gather_text_fragments(
+            node["content"],
+            fragments,
+            allow_free_strings=True,
+            seen=seen,
+            string_key_allowlist=string_key_allowlist,
+        )
+
+    for key in ("message", "parts", "input_text", "output_text"):
+        if key in node:
+            gather_text_fragments(
+                node[key],
+                fragments,
+                allow_free_strings=True,
+                seen=seen,
+                string_key_allowlist=string_key_allowlist,
+            )
+
+    for key, value in node.items():
+        if key in handled_keys:
+            continue
+        if isinstance(value, (list, dict)):
+            gather_text_fragments(
+                value,
+                fragments,
+                allow_free_strings=False,
+                seen=seen,
+                string_key_allowlist=string_key_allowlist,
+            )
+        elif isinstance(value, str) and key.lower() in allowlist:
+            stripped = value.strip()
+            if stripped:
+                fragments.append(stripped)
+
+
+def collect_text_fragments(
+    node: Any,
+    *,
+    allow_free_strings: bool = False,
+    string_key_allowlist: Optional[Collection[str]] = None,
+) -> list[str]:
+    fragments: list[str] = []
+    gather_text_fragments(
+        node,
+        fragments,
+        allow_free_strings=allow_free_strings,
+        seen=set(),
+        string_key_allowlist=string_key_allowlist,
+    )
+    return fragments
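The helpers in shared.py are self-contained and can be exercised directly. A small sketch, assuming an illustrative chat-completions-style message payload:

from osmosis_ai.cli_services.shared import calculate_statistics, collect_text_fragments

# Textual parts are collected; metadata keys such as role and id are skipped.
message = {
    "role": "assistant",
    "content": [
        {"type": "text", "text": "The capital of France is Paris."},
        {"type": "tool_result", "content": "lookup: France -> Paris"},
    ],
    "metadata": {"id": "msg_123"},
}
print(collect_text_fragments(message))
# ['The capital of France is Paris.', 'lookup: France -> Paris']

# Population statistics over per-run scores, as used for the per-record summaries.
print(calculate_statistics([0.8, 0.9, 1.0]))
# roughly {'average': 0.9, 'variance': 0.0067, 'stdev': 0.0816, 'min': 0.8, 'max': 1.0}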
osmosis_ai/consts.py
CHANGED