osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of osmosis-ai might be problematic; see the registry's advisory page for details.

@@ -0,0 +1,314 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+ import inspect
5
+ import time
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple
8
+
9
+ if TYPE_CHECKING: # pragma: no cover - typing helpers only
10
+ from google import genai as genai_module # type: ignore
11
+ from google.genai import types as genai_types_module # type: ignore
12
+
13
+ from ..rubric_types import ProviderRequestError, RewardRubricRunResult
14
+ from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
15
+ from .shared import dump_model, reward_schema_definition, sanitize_json
16
+
17
+
18
# Lazy-import caches populated by _load_google_genai so the optional
# google-genai dependency is only imported when the provider is used.
_GENAI_MODULE: Any | None = None
_GENAI_TYPES_MODULE: Any | None = None
# Regex matching the noisy pydantic UserWarning the google-genai SDK can
# trigger during schema generation; suppressed via _suppress_pydantic_any_warning.
_PYDANTIC_ANY_WARNING_MESSAGE = r".*<built-in function any> is not a Python type.*"

# Timeout/retry tuning for Gemini requests (values in seconds).
GEMINI_DEFAULT_TIMEOUT_SECONDS = 60.0
# Floor applied to caller-supplied timeouts.
GEMINI_MIN_TIMEOUT_SECONDS = 5.0
# Hard cap on any single attempt's timeout.
GEMINI_MAX_TIMEOUT_SECONDS = 180.0
# Total number of generate_content attempts before giving up.
GEMINI_RETRY_ATTEMPTS = 3
# Multiplier applied to the attempt timeout between retries.
GEMINI_TIMEOUT_BACKOFF = 1.5
# Sleep before retry 1, 2, ...; the last entry repeats for later retries.
GEMINI_RETRY_SLEEP_SECONDS = (0.5, 1.0, 2.0)
28
+
29
+
30
+ @contextmanager
31
+ def _suppress_pydantic_any_warning() -> Iterator[None]:
32
+ with warnings.catch_warnings():
33
+ warnings.filterwarnings(
34
+ "ignore",
35
+ message=_PYDANTIC_ANY_WARNING_MESSAGE,
36
+ category=UserWarning,
37
+ module=r"pydantic\._internal\._generate_schema",
38
+ )
39
+ yield
40
+
41
+
42
def _load_google_genai() -> Tuple[Any, Any]:
    """Import and cache the optional google-genai SDK modules.

    Returns the ``(genai, genai.types)`` module pair. The import is deferred
    so environments without the optional dependency avoid import-time side
    effects (like pydantic warnings) unless the Gemini provider is used.

    Raises:
        RuntimeError: if the SDK is not installed.
    """
    global _GENAI_MODULE, _GENAI_TYPES_MODULE
    if _GENAI_MODULE is None or _GENAI_TYPES_MODULE is None:
        try:  # pragma: no cover - optional dependency
            with _suppress_pydantic_any_warning():
                from google import genai as genai_mod  # type: ignore
                from google.genai import types as genai_types_mod  # type: ignore
        except ImportError as exc:  # pragma: no cover - optional dependency
            raise RuntimeError(
                "Google Generative AI SDK is required for provider 'gemini'. "
                "Install it via `pip install google-genai`."
            ) from exc
        _GENAI_MODULE = genai_mod
        _GENAI_TYPES_MODULE = genai_types_mod
    return _GENAI_MODULE, _GENAI_TYPES_MODULE
