osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of osmosis-ai might be problematic. See the package's registry page for more details.

@@ -0,0 +1,174 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+ from typing import Callable, Optional, Sequence
7
+
8
+ from ..rubric_eval import ensure_api_key_available
9
+ from ..rubric_types import MissingAPIKeyError
10
+ from .config import RubricConfig, RubricSuite, discover_rubric_config_path, load_rubric_suite
11
+ from .dataset import DatasetLoader, DatasetRecord
12
+ from .engine import RubricEvaluationEngine, EvaluationReport
13
+ from .errors import CLIError
14
+ from .reporting import BaselineComparator, BaselineStatistics, JsonReportWriter
15
+
16
+
17
+ _CACHE_ROOT = Path("~/.cache/osmosis/eval_result").expanduser()
18
+
19
+
20
+ def _sanitise_rubric_folder(rubric_id: str) -> str:
21
+ """Produce a filesystem-safe folder name for the rubric id."""
22
+ clean = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in rubric_id.strip())
23
+ clean = clean.strip("_") or "rubric"
24
+ return clean.lower()
25
+
26
+
27
@dataclass(frozen=True)
class EvaluationSessionRequest:
    """Immutable description of one rubric evaluation requested by a caller."""

    # Rubric identifier to evaluate; the session strips it and matches it
    # case-insensitively against dataset records.
    rubric_id: str
    # Dataset file to read records from (``~`` is expanded by the session).
    data_path: Path
    # Number of evaluation runs; the session treats ``None`` as 1 and rejects < 1.
    number: int = 1
    # Explicit rubric config file; when omitted, the session's config locator is
    # called with ``None`` and the data path.
    config_path: Optional[Path] = None
    # Report destination: a path with a suffix is used as the file itself, a
    # suffix-less path as a directory; omitted means a per-rubric cache folder.
    output_path: Optional[Path] = None
    # Identifier embedded in generated report file names; generated if omitted.
    output_identifier: Optional[str] = None
    # Optional baseline file handed to the baseline comparator for comparison.
    baseline_path: Optional[Path] = None
36
+
37
+
38
@dataclass
class EvaluationSessionResult:
    """Everything produced by one ``EvaluationSession.execute`` call."""

    # The request that was executed.
    request: EvaluationSessionRequest
    # Rubric config file that was actually used (explicit or located).
    config_path: Path
    # Dataset path after ``~`` expansion.
    data_path: Path
    # Config of the evaluated rubric, taken from the loaded suite.
    rubric_config: RubricConfig
    # Only the dataset records whose rubric id matched the request.
    records: Sequence[DatasetRecord]
    # Report returned by the evaluation engine.
    report: EvaluationReport
    # Statistics loaded from ``request.baseline_path``, or ``None``.
    baseline: Optional[BaselineStatistics]
    # Path returned by the report writer, or ``None`` if nothing was written.
    written_path: Optional[Path]
    # Identifier used for the report file name, if one was resolved.
    output_identifier: Optional[str]
