osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of osmosis-ai might be problematic.
- osmosis_ai/__init__.py +13 -4
- osmosis_ai/cli.py +50 -0
- osmosis_ai/cli_commands.py +181 -0
- osmosis_ai/cli_services/__init__.py +67 -0
- osmosis_ai/cli_services/config.py +407 -0
- osmosis_ai/cli_services/dataset.py +229 -0
- osmosis_ai/cli_services/engine.py +251 -0
- osmosis_ai/cli_services/errors.py +7 -0
- osmosis_ai/cli_services/reporting.py +307 -0
- osmosis_ai/cli_services/session.py +174 -0
- osmosis_ai/cli_services/shared.py +209 -0
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +314 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +498 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +392 -5
- osmosis_ai-0.2.3.dist-info/METADATA +303 -0
- osmosis_ai-0.2.3.dist-info/RECORD +27 -0
- osmosis_ai-0.2.3.dist-info/entry_points.txt +4 -0
- osmosis_ai-0.2.1.dist-info/METADATA +0 -143
- osmosis_ai-0.2.1.dist-info/RECORD +0 -8
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/top_level.txt +0 -0
osmosis_ai/providers/openai_family.py
@@ -0,0 +1,607 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Optional
+
+try: # pragma: no cover - optional dependency
+    from openai import BadRequestError, OpenAI, OpenAIError # type: ignore
+except ImportError: # pragma: no cover - optional dependency
+    OpenAI = None # type: ignore[assignment]
+    BadRequestError = None # type: ignore[assignment]
+    OpenAIError = None # type: ignore[assignment]
+
+from ..rubric_types import ProviderRequestError, RewardRubricRunResult
+from .base import DEFAULT_REQUEST_TIMEOUT_SECONDS, ProviderRequest, RubricProvider
+from .shared import dump_model, reward_json_schema, reward_schema_definition, sanitize_json
+
+# Families that generally require "max_completion_tokens" on chat.completions in SDK v2.
+FAMILIES_USE_MAX_COMPLETION = ("gpt-5", "gpt-4.1", "gpt-4o", "o3", "o4")
+
+# ---- Tunables for safer defaults ----
+# Higher token budget avoids "reasoning-only consumed budget" on gpt-5.
+SAFE_DEFAULT_MAX_OUTPUT = 1536 # you can raise to 2048 if you prefer
+SAFE_MAX_OUTPUT_CAP = 4096 # hard safety cap to avoid unbounded growth
+SAFE_BUMP_FACTOR = 2 # when hitting max_output_tokens, bump once by this factor
+
+
+def _should_use_openai_responses(model_id: str) -> bool:
+    normalized = model_id.strip().lower()
+    return any(
+        normalized.startswith(prefix) for prefix in ("gpt-4.1", "gpt-4o", "gpt-5", "o3", "o4")
+    )
+
+
+def _is_openai_gpt5_family(model_id: str) -> bool:
+    return model_id.strip().lower().startswith("gpt-5")
+
+
+def _extract_openai_responses_text(payload: Any) -> Optional[str]:
+    """
+    Extract visible text or JSON string from Responses API or Chat Completions payloads.
+
+    Supports:
+    - output_text: str | list[str]
+    - output -> message -> content[] with parts:
+        * {"type": "output_text" | "text" | "input_text", "text": str | {"value"/"content"/"text": str}}
+        * {"type": "output_json" | "json", "json": dict} # will be dumped to str
+    - response.output[...] (some SDKs wrap as {"response": {...}})
+    - message.content[...] (some SDKs expose "message" object)
+    - top-level "content" list (rare)
+    - chat.completions choices[0].message.content
+    """
+    if not isinstance(payload, dict):
+        return None
+
+    def _dump_json(obj: Any) -> Optional[str]:
+        try:
+            from json import dumps
+            s = dumps(obj, ensure_ascii=False)
+            return s if isinstance(s, str) and s.strip() else None
+        except Exception:
+            return None
+
+    # 0) helper: scan a Responses-style "output" list
+    def _scan_output_list(output_list: Any) -> Optional[str]:
+        if not isinstance(output_list, list):
+            return None
+        for entry in output_list:
+            if not isinstance(entry, dict):
+                continue
+            contents = entry.get("content") or [] # guard None
+            if not isinstance(contents, list):
+                continue
+            for part in contents:
+                if not isinstance(part, dict):
+                    continue
+                part_type = part.get("type")
+
+                # Structured JSON parts
+                if part_type in ("output_json", "json"):
+                    pj = part.get("json")
+                    dumped = _dump_json(pj)
+                    if dumped:
+                        return dumped
+
+                # Textual parts
+                if part_type in ("output_text", "text", "input_text"):
+                    text_field = part.get("text")
+                    if isinstance(text_field, dict):
+                        val = (
+                            text_field.get("value")
+                            or text_field.get("content")
+                            or text_field.get("text")
+                        )
+                        if isinstance(val, str) and val.strip():
+                            return val
+                    if isinstance(text_field, str) and text_field.strip():
+                        return text_field
+
+                # Extra leniency
+                if isinstance(part.get("text"), str) and part["text"].strip():
+                    return part["text"].strip()
+
+                # Nested tool_result -> content[] -> output_json
+                if part_type == "tool_result":
+                    nested = _scan_output_list(part.get("content"))
+                    if nested:
+                        return nested
+        return None
+
+    # 1) Fast path: aggregated text
+    output_text = payload.get("output_text")
+    if isinstance(output_text, str) and output_text.strip():
+        return output_text
+    if isinstance(output_text, list):
+        for item in output_text:
+            if isinstance(item, str) and item.strip():
+                return item
+
+    # 2) Some dumps wrap under "response"
+    response_obj = payload.get("response")
+    if isinstance(response_obj, dict):
+        t = response_obj.get("output_text")
+        if isinstance(t, str) and t.strip():
+            return t
+        nested = _scan_output_list(response_obj.get("output"))
+        if nested:
+            return nested
+
+    # 3) Top-level Responses API shape
+    top = _scan_output_list(payload.get("output"))
+    if top:
+        return top
+
+    # 3.1) Some SDKs expose top-level "message" like Responses(Message)
+    message_obj = payload.get("message")
+    if isinstance(message_obj, dict):
+        contents = message_obj.get("content")
+        if isinstance(contents, list):
+            for part in contents:
+                if not isinstance(part, dict):
+                    continue
+                if part.get("type") in ("output_text", "text", "input_text"):
+                    tf = part.get("text")
+                    if isinstance(tf, dict):
+                        val = tf.get("value") or tf.get("content") or tf.get("text")
+                        if isinstance(val, str) and val.strip():
+                            return val
+                    if isinstance(tf, str) and tf.strip():
+                        return tf
+                if part.get("type") in ("output_json", "json"):
+                    dumped = _dump_json(part.get("json"))
+                    if dumped:
+                        return dumped
+
+    # 3.2) Rare: top-level "content" directly
+    top_content = payload.get("content")
+    if isinstance(top_content, list):
+        for part in top_content:
+            if not isinstance(part, dict):
+                continue
+            if part.get("type") in ("text", "output_text", "input_text"):
+                tf = part.get("text")
+                if isinstance(tf, dict):
+                    val = tf.get("value") or tf.get("content") or tf.get("text")
+                    if isinstance(val, str) and val.strip():
+                        return val
+                if isinstance(tf, str) and tf.strip():
+                    return tf
+            if part.get("type") in ("output_json", "json"):
+                dumped = _dump_json(part.get("json"))
+                if dumped:
+                    return dumped
+
+    # 4) Chat Completions compatibility
+    choices = payload.get("choices")
+    if isinstance(choices, list) and choices:
+        first_choice = choices[0]
+        if isinstance(first_choice, dict):
+            message = first_choice.get("message")
+            if isinstance(message, dict):
+                content = message.get("content")
+                if isinstance(content, str) and content.strip():
+                    return content
+                if isinstance(content, list):
+                    for part in content:
+                        if isinstance(part, dict) and part.get("type") in ("text", "output_text"):
+                            t = part.get("text")
+                            if isinstance(t, dict):
+                                val = t.get("value") or t.get("content") or t.get("text")
+                                if isinstance(val, str) and val.strip():
+                                    return val
+                            if isinstance(t, str) and t.strip():
+                                return t
+                        # NEW: handle structured JSON in Chat content list
+                        if isinstance(part, dict) and part.get("type") in ("output_json", "json"):
+                            dumped = _dump_json(part.get("json"))
+                            if dumped:
+                                return dumped
+            # Legacy: choice-level content list
+            content_list = first_choice.get("content")
+            if isinstance(content_list, list):
+                for part in content_list:
+                    if isinstance(part, dict) and part.get("type") in ("text", "output_text"):
+                        t = part.get("text")
+                        if isinstance(t, dict):
+                            val = t.get("value") or t.get("content") or t.get("text")
+                            if isinstance(val, str) and val.strip():
+                                return val
+                        if isinstance(t, str) and t.strip():
+                            return t
+                    # NEW: structured JSON in legacy content list
+                    if isinstance(part, dict) and part.get("type") in ("output_json", "json"):
+                        dumped = _dump_json(part.get("json"))
+                        if dumped:
+                            return dumped
+
+    return None
+
+
+def _openai_error_message(err: Exception) -> str:
+    message = getattr(err, "message", None)
+    if isinstance(message, str) and message.strip():
+        return message.strip()
+    body = getattr(err, "body", None)
+    if isinstance(body, dict):
+        error_field = body.get("error")
+        if isinstance(error_field, dict):
+            detail = error_field.get("message") or error_field.get("code")
+            if isinstance(detail, str) and detail.strip():
+                return detail.strip()
+        elif isinstance(error_field, str) and error_field.strip():
+            return error_field.strip()
+    return str(err)
+
+
+def _iterative_trim_call(create_fn, label: str, **kwargs):
+    """
+    Call an OpenAI SDK method and iteratively strip unsupported kwargs on TypeError/BadRequest.
+    This keeps compatibility with older SDKs (e.g., 2.3.0) and server-side feature-gating.
+    """
+    UNSUPPORTED_CANDIDATES = [
+        "response_format", "modalities", "reasoning",
+        "instructions", "temperature", "max_completion_tokens", "max_tokens",
+    ]
+    attempts = 0
+    while attempts < 6:
+        try:
+            return create_fn(**kwargs)
+        except TypeError as e:
+            msg = str(e)
+            removed = False
+            for k in list(kwargs.keys()):
+                if any(k == bad and bad in msg for bad in UNSUPPORTED_CANDIDATES):
+                    kwargs.pop(k, None)
+                    removed = True
+                    break
+            if not removed:
+                # conservative extra trim
+                for bad in UNSUPPORTED_CANDIDATES:
+                    if bad in kwargs:
+                        kwargs.pop(bad, None)
+                        removed = True
+                        break
+            if not removed:
+                raise
+            attempts += 1
+        except BadRequestError as e:
+            # server-side "Unsupported parameter" (e.g., temperature on gpt-5)
+            msg = _openai_error_message(e)
+            lowered = msg.lower()
+            removed = False
+            # Strip any obviously rejected parameters
+            for bad in UNSUPPORTED_CANDIDATES:
+                if bad in kwargs and bad in lowered:
+                    kwargs.pop(bad, None)
+                    removed = True
+            if "temperature" in lowered and "temperature" in kwargs:
+                kwargs.pop("temperature", None)
+                removed = True
+            if removed:
+                attempts += 1
+                continue
+            raise
+        except OpenAIError as e:
+            msg = _openai_error_message(e)
+            # Typical gpt-5 complaint: temperature unsupported
+            if "temperature" in msg and "unsupported" in msg.lower():
+                kwargs.pop("temperature", None)
+                attempts += 1
+                continue
+            raise
+    raise RuntimeError(f"{label} failed after trims")
+
+
+def _call_openai_family(
+    provider: str,
+    model: str,
+    api_key: str,
+    system_content: str,
+    user_content: str,
+    score_min: float,
+    score_max: float,
+    timeout: float,
+    *,
+    base_url: Optional[str] = None,
+    force_responses_api: bool = False,
+    reasoning_effort: Optional[str] = None, # "low" | "medium" | "high" | "none"/"off"/None
+) -> RewardRubricRunResult:
+    # --- Guard: SDK available ---
+    if OpenAI is None or BadRequestError is None or OpenAIError is None:
+        raise ProviderRequestError(
+            provider,
+            model,
+            "OpenAI SDK is required. Install it via `pip install 'openai>=2.0.0'`.",
+        )
+
+    # --- Client / per-request options ---
+    client_kwargs = {"api_key": api_key}
+    if base_url:
+        client_kwargs["base_url"] = base_url
+    client = OpenAI(**client_kwargs)
+    req = client.with_options(timeout=timeout) # per-request timeout (SDK v2)
+
+    # --- Schema materials ---
+    _ = reward_schema_definition() # kept for parity/debug even if not used directly
+    schema_payload = reward_json_schema()
+    # Wrap bare JSON Schema into Responses v2 shape if necessary.
+    if isinstance(schema_payload, dict) and "schema" not in schema_payload:
+        schema_payload = {
+            "name": "reward_score_schema",
+            "strict": True,
+            "schema": schema_payload,
+        }
+
+    # --- Temperature rules ---
+    # gpt-5 family does NOT accept temperature; others should use 0 for determinism.
+    is_gpt5 = _is_openai_gpt5_family(model)
+    temperature_kwargs: Dict[str, Any] = {}
+    if not is_gpt5:
+        temperature_kwargs["temperature"] = 0
+
+    # --- Build inputs ---
+    # For Responses API: best practice -> system via `instructions`, user via `input` as `input_text`.
+    input_user_only_input_text = [
+        {"role": "user", "content": [{"type": "input_text", "text": user_content}]},
+    ]
+
+    # Chat Completions message shape (fallback)
+    chat_messages = [
+        {"role": "system", "content": system_content},
+        {"role": "user", "content": user_content},
+    ]
+
+    # --- Local helper: finalise ---
+    def _finalise(raw_response: Any, text: Optional[str], parsed_obj: Any = None) -> RewardRubricRunResult:
+        """
+        - If parsed_obj is provided (dict/list), serialize and sanitize.
+        - Else, try to sniff JSON objects in raw_response (output_json) and sanitize.
+        - Else, fall back to text extraction / sanitize_json(text).
+        """
+        try:
+            import json
+
+            if parsed_obj is not None:
+                parsed_str = json.dumps(parsed_obj, ensure_ascii=False)
+                score, explanation = sanitize_json(parsed_str)
+            else:
+                # sniff output_json even if SDK doesn't provide output_parsed
+                def _sniff_output_json(payload: Any) -> Optional[str]:
+                    if not isinstance(payload, dict):
+                        return None
+                    for node in (payload, payload.get("response")):
+                        if not isinstance(node, dict):
+                            continue
+                        out = node.get("output")
+                        if isinstance(out, list):
+                            for entry in out:
+                                if not isinstance(entry, dict):
+                                    continue
+                                contents = entry.get("content")
+                                if not isinstance(contents, list):
+                                    continue
+                                for part in contents:
+                                    if isinstance(part, dict) and part.get("type") in ("output_json", "json"):
+                                        pj = part.get("json")
+                                        try:
+                                            s = json.dumps(pj, ensure_ascii=False)
+                                            if isinstance(s, str) and s.strip():
+                                                return s
+                                        except Exception:
+                                            pass
+                    return None
+
+                json_str = _sniff_output_json(raw_response)
+                if json_str:
+                    score, explanation = sanitize_json(json_str)
+                else:
+                    if not text:
+                        text = _extract_openai_responses_text(raw_response)
+                    if not text:
+                        # Surface incomplete status for better error messages
+                        status = None
+                        incomplete = None
+                        if isinstance(raw_response, dict):
+                            node = raw_response.get("response") if isinstance(raw_response.get("response"), dict) else raw_response
+                            status = node.get("status")
+                            incomplete = node.get("incomplete_details")
+                        if status and status != "completed":
+                            reason = (incomplete or {}).get("reason")
+                            raise ProviderRequestError(provider, model, f"Response incomplete: status={status}, reason={reason or 'unknown'}")
+                        raise ProviderRequestError(provider, model, "Model response did not include any content.")
+                    score, explanation = sanitize_json(text)
+
+        except ValueError as err:
+            raise ProviderRequestError(provider, model, str(err)) from err
+
+        bounded = max(score_min, min(score_max, score))
+        return {"score": bounded, "explanation": explanation, "raw": raw_response}
+
+    # --- Helper: try Responses with a given response_format mode and adaptive bump ---
+    def _try_responses(mode: str) -> Optional[RewardRubricRunResult]:
+        """
+        mode: "json_schema" | "json_object" | "none"
+        Returns a RewardRubricRunResult or None to indicate fallback is needed.
+        """
+        kwargs_rf: Dict[str, Any] = {}
+        if mode == "json_schema":
+            kwargs_rf["response_format"] = {"type": "json_schema", "json_schema": schema_payload}
+        elif mode == "json_object":
+            kwargs_rf["response_format"] = {"type": "json_object"}
+
+        responses_base: Dict[str, Any] = {
+            "model": model,
+            "instructions": system_content, # system here
+            "input": input_user_only_input_text, # user only; input_text
+            "max_output_tokens": SAFE_DEFAULT_MAX_OUTPUT,
+            "modalities": ["text"], # be explicit; will be trimmed if unsupported
+        }
+        # gpt-5 must not receive temperature
+        if not is_gpt5:
+            responses_base.update(temperature_kwargs)
+
+        # Optional reasoning effort knob, if provided
+        effort = (reasoning_effort or "").strip().lower() if reasoning_effort else None
+        if effort in ("low", "medium", "high"):
+            responses_base["reasoning"] = {"effort": effort}
+        # "none"/"off"/None => do not send "reasoning" kw
+
+        try:
+            # First attempt
+            response = _iterative_trim_call(req.responses.create, "responses", **{**responses_base, **kwargs_rf})
+            raw = dump_model(response)
+            # If incomplete due to token ceiling, bump once and retry
+            try:
+                node = raw.get("response", raw) if isinstance(raw, dict) else {}
+                if isinstance(node, dict) and node.get("status") == "incomplete":
+                    inc = node.get("incomplete_details") or {}
+                    if inc.get("reason") == "max_output_tokens":
+                        bumped = {**responses_base, **kwargs_rf}
+                        bumped["max_output_tokens"] = min(max(SAFE_DEFAULT_MAX_OUTPUT * SAFE_BUMP_FACTOR, 2048), SAFE_MAX_OUTPUT_CAP)
+                        response_b = _iterative_trim_call(req.responses.create, "responses-bumped", **bumped)
+                        raw_b = dump_model(response_b)
+
+                        parsed_b = getattr(response_b, "output_parsed", None)
+                        if parsed_b is not None:
+                            return _finalise(raw_b, None, parsed_obj=parsed_b)
+
+                        text_b = getattr(response_b, "output_text", None) or _extract_openai_responses_text(raw_b)
+                        if text_b:
+                            return _finalise(raw_b, text_b)
+                        # else: fall through to chat fallback
+                        return None
+            except Exception:
+                # if any inspection error, just proceed to parse normally
+                pass
+
+            # Normal parse path
+            parsed = getattr(response, "output_parsed", None)
+            if parsed is not None:
+                return _finalise(raw, None, parsed_obj=parsed)
+
+            text = getattr(response, "output_text", None) or _extract_openai_responses_text(raw)
+            if text:
+                return _finalise(raw, text)
+
+            # No content -> let caller try next mode or chat
+            return None
+
+        except BadRequestError as err:
+            msg = _openai_error_message(err).lower()
+            # If backend complains about json_schema, try json_object / none next
+            if "response_format" in msg and "json_schema" in msg and mode == "json_schema":
+                return None # caller will try json_object
+            if "response_format" in msg and mode in ("json_schema", "json_object"):
+                return None # caller will try 'none'
+            # Other errors -> raise
+            raise ProviderRequestError(provider, model, f"Model request failed. {msg}") from err
+
+        except OpenAIError as err:
+            msg = _openai_error_message(err)
+            raise ProviderRequestError(provider, model, f"Model request failed. {msg}") from err
+
+    # --- Decide primary path ---
+    use_responses_api = force_responses_api or _should_use_openai_responses(model)
+
+    # --------- RESPONSES API PATH ---------
+    if use_responses_api:
+        # Try json_schema -> json_object -> none
+        for mode in ("json_schema", "json_object", "none"):
+            result = _try_responses(mode)
+            if result is not None:
+                return result
+        # fall through to Chat if Responses yielded nothing
+
+    # --------- CHAT COMPLETIONS PATH ---------
+    try:
+        fam = model.strip().lower().split(":")[0]
+        tokens_kw: Dict[str, Any] = {}
+        if any(fam.startswith(p) for p in FAMILIES_USE_MAX_COMPLETION):
+            tokens_kw["max_completion_tokens"] = SAFE_DEFAULT_MAX_OUTPUT
+        else:
+            tokens_kw["max_tokens"] = SAFE_DEFAULT_MAX_OUTPUT
+
+        # build response_format for Chat
+        chat_rf: Dict[str, Any] = {}
+        # We attempt to send json_schema; if SDK rejects, _iterative_trim_call will remove.
+        chat_rf = {"response_format": {"type": "json_schema", "json_schema": schema_payload}}
+
+        chat_base: Dict[str, Any] = {
+            "model": model,
+            "messages": chat_messages,
+            **tokens_kw,
+            **chat_rf,
+        }
+        if not is_gpt5:
+            chat_base.update(temperature_kwargs) # temperature=0 for non-gpt-5
+
+        completion = _iterative_trim_call(req.chat.completions.create, "chat.completions", **chat_base)
+        raw = dump_model(completion)
+        text = _extract_openai_responses_text(raw)
+        return _finalise(raw, text)
+
+    except OpenAIError as err:
+        msg = _openai_error_message(err)
+        raise ProviderRequestError(provider, model, f"Model request failed. {msg}") from err
+
+
+class OpenAIProvider(RubricProvider):
+    name = "openai"
+
+    def default_timeout(self, model: str) -> float:
+        return DEFAULT_REQUEST_TIMEOUT_SECONDS
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        # Try to fetch reasoning effort hint from request if present
+        effort_hint: Optional[str] = getattr(request, "reasoning_effort", None)
+        opts = getattr(request, "options", None)
+        if effort_hint is None and isinstance(opts, dict):
+            effort_hint = opts.get("reasoning_effort") or opts.get("effort")
+
+        return _call_openai_family(
+            provider=self.name,
+            model=request.model,
+            api_key=request.api_key,
+            system_content=request.system_content,
+            user_content=request.user_content,
+            score_min=request.score_min,
+            score_max=request.score_max,
+            timeout=request.timeout,
+            reasoning_effort=effort_hint,
+        )
+
+
+class XAIProvider(OpenAIProvider):
+    name = "xai"
+
+    def default_timeout(self, model: str) -> float:
+        normalized = model.strip().lower()
+        if normalized.startswith("grok-4"):
+            return 60.0
+        return 45.0
+
+    def run(self, request: ProviderRequest) -> RewardRubricRunResult:
+        # Try to fetch reasoning effort hint from request if present
+        effort_hint: Optional[str] = getattr(request, "reasoning_effort", None)
+        opts = getattr(request, "options", None)
+        if effort_hint is None and isinstance(opts, dict):
+            effort_hint = opts.get("reasoning_effort") or opts.get("effort")
+
+        return _call_openai_family(
+            provider=self.name,
+            model=request.model,
+            api_key=request.api_key,
+            system_content=request.system_content,
+            user_content=request.user_content,
+            score_min=request.score_min,
+            score_max=request.score_max,
+            timeout=request.timeout,
+            base_url="https://api.x.ai/v1",
+            force_responses_api=True,
+            reasoning_effort=effort_hint,
+        )
+
+
+__all__ = [
+    "OpenAIProvider",
+    "XAIProvider",
+]
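
For orientation, here is a minimal usage sketch of the OpenAIProvider class added in this file. It assumes ProviderRequest (imported from osmosis_ai/providers/base.py, whose body is not part of this hunk) can be constructed with keyword arguments matching the fields that run() reads; the actual constructor and field names may differ, so treat this as an illustration rather than the package's documented API.

import os

from osmosis_ai.providers.base import ProviderRequest
from osmosis_ai.providers.openai_family import OpenAIProvider

provider = OpenAIProvider()

# Hypothetical construction: keyword names mirror the attributes run() reads
# (model, api_key, system_content, user_content, score_min, score_max, timeout).
request = ProviderRequest(
    model="gpt-4o",
    api_key=os.environ["OPENAI_API_KEY"],
    system_content="Score the answer against the rubric and reply with JSON: {\"score\": ..., \"explanation\": ...}.",
    user_content="Question: ...\nAnswer to grade: ...",
    score_min=0.0,
    score_max=1.0,
    timeout=provider.default_timeout("gpt-4o"),
)

result = provider.run(request)
# run() returns {"score": <float bounded to [score_min, score_max]>, "explanation": <str>, "raw": <provider payload>}.
print(result["score"], result["explanation"])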