osmosis-ai 0.1.9__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- osmosis_ai/__init__.py +19 -132
- osmosis_ai/cli.py +50 -0
- osmosis_ai/cli_commands.py +181 -0
- osmosis_ai/cli_services/__init__.py +60 -0
- osmosis_ai/cli_services/config.py +410 -0
- osmosis_ai/cli_services/dataset.py +175 -0
- osmosis_ai/cli_services/engine.py +421 -0
- osmosis_ai/cli_services/errors.py +7 -0
- osmosis_ai/cli_services/reporting.py +307 -0
- osmosis_ai/cli_services/session.py +174 -0
- osmosis_ai/cli_services/shared.py +209 -0
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +314 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +356 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +258 -5
- osmosis_ai-0.2.4.dist-info/METADATA +314 -0
- osmosis_ai-0.2.4.dist-info/RECORD +27 -0
- osmosis_ai-0.2.4.dist-info/entry_points.txt +4 -0
- osmosis_ai-0.1.9.dist-info/METADATA +0 -143
- osmosis_ai-0.1.9.dist-info/RECORD +0 -8
- {osmosis_ai-0.1.9.dist-info → osmosis_ai-0.2.4.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.1.9.dist-info → osmosis_ai-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.1.9.dist-info → osmosis_ai-0.2.4.dist-info}/top_level.txt +0 -0
osmosis_ai/cli_services/engine.py
@@ -0,0 +1,421 @@
+from __future__ import annotations
+
+import copy
+import sys
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+import inspect
+from typing import Any, Optional, Sequence
+
+from tqdm import tqdm
+
+from ..rubric_eval import DEFAULT_API_KEY_ENV, evaluate_rubric
+from ..rubric_types import MissingAPIKeyError, ModelNotFoundError, ProviderRequestError
+from .config import RubricConfig
+from .dataset import DatasetRecord
+from .errors import CLIError
+from .shared import calculate_statistics, coerce_optional_float, collapse_preview_text
+
+
+def _normalize_config_str(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text or None
+
+
+def _compose_extra_info_context(
+    base: Optional[dict[str, Any]],
+    *,
+    rubric_text: str,
+    provider: Optional[str],
+    model: Optional[str],
+    system_prompt: Optional[str],
+    original_input: Optional[str],
+    api_key: Optional[str],
+    api_key_env: Optional[str],
+    score_min: Optional[float],
+    score_max: Optional[float],
+    model_info: dict[str, Any],
+) -> tuple[dict[str, Any], Optional[dict[str, Any]]]:
+    """
+    Build the runtime context passed to rubric functions along with a sanitised
+    copy safe for prompt injection.
+    """
+    base_payload = copy.deepcopy(base) if isinstance(base, dict) else {}
+    decorated_payload = copy.deepcopy(base_payload)
+
+    if provider:
+        decorated_payload["provider"] = provider
+    if model:
+        decorated_payload["model"] = model
+    if api_key:
+        decorated_payload["api_key"] = api_key
+        decorated_payload.pop("api_key_env", None)
+    elif api_key_env:
+        decorated_payload["api_key_env"] = api_key_env
+        decorated_payload.pop("api_key", None)
+    if rubric_text:
+        decorated_payload["rubric"] = rubric_text
+    if score_min is not None:
+        decorated_payload["score_min"] = float(score_min)
+    if score_max is not None:
+        decorated_payload["score_max"] = float(score_max)
+    if system_prompt:
+        decorated_payload["system_prompt"] = system_prompt
+    if original_input and isinstance(original_input, str):
+        decorated_payload["original_input"] = original_input
+
+    model_info_copy = copy.deepcopy(model_info)
+    if isinstance(model_info_copy, dict):
+        if api_key and "api_key" not in model_info_copy:
+            model_info_copy["api_key"] = api_key
+        if api_key_env and "api_key_env" not in model_info_copy:
+            model_info_copy["api_key_env"] = api_key_env
+    decorated_payload["model_info"] = model_info_copy
+
+    prompt_payload: Optional[dict[str, Any]] = None
+    base_metadata = decorated_payload.get("metadata")
+    if isinstance(base_metadata, dict):
+        prompt_payload = copy.deepcopy(base_metadata)
+    elif base_metadata is not None:
+        prompt_payload = dict(base_metadata)
+
+    dataset_metadata = decorated_payload.get("dataset_metadata")
+    if isinstance(dataset_metadata, dict):
+        if prompt_payload is None:
+            prompt_payload = {}
+        prompt_payload.setdefault("dataset_metadata", copy.deepcopy(dataset_metadata))
+
+    if prompt_payload is not None:
+        decorated_payload["metadata"] = copy.deepcopy(prompt_payload)
+    else:
+        decorated_payload.pop("metadata", None)
+
+    return decorated_payload, prompt_payload
+
+
+def _merge_system_prompts(
+    prepend_prompt: Optional[str],
+    base_prompt: Optional[str],
+) -> Optional[str]:
+    prompts: list[str] = []
+    if prepend_prompt:
+        prompts.append(prepend_prompt)
+    if base_prompt:
+        prompts.append(base_prompt)
+    if not prompts:
+        return None
+    return "\n\n".join(prompts)
+
+
+class RubricEvaluator:
+    """Thin wrapper over evaluate_rubric to enable injection during tests."""
+
+    def __init__(self, evaluate_fn: Any = evaluate_rubric):
+        self._evaluate_fn = evaluate_fn
+
+    def run(self, config: RubricConfig, record: DatasetRecord) -> dict[str, Any]:
+        solution = record.solution_str
+        if not isinstance(solution, str) or not solution.strip():
+            label = record.conversation_id or record.rubric_id or "<record>"
+            raise CLIError(f"Record '{label}' must include a non-empty 'solution_str' string.")
+
+        score_min = coerce_optional_float(
+            record.score_min if record.score_min is not None else config.score_min,
+            "score_min",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+        score_max = coerce_optional_float(
+            record.score_max if record.score_max is not None else config.score_max,
+            "score_max",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+
+        ground_truth = record.ground_truth if record.ground_truth is not None else config.ground_truth
+        original_input = record.original_input if record.original_input is not None else config.original_input
+
+        provider_value = _normalize_config_str(config.model_info.get("provider"))
+        model_value = _normalize_config_str(config.model_info.get("model"))
+        system_prompt_value = _normalize_config_str(config.system_prompt)
+        api_key_value = _normalize_config_str(config.model_info.get("api_key"))
+        api_key_env_value = _normalize_config_str(config.model_info.get("api_key_env"))
+        if api_key_env_value is None and provider_value:
+            default_env = DEFAULT_API_KEY_ENV.get(provider_value.lower())
+            if default_env:
+                api_key_env_value = default_env
+
+        try:
+            model_info_payload = copy.deepcopy(config.model_info)
+            base_system_prompt = _normalize_config_str(model_info_payload.get("system_prompt"))
+            combined_system_prompt = _merge_system_prompts(system_prompt_value, base_system_prompt)
+            if combined_system_prompt is not None:
+                model_info_payload["system_prompt"] = combined_system_prompt
+            else:
+                model_info_payload.pop("system_prompt", None)
+
+            decorated_extra, prompt_metadata = _compose_extra_info_context(
+                record.merged_extra_info(),
+                rubric_text=config.rubric_text,
+                provider=provider_value,
+                model=model_value,
+                system_prompt=system_prompt_value,
+                original_input=original_input,
+                api_key=api_key_value,
+                api_key_env=api_key_env_value,
+                score_min=score_min,
+                score_max=score_max,
+                model_info=model_info_payload,
+            )
+
+            signature = inspect.signature(self._evaluate_fn)
+            parameters = signature.parameters
+            accepts_var_kwargs = any(
+                param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values()
+            )
+            is_evaluate_rubric_style = accepts_var_kwargs or "rubric" in parameters or "model_info" in parameters
+
+            if is_evaluate_rubric_style:
+                return self._evaluate_fn(
+                    rubric=config.rubric_text,
+                    solution_str=solution,
+                    model_info=model_info_payload,
+                    ground_truth=ground_truth,
+                    original_input=original_input,
+                    metadata=prompt_metadata,
+                    score_min=score_min,
+                    score_max=score_max,
+                    return_details=True,
+                )
+
+            call_args: list[Any] = []
+            call_kwargs: dict[str, Any] = {}
+            for param in parameters.values():
+                if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
+                    continue
+
+                if param.name == "solution_str":
+                    value = solution
+                elif param.name == "ground_truth":
+                    value = ground_truth
+                elif param.name == "extra_info":
+                    value = decorated_extra
+                else:
+                    continue
+
+                if param.kind in (
+                    inspect.Parameter.POSITIONAL_ONLY,
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                ):
+                    call_args.append(value)
+                else:
+                    call_kwargs[param.name] = value
+
+            return self._evaluate_fn(*call_args, **call_kwargs)
+        except (MissingAPIKeyError, ProviderRequestError, ModelNotFoundError) as exc:
+            raise CLIError(str(exc)) from exc
+
+
+@dataclass
+class EvaluationRun:
+    run_index: int
+    status: str
+    score: Optional[float]
+    explanation: Optional[str]
+    preview: Optional[str]
+    duration_seconds: float
+    started_at: datetime
+    completed_at: datetime
+    error: Optional[str]
+    raw: Any
+
+
+@dataclass
+class EvaluationRecordResult:
+    record_index: int
+    record: DatasetRecord
+    conversation_label: str
+    runs: list[EvaluationRun]
+    statistics: dict[str, float]
+
+
+@dataclass
+class EvaluationReport:
+    rubric_config: RubricConfig
+    config_path: Path
+    data_path: Path
+    number: int
+    record_results: list[EvaluationRecordResult]
+    overall_statistics: dict[str, float]
+
+
+class RubricEvaluationEngine:
+    """Executes rubric evaluations across a dataset and aggregates statistics."""
+
+    def __init__(self, evaluator: Optional[RubricEvaluator] = None):
+        self._evaluator = evaluator or RubricEvaluator()
+
+    def execute(
+        self,
+        *,
+        rubric_config: RubricConfig,
+        config_path: Path,
+        data_path: Path,
+        records: Sequence[DatasetRecord],
+        number: int,
+    ) -> EvaluationReport:
+        record_results: list[EvaluationRecordResult] = []
+        aggregate_scores: list[float] = []
+        total_runs = 0
+        total_successes = 0
+
+        progress_total = len(records) * number
+        show_progress = progress_total > 1 and getattr(sys.stderr, "isatty", lambda: False)()
+        progress = (
+            tqdm(
+                total=progress_total,
+                file=sys.stderr,
+                dynamic_ncols=True,
+                leave=False,
+            )
+            if show_progress
+            else None
+        )
+
+        try:
+            for record_index, record in enumerate(records, start=1):
+                conversation_label = record.conversation_label(record_index)
+                fallback_preview = record.assistant_preview()
+
+                runs: list[EvaluationRun] = []
+                scores: list[float] = []
+
+                for attempt in range(1, number + 1):
+                    started_at = datetime.now(timezone.utc)
+                    timer_start = time.perf_counter()
+                    status = "success"
+                    error_message: Optional[str] = None
+                    score_value: Optional[float] = None
+                    explanation_value: Optional[str] = None
+                    preview_value: Optional[str] = None
+                    raw_payload: Any = None
+
+                    try:
+                        result = self._evaluator.run(rubric_config, record)
+                    except CLIError as exc:
+                        status = "error"
+                        error_message = str(exc)
+                        result = None
+                    except Exception as exc:  # pragma: no cover - unexpected path
+                        status = "error"
+                        error_message = f"{type(exc).__name__}: {exc}"
+                        result = None
+
+                    duration_seconds = time.perf_counter() - timer_start
+                    completed_at = datetime.now(timezone.utc)
+
+                    if status == "success":
+                        if isinstance(result, dict):
+                            raw_payload = result.get("raw")
+                            score_value = _extract_float(result.get("score"))
+                            explanation_value = _normalize_optional_text(result.get("explanation"))
+                            preview_value = self._resolve_preview_text(result, fallback_preview)
+                            if score_value is not None:
+                                scores.append(score_value)
+                                aggregate_scores.append(score_value)
+                                total_successes += 1
+                        else:
+                            preview_value = fallback_preview
+                            numeric_score = _extract_float(result)
+                            if numeric_score is not None:
+                                score_value = numeric_score
+                                raw_payload = result
+                                scores.append(score_value)
+                                aggregate_scores.append(score_value)
+                                total_successes += 1
+                    else:
+                        preview_value = fallback_preview
+
+                    total_runs += 1
+
+                    runs.append(
+                        EvaluationRun(
+                            run_index=attempt,
+                            status=status,
+                            score=score_value,
+                            explanation=explanation_value,
+                            preview=preview_value,
+                            duration_seconds=duration_seconds,
+                            started_at=started_at,
+                            completed_at=completed_at,
+                            error=error_message,
+                            raw=raw_payload,
+                        )
+                    )
+
+                    if progress:
+                        progress.update()
+
+                statistics = calculate_statistics(scores)
+                statistics["total_runs"] = len(runs)
+                statistics["success_count"] = len(scores)
+                statistics["failure_count"] = len(runs) - len(scores)
+                record_results.append(
+                    EvaluationRecordResult(
+                        record_index=record_index,
+                        record=record,
+                        conversation_label=conversation_label,
+                        runs=runs,
+                        statistics=statistics,
+                    )
+                )
+        finally:
+            if progress:
+                progress.close()
+
+        overall_statistics = calculate_statistics(aggregate_scores)
+        overall_statistics["total_runs"] = total_runs
+        overall_statistics["success_count"] = total_successes
+        overall_statistics["failure_count"] = total_runs - total_successes
+        return EvaluationReport(
+            rubric_config=rubric_config,
+            config_path=config_path,
+            data_path=data_path,
+            number=number,
+            record_results=record_results,
+            overall_statistics=overall_statistics,
+        )
+
+    @staticmethod
+    def _resolve_preview_text(result: Optional[dict[str, Any]], fallback: Optional[str]) -> Optional[str]:
+        if not isinstance(result, dict):
+            return fallback
+        preview = collapse_preview_text(result.get("preview"))
+        if preview:
+            return preview
+
+        raw_payload = result.get("raw")
+        if isinstance(raw_payload, dict):
+            for key in ("preview", "summary", "text"):
+                preview = collapse_preview_text(raw_payload.get(key))
+                if preview:
+                    return preview
+        return fallback
+
+
+def _extract_float(value: Any) -> Optional[float]:
+    try:
+        if isinstance(value, (int, float)) and not isinstance(value, bool):
+            return float(value)
+        return None
+    except (TypeError, ValueError):
+        return None
+
+
+def _normalize_optional_text(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text or None
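For orientation, here is a minimal usage sketch of the new engine module. The class names, the `evaluate_fn` injection point, and the keyword arguments of `execute` are taken from the diff above; the `exact_match_scorer` callable and the `rubric_config`/`records` values are illustrative assumptions (in practice they would come from the new cli_services.config and cli_services.dataset modules, whose loaders are not part of this hunk).

# Hypothetical usage sketch; names other than the engine classes are assumptions.
from pathlib import Path

from osmosis_ai.cli_services.engine import RubricEvaluationEngine, RubricEvaluator


def exact_match_scorer(solution_str: str, ground_truth=None, extra_info=None) -> float:
    # Plain-scorer style: because this callable exposes solution_str/ground_truth/extra_info
    # and no rubric/model_info parameters, RubricEvaluator.run dispatches to it through the
    # inspect-based branch shown above instead of the evaluate_rubric keyword call.
    if ground_truth is None:
        return 0.0
    return 1.0 if solution_str.strip() == str(ground_truth).strip() else 0.0


engine = RubricEvaluationEngine(evaluator=RubricEvaluator(evaluate_fn=exact_match_scorer))

# rubric_config and records are assumed to be produced by the new config/dataset loaders
# (not shown in this hunk); number=3 repeats each record three times before aggregating.
report = engine.execute(
    rubric_config=rubric_config,
    config_path=Path("rubric_config.yaml"),
    data_path=Path("records.jsonl"),
    records=records,
    number=3,
)
print(report.overall_statistics)

The same `evaluate_fn` injection point is what the docstring refers to for tests: a stub scorer can exercise the engine's retry, timing, and statistics logic without making a real provider call.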