49
+
50
+
51
class EvaluationSession:
    """Coordinate rubric evaluation end-to-end for reusable orchestration.

    Every collaborator is injectable: config discovery, suite loading, dataset
    loading, the evaluation engine, baseline loading, report writing and output
    identifier generation. Omitted collaborators fall back to the package
    defaults.
    """

    def __init__(
        self,
        *,
        config_locator: Callable[[Optional[str], Path], Path] = discover_rubric_config_path,
        suite_loader: Callable[[Path], RubricSuite] = load_rubric_suite,
        dataset_loader: Optional[DatasetLoader] = None,
        engine: Optional[RubricEvaluationEngine] = None,
        baseline_comparator: Optional[BaselineComparator] = None,
        report_writer: Optional[JsonReportWriter] = None,
        identifier_factory: Optional[Callable[[], str]] = None,
    ):
        self._config_locator = config_locator
        self._suite_loader = suite_loader
        self._dataset_loader = dataset_loader or DatasetLoader()
        self._engine = engine or RubricEvaluationEngine()
        self._baseline_comparator = baseline_comparator or BaselineComparator()
        self._report_writer = report_writer or JsonReportWriter()
        # Falls back to a seconds-resolution Unix timestamp (see _default_identifier).
        self._identifier_factory = identifier_factory or self._default_identifier

    def execute(self, request: EvaluationSessionRequest) -> EvaluationSessionResult:
        """Run the evaluation described by *request* and return its result.

        Steps:
        1. Validate the inputs.
        2. Locate and load the rubric config, and check that an API key is
           available for the rubric's model.
        3. Load the dataset records for the rubric, plus an optional baseline.
        4. Resolve the report destination.
        5. Run the engine and write the JSON report.

        Raises:
            CLIError: for an empty rubric id, a non-numeric or non-positive run
                count, a missing or directory data path, a missing API key, no
                matching records, or an output file path that is a directory.
        """
        rubric_id = request.rubric_id.strip()
        if not rubric_id:
            raise CLIError("Rubric identifier cannot be empty.")

        number = self._coerce_run_count(request.number)
        data_path = self._validate_data_path(request.data_path)

        config_override = str(request.config_path.expanduser()) if request.config_path else None
        config_path = self._config_locator(config_override, data_path)
        suite = self._suite_loader(config_path)
        rubric_config = suite.get(rubric_id)

        try:
            ensure_api_key_available(rubric_config.model_info)
        except (MissingAPIKeyError, TypeError) as exc:
            raise CLIError(str(exc)) from exc

        matching_records = self._select_records(data_path, rubric_id)
        baseline_stats = self._load_baseline(request.baseline_path)

        # Resolve (and create) the output location before running the engine so
        # an unusable output path is reported without evaluating first.
        resolved_output_path, resolved_identifier = self._resolve_output_path(
            request.output_path,
            request.output_identifier,
            rubric_id=rubric_id,
        )

        report = self._engine.execute(
            rubric_config=rubric_config,
            config_path=config_path,
            data_path=data_path,
            records=matching_records,
            number=number,
        )

        written_path = None
        if resolved_output_path is not None:
            written_path = self._report_writer.write(
                report,
                output_path=resolved_output_path,
                output_identifier=resolved_identifier,
                baseline=baseline_stats,
            )

        return EvaluationSessionResult(
            request=request,
            config_path=config_path,
            data_path=data_path,
            rubric_config=rubric_config,
            records=matching_records,
            report=report,
            baseline=baseline_stats,
            written_path=written_path,
            output_identifier=resolved_identifier,
        )

    @staticmethod
    def _coerce_run_count(value: Optional[int]) -> int:
        """Return the requested run count as a positive int (``None`` means 1).

        Raises:
            CLIError: if *value* cannot be converted by ``int()`` or is below 1.
        """
        if value is None:
            return 1
        try:
            number = int(value)
        except (TypeError, ValueError) as exc:
            # Surface bad input as a CLIError like every other validation here,
            # instead of leaking a raw ValueError/TypeError.
            raise CLIError("Number of runs must be a positive integer.") from exc
        if number < 1:
            raise CLIError("Number of runs must be a positive integer.")
        return number

    @staticmethod
    def _validate_data_path(raw_path: Path) -> Path:
        """Expand ``~`` in *raw_path* and check that it names an existing non-directory.

        Raises:
            CLIError: if the path does not exist or is a directory.
        """
        data_path = raw_path.expanduser()
        if not data_path.exists():
            raise CLIError(f"Data path '{data_path}' does not exist.")
        if data_path.is_dir():
            raise CLIError(f"Expected a JSONL file but received directory '{data_path}'.")
        return data_path

    def _select_records(self, data_path: Path, rubric_id: str) -> list[DatasetRecord]:
        """Load *data_path* and keep records whose rubric id matches case-insensitively.

        Raises:
            CLIError: if no record references *rubric_id*.
        """
        wanted = rubric_id.lower()  # hoisted out of the per-record comparison
        matching = [
            record
            for record in self._dataset_loader.load(data_path)
            if record.rubric_id.lower() == wanted
        ]
        if not matching:
            raise CLIError(f"No records in '{data_path}' reference rubric '{rubric_id}'.")
        return matching

    def _load_baseline(self, baseline_path: Optional[Path]) -> Optional[BaselineStatistics]:
        """Load baseline statistics from *baseline_path* (``~`` expanded), if given."""
        if baseline_path is None:
            return None
        resolved = baseline_path.expanduser()
        return self._baseline_comparator.load(resolved)

    def _resolve_output_path(
        self,
        output_candidate: Optional[Path],
        output_identifier: Optional[str],
        *,
        rubric_id: str,
    ) -> tuple[Optional[Path], Optional[str]]:
        """Return ``(report_path, identifier)`` for the evaluation report.

        The destination depends on *output_candidate*:

        * ``None``: the report goes to
          ``_CACHE_ROOT/<sanitised rubric id>/rubric_eval_result_<id>.json``.
        * A path with a suffix: the path itself is the report file, and the
          caller's identifier is returned unchanged (possibly ``None``).
        * A path without a suffix: the path is treated as a directory and a
          generated file name is placed inside it.

        Directories are created as needed.

        Raises:
            CLIError: if a suffixed path already exists as a directory.
        """
        if output_candidate is None:
            identifier = output_identifier or self._identifier_factory()
            target_dir = _CACHE_ROOT / _sanitise_rubric_folder(rubric_id)
            target_dir.mkdir(parents=True, exist_ok=True)
            return target_dir / f"rubric_eval_result_{identifier}.json", identifier

        candidate = output_candidate.expanduser()
        if candidate.suffix:
            if candidate.exists() and candidate.is_dir():
                raise CLIError(f"Output path '{candidate}' is a directory.")
            return candidate, output_identifier

        candidate.mkdir(parents=True, exist_ok=True)
        identifier = output_identifier or self._identifier_factory()
        return candidate / f"rubric_eval_result_{identifier}.json", identifier

    @staticmethod
    def _default_identifier() -> str:
        """Return the current Unix time in whole seconds, as a string.

        NOTE(review): runs started within the same second share an identifier,
        so their default report files collide.
        """
        return str(int(time.time()))