65
+
66
+
67
+ def _normalize_gemini_model(model_id: str) -> str:
68
+ import re
69
+
70
+ return re.sub(r"^models/", "", model_id, flags=re.IGNORECASE)
71
+
72
+
73
def _json_schema_to_genai(
    schema: Dict[str, Any],
    genai_types: Any,
) -> "genai_types_module.Schema":  # type: ignore[name-defined]
    """Recursively convert a JSON-schema style dict into ``genai_types.Schema``.

    Only keys the Gemini SDK understands are translated; unknown keys are
    ignored. ``properties`` and ``items`` are converted recursively.

    Args:
        schema: JSON-schema-like mapping describing the expected response.
        genai_types: The lazily imported ``google.genai.types`` module.

    Returns:
        A ``genai_types.Schema`` mirroring *schema*.
    """

    def _first_present(*keys: str) -> Any:
        # The genai SDK spells bounds in snake_case (``min_items``) while
        # standard JSON Schema uses camelCase (``minItems``). The original
        # only read snake_case, silently dropping camelCase constraints;
        # accept both, snake_case taking precedence.
        for key in keys:
            if key in schema:
                return schema[key]
        return None

    type_map = {
        "object": genai_types.Type.OBJECT,
        "string": genai_types.Type.STRING,
        "number": genai_types.Type.NUMBER,
        "integer": genai_types.Type.INTEGER,
        "boolean": genai_types.Type.BOOLEAN,
        "array": genai_types.Type.ARRAY,
    }

    kwargs: Dict[str, Any] = {}
    type_value = schema.get("type")
    if isinstance(type_value, str):
        mapped = type_map.get(type_value.lower())
        if mapped is not None:
            kwargs["type"] = mapped

    required = schema.get("required")
    if isinstance(required, list):
        # Drop non-string entries rather than passing bad data to the SDK.
        filtered_required = [name for name in required if isinstance(name, str)]
        if filtered_required:
            kwargs["required"] = filtered_required

    properties = schema.get("properties")
    if isinstance(properties, dict):
        converted_properties = {
            key: _json_schema_to_genai(value, genai_types)
            for key, value in properties.items()
            if isinstance(key, str) and isinstance(value, dict)
        }
        if converted_properties:
            kwargs["properties"] = converted_properties

    items = schema.get("items")
    if isinstance(items, dict):
        kwargs["items"] = _json_schema_to_genai(items, genai_types)

    enum_values = schema.get("enum")
    if isinstance(enum_values, list):
        # Gemini enum options are string-valued; coerce every option.
        filtered_enum = [str(option) for option in enum_values]
        if filtered_enum:
            kwargs["enum"] = filtered_enum

    description = schema.get("description")
    if isinstance(description, str):
        kwargs["description"] = description

    # Numeric bounds: the SDK expects floats.
    minimum = schema.get("minimum")
    if isinstance(minimum, (int, float)):
        kwargs["minimum"] = float(minimum)

    maximum = schema.get("maximum")
    if isinstance(maximum, (int, float)):
        kwargs["maximum"] = float(maximum)

    min_items = _first_present("min_items", "minItems")
    if isinstance(min_items, int):
        kwargs["min_items"] = min_items

    max_items = _first_present("max_items", "maxItems")
    if isinstance(max_items, int):
        kwargs["max_items"] = max_items

    min_length = _first_present("min_length", "minLength")
    if isinstance(min_length, int):
        kwargs["min_length"] = min_length

    max_length = _first_present("max_length", "maxLength")
    if isinstance(max_length, int):
        kwargs["max_length"] = max_length

    nullable = schema.get("nullable")
    if isinstance(nullable, bool):
        kwargs["nullable"] = nullable

    # Schema construction goes through pydantic, which can emit the noisy
    # "<built-in function any>" warning; keep it suppressed.
    with _suppress_pydantic_any_warning():
        return genai_types.Schema(**kwargs)
153
+
154
+
155
def _build_retry_timeouts(requested_timeout: float) -> List[float]:
    """Return one timeout (seconds) per retry attempt.

    The first attempt is at least as generous as both the requested and the
    Gemini default timeout; each subsequent attempt grows by the backoff
    factor, and every value is capped at GEMINI_MAX_TIMEOUT_SECONDS.
    """
    attempt_timeout = max(
        requested_timeout,
        GEMINI_MIN_TIMEOUT_SECONDS,
        GEMINI_DEFAULT_TIMEOUT_SECONDS,
    )
    schedule: List[float] = []
    for _attempt in range(GEMINI_RETRY_ATTEMPTS):
        schedule.append(min(attempt_timeout, GEMINI_MAX_TIMEOUT_SECONDS))
        attempt_timeout = min(
            attempt_timeout * GEMINI_TIMEOUT_BACKOFF,
            GEMINI_MAX_TIMEOUT_SECONDS,
        )
    return schedule
