osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Potentially problematic release: this version of osmosis-ai has been flagged as possibly problematic.
- osmosis_ai/__init__.py +13 -4
- osmosis_ai/cli.py +50 -0
- osmosis_ai/cli_commands.py +181 -0
- osmosis_ai/cli_services/__init__.py +67 -0
- osmosis_ai/cli_services/config.py +407 -0
- osmosis_ai/cli_services/dataset.py +229 -0
- osmosis_ai/cli_services/engine.py +251 -0
- osmosis_ai/cli_services/errors.py +7 -0
- osmosis_ai/cli_services/reporting.py +307 -0
- osmosis_ai/cli_services/session.py +174 -0
- osmosis_ai/cli_services/shared.py +209 -0
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +314 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +498 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +392 -5
- osmosis_ai-0.2.3.dist-info/METADATA +303 -0
- osmosis_ai-0.2.3.dist-info/RECORD +27 -0
- osmosis_ai-0.2.3.dist-info/entry_points.txt +4 -0
- osmosis_ai-0.2.1.dist-info/METADATA +0 -143
- osmosis_ai-0.2.1.dist-info/RECORD +0 -8
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,314 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+import inspect
+import time
+import warnings
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple
+
+if TYPE_CHECKING:  # pragma: no cover - typing helpers only
+    from google import genai as genai_module  # type: ignore
+    from google.genai import types as genai_types_module  # type: ignore
+
+from ..rubric_types import ProviderRequestError, RewardRubricRunResult
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
+from .shared import dump_model, reward_schema_definition, sanitize_json
+
+
+_GENAI_MODULE: Any | None = None
+_GENAI_TYPES_MODULE: Any | None = None
+_PYDANTIC_ANY_WARNING_MESSAGE = r".*<built-in function any> is not a Python type.*"
+
+GEMINI_DEFAULT_TIMEOUT_SECONDS = 60.0
+GEMINI_MIN_TIMEOUT_SECONDS = 5.0
+GEMINI_MAX_TIMEOUT_SECONDS = 180.0
+GEMINI_RETRY_ATTEMPTS = 3
+GEMINI_TIMEOUT_BACKOFF = 1.5
+GEMINI_RETRY_SLEEP_SECONDS = (0.5, 1.0, 2.0)
+
+
+@contextmanager
+def _suppress_pydantic_any_warning() -> Iterator[None]:
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message=_PYDANTIC_ANY_WARNING_MESSAGE,
+            category=UserWarning,
+            module=r"pydantic\._internal\._generate_schema",
+        )
+        yield
+
+
+def _load_google_genai() -> Tuple[Any, Any]:
+    """
+    Lazily import the Google Generative AI SDK so that environments without the optional
+    dependency avoid import-time side effects (like pydantic warnings) unless the Gemini
+    provider is actually used.
+    """
+    global _GENAI_MODULE, _GENAI_TYPES_MODULE
+    if _GENAI_MODULE is not None and _GENAI_TYPES_MODULE is not None:
+        return _GENAI_MODULE, _GENAI_TYPES_MODULE
+
+    try:  # pragma: no cover - optional dependency
+        with _suppress_pydantic_any_warning():
+            from google import genai as genai_mod  # type: ignore
+            from google.genai import types as genai_types_mod  # type: ignore
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "Google Generative AI SDK is required for provider 'gemini'. "
+            "Install it via `pip install google-genai`."
+        ) from exc
+
+    _GENAI_MODULE = genai_mod
+    _GENAI_TYPES_MODULE = genai_types_mod
+    return _GENAI_MODULE, _GENAI_TYPES_MODULE
+
+
+def _normalize_gemini_model(model_id: str) -> str:
+    import re
+
+    return re.sub(r"^models/", "", model_id, flags=re.IGNORECASE)
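
Note: the normalization only strips a leading "models/" resource prefix, so fully qualified resource names and bare model ids resolve identically. A doctest-style illustration (model id hypothetical):

    >>> _normalize_gemini_model("models/gemini-1.5-flash")
    'gemini-1.5-flash'
    >>> _normalize_gemini_model("gemini-1.5-flash")
    'gemini-1.5-flash'
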
+
+
+def _json_schema_to_genai(
+    schema: Dict[str, Any],
+    genai_types: Any,
+) -> "genai_types_module.Schema":  # type: ignore[name-defined]
+
+    type_map = {
+        "object": genai_types.Type.OBJECT,
+        "string": genai_types.Type.STRING,
+        "number": genai_types.Type.NUMBER,
+        "integer": genai_types.Type.INTEGER,
+        "boolean": genai_types.Type.BOOLEAN,
+        "array": genai_types.Type.ARRAY,
+    }
+
+    kwargs: Dict[str, Any] = {}
+    type_value = schema.get("type")
+    if isinstance(type_value, str):
+        mapped = type_map.get(type_value.lower())
+        if mapped is not None:
+            kwargs["type"] = mapped
+
+    required = schema.get("required")
+    if isinstance(required, list):
+        filtered_required = [name for name in required if isinstance(name, str)]
+        if filtered_required:
+            kwargs["required"] = filtered_required
+
+    properties = schema.get("properties")
+    if isinstance(properties, dict):
+        converted_properties = {}
+        for key, value in properties.items():
+            if isinstance(key, str) and isinstance(value, dict):
+                converted_properties[key] = _json_schema_to_genai(value, genai_types)
+        if converted_properties:
+            kwargs["properties"] = converted_properties
+
+    items = schema.get("items")
+    if isinstance(items, dict):
+        kwargs["items"] = _json_schema_to_genai(items, genai_types)
+
+    enum_values = schema.get("enum")
+    if isinstance(enum_values, list):
+        filtered_enum = [str(option) for option in enum_values]
+        if filtered_enum:
+            kwargs["enum"] = filtered_enum
+
+    description = schema.get("description")
+    if isinstance(description, str):
+        kwargs["description"] = description
+
+    minimum = schema.get("minimum")
+    if isinstance(minimum, (int, float)):
+        kwargs["minimum"] = float(minimum)
+
+    maximum = schema.get("maximum")
+    if isinstance(maximum, (int, float)):
+        kwargs["maximum"] = float(maximum)
+
+    min_items = schema.get("min_items")
+    if isinstance(min_items, int):
+        kwargs["min_items"] = min_items
+
+    max_items = schema.get("max_items")
+    if isinstance(max_items, int):
+        kwargs["max_items"] = max_items
+
+    min_length = schema.get("min_length")
+    if isinstance(min_length, int):
+        kwargs["min_length"] = min_length
+
+    max_length = schema.get("max_length")
+    if isinstance(max_length, int):
+        kwargs["max_length"] = max_length
+
+    nullable = schema.get("nullable")
+    if isinstance(nullable, bool):
+        kwargs["nullable"] = nullable
+
+    with _suppress_pydantic_any_warning():
+        return genai_types.Schema(**kwargs)
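
Note: _json_schema_to_genai walks a plain JSON-schema dict, keeps only the keys the Gemini Schema type understands, and recurses into "properties" and "items"; unrecognized or malformed keys are dropped rather than raised on. A sketch of the input shape it accepts, with an illustrative reward schema (the real one comes from reward_schema_definition()):

    gemini_schema = _json_schema_to_genai(
        {
            "type": "object",
            "required": ["score", "explanation"],
            "properties": {
                "score": {"type": "number", "minimum": 0.0, "maximum": 1.0},
                "explanation": {"type": "string"},
            },
        },
        genai_types,
    )
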
+
+
+def _build_retry_timeouts(requested_timeout: float) -> List[float]:
+    # Keep the first attempt generous, then increase for retries while capping growth.
+    base = max(requested_timeout, GEMINI_MIN_TIMEOUT_SECONDS, GEMINI_DEFAULT_TIMEOUT_SECONDS)
+    timeouts: List[float] = []
+    current = base
+    for _ in range(GEMINI_RETRY_ATTEMPTS):
+        timeouts.append(min(current, GEMINI_MAX_TIMEOUT_SECONDS))
+        current = min(current * GEMINI_TIMEOUT_BACKOFF, GEMINI_MAX_TIMEOUT_SECONDS)
+    return timeouts
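
Worked example: with the constants above, a requested timeout of 30s is first raised to the 60s default floor, then grown by the 1.5x backoff and capped at 180s across the three attempts:

    >>> _build_retry_timeouts(30.0)
    [60.0, 90.0, 135.0]
    >>> _build_retry_timeouts(150.0)
    [150.0, 180.0, 180.0]
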
+
+
+def _seconds_to_millis(seconds: float) -> int:
+    # Gemini client expects timeout in milliseconds. Clamp to at least 1ms.
+    return max(int(round(seconds * 1000)), 1)
+
+
+def _supports_request_options(generate_content: Any) -> bool:
+    try:
+        signature = inspect.signature(generate_content)
+    except (TypeError, ValueError):
+        return False
+    return "request_options" in signature.parameters
+
+
+class GeminiProvider(RubricProvider):
+    name = "gemini"
+
+    def default_timeout(self, model: str) -> float:
+        return max(DEFAULT_REQUEST_TIMEOUT_SECONDS, GEMINI_DEFAULT_TIMEOUT_SECONDS)
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        try:
+            genai, genai_types = _load_google_genai()
+        except RuntimeError as exc:
+            detail = str(exc).strip() or "Google Generative AI SDK is required."
+            raise ProviderRequestError(self.name, request.model, detail) from exc
+
+        try:
+            requested_timeout = float(request.timeout)
+        except (TypeError, ValueError):
+            requested_timeout = float(DEFAULT_REQUEST_TIMEOUT_SECONDS)
+
+        retry_timeouts = _build_retry_timeouts(requested_timeout)
+        max_timeout = max(retry_timeouts)
+
+        supports_request_options = False
+        shared_client: Any | None = None
+
+        with _suppress_pydantic_any_warning():
+            probe_client = genai.Client(
+                api_key=request.api_key,
+                http_options={"timeout": _seconds_to_millis(max_timeout)},
+            )
+        try:
+            supports_request_options = _supports_request_options(probe_client.models.generate_content)
+        except Exception:
+            try:
+                probe_client.close()
+            except Exception:
+                pass
+            raise
+
+        if supports_request_options:
+            shared_client = probe_client
+        else:
+            try:
+                probe_client.close()
+            except Exception:
+                pass
+
+        schema_definition = reward_schema_definition()
+        gemini_schema = _json_schema_to_genai(schema_definition, genai_types)
+        config = genai_types.GenerateContentConfig(
+            response_mime_type="application/json",
+            response_schema=gemini_schema,
+            temperature=0,
+        )
+
+        combined_prompt = f"{request.system_content}\n\n{request.user_content}"
+
+        response: Any | None = None
+        last_error: Exception | None = None
+
+        try:
+            for attempt_index, attempt_timeout in enumerate(retry_timeouts, start=1):
+                per_attempt_client: Any | None = None
+                http_timeout_ms = _seconds_to_millis(attempt_timeout)
+                try:
+                    call_kwargs = {
+                        "model": _normalize_gemini_model(request.model),
+                        "contents": combined_prompt,
+                        "config": config,
+                    }
+                    if supports_request_options and shared_client is not None:
+                        call_client = shared_client
+                        call_kwargs["request_options"] = {"timeout": http_timeout_ms}
+                    else:
+                        with _suppress_pydantic_any_warning():
+                            per_attempt_client = genai.Client(
+                                api_key=request.api_key,
+                                http_options={"timeout": http_timeout_ms},
+                            )
+                        call_client = per_attempt_client
+
+                    with _suppress_pydantic_any_warning():
+                        response = call_client.models.generate_content(**call_kwargs)
+                    break
+                except Exception as err:  # pragma: no cover - network failures depend on runtime
+                    last_error = err
+                    if attempt_index >= len(retry_timeouts):
+                        detail = str(err).strip() or "Gemini request failed."
+                        raise ProviderRequestError(self.name, request.model, detail) from err
+                    sleep_idx = min(attempt_index - 1, len(GEMINI_RETRY_SLEEP_SECONDS) - 1)
+                    time.sleep(GEMINI_RETRY_SLEEP_SECONDS[sleep_idx])
+                finally:
+                    if per_attempt_client is not None:
+                        try:
+                            per_attempt_client.close()
+                        except Exception:
+                            pass
+        finally:
+            if shared_client is not None:
+                try:
+                    shared_client.close()
+                except Exception:
+                    pass
+
+        if response is None and last_error is not None:
+            detail = str(last_error).strip() or "Gemini request failed."
+            raise ProviderRequestError(self.name, request.model, detail) from last_error
+
+        raw = dump_model(response)
+
+        text = getattr(response, "text", None)
+        if not isinstance(text, str) or not text.strip():
+            candidates = raw.get("candidates") if isinstance(raw, dict) else None
+            if isinstance(candidates, list) and candidates:
+                first = candidates[0]
+                if isinstance(first, dict):
+                    content = first.get("content")
+                    if isinstance(content, dict):
+                        parts = content.get("parts")
+                        if isinstance(parts, list):
+                            for part in parts:
+                                if isinstance(part, dict):
+                                    candidate_text = part.get("text")
+                                    if isinstance(candidate_text, str) and candidate_text.strip():
+                                        text = candidate_text
+                                        break
+        if not isinstance(text, str) or not text.strip():
+            raise ProviderRequestError(self.name, request.model, "Model response did not include any text content.")
+        try:
+            score, explanation = sanitize_json(text)
+        except ValueError as err:
+            raise ProviderRequestError(self.name, request.model, str(err)) from err
+        bounded = max(request.score_min, min(request.score_max, score))
+        return {"score": bounded, "explanation": explanation, "raw": raw}
+
+
+__all__ = ["GeminiProvider"]
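
For context, GeminiProvider is normally driven by the rubric evaluation layer, but a direct call looks roughly like the sketch below; the ProviderRequest field values are illustrative, and the exact constructor shape should be checked against osmosis_ai/providers/base.py:

    from osmosis_ai.providers.base import ProviderRequest
    from osmosis_ai.providers.gemini_provider import GeminiProvider

    provider = GeminiProvider()
    result = provider.run(
        ProviderRequest(
            model="gemini-1.5-flash",
            api_key="...",
            system_content="You are a strict grader.",
            user_content="Score this answer against the rubric.",
            score_min=0.0,
            score_max=1.0,
            timeout=60.0,
        )
    )
    print(result["score"], result["explanation"])

run() clamps the model-reported score to [score_min, score_max] and returns it alongside the explanation and the raw response dump.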