@@ -0,0 +1,209 @@
1
+ from __future__ import annotations
2
+
3
+ from statistics import mean, pvariance, pstdev
4
+ from typing import Any, Collection, Optional, Set
5
+
6
+ from .errors import CLIError
7
+
8
+
9
def coerce_optional_float(value: Any, field_name: str, source_label: str) -> Optional[float]:
    """Convert a numeric *value* to ``float``, passing ``None`` straight through.

    Booleans are rejected even though ``bool`` is a subclass of ``int``.

    Raises:
        CLIError: if *value* is neither ``None`` nor a non-bool int/float.
    """
    is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
    if value is not None and not is_number:
        raise CLIError(
            f"Expected '{field_name}' in {source_label} to be numeric, got {type(value).__name__}."
        )
    return None if value is None else float(value)
17
+
18
+
19
def collapse_preview_text(value: Any, *, max_length: int = 140) -> Optional[str]:
    """Collapse *value* into a single-line preview of at most *max_length* characters.

    Runs of whitespace become single spaces. Text longer than *max_length* is
    cut and suffixed with ``"..."``. When *max_length* is too small to hold the
    ellipsis, the text is hard-truncated instead, so the limit is never exceeded.

    Returns:
        The preview string, or ``None`` if *value* is not a string or nothing
        is left after collapsing/truncation.
    """
    if not isinstance(value, str):
        return None
    # str.split() with no arguments already discards leading/trailing whitespace.
    collapsed = " ".join(value.split())
    if len(collapsed) > max_length:
        ellipsis = "..."
        if max_length < len(ellipsis):
            # A negative slice bound here used to produce output longer than
            # max_length (e.g. "abcde..." for max_length=2).
            collapsed = collapsed[: max(max_length, 0)]
        else:
            collapsed = collapsed[: max_length - len(ellipsis)].rstrip() + ellipsis
    return collapsed or None
28
+
29
+
30
def calculate_statistics(scores: list[float]) -> dict[str, float]:
    """Summarise *scores* with population statistics.

    Returns a dict with ``average``, ``variance`` (population), ``stdev``
    (population), ``min`` and ``max``. Every value is ``0.0`` when *scores*
    is empty.
    """
    if scores:
        return {
            "average": mean(scores),
            "variance": pvariance(scores),
            "stdev": pstdev(scores),
            "min": min(scores),
            "max": max(scores),
        }
    return dict.fromkeys(("average", "variance", "stdev", "min", "max"), 0.0)
49
+
50
+
51
def calculate_stat_deltas(baseline: dict[str, float], current: dict[str, float]) -> dict[str, float]:
    """Return ``current[k] - baseline[k]`` for every key present in both dicts.

    Keys whose values cannot be converted with ``float()`` on either side are
    skipped silently. The result follows the key order of *current*.
    """
    deltas: dict[str, float] = {}
    for name in current:
        if name not in baseline:
            continue
        try:
            before = float(baseline[name])
            after = float(current[name])
        except (TypeError, ValueError):
            continue
        deltas[name] = after - before
    return deltas
