osmosis-ai 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

osmosis_ai/__init__.py CHANGED
@@ -1,8 +1,8 @@
 """
 osmosis-ai: A Python library for reward function validation with strict type enforcement.
 
-This library provides the @osmosis_reward decorator that enforces standardized
-function signatures for reward functions used in LLM applications.
+This library provides decorators such as @osmosis_reward and @osmosis_rubric that
+enforce standardized function signatures for LLM-centric workflows.
 
 Features:
 - Type-safe reward function decoration
@@ -10,6 +10,15 @@ Features:
 - Support for optional configuration parameters
 """
 
-from .utils import osmosis_reward
+from .rubric_eval import MissingAPIKeyError, evaluate_rubric
+from .rubric_types import ModelNotFoundError, ProviderRequestError
+from .utils import osmosis_reward, osmosis_rubric
 
-__all__ = ["osmosis_reward"]
+__all__ = [
+    "osmosis_reward",
+    "osmosis_rubric",
+    "evaluate_rubric",
+    "MissingAPIKeyError",
+    "ProviderRequestError",
+    "ModelNotFoundError",
+]
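The widened public API means downstream code can import the rubric entry point and its error types straight from the package root. A minimal usage sketch, assuming an installed build of this release; the exact evaluate_rubric signature lives in osmosis_ai/rubric_eval.py, which this diff does not show, so the call below is hypothetical:

from osmosis_ai import (
    MissingAPIKeyError,
    ModelNotFoundError,
    ProviderRequestError,
    evaluate_rubric,
)

try:
    result = evaluate_rubric(...)  # hypothetical arguments; see rubric_eval for the real signature
except MissingAPIKeyError:
    pass  # no API key configured for the chosen provider
except (ModelNotFoundError, ProviderRequestError):
    pass  # raised by the provider layer shown further down this diff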
osmosis_ai/consts.py CHANGED
@@ -1,3 +1,3 @@
 # package metadata
 package_name = "osmosis-ai"
-package_version = "0.2.0"
+package_version = "0.2.2"
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+from .anthropic_provider import AnthropicProvider
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRegistry, ProviderRequest, RubricProvider
+from .gemini_provider import GeminiProvider
+from .openai_family import OpenAIProvider, XAIProvider
+
+_REGISTRY = ProviderRegistry()
+_REGISTRY.register(OpenAIProvider())
+_REGISTRY.register(XAIProvider())
+_REGISTRY.register(AnthropicProvider())
+_REGISTRY.register(GeminiProvider())
+
+
+def get_provider(name: str) -> RubricProvider:
+    return _REGISTRY.get(name)
+
+
+def register_provider(provider: RubricProvider) -> None:
+    _REGISTRY.register(provider)
+
+
+def supported_providers() -> Tuple[str, ...]:
+    return _REGISTRY.supported_providers()
+
+
+__all__ = [
+    "DEFAULT_REQUEST_TIMEOUT_SECONDS",
+    "ProviderRequest",
+    "RubricProvider",
+    "get_provider",
+    "register_provider",
+    "supported_providers",
+]
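This new registry module (its file path is not shown in this diff) wires up the four built-in providers at import time and exposes get_provider, register_provider, and supported_providers as module-level helpers, so third-party backends can be added without touching the built-ins. A sketch of registering a custom provider, assuming the module is importable as osmosis_ai.providers and using only the RubricProvider interface defined in base:

from osmosis_ai.providers import (  # assumed module path
    ProviderRequest,
    RubricProvider,
    register_provider,
    supported_providers,
)
from osmosis_ai.rubric_types import RewardRubricRunResult


class EchoProvider(RubricProvider):
    """Toy backend that returns a fixed midpoint score."""

    name = "echo"

    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
        midpoint = (request.score_min + request.score_max) / 2
        return {"score": midpoint, "explanation": "stub", "raw": {}}


register_provider(EchoProvider())  # a second registration of "echo" would raise ValueError
print(supported_providers())       # sorted tuple including "echo" and the built-ins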
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from typing import Any, Dict
+
+try:  # pragma: no cover - optional dependency
+    import anthropic  # type: ignore
+    from anthropic import APIError  # type: ignore
+except ImportError:  # pragma: no cover - optional dependency
+    anthropic = None  # type: ignore[assignment]
+    APIError = None  # type: ignore[assignment]
+
+from ..rubric_types import ModelNotFoundError, ProviderRequestError, RewardRubricRunResult
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
+from .shared import dump_model, extract_structured_score, reward_schema_definition
+
+
+class AnthropicProvider(RubricProvider):
+    name = "anthropic"
+
+    def default_timeout(self, model: str) -> float:
+        return DEFAULT_REQUEST_TIMEOUT_SECONDS
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        if anthropic is None or APIError is None:
+            raise ProviderRequestError(
+                self.name,
+                request.model,
+                "Anthropic SDK is required. Install it via `pip install anthropic`.",
+            )
+
+        client = anthropic.Anthropic(api_key=request.api_key)
+        tool_name = "emit_reward_rubric_response"
+        schema_definition = reward_schema_definition()
+        tool = {
+            "name": tool_name,
+            "description": "Return the reward rubric score and explanation as structured JSON.",
+            "input_schema": schema_definition,
+        }
+
+        try:
+            response = client.messages.create(
+                model=request.model,
+                system=request.system_content,
+                messages=[{"role": "user", "content": [{"type": "text", "text": request.user_content}]}],
+                tools=[tool],
+                tool_choice={"type": "tool", "name": tool_name},
+                max_tokens=512,
+                temperature=0,
+                timeout=request.timeout,
+            )
+        except APIError as err:
+            detail = getattr(err, "message", None)
+            if not isinstance(detail, str) or not detail.strip():
+                detail = str(err)
+            status_code = getattr(err, "status_code", None)
+            if status_code == 404:
+                not_found_detail = (
+                    f"Model '{request.model}' was not found. Confirm your Anthropic account has access "
+                    "to the requested snapshot or update the model identifier."
+                )
+                raise ModelNotFoundError(self.name, request.model, not_found_detail) from err
+            raise ProviderRequestError(self.name, request.model, detail) from err
+        except Exception as err:
+            detail = str(err).strip() or "Unexpected error during Anthropic request."
+            raise ProviderRequestError(self.name, request.model, detail) from err
+
+        raw = dump_model(response)
+
+        payload: Dict[str, Any] | None = None
+        content_blocks = raw.get("content") if isinstance(raw, dict) else None
+        if isinstance(content_blocks, list):
+            for block in content_blocks:
+                if isinstance(block, dict) and block.get("type") == "tool_use" and block.get("name") == tool_name:
+                    maybe_input = block.get("input")
+                    if isinstance(maybe_input, dict):
+                        payload = maybe_input
+                        break
+        if payload is None:
+            raise ProviderRequestError(self.name, request.model, "Model response missing expected tool output.")
+        score, explanation = extract_structured_score(payload)
+        bounded = max(request.score_min, min(request.score_max, score))
+        return {"score": bounded, "explanation": explanation, "raw": raw}
+
+
+__all__ = ["AnthropicProvider"]
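AnthropicProvider forces the reply through a pinned tool call (tool_choice selects emit_reward_rubric_response), so the score and explanation come back as structured JSON rather than free text, and the score is clamped into [score_min, score_max] before being returned. A direct-invocation sketch, assuming the same osmosis_ai.providers path as above, a placeholder model id, and the anthropic SDK installed:

import os

from osmosis_ai.providers import ProviderRequest, get_provider  # assumed module path

provider = get_provider("anthropic")
model = "claude-3-5-sonnet-latest"  # hypothetical id; use a model your account can access
request = ProviderRequest(
    provider=provider.name,
    model=model,
    api_key=os.environ["ANTHROPIC_API_KEY"],
    system_content="You are a strict grader.",
    user_content="Rubric: ...\n\nSubmission: ...",
    score_min=0.0,
    score_max=1.0,
    timeout=provider.default_timeout(model),
)
result = provider.run(request)
print(result["score"], result["explanation"])  # "raw" holds the full API response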
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, Tuple
+
+from ..rubric_types import RewardRubricRunResult
+
+DEFAULT_REQUEST_TIMEOUT_SECONDS = 30.0
+
+
+@dataclass(frozen=True)
+class ProviderRequest:
+    provider: str
+    model: str
+    api_key: str
+    system_content: str
+    user_content: str
+    score_min: float
+    score_max: float
+    timeout: float
+
+
+class RubricProvider:
+    """Interface for hosted LLM providers that can score rubrics."""
+
+    name: str
+
+    def default_timeout(self, model: str) -> float:
+        return DEFAULT_REQUEST_TIMEOUT_SECONDS
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        raise NotImplementedError
+
+
+class ProviderRegistry:
+    def __init__(self) -> None:
+        self._providers: Dict[str, RubricProvider] = {}
+
+    def register(self, provider: RubricProvider) -> None:
+        key = provider.name
+        if key in self._providers:
+            raise ValueError(f"Provider '{key}' is already registered.")
+        self._providers[key] = provider
+
+    def get(self, name: str) -> RubricProvider:
+        try:
+            return self._providers[name]
+        except KeyError as exc:
+            raise ValueError(f"Unsupported provider '{name}'.") from exc
+
+    def supported_providers(self) -> Tuple[str, ...]:
+        return tuple(sorted(self._providers))
+
+
+__all__ = [
+    "DEFAULT_REQUEST_TIMEOUT_SECONDS",
+    "ProviderRequest",
+    "RubricProvider",
+    "ProviderRegistry",
+]
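ProviderRegistry fails loudly in both directions: registering a duplicate name raises instead of silently replacing an existing backend, and looking up an unknown name raises ValueError rather than returning None. A small sketch of those semantics (module path again assumed):

from osmosis_ai.providers.base import ProviderRegistry, RubricProvider  # assumed path

class StubProvider(RubricProvider):
    name = "stub"

registry = ProviderRegistry()
registry.register(StubProvider())

try:
    registry.register(StubProvider())
except ValueError as err:
    print(err)  # Provider 'stub' is already registered.

try:
    registry.get("missing")
except ValueError as err:
    print(err)  # Unsupported provider 'missing'.

print(registry.supported_providers())  # ('stub',)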
@@ -0,0 +1,269 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+import time
+import warnings
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple
+
+if TYPE_CHECKING:  # pragma: no cover - typing helpers only
+    from google import genai as genai_module  # type: ignore
+    from google.genai import types as genai_types_module  # type: ignore
+
+from ..rubric_types import ProviderRequestError, RewardRubricRunResult
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
+from .shared import dump_model, reward_schema_definition, sanitize_json
+
+
+_GENAI_MODULE: Any | None = None
+_GENAI_TYPES_MODULE: Any | None = None
+_PYDANTIC_ANY_WARNING_MESSAGE = r".*<built-in function any> is not a Python type.*"
+
+GEMINI_DEFAULT_TIMEOUT_SECONDS = 60.0
+GEMINI_MIN_TIMEOUT_SECONDS = 5.0
+GEMINI_MAX_TIMEOUT_SECONDS = 180.0
+GEMINI_RETRY_ATTEMPTS = 3
+GEMINI_TIMEOUT_BACKOFF = 1.5
+GEMINI_RETRY_SLEEP_SECONDS = (0.5, 1.0, 2.0)
+
+
+@contextmanager
+def _suppress_pydantic_any_warning() -> Iterator[None]:
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message=_PYDANTIC_ANY_WARNING_MESSAGE,
+            category=UserWarning,
+            module=r"pydantic\._internal\._generate_schema",
+        )
+        yield
+
+
+def _load_google_genai() -> Tuple[Any, Any]:
+    """
+    Lazily import the Google Generative AI SDK so that environments without the optional
+    dependency avoid import-time side effects (like pydantic warnings) unless the Gemini
+    provider is actually used.
+    """
+    global _GENAI_MODULE, _GENAI_TYPES_MODULE
+    if _GENAI_MODULE is not None and _GENAI_TYPES_MODULE is not None:
+        return _GENAI_MODULE, _GENAI_TYPES_MODULE
+
+    try:  # pragma: no cover - optional dependency
+        with _suppress_pydantic_any_warning():
+            from google import genai as genai_mod  # type: ignore
+            from google.genai import types as genai_types_mod  # type: ignore
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "Google Generative AI SDK is required for provider 'gemini'. "
+            "Install it via `pip install google-genai`."
+        ) from exc
+
+    _GENAI_MODULE = genai_mod
+    _GENAI_TYPES_MODULE = genai_types_mod
+    return _GENAI_MODULE, _GENAI_TYPES_MODULE
+
+
+def _normalize_gemini_model(model_id: str) -> str:
+    import re
+
+    return re.sub(r"^models/", "", model_id, flags=re.IGNORECASE)
+
+
+def _json_schema_to_genai(
+    schema: Dict[str, Any],
+    genai_types: Any,
+) -> "genai_types_module.Schema":  # type: ignore[name-defined]
+
+    type_map = {
+        "object": genai_types.Type.OBJECT,
+        "string": genai_types.Type.STRING,
+        "number": genai_types.Type.NUMBER,
+        "integer": genai_types.Type.INTEGER,
+        "boolean": genai_types.Type.BOOLEAN,
+        "array": genai_types.Type.ARRAY,
+    }
+
+    kwargs: Dict[str, Any] = {}
+    type_value = schema.get("type")
+    if isinstance(type_value, str):
+        mapped = type_map.get(type_value.lower())
+        if mapped is not None:
+            kwargs["type"] = mapped
+
+    required = schema.get("required")
+    if isinstance(required, list):
+        filtered_required = [name for name in required if isinstance(name, str)]
+        if filtered_required:
+            kwargs["required"] = filtered_required
+
+    properties = schema.get("properties")
+    if isinstance(properties, dict):
+        converted_properties = {}
+        for key, value in properties.items():
+            if isinstance(key, str) and isinstance(value, dict):
+                converted_properties[key] = _json_schema_to_genai(value, genai_types)
+        if converted_properties:
+            kwargs["properties"] = converted_properties
+
+    items = schema.get("items")
+    if isinstance(items, dict):
+        kwargs["items"] = _json_schema_to_genai(items, genai_types)
+
+    enum_values = schema.get("enum")
+    if isinstance(enum_values, list):
+        filtered_enum = [str(option) for option in enum_values]
+        if filtered_enum:
+            kwargs["enum"] = filtered_enum
+
+    description = schema.get("description")
+    if isinstance(description, str):
+        kwargs["description"] = description
+
+    minimum = schema.get("minimum")
+    if isinstance(minimum, (int, float)):
+        kwargs["minimum"] = float(minimum)
+
+    maximum = schema.get("maximum")
+    if isinstance(maximum, (int, float)):
+        kwargs["maximum"] = float(maximum)
+
+    min_items = schema.get("min_items")
+    if isinstance(min_items, int):
+        kwargs["min_items"] = min_items
+
+    max_items = schema.get("max_items")
+    if isinstance(max_items, int):
+        kwargs["max_items"] = max_items
+
+    min_length = schema.get("min_length")
+    if isinstance(min_length, int):
+        kwargs["min_length"] = min_length
+
+    max_length = schema.get("max_length")
+    if isinstance(max_length, int):
+        kwargs["max_length"] = max_length
+
+    nullable = schema.get("nullable")
+    if isinstance(nullable, bool):
+        kwargs["nullable"] = nullable
+
+    with _suppress_pydantic_any_warning():
+        return genai_types.Schema(**kwargs)
+
+
+def _build_retry_timeouts(requested_timeout: float) -> List[float]:
+    # Keep the first attempt generous, then increase for retries while capping growth.
+    base = max(requested_timeout, GEMINI_MIN_TIMEOUT_SECONDS, GEMINI_DEFAULT_TIMEOUT_SECONDS)
+    timeouts: List[float] = []
+    current = base
+    for _ in range(GEMINI_RETRY_ATTEMPTS):
+        timeouts.append(min(current, GEMINI_MAX_TIMEOUT_SECONDS))
+        current = min(current * GEMINI_TIMEOUT_BACKOFF, GEMINI_MAX_TIMEOUT_SECONDS)
+    return timeouts
+
+
+def _seconds_to_millis(seconds: float) -> int:
+    # Gemini client expects timeout in milliseconds. Clamp to at least 1ms.
+    return max(int(round(seconds * 1000)), 1)
+
+
+class GeminiProvider(RubricProvider):
+    name = "gemini"
+
+    def default_timeout(self, model: str) -> float:
+        return max(DEFAULT_REQUEST_TIMEOUT_SECONDS, GEMINI_DEFAULT_TIMEOUT_SECONDS)
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        try:
+            genai, genai_types = _load_google_genai()
+        except RuntimeError as exc:
+            detail = str(exc).strip() or "Google Generative AI SDK is required."
+            raise ProviderRequestError(self.name, request.model, detail) from exc
+
+        try:
+            requested_timeout = float(request.timeout)
+        except (TypeError, ValueError):
+            requested_timeout = float(DEFAULT_REQUEST_TIMEOUT_SECONDS)
+
+        retry_timeouts = _build_retry_timeouts(requested_timeout)
+        max_timeout = max(retry_timeouts)
+
+        with _suppress_pydantic_any_warning():
+            client = genai.Client(
+                api_key=request.api_key,
+                http_options={"timeout": _seconds_to_millis(max_timeout)},
+            )
+            schema_definition = reward_schema_definition()
+            gemini_schema = _json_schema_to_genai(schema_definition, genai_types)
+            config = genai_types.GenerateContentConfig(
+                response_mime_type="application/json",
+                response_schema=gemini_schema,
+                temperature=0,
+            )
+
+        combined_prompt = f"{request.system_content}\n\n{request.user_content}"
+
+        response: Any | None = None
+        last_error: Exception | None = None
+
+        for attempt_index, attempt_timeout in enumerate(retry_timeouts, start=1):
+            try:
+                with _suppress_pydantic_any_warning():
+                    try:
+                        response = client.models.generate_content(
+                            model=_normalize_gemini_model(request.model),
+                            contents=combined_prompt,
+                            config=config,
+                            request_options={"timeout": _seconds_to_millis(attempt_timeout)},
+                        )
+                    except TypeError as err:
+                        # Older SDKs may not accept request_options; retry without it.
+                        if "request_options" not in str(err):
+                            raise
+                        response = client.models.generate_content(
+                            model=_normalize_gemini_model(request.model),
+                            contents=combined_prompt,
+                            config=config,
+                        )
+                break
+            except Exception as err:  # pragma: no cover - network failures depend on runtime
+                last_error = err
+                if attempt_index >= len(retry_timeouts):
+                    detail = str(err).strip() or "Gemini request failed."
+                    raise ProviderRequestError(self.name, request.model, detail) from err
+                sleep_idx = min(attempt_index - 1, len(GEMINI_RETRY_SLEEP_SECONDS) - 1)
+                time.sleep(GEMINI_RETRY_SLEEP_SECONDS[sleep_idx])
+
+        if response is None and last_error is not None:
+            detail = str(last_error).strip() or "Gemini request failed."
+            raise ProviderRequestError(self.name, request.model, detail) from last_error
+
+        raw = dump_model(response)
+
+        text = getattr(response, "text", None)
+        if not isinstance(text, str) or not text.strip():
+            candidates = raw.get("candidates") if isinstance(raw, dict) else None
+            if isinstance(candidates, list) and candidates:
+                first = candidates[0]
+                if isinstance(first, dict):
+                    content = first.get("content")
+                    if isinstance(content, dict):
+                        parts = content.get("parts")
+                        if isinstance(parts, list):
+                            for part in parts:
+                                if isinstance(part, dict):
+                                    candidate_text = part.get("text")
+                                    if isinstance(candidate_text, str) and candidate_text.strip():
+                                        text = candidate_text
+                                        break
+        if not isinstance(text, str) or not text.strip():
+            raise ProviderRequestError(self.name, request.model, "Model response did not include any text content.")
+        try:
+            score, explanation = sanitize_json(text)
+        except ValueError as err:
+            raise ProviderRequestError(self.name, request.model, str(err)) from err
+        bounded = max(request.score_min, min(request.score_max, score))
+        return {"score": bounded, "explanation": explanation, "raw": raw}
+
+
+__all__ = ["GeminiProvider"]
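The Gemini retry ladder is easy to check by hand: _build_retry_timeouts takes the larger of the requested timeout, the 5 s floor, and the 60 s Gemini default as its base, grows each attempt by 1.5x with a 180 s cap, and the sleeps between failed attempts step through GEMINI_RETRY_SLEEP_SECONDS. A worked sketch using the helper defined above:

# Requested 30 s: base = max(30, 5, 60) = 60
print(_build_retry_timeouts(30.0))   # [60.0, 90.0, 135.0]; sleeps 0.5 s, then 1.0 s between retries

# Requested 150 s: base = max(150, 5, 60) = 150; growth is capped at GEMINI_MAX_TIMEOUT_SECONDS
print(_build_retry_timeouts(150.0))  # [150.0, 180.0, 180.0]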