osmosis-ai 0.1.9__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release.


This version of osmosis-ai might be problematic.

@@ -0,0 +1,421 @@
+from __future__ import annotations
+
+import copy
+import sys
+import time
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+import inspect
+from typing import Any, Optional, Sequence
+
+from tqdm import tqdm
+
+from ..rubric_eval import DEFAULT_API_KEY_ENV, evaluate_rubric
+from ..rubric_types import MissingAPIKeyError, ModelNotFoundError, ProviderRequestError
+from .config import RubricConfig
+from .dataset import DatasetRecord
+from .errors import CLIError
+from .shared import calculate_statistics, coerce_optional_float, collapse_preview_text
+
+
+def _normalize_config_str(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text or None
+
+
+def _compose_extra_info_context(
+    base: Optional[dict[str, Any]],
+    *,
+    rubric_text: str,
+    provider: Optional[str],
+    model: Optional[str],
+    system_prompt: Optional[str],
+    original_input: Optional[str],
+    api_key: Optional[str],
+    api_key_env: Optional[str],
+    score_min: Optional[float],
+    score_max: Optional[float],
+    model_info: dict[str, Any],
+) -> tuple[dict[str, Any], Optional[dict[str, Any]]]:
+    """
+    Build the runtime context passed to rubric functions along with a sanitised
+    copy safe for prompt injection.
+    """
+    base_payload = copy.deepcopy(base) if isinstance(base, dict) else {}
+    decorated_payload = copy.deepcopy(base_payload)
+
+    if provider:
+        decorated_payload["provider"] = provider
+    if model:
+        decorated_payload["model"] = model
+    if api_key:
+        decorated_payload["api_key"] = api_key
+        decorated_payload.pop("api_key_env", None)
+    elif api_key_env:
+        decorated_payload["api_key_env"] = api_key_env
+        decorated_payload.pop("api_key", None)
+    if rubric_text:
+        decorated_payload["rubric"] = rubric_text
+    if score_min is not None:
+        decorated_payload["score_min"] = float(score_min)
+    if score_max is not None:
+        decorated_payload["score_max"] = float(score_max)
+    if system_prompt:
+        decorated_payload["system_prompt"] = system_prompt
+    if original_input and isinstance(original_input, str):
+        decorated_payload["original_input"] = original_input
+
+    model_info_copy = copy.deepcopy(model_info)
+    if isinstance(model_info_copy, dict):
+        if api_key and "api_key" not in model_info_copy:
+            model_info_copy["api_key"] = api_key
+        if api_key_env and "api_key_env" not in model_info_copy:
+            model_info_copy["api_key_env"] = api_key_env
+    decorated_payload["model_info"] = model_info_copy
+
+    prompt_payload: Optional[dict[str, Any]] = None
+    base_metadata = decorated_payload.get("metadata")
+    if isinstance(base_metadata, dict):
+        prompt_payload = copy.deepcopy(base_metadata)
+    elif base_metadata is not None:
+        prompt_payload = dict(base_metadata)
+
+    dataset_metadata = decorated_payload.get("dataset_metadata")
+    if isinstance(dataset_metadata, dict):
+        if prompt_payload is None:
+            prompt_payload = {}
+        prompt_payload.setdefault("dataset_metadata", copy.deepcopy(dataset_metadata))
+
+    if prompt_payload is not None:
+        decorated_payload["metadata"] = copy.deepcopy(prompt_payload)
+    else:
+        decorated_payload.pop("metadata", None)
+
+    return decorated_payload, prompt_payload
+
+
+def _merge_system_prompts(
+    prepend_prompt: Optional[str],
+    base_prompt: Optional[str],
+) -> Optional[str]:
+    prompts: list[str] = []
+    if prepend_prompt:
+        prompts.append(prepend_prompt)
+    if base_prompt:
+        prompts.append(base_prompt)
+    if not prompts:
+        return None
+    return "\n\n".join(prompts)
+
+
+class RubricEvaluator:
+    """Thin wrapper over evaluate_rubric to enable injection during tests."""
+
+    def __init__(self, evaluate_fn: Any = evaluate_rubric):
+        self._evaluate_fn = evaluate_fn
+
+    def run(self, config: RubricConfig, record: DatasetRecord) -> dict[str, Any]:
+        solution = record.solution_str
+        if not isinstance(solution, str) or not solution.strip():
+            label = record.conversation_id or record.rubric_id or "<record>"
+            raise CLIError(f"Record '{label}' must include a non-empty 'solution_str' string.")
+
+        score_min = coerce_optional_float(
+            record.score_min if record.score_min is not None else config.score_min,
+            "score_min",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+        score_max = coerce_optional_float(
+            record.score_max if record.score_max is not None else config.score_max,
+            "score_max",
+            f"record '{record.conversation_id or '<record>'}'",
+        )
+
+        ground_truth = record.ground_truth if record.ground_truth is not None else config.ground_truth
+        original_input = record.original_input if record.original_input is not None else config.original_input
+
+        provider_value = _normalize_config_str(config.model_info.get("provider"))
+        model_value = _normalize_config_str(config.model_info.get("model"))
+        system_prompt_value = _normalize_config_str(config.system_prompt)
+        api_key_value = _normalize_config_str(config.model_info.get("api_key"))
+        api_key_env_value = _normalize_config_str(config.model_info.get("api_key_env"))
+        if api_key_env_value is None and provider_value:
+            default_env = DEFAULT_API_KEY_ENV.get(provider_value.lower())
+            if default_env:
+                api_key_env_value = default_env
+
+        try:
+            model_info_payload = copy.deepcopy(config.model_info)
+            base_system_prompt = _normalize_config_str(model_info_payload.get("system_prompt"))
+            combined_system_prompt = _merge_system_prompts(system_prompt_value, base_system_prompt)
+            if combined_system_prompt is not None:
+                model_info_payload["system_prompt"] = combined_system_prompt
+            else:
+                model_info_payload.pop("system_prompt", None)
+
+            decorated_extra, prompt_metadata = _compose_extra_info_context(
+                record.merged_extra_info(),
+                rubric_text=config.rubric_text,
+                provider=provider_value,
+                model=model_value,
+                system_prompt=system_prompt_value,
+                original_input=original_input,
+                api_key=api_key_value,
+                api_key_env=api_key_env_value,
+                score_min=score_min,
+                score_max=score_max,
+                model_info=model_info_payload,
+            )
+
+            signature = inspect.signature(self._evaluate_fn)
+            parameters = signature.parameters
+            accepts_var_kwargs = any(
+                param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values()
+            )
+            is_evaluate_rubric_style = accepts_var_kwargs or "rubric" in parameters or "model_info" in parameters
+
+            if is_evaluate_rubric_style:
+                return self._evaluate_fn(
+                    rubric=config.rubric_text,
+                    solution_str=solution,
+                    model_info=model_info_payload,
+                    ground_truth=ground_truth,
+                    original_input=original_input,
+                    metadata=prompt_metadata,
+                    score_min=score_min,
+                    score_max=score_max,
+                    return_details=True,
+                )
+
+            call_args: list[Any] = []
+            call_kwargs: dict[str, Any] = {}
+            for param in parameters.values():
+                if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
+                    continue
+
+                if param.name == "solution_str":
+                    value = solution
+                elif param.name == "ground_truth":
+                    value = ground_truth
+                elif param.name == "extra_info":
+                    value = decorated_extra
+                else:
+                    continue
+
+                if param.kind in (
+                    inspect.Parameter.POSITIONAL_ONLY,
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                ):
+                    call_args.append(value)
+                else:
+                    call_kwargs[param.name] = value
+
+            return self._evaluate_fn(*call_args, **call_kwargs)
+        except (MissingAPIKeyError, ProviderRequestError, ModelNotFoundError) as exc:
+            raise CLIError(str(exc)) from exc
+
+
+@dataclass
+class EvaluationRun:
+    run_index: int
+    status: str
+    score: Optional[float]
+    explanation: Optional[str]
+    preview: Optional[str]
+    duration_seconds: float
+    started_at: datetime
+    completed_at: datetime
+    error: Optional[str]
+    raw: Any
+
+
+@dataclass
+class EvaluationRecordResult:
+    record_index: int
+    record: DatasetRecord
+    conversation_label: str
+    runs: list[EvaluationRun]
+    statistics: dict[str, float]
+
+
+@dataclass
+class EvaluationReport:
+    rubric_config: RubricConfig
+    config_path: Path
+    data_path: Path
+    number: int
+    record_results: list[EvaluationRecordResult]
+    overall_statistics: dict[str, float]
+
+
+class RubricEvaluationEngine:
+    """Executes rubric evaluations across a dataset and aggregates statistics."""
+
+    def __init__(self, evaluator: Optional[RubricEvaluator] = None):
+        self._evaluator = evaluator or RubricEvaluator()
+
+    def execute(
+        self,
+        *,
+        rubric_config: RubricConfig,
+        config_path: Path,
+        data_path: Path,
+        records: Sequence[DatasetRecord],
+        number: int,
+    ) -> EvaluationReport:
+        record_results: list[EvaluationRecordResult] = []
+        aggregate_scores: list[float] = []
+        total_runs = 0
+        total_successes = 0
+
+        progress_total = len(records) * number
+        show_progress = progress_total > 1 and getattr(sys.stderr, "isatty", lambda: False)()
+        progress = (
+            tqdm(
+                total=progress_total,
+                file=sys.stderr,
+                dynamic_ncols=True,
+                leave=False,
+            )
+            if show_progress
+            else None
+        )
+
+        try:
+            for record_index, record in enumerate(records, start=1):
+                conversation_label = record.conversation_label(record_index)
+                fallback_preview = record.assistant_preview()
+
+                runs: list[EvaluationRun] = []
+                scores: list[float] = []
+
+                for attempt in range(1, number + 1):
+                    started_at = datetime.now(timezone.utc)
+                    timer_start = time.perf_counter()
+                    status = "success"
+                    error_message: Optional[str] = None
+                    score_value: Optional[float] = None
+                    explanation_value: Optional[str] = None
+                    preview_value: Optional[str] = None
+                    raw_payload: Any = None
+
+                    try:
+                        result = self._evaluator.run(rubric_config, record)
+                    except CLIError as exc:
+                        status = "error"
+                        error_message = str(exc)
+                        result = None
+                    except Exception as exc:  # pragma: no cover - unexpected path
+                        status = "error"
+                        error_message = f"{type(exc).__name__}: {exc}"
+                        result = None
+
+                    duration_seconds = time.perf_counter() - timer_start
+                    completed_at = datetime.now(timezone.utc)
+
+                    if status == "success":
+                        if isinstance(result, dict):
+                            raw_payload = result.get("raw")
+                            score_value = _extract_float(result.get("score"))
+                            explanation_value = _normalize_optional_text(result.get("explanation"))
+                            preview_value = self._resolve_preview_text(result, fallback_preview)
+                            if score_value is not None:
+                                scores.append(score_value)
+                                aggregate_scores.append(score_value)
+                                total_successes += 1
+                        else:
+                            preview_value = fallback_preview
+                            numeric_score = _extract_float(result)
+                            if numeric_score is not None:
+                                score_value = numeric_score
+                                raw_payload = result
+                                scores.append(score_value)
+                                aggregate_scores.append(score_value)
+                                total_successes += 1
+                    else:
+                        preview_value = fallback_preview
+
+                    total_runs += 1
+
+                    runs.append(
+                        EvaluationRun(
+                            run_index=attempt,
+                            status=status,
+                            score=score_value,
+                            explanation=explanation_value,
+                            preview=preview_value,
+                            duration_seconds=duration_seconds,
+                            started_at=started_at,
+                            completed_at=completed_at,
+                            error=error_message,
+                            raw=raw_payload,
+                        )
+                    )
+
+                    if progress:
+                        progress.update()
+
+                statistics = calculate_statistics(scores)
+                statistics["total_runs"] = len(runs)
+                statistics["success_count"] = len(scores)
+                statistics["failure_count"] = len(runs) - len(scores)
+                record_results.append(
+                    EvaluationRecordResult(
+                        record_index=record_index,
+                        record=record,
+                        conversation_label=conversation_label,
+                        runs=runs,
+                        statistics=statistics,
+                    )
+                )
+        finally:
+            if progress:
+                progress.close()
+
+        overall_statistics = calculate_statistics(aggregate_scores)
+        overall_statistics["total_runs"] = total_runs
+        overall_statistics["success_count"] = total_successes
+        overall_statistics["failure_count"] = total_runs - total_successes
+        return EvaluationReport(
+            rubric_config=rubric_config,
+            config_path=config_path,
+            data_path=data_path,
+            number=number,
+            record_results=record_results,
+            overall_statistics=overall_statistics,
+        )
+
+    @staticmethod
+    def _resolve_preview_text(result: Optional[dict[str, Any]], fallback: Optional[str]) -> Optional[str]:
+        if not isinstance(result, dict):
+            return fallback
+        preview = collapse_preview_text(result.get("preview"))
+        if preview:
+            return preview
+
+        raw_payload = result.get("raw")
+        if isinstance(raw_payload, dict):
+            for key in ("preview", "summary", "text"):
+                preview = collapse_preview_text(raw_payload.get(key))
+                if preview:
+                    return preview
+        return fallback
+
+
+def _extract_float(value: Any) -> Optional[float]:
+    try:
+        if isinstance(value, (int, float)) and not isinstance(value, bool):
+            return float(value)
+        return None
+    except (TypeError, ValueError):
+        return None
+
+
+def _normalize_optional_text(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text or None
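
The `RubricEvaluator.run` method above dispatches on the injected callable's signature: a callable that accepts `**kwargs` or declares a `rubric` or `model_info` parameter is invoked with the full `evaluate_rubric` keyword set, while a plainer reward-style callable receives only whichever of `solution_str`, `ground_truth`, and `extra_info` it declares. The sketch below illustrates that second path with a hypothetical scorer; it is not part of the package, and the `config` and `record` objects are assumed to come from the package's own `RubricConfig` and `DatasetRecord` loaders, which this diff does not show.

```python
from typing import Any, Optional


def keyword_overlap_reward(
    solution_str: str,
    ground_truth: Optional[str] = None,
    extra_info: Optional[dict[str, Any]] = None,
) -> float:
    """Toy reward: fraction of ground-truth tokens that appear in the solution."""
    if not ground_truth:
        return 0.0
    wanted = set(ground_truth.lower().split())
    found = set(solution_str.lower().split())
    return len(wanted & found) / len(wanted)


# Hypothetical usage: because this signature has no **kwargs and no
# `rubric`/`model_info` parameter, RubricEvaluator.run passes it
# solution_str, ground_truth, and extra_info positionally.
# evaluator = RubricEvaluator(evaluate_fn=keyword_overlap_reward)
# result = evaluator.run(config, record)  # config/record constructed elsewhere
```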
@@ -0,0 +1,7 @@
1
+ from __future__ import annotations
2
+
3
+
4
+ class CLIError(Exception):
5
+ """Raised when the CLI encounters a recoverable error."""
6
+
7
+ pass