63
+
64
+
65
# String-valued keys collected from otherwise-unhandled dict entries when the
# caller gives no allowlist. (All three are also handled keys, so by default
# the final key scan collects no bare strings.)
_DEFAULT_STRING_KEY_ALLOWLIST = frozenset({"text", "value", "message"})
# Tried in order; the first one that yields any fragment wins.
_PRIORITIZED_TEXT_KEYS = ("text", "value")
# Keys whose values are always traversed as message content.
_NESTED_CONTENT_KEYS = ("message", "parts", "input_text", "output_text")
# Keys dealt with explicitly (or deliberately ignored) and so skipped by the
# generic key scan.
_HANDLED_KEYS = frozenset(
    {
        "text",
        "value",
        "content",
        "message",
        "parts",
        "input_text",
        "output_text",
        "type",
        "role",
        "name",
        "id",
        "index",
        "finish_reason",
        "reason",
        "tool_call_id",
        "metadata",
    }
)


def gather_text_fragments(
    node: Any,
    fragments: list[str],
    *,
    allow_free_strings: bool = False,
    seen: Optional[Set[int]] = None,
    string_key_allowlist: Optional[Collection[str]] = None,
) -> None:
    """Collect textual snippets from nested message-like structures into *fragments*.

    The traversal favours common chat-completions shapes (e.g. ``{"type": "text"}``
    blocks) and avoids indiscriminately pulling in metadata values such as IDs.

    Args:
        node: Value to traverse (str, list, dict; anything else is ignored).
        fragments: Output list; stripped, non-empty strings are appended in
            traversal order.
        allow_free_strings: Whether bare strings at this level count as text.
            Useful for raw message content, but typically disabled for metadata
            fields.
        seen: ``id()`` set used for cycle protection. Dicts are visited at most
            once. Lists are guarded only while being traversed, so a list shared
            in several places is still visited each time, but a list that
            contains itself terminates.
        string_key_allowlist: Case-insensitive keys whose string values are
            collected from otherwise-unhandled dict entries. Defaults to
            ``{"text", "value", "message"}``.
    """
    if seen is None:
        seen = set()

    def _recurse(child: Any, *, free: bool) -> None:
        gather_text_fragments(
            child,
            fragments,
            allow_free_strings=free,
            seen=seen,
            string_key_allowlist=string_key_allowlist,
        )

    if isinstance(node, str):
        if allow_free_strings:
            stripped = node.strip()
            if stripped:
                fragments.append(stripped)
        return

    if isinstance(node, list):
        list_id = id(node)
        if list_id in seen:
            # The list is already on the current traversal path, i.e. a cycle.
            return
        seen.add(list_id)
        try:
            for item in node:
                _recurse(item, free=allow_free_strings)
        finally:
            seen.discard(list_id)
        return

    if not isinstance(node, dict):
        return

    node_id = id(node)
    if node_id in seen:
        return
    seen.add(node_id)

    if string_key_allowlist is None:
        allowlist = _DEFAULT_STRING_KEY_ALLOWLIST
    else:
        allowlist = {key.lower() for key in string_key_allowlist}

    for key in _PRIORITIZED_TEXT_KEYS:
        if key not in node:
            continue
        before_count = len(fragments)
        _recurse(node[key], free=True)
        if len(fragments) > before_count:
            break

    # Covers plain message content and ``{"type": "tool_result"}`` blocks alike;
    # both were previously handled by identical duplicated branches.
    if "content" in node:
        _recurse(node["content"], free=True)

    for key in _NESTED_CONTENT_KEYS:
        if key in node:
            _recurse(node[key], free=True)

    for key, value in node.items():
        if key in _HANDLED_KEYS:
            continue
        if isinstance(value, (list, dict)):
            _recurse(value, free=False)
        # Non-str keys (possible in Python-built dicts) have no .lower(); skip them
        # instead of raising AttributeError.
        elif isinstance(value, str) and isinstance(key, str) and key.lower() in allowlist:
            stripped = value.strip()
            if stripped:
                fragments.append(stripped)
193
+
194
+
195
def collect_text_fragments(
    node: Any,
    *,
    allow_free_strings: bool = False,
    string_key_allowlist: Optional[Collection[str]] = None,
) -> list[str]:
    """Return the text fragments found in *node* as a new list.

    Convenience wrapper around ``gather_text_fragments`` that supplies a fresh
    output list and a fresh ``seen`` set on every call.
    """
    collected: list[str] = []
    gather_text_fragments(
        node,
        collected,
        allow_free_strings=allow_free_strings,
        seen=set(),
        string_key_allowlist=string_key_allowlist,
    )
    return collected
