osmosis-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of osmosis-ai might be problematic.
- osmosis_ai/__init__.py +13 -4
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +269 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +537 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +392 -1
- osmosis_ai-0.2.2.dist-info/METADATA +241 -0
- osmosis_ai-0.2.2.dist-info/RECORD +16 -0
- osmosis_ai-0.2.1.dist-info/METADATA +0 -143
- osmosis_ai-0.2.1.dist-info/RECORD +0 -8
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.2.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.2.dist-info}/top_level.txt +0 -0
osmosis_ai/__init__.py
CHANGED
@@ -1,8 +1,8 @@
 """
 osmosis-ai: A Python library for reward function validation with strict type enforcement.

-This library provides
-function signatures for
+This library provides decorators such as @osmosis_reward and @osmosis_rubric that
+enforce standardized function signatures for LLM-centric workflows.

 Features:
 - Type-safe reward function decoration
@@ -10,6 +10,15 @@ Features:
 - Support for optional configuration parameters
 """

-from .
+from .rubric_eval import MissingAPIKeyError, evaluate_rubric
+from .rubric_types import ModelNotFoundError, ProviderRequestError
+from .utils import osmosis_reward, osmosis_rubric

-__all__ = [
+__all__ = [
+    "osmosis_reward",
+    "osmosis_rubric",
+    "evaluate_rubric",
+    "MissingAPIKeyError",
+    "ProviderRequestError",
+    "ModelNotFoundError",
+]
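With this change, both decorators and the new rubric evaluation entry point are importable from the package root. A minimal sketch of the new surface (only the exported names are confirmed by this diff; evaluate_rubric's parameters are not shown, so the call below is left as a placeholder):

    from osmosis_ai import (
        MissingAPIKeyError,
        ModelNotFoundError,
        ProviderRequestError,
        evaluate_rubric,
        osmosis_reward,
        osmosis_rubric,
    )

    # Placeholder call: the diff exposes the name, not the signature.
    try:
        result = evaluate_rubric(...)
    except MissingAPIKeyError:
        pass  # no API key configured for the selected provider
    except (ModelNotFoundError, ProviderRequestError):
        pass  # provider rejected the model or the request failed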
osmosis_ai/consts.py
CHANGED
osmosis_ai/providers/__init__.py
ADDED
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from typing import Tuple
+
+from .anthropic_provider import AnthropicProvider
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRegistry, ProviderRequest, RubricProvider
+from .gemini_provider import GeminiProvider
+from .openai_family import OpenAIProvider, XAIProvider
+
+_REGISTRY = ProviderRegistry()
+_REGISTRY.register(OpenAIProvider())
+_REGISTRY.register(XAIProvider())
+_REGISTRY.register(AnthropicProvider())
+_REGISTRY.register(GeminiProvider())
+
+
+def get_provider(name: str) -> RubricProvider:
+    return _REGISTRY.get(name)
+
+
+def register_provider(provider: RubricProvider) -> None:
+    _REGISTRY.register(provider)
+
+
+def supported_providers() -> Tuple[str, ...]:
+    return _REGISTRY.supported_providers()
+
+
+__all__ = [
+    "DEFAULT_REQUEST_TIMEOUT_SECONDS",
+    "ProviderRequest",
+    "RubricProvider",
+    "get_provider",
+    "register_provider",
+    "supported_providers",
+]
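Registration happens once at import time, and the module-level helpers wrap a single shared registry. A short sketch of how the registry API composes (the EchoProvider here is hypothetical, purely to illustrate register_provider):

    from osmosis_ai.providers import get_provider, register_provider, supported_providers
    from osmosis_ai.providers.base import ProviderRequest, RubricProvider

    class EchoProvider(RubricProvider):
        # Hypothetical provider used only to demonstrate registration.
        name = "echo"

        def run(self, request: ProviderRequest):
            return {"score": request.score_min, "explanation": "echo", "raw": {}}

    register_provider(EchoProvider())
    print(supported_providers())  # sorted tuple; now includes 'echo' alongside the four built-ins
    provider = get_provider("echo")

Registering a second provider under an existing name raises ValueError, so replacing a built-in requires a new name.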
osmosis_ai/providers/anthropic_provider.py
ADDED
@@ -0,0 +1,85 @@
+from __future__ import annotations
+
+from typing import Any, Dict
+
+try:  # pragma: no cover - optional dependency
+    import anthropic  # type: ignore
+    from anthropic import APIError  # type: ignore
+except ImportError:  # pragma: no cover - optional dependency
+    anthropic = None  # type: ignore[assignment]
+    APIError = None  # type: ignore[assignment]
+
+from ..rubric_types import ModelNotFoundError, ProviderRequestError, RewardRubricRunResult
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
+from .shared import dump_model, extract_structured_score, reward_schema_definition
+
+
+class AnthropicProvider(RubricProvider):
+    name = "anthropic"
+
+    def default_timeout(self, model: str) -> float:
+        return DEFAULT_REQUEST_TIMEOUT_SECONDS
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        if anthropic is None or APIError is None:
+            raise ProviderRequestError(
+                self.name,
+                request.model,
+                "Anthropic SDK is required. Install it via `pip install anthropic`.",
+            )
+
+        client = anthropic.Anthropic(api_key=request.api_key)
+        tool_name = "emit_reward_rubric_response"
+        schema_definition = reward_schema_definition()
+        tool = {
+            "name": tool_name,
+            "description": "Return the reward rubric score and explanation as structured JSON.",
+            "input_schema": schema_definition,
+        }
+
+        try:
+            response = client.messages.create(
+                model=request.model,
+                system=request.system_content,
+                messages=[{"role": "user", "content": [{"type": "text", "text": request.user_content}]}],
+                tools=[tool],
+                tool_choice={"type": "tool", "name": tool_name},
+                max_tokens=512,
+                temperature=0,
+                timeout=request.timeout,
+            )
+        except APIError as err:
+            detail = getattr(err, "message", None)
+            if not isinstance(detail, str) or not detail.strip():
+                detail = str(err)
+            status_code = getattr(err, "status_code", None)
+            if status_code == 404:
+                not_found_detail = (
+                    f"Model '{request.model}' was not found. Confirm your Anthropic account has access "
+                    "to the requested snapshot or update the model identifier."
+                )
+                raise ModelNotFoundError(self.name, request.model, not_found_detail) from err
+            raise ProviderRequestError(self.name, request.model, detail) from err
+        except Exception as err:
+            detail = str(err).strip() or "Unexpected error during Anthropic request."
+            raise ProviderRequestError(self.name, request.model, detail) from err
+
+        raw = dump_model(response)
+
+        payload: Dict[str, Any] | None = None
+        content_blocks = raw.get("content") if isinstance(raw, dict) else None
+        if isinstance(content_blocks, list):
+            for block in content_blocks:
+                if isinstance(block, dict) and block.get("type") == "tool_use" and block.get("name") == tool_name:
+                    maybe_input = block.get("input")
+                    if isinstance(maybe_input, dict):
+                        payload = maybe_input
+                        break
+        if payload is None:
+            raise ProviderRequestError(self.name, request.model, "Model response missing expected tool output.")
+        score, explanation = extract_structured_score(payload)
+        bounded = max(request.score_min, min(request.score_max, score))
+        return {"score": bounded, "explanation": explanation, "raw": raw}
+
+
+__all__ = ["AnthropicProvider"]
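The provider forces a tool call (tool_choice={"type": "tool", ...}) so the model must return structured JSON rather than free text, then recovers the payload from the tool_use content block. A condensed sketch of the shape the parsing loop expects (values illustrative):

    raw = {
        "content": [
            {"type": "text", "text": "reasoning..."},
            {
                "type": "tool_use",
                "name": "emit_reward_rubric_response",
                "input": {"score": 1.4, "explanation": "Exceeds the rubric."},
            },
        ]
    }

    payload = next(
        block["input"]
        for block in raw["content"]
        if block.get("type") == "tool_use"
        and block.get("name") == "emit_reward_rubric_response"
    )
    # With score_min=0.0 and score_max=1.0, the final clamp
    # max(0.0, min(1.0, 1.4)) bounds the returned score to 1.0.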
osmosis_ai/providers/base.py
ADDED
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, Tuple
+
+from ..rubric_types import RewardRubricRunResult
+
+DEFAULT_REQUEST_TIMEOUT_SECONDS = 30.0
+
+
+@dataclass(frozen=True)
+class ProviderRequest:
+    provider: str
+    model: str
+    api_key: str
+    system_content: str
+    user_content: str
+    score_min: float
+    score_max: float
+    timeout: float
+
+
+class RubricProvider:
+    """Interface for hosted LLM providers that can score rubrics."""
+
+    name: str
+
+    def default_timeout(self, model: str) -> float:
+        return DEFAULT_REQUEST_TIMEOUT_SECONDS
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        raise NotImplementedError
+
+
+class ProviderRegistry:
+    def __init__(self) -> None:
+        self._providers: Dict[str, RubricProvider] = {}
+
+    def register(self, provider: RubricProvider) -> None:
+        key = provider.name
+        if key in self._providers:
+            raise ValueError(f"Provider '{key}' is already registered.")
+        self._providers[key] = provider
+
+    def get(self, name: str) -> RubricProvider:
+        try:
+            return self._providers[name]
+        except KeyError as exc:
+            raise ValueError(f"Unsupported provider '{name}'.") from exc
+
+    def supported_providers(self) -> Tuple[str, ...]:
+        return tuple(sorted(self._providers))
+
+
+__all__ = [
+    "DEFAULT_REQUEST_TIMEOUT_SECONDS",
+    "ProviderRequest",
+    "RubricProvider",
+    "ProviderRegistry",
+]
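ProviderRequest is a frozen dataclass, so each request is an immutable value object, and RubricProvider only requires subclasses to implement run. A self-contained sketch (the provider, model name, and key below are placeholders for illustration):

    from osmosis_ai.providers.base import (
        DEFAULT_REQUEST_TIMEOUT_SECONDS,
        ProviderRegistry,
        ProviderRequest,
        RubricProvider,
    )

    class ConstantProvider(RubricProvider):
        # Hypothetical provider that always returns the floor score.
        name = "constant"

        def run(self, request: ProviderRequest):
            return {"score": request.score_min, "explanation": "stub", "raw": {}}

    registry = ProviderRegistry()
    registry.register(ConstantProvider())

    request = ProviderRequest(
        provider="constant",
        model="demo-model",        # placeholder identifier
        api_key="sk-placeholder",  # placeholder credential
        system_content="You are a strict grader.",
        user_content="Grade this answer against the rubric.",
        score_min=0.0,
        score_max=1.0,
        timeout=DEFAULT_REQUEST_TIMEOUT_SECONDS,
    )
    result = registry.get("constant").run(request)  # {'score': 0.0, ...}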
osmosis_ai/providers/gemini_provider.py
ADDED
@@ -0,0 +1,269 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+import time
+import warnings
+from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Tuple
+
+if TYPE_CHECKING:  # pragma: no cover - typing helpers only
+    from google import genai as genai_module  # type: ignore
+    from google.genai import types as genai_types_module  # type: ignore
+
+from ..rubric_types import ProviderRequestError, RewardRubricRunResult
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
+from .shared import dump_model, reward_schema_definition, sanitize_json
+
+
+_GENAI_MODULE: Any | None = None
+_GENAI_TYPES_MODULE: Any | None = None
+_PYDANTIC_ANY_WARNING_MESSAGE = r".*<built-in function any> is not a Python type.*"
+
+GEMINI_DEFAULT_TIMEOUT_SECONDS = 60.0
+GEMINI_MIN_TIMEOUT_SECONDS = 5.0
+GEMINI_MAX_TIMEOUT_SECONDS = 180.0
+GEMINI_RETRY_ATTEMPTS = 3
+GEMINI_TIMEOUT_BACKOFF = 1.5
+GEMINI_RETRY_SLEEP_SECONDS = (0.5, 1.0, 2.0)
+
+
+@contextmanager
+def _suppress_pydantic_any_warning() -> Iterator[None]:
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore",
+            message=_PYDANTIC_ANY_WARNING_MESSAGE,
+            category=UserWarning,
+            module=r"pydantic\._internal\._generate_schema",
+        )
+        yield
+
+
+def _load_google_genai() -> Tuple[Any, Any]:
+    """
+    Lazily import the Google Generative AI SDK so that environments without the optional
+    dependency avoid import-time side effects (like pydantic warnings) unless the Gemini
+    provider is actually used.
+    """
+    global _GENAI_MODULE, _GENAI_TYPES_MODULE
+    if _GENAI_MODULE is not None and _GENAI_TYPES_MODULE is not None:
+        return _GENAI_MODULE, _GENAI_TYPES_MODULE
+
+    try:  # pragma: no cover - optional dependency
+        with _suppress_pydantic_any_warning():
+            from google import genai as genai_mod  # type: ignore
+            from google.genai import types as genai_types_mod  # type: ignore
+    except ImportError as exc:  # pragma: no cover - optional dependency
+        raise RuntimeError(
+            "Google Generative AI SDK is required for provider 'gemini'. "
+            "Install it via `pip install google-genai`."
+        ) from exc
+
+    _GENAI_MODULE = genai_mod
+    _GENAI_TYPES_MODULE = genai_types_mod
+    return _GENAI_MODULE, _GENAI_TYPES_MODULE
+
+
+def _normalize_gemini_model(model_id: str) -> str:
+    import re
+
+    return re.sub(r"^models/", "", model_id, flags=re.IGNORECASE)
+
+
+def _json_schema_to_genai(
+    schema: Dict[str, Any],
+    genai_types: Any,
+) -> "genai_types_module.Schema":  # type: ignore[name-defined]
+
+    type_map = {
+        "object": genai_types.Type.OBJECT,
+        "string": genai_types.Type.STRING,
+        "number": genai_types.Type.NUMBER,
+        "integer": genai_types.Type.INTEGER,
+        "boolean": genai_types.Type.BOOLEAN,
+        "array": genai_types.Type.ARRAY,
+    }
+
+    kwargs: Dict[str, Any] = {}
+    type_value = schema.get("type")
+    if isinstance(type_value, str):
+        mapped = type_map.get(type_value.lower())
+        if mapped is not None:
+            kwargs["type"] = mapped
+
+    required = schema.get("required")
+    if isinstance(required, list):
+        filtered_required = [name for name in required if isinstance(name, str)]
+        if filtered_required:
+            kwargs["required"] = filtered_required
+
+    properties = schema.get("properties")
+    if isinstance(properties, dict):
+        converted_properties = {}
+        for key, value in properties.items():
+            if isinstance(key, str) and isinstance(value, dict):
+                converted_properties[key] = _json_schema_to_genai(value, genai_types)
+        if converted_properties:
+            kwargs["properties"] = converted_properties
+
+    items = schema.get("items")
+    if isinstance(items, dict):
+        kwargs["items"] = _json_schema_to_genai(items, genai_types)
+
+    enum_values = schema.get("enum")
+    if isinstance(enum_values, list):
+        filtered_enum = [str(option) for option in enum_values]
+        if filtered_enum:
+            kwargs["enum"] = filtered_enum
+
+    description = schema.get("description")
+    if isinstance(description, str):
+        kwargs["description"] = description
+
+    minimum = schema.get("minimum")
+    if isinstance(minimum, (int, float)):
+        kwargs["minimum"] = float(minimum)
+
+    maximum = schema.get("maximum")
+    if isinstance(maximum, (int, float)):
+        kwargs["maximum"] = float(maximum)
+
+    min_items = schema.get("min_items")
+    if isinstance(min_items, int):
+        kwargs["min_items"] = min_items
+
+    max_items = schema.get("max_items")
+    if isinstance(max_items, int):
+        kwargs["max_items"] = max_items
+
+    min_length = schema.get("min_length")
+    if isinstance(min_length, int):
+        kwargs["min_length"] = min_length
+
+    max_length = schema.get("max_length")
+    if isinstance(max_length, int):
+        kwargs["max_length"] = max_length
+
+    nullable = schema.get("nullable")
+    if isinstance(nullable, bool):
+        kwargs["nullable"] = nullable
+
+    with _suppress_pydantic_any_warning():
+        return genai_types.Schema(**kwargs)
+
+
+def _build_retry_timeouts(requested_timeout: float) -> List[float]:
+    # Keep the first attempt generous, then increase for retries while capping growth.
+    base = max(requested_timeout, GEMINI_MIN_TIMEOUT_SECONDS, GEMINI_DEFAULT_TIMEOUT_SECONDS)
+    timeouts: List[float] = []
+    current = base
+    for _ in range(GEMINI_RETRY_ATTEMPTS):
+        timeouts.append(min(current, GEMINI_MAX_TIMEOUT_SECONDS))
+        current = min(current * GEMINI_TIMEOUT_BACKOFF, GEMINI_MAX_TIMEOUT_SECONDS)
+    return timeouts
+
+
+def _seconds_to_millis(seconds: float) -> int:
+    # Gemini client expects timeout in milliseconds. Clamp to at least 1ms.
+    return max(int(round(seconds * 1000)), 1)
+
+
+class GeminiProvider(RubricProvider):
+    name = "gemini"
+
+    def default_timeout(self, model: str) -> float:
+        return max(DEFAULT_REQUEST_TIMEOUT_SECONDS, GEMINI_DEFAULT_TIMEOUT_SECONDS)
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        try:
+            genai, genai_types = _load_google_genai()
+        except RuntimeError as exc:
+            detail = str(exc).strip() or "Google Generative AI SDK is required."
+            raise ProviderRequestError(self.name, request.model, detail) from exc
+
+        try:
+            requested_timeout = float(request.timeout)
+        except (TypeError, ValueError):
+            requested_timeout = float(DEFAULT_REQUEST_TIMEOUT_SECONDS)
+
+        retry_timeouts = _build_retry_timeouts(requested_timeout)
+        max_timeout = max(retry_timeouts)
+
+        with _suppress_pydantic_any_warning():
+            client = genai.Client(
+                api_key=request.api_key,
+                http_options={"timeout": _seconds_to_millis(max_timeout)},
+            )
+        schema_definition = reward_schema_definition()
+        gemini_schema = _json_schema_to_genai(schema_definition, genai_types)
+        config = genai_types.GenerateContentConfig(
+            response_mime_type="application/json",
+            response_schema=gemini_schema,
+            temperature=0,
+        )
+
+        combined_prompt = f"{request.system_content}\n\n{request.user_content}"
+
+        response: Any | None = None
+        last_error: Exception | None = None
+
+        for attempt_index, attempt_timeout in enumerate(retry_timeouts, start=1):
+            try:
+                with _suppress_pydantic_any_warning():
+                    try:
+                        response = client.models.generate_content(
+                            model=_normalize_gemini_model(request.model),
+                            contents=combined_prompt,
+                            config=config,
+                            request_options={"timeout": _seconds_to_millis(attempt_timeout)},
+                        )
+                    except TypeError as err:
+                        # Older SDKs may not accept request_options; retry without it.
+                        if "request_options" not in str(err):
+                            raise
+                        response = client.models.generate_content(
+                            model=_normalize_gemini_model(request.model),
+                            contents=combined_prompt,
+                            config=config,
+                        )
+                break
+            except Exception as err:  # pragma: no cover - network failures depend on runtime
+                last_error = err
+                if attempt_index >= len(retry_timeouts):
+                    detail = str(err).strip() or "Gemini request failed."
+                    raise ProviderRequestError(self.name, request.model, detail) from err
+                sleep_idx = min(attempt_index - 1, len(GEMINI_RETRY_SLEEP_SECONDS) - 1)
+                time.sleep(GEMINI_RETRY_SLEEP_SECONDS[sleep_idx])
+
+        if response is None and last_error is not None:
+            detail = str(last_error).strip() or "Gemini request failed."
+            raise ProviderRequestError(self.name, request.model, detail) from last_error
+
+        raw = dump_model(response)
+
+        text = getattr(response, "text", None)
+        if not isinstance(text, str) or not text.strip():
+            candidates = raw.get("candidates") if isinstance(raw, dict) else None
+            if isinstance(candidates, list) and candidates:
+                first = candidates[0]
+                if isinstance(first, dict):
+                    content = first.get("content")
+                    if isinstance(content, dict):
+                        parts = content.get("parts")
+                        if isinstance(parts, list):
+                            for part in parts:
+                                if isinstance(part, dict):
+                                    candidate_text = part.get("text")
+                                    if isinstance(candidate_text, str) and candidate_text.strip():
+                                        text = candidate_text
+                                        break
+        if not isinstance(text, str) or not text.strip():
+            raise ProviderRequestError(self.name, request.model, "Model response did not include any text content.")
+        try:
+            score, explanation = sanitize_json(text)
+        except ValueError as err:
+            raise ProviderRequestError(self.name, request.model, str(err)) from err
+        bounded = max(request.score_min, min(request.score_max, score))
+        return {"score": bounded, "explanation": explanation, "raw": raw}
+
+
+__all__ = ["GeminiProvider"]
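The retry ladder is easiest to see with concrete numbers: requested timeouts below the Gemini default are lifted to 60 s, each retry grows by the 1.5x backoff, and every attempt is capped at 180 s before being converted to milliseconds for the client. Worked examples derived from the constants above:

    _build_retry_timeouts(30.0)    # [60.0, 90.0, 135.0]
    _build_retry_timeouts(200.0)   # [180.0, 180.0, 180.0] -- cap applies from the first attempt
    _seconds_to_millis(60.0)       # 60000
    _seconds_to_millis(0.0004)     # 1 -- clamped to at least 1 ms

Between failed attempts the provider sleeps 0.5 s after the first failure and 1.0 s after the second; the 2.0 s entry in GEMINI_RETRY_SLEEP_SECONDS would only apply if more attempts were configured.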