164
+
165
+
166
+ def _seconds_to_millis(seconds: float) -> int:
167
+ # Gemini client expects timeout in milliseconds. Clamp to at least 1ms.
168
+ return max(int(round(seconds * 1000)), 1)
169
+
170
+
171
+ def _supports_request_options(generate_content: Any) -> bool:
172
+ try:
173
+ signature = inspect.signature(generate_content)
174
+ except (TypeError, ValueError):
175
+ return False
176
+ return "request_options" in signature.parameters
177
+
178
+
179
class GeminiProvider(RubricProvider):
    """Rubric provider that scores responses with Google's Gemini models."""

    # Registry key used to select this provider.
    name = "gemini"

    def default_timeout(self, model: str) -> float:
        # Gemini calls can be slow; never go below the Gemini-specific default.
        return max(DEFAULT_REQUEST_TIMEOUT_SECONDS, GEMINI_DEFAULT_TIMEOUT_SECONDS)

    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
        """Send the rubric prompt to Gemini and return the bounded score.

        Retries up to GEMINI_RETRY_ATTEMPTS times with escalating timeouts.
        When the installed SDK supports per-call ``request_options``, one
        client is reused across attempts; otherwise a fresh client (with the
        attempt's HTTP timeout baked in) is created for each attempt.

        Raises:
            ProviderRequestError: when the SDK is missing, every attempt
                fails, or the response lacks usable text/JSON.
        """
        try:
            genai, genai_types = _load_google_genai()
        except RuntimeError as exc:
            detail = str(exc).strip() or "Google Generative AI SDK is required."
            raise ProviderRequestError(self.name, request.model, detail) from exc

        # Tolerate non-numeric timeouts by falling back to the global default.
        try:
            requested_timeout = float(request.timeout)
        except (TypeError, ValueError):
            requested_timeout = float(DEFAULT_REQUEST_TIMEOUT_SECONDS)

        retry_timeouts = _build_retry_timeouts(requested_timeout)
        max_timeout = max(retry_timeouts)

        supports_request_options = False
        shared_client: Any | None = None

        # Probe client: used to inspect generate_content's signature; it is
        # promoted to the shared client only when request_options is supported.
        with _suppress_pydantic_any_warning():
            probe_client = genai.Client(
                api_key=request.api_key,
                http_options={"timeout": _seconds_to_millis(max_timeout)},
            )
        try:
            supports_request_options = _supports_request_options(probe_client.models.generate_content)
        except Exception:
            # Best-effort close before re-raising so the probe client never leaks.
            try:
                probe_client.close()
            except Exception:
                pass
            raise

        if supports_request_options:
            shared_client = probe_client
        else:
            try:
                probe_client.close()
            except Exception:
                pass

        # Force a JSON response that matches the shared reward schema.
        schema_definition = reward_schema_definition()
        gemini_schema = _json_schema_to_genai(schema_definition, genai_types)
        config = genai_types.GenerateContentConfig(
            response_mime_type="application/json",
            response_schema=gemini_schema,
            temperature=0,
        )

        # Gemini takes a single prompt string; merge system and user content.
        combined_prompt = f"{request.system_content}\n\n{request.user_content}"

        response: Any | None = None
        last_error: Exception | None = None

        try:
            for attempt_index, attempt_timeout in enumerate(retry_timeouts, start=1):
                per_attempt_client: Any | None = None
                http_timeout_ms = _seconds_to_millis(attempt_timeout)
                try:
                    call_kwargs = {
                        "model": _normalize_gemini_model(request.model),
                        "contents": combined_prompt,
                        "config": config,
                    }
                    if supports_request_options and shared_client is not None:
                        call_client = shared_client
                        call_kwargs["request_options"] = {"timeout": http_timeout_ms}
                    else:
                        # Older SDKs: bake the timeout into a throwaway client.
                        with _suppress_pydantic_any_warning():
                            per_attempt_client = genai.Client(
                                api_key=request.api_key,
                                http_options={"timeout": http_timeout_ms},
                            )
                        call_client = per_attempt_client

                    with _suppress_pydantic_any_warning():
                        response = call_client.models.generate_content(**call_kwargs)
                    break
                except Exception as err:  # pragma: no cover - network failures depend on runtime
                    last_error = err
                    if attempt_index >= len(retry_timeouts):
                        detail = str(err).strip() or "Gemini request failed."
                        raise ProviderRequestError(self.name, request.model, detail) from err
                    # Back off before the next attempt; the sleep schedule is
                    # clamped to its last entry for late retries.
                    sleep_idx = min(attempt_index - 1, len(GEMINI_RETRY_SLEEP_SECONDS) - 1)
                    time.sleep(GEMINI_RETRY_SLEEP_SECONDS[sleep_idx])
                finally:
                    if per_attempt_client is not None:
                        try:
                            per_attempt_client.close()
                        except Exception:
                            pass
        finally:
            if shared_client is not None:
                try:
                    shared_client.close()
                except Exception:
                    pass

        # Defensive: should be unreachable because the final failed attempt
        # raises above, but keep a clear error if the loop exits empty-handed.
        if response is None and last_error is not None:
            detail = str(last_error).strip() or "Gemini request failed."
            raise ProviderRequestError(self.name, request.model, detail) from last_error

        raw = dump_model(response)

        # Prefer the SDK's convenience ``text`` attribute; otherwise dig the
        # first non-empty text part out of the raw candidate payload.
        text = getattr(response, "text", None)
        if not isinstance(text, str) or not text.strip():
            candidates = raw.get("candidates") if isinstance(raw, dict) else None
            if isinstance(candidates, list) and candidates:
                first = candidates[0]
                if isinstance(first, dict):
                    content = first.get("content")
                    if isinstance(content, dict):
                        parts = content.get("parts")
                        if isinstance(parts, list):
                            for part in parts:
                                if isinstance(part, dict):
                                    candidate_text = part.get("text")
                                    if isinstance(candidate_text, str) and candidate_text.strip():
                                        text = candidate_text
                                        break
        if not isinstance(text, str) or not text.strip():
            raise ProviderRequestError(self.name, request.model, "Model response did not include any text content.")
        try:
            score, explanation = sanitize_json(text)
        except ValueError as err:
            raise ProviderRequestError(self.name, request.model, str(err)) from err
        # Clamp the model-reported score into the caller's allowed range.
        bounded = max(request.score_min, min(request.score_max, score))
        return {"score": bounded, "explanation": explanation, "raw": raw}
312
+
313
+
314
# Explicit public API of this module.
__all__ = ["GeminiProvider"]