osmosis_ai/consts.py CHANGED
@@ -1,3 +1,3 @@
1
1
# package metadata
package_name = "osmosis-ai"
# Should match the version of the published distribution (this release is 0.2.3);
# it was previously left at "0.2.2".
package_version = "0.2.3"
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Tuple
4
+
5
+ from .anthropic_provider import AnthropicProvider
6
+ from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRegistry, ProviderRequest, RubricProvider
7
+ from .gemini_provider import GeminiProvider
8
+ from .openai_family import OpenAIProvider, XAIProvider
9
+
10
# Module-wide registry pre-populated with the built-in providers. Each class is
# instantiated and registered in turn, in this order.
_REGISTRY = ProviderRegistry()
for _provider_cls in (OpenAIProvider, XAIProvider, AnthropicProvider, GeminiProvider):
    _REGISTRY.register(_provider_cls())
del _provider_cls
15
+
16
+
17
def get_provider(name: str) -> RubricProvider:
    """Return the provider registered under *name*.

    Raises:
        ValueError: propagated from the registry when *name* is not registered.
    """
    provider = _REGISTRY.get(name)
    return provider
19
+
20
+
21
def register_provider(provider: RubricProvider) -> None:
    """Add *provider* to the module-wide registry under its ``name``.

    Raises:
        ValueError: propagated from the registry when the name is already taken.
    """
    _REGISTRY.register(provider)
    return None
23
+
24
+
25
def supported_providers() -> Tuple[str, ...]:
    """Return the names of all registered providers, sorted alphabetically."""
    names = _REGISTRY.supported_providers()
    return names
27
+
28
+
29
# Public API of the providers package.
__all__ = [
    "DEFAULT_REQUEST_TIMEOUT_SECONDS",
    "ProviderRequest",
    "RubricProvider",
    "get_provider",
    "register_provider",
    "supported_providers",
]
@@ -0,0 +1,85 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any, Dict
4
+
5
+ try: # pragma: no cover - optional dependency
6
+ import anthropic # type: ignore
7
+ from anthropic import APIError # type: ignore
8
+ except ImportError: # pragma: no cover - optional dependency
9
+ anthropic = None # type: ignore[assignment]
10
+ APIError = None # type: ignore[assignment]
11
+
12
+ from ..rubric_types import ModelNotFoundError, ProviderRequestError, RewardRubricRunResult
13
+ from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
14
+ from .shared import dump_model, extract_structured_score, reward_schema_definition
15
+
16
+
17
class AnthropicProvider(RubricProvider):
    """Rubric provider backed by Anthropic's Messages API.

    The model is forced to answer through a single tool call whose input schema
    is the reward rubric schema. The tool input is then parsed into a score
    (clamped to the request's bounds) and an explanation.
    """

    name = "anthropic"

    def default_timeout(self, model: str) -> float:
        """Return the shared default request timeout; *model* is not consulted."""
        return DEFAULT_REQUEST_TIMEOUT_SECONDS

    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
        """Score *request* with Anthropic and return score, explanation and raw response.

        Raises:
            ProviderRequestError: if the SDK is missing, the API call fails, or
                the response lacks the expected tool output.
            ModelNotFoundError: if the API reports HTTP 404 for the model.
        """
        if anthropic is None or APIError is None:
            raise ProviderRequestError(
                self.name,
                request.model,
                "Anthropic SDK is required. Install it via `pip install anthropic`.",
            )

        client = anthropic.Anthropic(api_key=request.api_key)
        tool_name = "emit_reward_rubric_response"
        tool = {
            "name": tool_name,
            "description": "Return the reward rubric score and explanation as structured JSON.",
            "input_schema": reward_schema_definition(),
        }
        user_message = {"role": "user", "content": [{"type": "text", "text": request.user_content}]}

        try:
            response = client.messages.create(
                model=request.model,
                system=request.system_content,
                messages=[user_message],
                tools=[tool],
                # Force the model to reply via the tool so the output is structured.
                tool_choice={"type": "tool", "name": tool_name},
                max_tokens=512,
                temperature=0,
                timeout=request.timeout,
            )
        except APIError as err:
            detail = self._describe_api_error(err)
            if getattr(err, "status_code", None) == 404:
                not_found_detail = (
                    f"Model '{request.model}' was not found. Confirm your Anthropic account has access "
                    "to the requested snapshot or update the model identifier."
                )
                raise ModelNotFoundError(self.name, request.model, not_found_detail) from err
            raise ProviderRequestError(self.name, request.model, detail) from err
        except Exception as err:
            detail = str(err).strip() or "Unexpected error during Anthropic request."
            raise ProviderRequestError(self.name, request.model, detail) from err

        raw = dump_model(response)
        payload = self._find_tool_input(raw, tool_name)
        if payload is None:
            raise ProviderRequestError(self.name, request.model, "Model response missing expected tool output.")

        score, explanation = extract_structured_score(payload)
        bounded = max(request.score_min, min(request.score_max, score))
        return {"score": bounded, "explanation": explanation, "raw": raw}

    @staticmethod
    def _describe_api_error(err: Exception) -> str:
        """Prefer the error's non-blank ``message`` attribute; fall back to ``str(err)``."""
        message = getattr(err, "message", None)
        if isinstance(message, str) and message.strip():
            return message
        return str(err)

    @staticmethod
    def _find_tool_input(raw: Any, tool_name: str) -> Dict[str, Any] | None:
        """Return the dict ``input`` of the first ``tool_use`` block named *tool_name*.

        Blocks with a non-dict ``input`` are skipped. Returns ``None`` when *raw*
        has no usable ``content`` list or no matching block.
        """
        blocks = raw.get("content") if isinstance(raw, dict) else None
        if not isinstance(blocks, list):
            return None
        for block in blocks:
            if not isinstance(block, dict):
                continue
            if block.get("type") != "tool_use" or block.get("name") != tool_name:
                continue
            candidate = block.get("input")
            if isinstance(candidate, dict):
                return candidate
        return None


__all__ = ["AnthropicProvider"]
@@ -0,0 +1,60 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Tuple
5
+
6
+ from ..rubric_types import RewardRubricRunResult
7
+
8
# Fallback per-request timeout in seconds, returned by the base provider's
# default_timeout().
DEFAULT_REQUEST_TIMEOUT_SECONDS = 30.0


@dataclass(frozen=True)
class ProviderRequest:
    """Immutable bundle of everything a provider needs to score one rubric."""

    # Provider name (presumably the registry key, e.g. "anthropic" — confirm with callers).
    provider: str
    # Model identifier sent to the provider's API.
    model: str
    # API key used to authenticate the request.
    api_key: str
    # System prompt text.
    system_content: str
    # User message text to be scored.
    user_content: str
    # Lower bound of the rubric score range.
    score_min: float
    # Upper bound of the rubric score range.
    score_max: float
    # Per-request timeout in seconds.
    timeout: float
21
+
22
+
23
class RubricProvider:
    """Base interface for hosted LLM providers that can score rubrics.

    Subclasses set ``name`` (the key the registry stores them under) and
    override :meth:`run`.
    """

    name: str

    def default_timeout(self, model: str) -> float:
        """Return the request timeout in seconds; the base class ignores *model*."""
        return DEFAULT_REQUEST_TIMEOUT_SECONDS

    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
        """Score *request*; concrete providers must override this."""
        raise NotImplementedError()
33
+
34
+
35
class ProviderRegistry:
    """Name-keyed collection of rubric providers."""

    def __init__(self) -> None:
        self._providers: Dict[str, RubricProvider] = {}

    def register(self, provider: RubricProvider) -> None:
        """Store *provider* under its ``name``.

        Raises:
            ValueError: if a provider with the same name is already registered.
        """
        provider_name = provider.name
        if provider_name in self._providers:
            raise ValueError(f"Provider '{provider_name}' is already registered.")
        self._providers[provider_name] = provider

    def get(self, name: str) -> RubricProvider:
        """Return the provider registered as *name*.

        Raises:
            ValueError: if *name* is unknown (chained from the ``KeyError``).
        """
        try:
            provider = self._providers[name]
        except KeyError as missing:
            raise ValueError(f"Unsupported provider '{name}'.") from missing
        return provider

    def supported_providers(self) -> Tuple[str, ...]:
        """Return all registered names in sorted order."""
        names = sorted(self._providers.keys())
        return tuple(names)
53
+
54
+
55
# Public API of the provider base module.
__all__ = ["DEFAULT_REQUEST_TIMEOUT_SECONDS", "ProviderRequest", "RubricProvider", "ProviderRegistry"]
+ ]