osmosis-ai 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of osmosis-ai might be problematic. Click here for more details.
- osmosis_ai/__init__.py +13 -4
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +269 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +537 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +392 -1
- osmosis_ai-0.2.2.dist-info/METADATA +241 -0
- osmosis_ai-0.2.2.dist-info/RECORD +16 -0
- osmosis_ai-0.2.1.dist-info/METADATA +0 -143
- osmosis_ai-0.2.1.dist-info/RECORD +0 -8
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.2.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.2.dist-info}/top_level.txt +0 -0
osmosis_ai/utils.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
|
|
2
2
|
import functools
|
|
3
3
|
import inspect
|
|
4
|
-
|
|
4
|
+
import types
|
|
5
|
+
from typing import Any, Callable, Mapping, Union, get_args, get_origin, get_type_hints
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
def osmosis_reward(func: Callable) -> Callable:
|
|
@@ -61,3 +62,393 @@ def osmosis_reward(func: Callable) -> Callable:
|
|
|
61
62
|
return result
|
|
62
63
|
|
|
63
64
|
return wrapper
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
ALLOWED_ROLES = {"user", "system", "assistant", "developer", "tool", "function"}
|
|
68
|
+
|
|
69
|
+
_UNION_TYPES = {Union}
|
|
70
|
+
_types_union_type = getattr(types, "UnionType", None)
|
|
71
|
+
if _types_union_type is not None:
|
|
72
|
+
_UNION_TYPES.add(_types_union_type)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _is_str_annotation(annotation: Any) -> bool:
|
|
76
|
+
if annotation is inspect.Parameter.empty:
|
|
77
|
+
return False
|
|
78
|
+
if annotation is str:
|
|
79
|
+
return True
|
|
80
|
+
if isinstance(annotation, str):
|
|
81
|
+
return annotation in {"str", "builtins.str"}
|
|
82
|
+
if isinstance(annotation, type):
|
|
83
|
+
try:
|
|
84
|
+
return issubclass(annotation, str)
|
|
85
|
+
except TypeError:
|
|
86
|
+
return False
|
|
87
|
+
forward_arg = getattr(annotation, "__forward_arg__", None)
|
|
88
|
+
if isinstance(forward_arg, str):
|
|
89
|
+
return forward_arg in {"str", "builtins.str"}
|
|
90
|
+
return False
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _is_optional_str(annotation: Any) -> bool:
    """Return True when *annotation* denotes ``str`` or ``Optional[str]``.

    Plain ``str`` annotations (as judged by ``_is_str_annotation``) qualify,
    as do a fixed set of stringified spellings (compared with spaces removed)
    and any union whose only non-``None`` member is a ``str`` annotation.
    """
    if _is_str_annotation(annotation):
        return True

    if isinstance(annotation, str):
        optional_spellings = {
            "Optional[str]",
            "typing.Optional[str]",
            "Str|None",
            "str|None",
            "builtins.str|None",
            "None|str",
            "None|builtins.str",
        }
        if annotation.replace(" ", "") in optional_spellings:
            return True

    if get_origin(annotation) not in _UNION_TYPES:
        return False

    # Drop ``NoneType`` so that exactly one member must remain.
    non_none_members = [
        member for member in get_args(annotation) if member is not type(None)  # noqa: E721
    ]
    return len(non_none_members) == 1 and _is_str_annotation(non_none_members[0])
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _is_list_annotation(annotation: Any) -> bool:
|
|
116
|
+
if annotation is list:
|
|
117
|
+
return True
|
|
118
|
+
if isinstance(annotation, str):
|
|
119
|
+
normalized = annotation.replace(" ", "")
|
|
120
|
+
return (
|
|
121
|
+
normalized in {"list", "builtins.list", "typing.List", "List"}
|
|
122
|
+
or normalized.startswith("list[")
|
|
123
|
+
or normalized.startswith("builtins.list[")
|
|
124
|
+
or normalized.startswith("typing.List[")
|
|
125
|
+
or normalized.startswith("List[")
|
|
126
|
+
)
|
|
127
|
+
origin = get_origin(annotation)
|
|
128
|
+
return origin is list
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _is_float_annotation(annotation: Any) -> bool:
|
|
132
|
+
if annotation in {inspect.Parameter.empty, float}:
|
|
133
|
+
return True
|
|
134
|
+
if isinstance(annotation, str):
|
|
135
|
+
return annotation in {"float", "builtins.float"}
|
|
136
|
+
origin = get_origin(annotation)
|
|
137
|
+
return origin is float
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _is_numeric(value: Any) -> bool:
|
|
141
|
+
return isinstance(value, (int, float)) and not isinstance(value, bool)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _is_dict_annotation(annotation: Any) -> bool:
|
|
145
|
+
if annotation in {dict, Mapping}:
|
|
146
|
+
return True
|
|
147
|
+
origin = get_origin(annotation)
|
|
148
|
+
if origin in {dict, Mapping}:
|
|
149
|
+
return True
|
|
150
|
+
if isinstance(annotation, type):
|
|
151
|
+
try:
|
|
152
|
+
return issubclass(annotation, dict)
|
|
153
|
+
except TypeError:
|
|
154
|
+
return False
|
|
155
|
+
if isinstance(annotation, str):
|
|
156
|
+
normalized = annotation.replace(" ", "")
|
|
157
|
+
return (
|
|
158
|
+
normalized in {"dict", "builtins.dict", "typing.Mapping", "collections.abc.Mapping", "Mapping"}
|
|
159
|
+
or normalized.startswith("dict[")
|
|
160
|
+
or normalized.startswith("builtins.dict[")
|
|
161
|
+
or normalized.startswith("typing.Dict[")
|
|
162
|
+
or normalized.startswith("Dict[")
|
|
163
|
+
or normalized.startswith("typing.Mapping[")
|
|
164
|
+
or normalized.startswith("Mapping[")
|
|
165
|
+
)
|
|
166
|
+
return False
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def osmosis_rubric(func: Callable) -> Callable:
    """
    Decorator for rubric functions that enforces the signature:
    (model_info: dict, rubric: str, messages: list, ground_truth: Optional[str] = None,
     system_message: Optional[str] = None, extra_info: dict = None,
     score_min: float = 0.0, score_max: float = 1.0) -> float

    Only the first three parameters are mandatory. The remaining five are
    optional, but when present they must appear in the order shown, use the
    names shown, and carry the defaults shown.

    The `model_info` mapping must provide non-empty string entries for both `provider` and `model`.

    At call time the wrapper drops a `data_source` keyword argument if one is
    supplied, validates the bound arguments, calls the wrapped function with
    the remaining arguments, and finally checks that the result is a float.

    Args:
        func: The rubric function to be wrapped

    Returns:
        The wrapped function

    Raises:
        TypeError: At decoration time, if the function doesn't have the required
            signature. At call time, if an argument has the wrong type or the
            wrapped function doesn't return a float.
        ValueError: At call time, if `model_info` or a message is missing a
            required field, a message role is not in ALLOWED_ROLES, or
            `score_max` is not greater than `score_min`.

    Example:
        @osmosis_rubric
        def evaluate_response(
            model_info: dict,
            rubric: str,
            messages: list,
            ground_truth: str | None = None,
            system_message: str | None = None,
            extra_info: dict = None,
            score_min: float = 0.0,
            score_max: float = 1.0,
        ) -> float:
            return some_evaluation(model_info, messages, ground_truth)
    """
    # Validate function signature
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    try:
        resolved_annotations = get_type_hints(
            func,
            globalns=getattr(func, "__globals__", {}),
            include_extras=True,
        )
    except Exception:  # pragma: no cover - best effort for forward refs
        # Fall back to the raw (possibly stringified) annotations on each parameter.
        resolved_annotations = {}

    # Check parameter count
    if len(params) < 3 or len(params) > 8:
        raise TypeError(f"Function {func.__name__} must have between 3 and 8 parameters, got {len(params)}")

    # Check first parameter: model_info: dict
    model_info_param = params[0]
    if model_info_param.name != "model_info":
        raise TypeError(f"First parameter must be named 'model_info', got '{model_info_param.name}'")
    model_info_annotation = resolved_annotations.get(model_info_param.name, model_info_param.annotation)
    if not _is_dict_annotation(model_info_annotation):
        raise TypeError(
            f"First parameter 'model_info' must be annotated as a dict or mapping, got {model_info_annotation}"
        )
    if model_info_param.default is not inspect.Parameter.empty:
        raise TypeError("First parameter 'model_info' cannot have a default value")

    # Check second parameter: rubric: str
    rubric_param = params[1]
    if rubric_param.name != "rubric":
        raise TypeError(f"Second parameter must be named 'rubric', got '{rubric_param.name}'")
    rubric_annotation = resolved_annotations.get(rubric_param.name, rubric_param.annotation)
    if not _is_str_annotation(rubric_annotation):
        raise TypeError(f"Second parameter 'rubric' must be annotated as str, got {rubric_annotation}")
    if rubric_param.default is not inspect.Parameter.empty:
        raise TypeError("Second parameter 'rubric' cannot have a default value")

    # Check third parameter: messages: list
    messages_param = params[2]
    if messages_param.name != "messages":
        raise TypeError(f"Third parameter must be named 'messages', got '{messages_param.name}'")
    messages_annotation = resolved_annotations.get(messages_param.name, messages_param.annotation)
    if messages_annotation is inspect.Parameter.empty:
        raise TypeError("Third parameter 'messages' must be annotated as list")
    if not _is_list_annotation(messages_annotation):
        raise TypeError(f"Third parameter 'messages' must be annotated as list, got {messages_annotation}")
    if messages_param.default is not inspect.Parameter.empty:
        raise TypeError("Third parameter 'messages' cannot have a default value")

    # Each optional parameter is consumed from the front of this list once it
    # has been validated, so whatever is left at the end is unexpected.
    optional_params = params[3:]

    if optional_params:
        ground_truth_param = optional_params[0]
        # Check fourth parameter: ground_truth: Optional[str]
        if ground_truth_param.name != "ground_truth":
            raise TypeError(f"Fourth parameter must be named 'ground_truth', got '{ground_truth_param.name}'")
        ground_truth_annotation = resolved_annotations.get(
            ground_truth_param.name,
            ground_truth_param.annotation,
        )
        if ground_truth_annotation is inspect.Parameter.empty or not _is_optional_str(ground_truth_annotation):
            raise TypeError(
                "Fourth parameter 'ground_truth' must be annotated as Optional[str] or str"
            )
        if ground_truth_param.default is inspect.Parameter.empty:
            raise TypeError("Fourth parameter 'ground_truth' must have a default value of None")
        if ground_truth_param.default is not None:
            raise TypeError("Fourth parameter 'ground_truth' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        system_message_param = optional_params[0]
        # Check fifth parameter: system_message: Optional[str]
        if system_message_param.name != "system_message":
            raise TypeError(f"Fifth parameter must be named 'system_message', got '{system_message_param.name}'")
        system_message_annotation = resolved_annotations.get(
            system_message_param.name,
            system_message_param.annotation,
        )
        if system_message_annotation is inspect.Parameter.empty or not _is_optional_str(system_message_annotation):
            raise TypeError(
                "Fifth parameter 'system_message' must be annotated as Optional[str] or str"
            )
        if system_message_param.default is inspect.Parameter.empty:
            raise TypeError("Fifth parameter 'system_message' must have a default value of None")
        if system_message_param.default is not None:
            raise TypeError("Fifth parameter 'system_message' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        extra_info_param = optional_params[0]
        # Check sixth parameter: extra_info: dict = None
        if extra_info_param.name != "extra_info":
            raise TypeError(f"Sixth parameter must be named 'extra_info', got '{extra_info_param.name}'")
        extra_info_annotation = resolved_annotations.get(
            extra_info_param.name,
            extra_info_param.annotation,
        )
        if extra_info_annotation is inspect.Parameter.empty or not _is_dict_annotation(extra_info_annotation):
            raise TypeError(
                f"Sixth parameter 'extra_info' must be annotated as dict, got {extra_info_annotation}"
            )
        if extra_info_param.default is inspect.Parameter.empty:
            raise TypeError("Sixth parameter 'extra_info' must have a default value of None")
        if extra_info_param.default is not None:
            raise TypeError("Sixth parameter 'extra_info' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        score_min_param = optional_params[0]
        # Check seventh parameter: score_min: float = 0.0
        if score_min_param.name != "score_min":
            raise TypeError(f"Seventh parameter must be named 'score_min', got '{score_min_param.name}'")
        score_min_annotation = resolved_annotations.get(
            score_min_param.name,
            score_min_param.annotation,
        )
        if not _is_float_annotation(score_min_annotation):
            raise TypeError(
                f"Seventh parameter 'score_min' must be annotated as float, got {score_min_annotation}"
            )
        if score_min_param.default is inspect.Parameter.empty:
            raise TypeError("Seventh parameter 'score_min' must have a default value of 0.0")
        if not _is_numeric(score_min_param.default):
            raise TypeError("Seventh parameter 'score_min' must default to a numeric value")
        if float(score_min_param.default) != 0.0:
            raise TypeError("Seventh parameter 'score_min' must default to 0.0")
        optional_params = optional_params[1:]

    if optional_params:
        score_max_param = optional_params[0]
        # Check eighth parameter: score_max: float = 1.0
        if score_max_param.name != "score_max":
            raise TypeError(f"Eighth parameter must be named 'score_max', got '{score_max_param.name}'")
        score_max_annotation = resolved_annotations.get(
            score_max_param.name,
            score_max_param.annotation,
        )
        if not _is_float_annotation(score_max_annotation):
            raise TypeError(
                f"Eighth parameter 'score_max' must be annotated as float, got {score_max_annotation}"
            )
        if score_max_param.default is inspect.Parameter.empty:
            raise TypeError("Eighth parameter 'score_max' must have a default value of 1.0")
        if not _is_numeric(score_max_param.default):
            raise TypeError("Eighth parameter 'score_max' must default to a numeric value")
        if float(score_max_param.default) != 1.0:
            raise TypeError("Eighth parameter 'score_max' must default to 1.0")
        optional_params = optional_params[1:]

    if optional_params:
        unexpected_param = optional_params[0]
        raise TypeError(f"Function {func.__name__} has unexpected parameter '{unexpected_param.name}'")

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Remove unsupported kwargs
        kwargs.pop("data_source", None)
        # bind_partial tolerates missing required arguments, so the checks
        # below can report them with a specific message.
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()

        # Validate model_info argument
        if "model_info" not in bound.arguments:
            raise TypeError("'model_info' argument is required")
        model_info_value = bound.arguments["model_info"]
        if not isinstance(model_info_value, Mapping):
            raise TypeError(f"'model_info' must be a mapping, got {type(model_info_value).__name__}")
        required_model_fields = {"provider", "model"}
        missing_model_fields = required_model_fields - set(model_info_value.keys())
        if missing_model_fields:
            raise ValueError(f"'model_info' is missing required fields: {sorted(missing_model_fields)}")
        provider_value = model_info_value.get("provider")
        if not isinstance(provider_value, str) or not provider_value.strip():
            raise TypeError("'model_info[\"provider\"]' must be a non-empty string")
        model_value = model_info_value.get("model")
        if not isinstance(model_value, str) or not model_value.strip():
            raise TypeError("'model_info[\"model\"]' must be a non-empty string")

        # Validate rubric argument
        if "rubric" not in bound.arguments:
            raise TypeError("'rubric' argument is required")
        rubric_value = bound.arguments["rubric"]
        if not isinstance(rubric_value, str):
            raise TypeError(f"'rubric' must be a string, got {type(rubric_value).__name__}")

        # Validate messages argument
        if "messages" not in bound.arguments:
            raise TypeError("'messages' argument is required")
        messages_value = bound.arguments["messages"]
        if not isinstance(messages_value, list):
            raise TypeError(f"'messages' must be a list, got {type(messages_value).__name__}")

        # Validate optional ground_truth argument
        ground_truth_value = bound.arguments.get("ground_truth")
        if ground_truth_value is not None and not isinstance(ground_truth_value, str):
            raise TypeError(
                f"'ground_truth' must be a string or None, got {type(ground_truth_value).__name__}"
            )

        # Validate optional system_message argument
        system_message_value = bound.arguments.get("system_message")
        if system_message_value is not None and not isinstance(system_message_value, str):
            raise TypeError(
                f"'system_message' must be a string or None, got {type(system_message_value).__name__}"
            )

        # Validate messages structure
        for index, message in enumerate(messages_value):
            if not isinstance(message, dict):
                raise TypeError(f"'messages[{index}]' must be a dict, got {type(message).__name__}")
            missing_fields = {"type", "role", "content"} - message.keys()
            if missing_fields:
                # Sorted so the message is deterministic and matches the
                # formatting of the model_info error above.
                raise ValueError(f"'messages[{index}]' is missing required fields: {sorted(missing_fields)}")
            if message["role"] not in ALLOWED_ROLES:
                raise ValueError(
                    f"'messages[{index}]['role']' must be one of {sorted(ALLOWED_ROLES)}, "
                    f"got '{message['role']}'"
                )
            if not isinstance(message["content"], list):
                raise TypeError(f"'messages[{index}]['content']' must be a list")

        # The score bounds only appear in bound.arguments when the wrapped
        # function declares them (apply_defaults fills in their defaults).
        score_min_present = "score_min" in bound.arguments
        score_max_present = "score_max" in bound.arguments

        if score_min_present:
            score_min_value = bound.arguments["score_min"]
            if not _is_numeric(score_min_value):
                raise TypeError(
                    f"'score_min' must be a numeric type, got {type(score_min_value).__name__}"
                )
        else:
            score_min_value = None

        if score_max_present:
            score_max_value = bound.arguments["score_max"]
            if not _is_numeric(score_max_value):
                raise TypeError(
                    f"'score_max' must be a numeric type, got {type(score_max_value).__name__}"
                )
        else:
            score_max_value = None

        if score_min_present and score_max_present:
            if float(score_max_value) <= float(score_min_value):
                raise ValueError("'score_max' must be greater than 'score_min'")

        # Validate return type
        result = func(*args, **kwargs)
        if not isinstance(result, float):
            raise TypeError(f"Function {func.__name__} must return a float, got {type(result).__name__}")
        return result

    return wrapper
|
|
@@ -0,0 +1,241 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: osmosis-ai
|
|
3
|
+
Version: 0.2.2
|
|
4
|
+
Summary: A Python library for reward function validation with strict type enforcement.
|
|
5
|
+
Author-email: Osmosis AI <jake@osmosis.ai>
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2025 Gulp AI
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
Project-URL: Homepage, https://github.com/Osmosis-AI/osmosis-sdk-python
|
|
28
|
+
Project-URL: Issues, https://github.com/Osmosis-AI/osmosis-sdk-python/issues
|
|
29
|
+
Classifier: Programming Language :: Python :: 3
|
|
30
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
31
|
+
Classifier: Operating System :: OS Independent
|
|
32
|
+
Requires-Python: >=3.6
|
|
33
|
+
Description-Content-Type: text/markdown
|
|
34
|
+
License-File: LICENSE
|
|
35
|
+
Dynamic: license-file
|
|
36
|
+
|
|
37
|
+
# osmosis-ai
|
|
38
|
+
|
|
39
|
+
A Python library that provides reward and rubric validation helpers for LLM applications with strict type enforcement.
|
|
40
|
+
|
|
41
|
+
## Installation
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install osmosis-ai
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
For development:
|
|
48
|
+
```bash
|
|
49
|
+
git clone https://github.com/Osmosis-AI/osmosis-sdk-python
|
|
50
|
+
cd osmosis-sdk-python
|
|
51
|
+
pip install -e .
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
## Quick Start
|
|
55
|
+
|
|
56
|
+
```python
|
|
57
|
+
from osmosis_ai import osmosis_reward
|
|
58
|
+
|
|
59
|
+
@osmosis_reward
|
|
60
|
+
def simple_reward(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
|
|
61
|
+
"""Basic exact match reward function."""
|
|
62
|
+
return 1.0 if solution_str.strip() == ground_truth.strip() else 0.0
|
|
63
|
+
|
|
64
|
+
# Use the reward function
|
|
65
|
+
score = simple_reward("hello world", "hello world") # Returns 1.0
|
|
66
|
+
```
|
|
67
|
+
|
|
68
|
+
```python
|
|
69
|
+
from osmosis_ai import evaluate_rubric
|
|
70
|
+
|
|
71
|
+
messages = [
|
|
72
|
+
{
|
|
73
|
+
"type": "message",
|
|
74
|
+
"role": "user",
|
|
75
|
+
"content": [{"type": "input_text", "text": "What is the capital of France?"}],
|
|
76
|
+
},
|
|
77
|
+
{
|
|
78
|
+
"type": "message",
|
|
79
|
+
"role": "assistant",
|
|
80
|
+
"content": [{"type": "output_text", "text": "The capital of France is Paris."}],
|
|
81
|
+
},
|
|
82
|
+
]
|
|
83
|
+
|
|
84
|
+
# Export OPENAI_API_KEY in your shell before running this snippet.
|
|
85
|
+
rubric_score = evaluate_rubric(
|
|
86
|
+
rubric="Assistant must mention the verified capital city.",
|
|
87
|
+
messages=messages,
|
|
88
|
+
model_info={
|
|
89
|
+
"provider": "openai",
|
|
90
|
+
"model": "gpt-5",
|
|
91
|
+
"api_key_env": "OPENAI_API_KEY",
|
|
92
|
+
},
|
|
93
|
+
ground_truth="Paris",
|
|
94
|
+
)
|
|
95
|
+
|
|
96
|
+
print(rubric_score) # -> 1.0 (full payload available via return_details=True)
|
|
97
|
+
```
|
|
98
|
+
|
|
99
|
+
## Remote Rubric Evaluation
|
|
100
|
+
|
|
101
|
+
`evaluate_rubric` talks to each provider through its official Python SDK while enforcing the same JSON schema everywhere:
|
|
102
|
+
|
|
103
|
+
- **OpenAI / xAI** – Uses `OpenAI(...).responses.create` (or `chat.completions.create`) with `response_format={"type": "json_schema"}` and falls back to `json_object` when needed.
|
|
104
|
+
- **Anthropic** – Forces a tool call with a JSON schema via `Anthropic(...).messages.create`, extracting the returned tool arguments.
|
|
105
|
+
- **Google Gemini** – Invokes `google.genai.Client(...).models.generate_content` with `response_mime_type="application/json"` and `response_schema`.
|
|
106
|
+
|
|
107
|
+
Every provider therefore returns a strict JSON object with `{"score": number, "explanation": string}`. The helper clamps the score into your configured range, validates the structure, and exposes the raw payload when `return_details=True`.
|
|
108
|
+
|
|
109
|
+
Credentials are resolved from environment variables by default:
|
|
110
|
+
|
|
111
|
+
- `OPENAI_API_KEY` for OpenAI
|
|
112
|
+
- `ANTHROPIC_API_KEY` for Anthropic
|
|
113
|
+
- `GOOGLE_API_KEY` for Google Gemini
|
|
114
|
+
- `XAI_API_KEY` for xAI
|
|
115
|
+
|
|
116
|
+
Override the environment variable name with `model_info={"api_key_env": "CUSTOM_ENV_NAME"}` when needed, or supply an inline secret with `model_info={"api_key": "sk-..."}` for ephemeral credentials. Missing API keys raise a `MissingAPIKeyError` that explains how to export the secret before trying again.
|
|
117
|
+
|
|
118
|
+
`model_info` accepts additional rubric-specific knobs:
|
|
119
|
+
|
|
120
|
+
- `score_min` / `score_max` – change the default `[0.0, 1.0]` scoring bounds.
|
|
121
|
+
- `system_prompt` / `original_input` – override the helper’s transcript inference when those entries are absent.
|
|
122
|
+
- `timeout` – customise the provider timeout in seconds.
|
|
123
|
+
|
|
124
|
+
Pass `extra_info={...}` to `evaluate_rubric` when you need structured context quoted in the judge prompt, and set `return_details=True` to receive the full `RewardRubricRunResult` payload (including the provider’s raw response).
|
|
125
|
+
|
|
126
|
+
Remote failures surface as `ProviderRequestError` instances, with `ModelNotFoundError` reserved for missing model identifiers so you can retry with a new snapshot.
|
|
127
|
+
|
|
128
|
+
> Older SDK versions that lack schema parameters automatically fall back to instruction-only JSON; the helper still validates the response payload before returning.
|
|
129
|
+
> Provider model snapshot names change frequently. Check each vendor's dashboard for the latest identifier if you encounter a “model not found” error.
|
|
130
|
+
|
|
131
|
+
### Provider Architecture
|
|
132
|
+
|
|
133
|
+
All remote integrations live in `osmosis_ai/providers/` and implement the `RubricProvider` interface. At import time the default registry registers OpenAI, xAI, Anthropic, and Google Gemini so `evaluate_rubric` can route requests without additional configuration. The request/response plumbing is encapsulated in each provider module, keeping `evaluate_rubric` focused on prompt construction, payload validation, and credential resolution.
|
|
134
|
+
|
|
135
|
+
Add your own provider by subclassing `RubricProvider`, implementing `run()` with the vendor SDK, and calling `register_provider()` during start-up. A step-by-step guide is available in [`osmosis_ai/providers/README.md`](osmosis_ai/providers/README.md).
|
|
136
|
+
|
|
137
|
+
## Required Function Signature
|
|
138
|
+
|
|
139
|
+
All functions decorated with `@osmosis_reward` must have exactly this signature:
|
|
140
|
+
|
|
141
|
+
```python
|
|
142
|
+
@osmosis_reward
|
|
143
|
+
def your_function(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
|
|
144
|
+
# Your reward logic here
|
|
145
|
+
return float_score
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
### Parameters
|
|
149
|
+
|
|
150
|
+
- **`solution_str: str`** - The solution string to evaluate (required)
|
|
151
|
+
- **`ground_truth: str`** - The correct/expected answer (required)
|
|
152
|
+
- **`extra_info: dict = None`** - Optional dictionary for additional configuration
|
|
153
|
+
|
|
154
|
+
### Return Value
|
|
155
|
+
|
|
156
|
+
- **`-> float`** - Must return a float value representing the reward score
|
|
157
|
+
|
|
158
|
+
The decorator will raise a `TypeError` if the function doesn't match this exact signature or doesn't return a float.
|
|
159
|
+
|
|
160
|
+
## Rubric Function Signature
|
|
161
|
+
|
|
162
|
+
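Rubric functions decorated with `@osmosis_rubric` must declare the following parameters in this order. The first three are required; the remaining five are optional, but when present they must use the names and defaults shown: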
|
|
163
|
+
|
|
164
|
+
- `model_info: dict`
|
|
165
|
+
- `rubric: str`
|
|
166
|
+
- `messages: list`
|
|
167
|
+
- `ground_truth: Optional[str] = None`
|
|
168
|
+
- `system_message: Optional[str] = None`
|
|
169
|
+
- `extra_info: dict = None`
|
|
170
|
+
- `score_min: float = 0.0` *(optional lower bound; must default to 0.0 and stay below `score_max`)*
|
|
171
|
+
- `score_max: float = 1.0` *(optional upper bound; must default to 1.0 and stay above `score_min`)*
|
|
172
|
+
|
|
173
|
+
and must return a `float`. The decorator validates the signature when the function is decorated, validates the runtime payload (including message roles) before delegating to your custom logic, and checks that the returned value is a `float` afterwards.
|
|
174
|
+
|
|
175
|
+
> Required fields: `model_info` must contain non-empty `provider` and `model` string entries.
|
|
176
|
+
|
|
177
|
+
> Annotation quirk: `extra_info` must be explicitly annotated with a dict-style type (for example `dict` or `Dict[str, Any]`) and default to `None`; the validator rejects an unannotated `extra_info`.
|
|
178
|
+
|
|
179
|
+
> Tip: You can call `evaluate_rubric` from inside a rubric function (or any other orchestrator) to outsource judging to a hosted model while still benefiting from the decorator’s validation.
|
|
180
|
+
|
|
181
|
+
## Examples
|
|
182
|
+
|
|
183
|
+
See the [`examples/`](examples/) directory for complete examples:
|
|
184
|
+
|
|
185
|
+
```python
|
|
186
|
+
@osmosis_reward
|
|
187
|
+
def case_insensitive_match(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
|
|
188
|
+
"""Case-insensitive string matching with partial credit."""
|
|
189
|
+
match = solution_str.lower().strip() == ground_truth.lower().strip()
|
|
190
|
+
|
|
191
|
+
if extra_info and 'partial_credit' in extra_info:
|
|
192
|
+
if not match and extra_info['partial_credit']:
|
|
193
|
+
len_diff = abs(len(solution_str) - len(ground_truth))
|
|
194
|
+
if len_diff <= 2:
|
|
195
|
+
return 0.5
|
|
196
|
+
|
|
197
|
+
return 1.0 if match else 0.0
|
|
198
|
+
|
|
199
|
+
@osmosis_reward
|
|
200
|
+
def numeric_tolerance(solution_str: str, ground_truth: str, extra_info: dict = None) -> float:
|
|
201
|
+
"""Numeric comparison with configurable tolerance."""
|
|
202
|
+
try:
|
|
203
|
+
solution_num = float(solution_str.strip())
|
|
204
|
+
truth_num = float(ground_truth.strip())
|
|
205
|
+
|
|
206
|
+
tolerance = extra_info.get('tolerance', 0.01) if extra_info else 0.01
|
|
207
|
+
return 1.0 if abs(solution_num - truth_num) <= tolerance else 0.0
|
|
208
|
+
except ValueError:
|
|
209
|
+
return 0.0
|
|
210
|
+
```
|
|
211
|
+
|
|
212
|
+
- `examples/rubric_functions.py` demonstrates `evaluate_rubric` with OpenAI, Anthropic, Gemini, and xAI using the schema-enforced SDK integrations.
|
|
213
|
+
- `examples/reward_functions.py` keeps local reward helpers that showcase the decorator contract without external calls.
|
|
214
|
+
|
|
215
|
+
## Running Examples
|
|
216
|
+
|
|
217
|
+
```bash
|
|
218
|
+
PYTHONPATH=. python examples/reward_functions.py
|
|
219
|
+
PYTHONPATH=. python examples/rubric_functions.py # Uncomment the provider you need before running
|
|
220
|
+
```
|
|
221
|
+
|
|
222
|
+
## Testing
|
|
223
|
+
|
|
224
|
+
Run `python -m pytest tests/test_rubric_eval.py` to exercise the guards that ensure rubric prompts ignore message metadata (for example `tests/test_rubric_eval.py::test_collect_text_skips_metadata_fields`) while still preserving nested tool output. Add additional tests under `tests/` as you extend the library.
|
|
225
|
+
|
|
226
|
+
## License
|
|
227
|
+
|
|
228
|
+
MIT License - see [LICENSE](LICENSE) file for details.
|
|
229
|
+
|
|
230
|
+
## Contributing
|
|
231
|
+
|
|
232
|
+
1. Fork the repository
|
|
233
|
+
2. Create a feature branch
|
|
234
|
+
3. Make your changes
|
|
235
|
+
4. Run tests and examples
|
|
236
|
+
5. Submit a pull request
|
|
237
|
+
|
|
238
|
+
## Links
|
|
239
|
+
|
|
240
|
+
- [Homepage](https://github.com/Osmosis-AI/osmosis-sdk-python)
|
|
241
|
+
- [Issues](https://github.com/Osmosis-AI/osmosis-sdk-python/issues)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
osmosis_ai/__init__.py,sha256=2_qXxu18Yc7UicqxFZds8PjR4q0mTY1Xt17iR38OFbw,725
|
|
2
|
+
osmosis_ai/consts.py,sha256=-NDo9FaqBTebkCnhiFDxne6BY0W7BL3oM8HnGQDDgSE,73
|
|
3
|
+
osmosis_ai/rubric_eval.py,sha256=bFgxgnbQeD-7K2LkTJfnSk5aG9s4lefLfmvQt4GQSnM,18332
|
|
4
|
+
osmosis_ai/rubric_types.py,sha256=kJvNAjLd3Y-1Q-_Re9HLTprLAUO3qtwR-IWOBeMkFI8,1279
|
|
5
|
+
osmosis_ai/utils.py,sha256=yjC_oQt1wwTJsX7lCx0ZGMa5txHURByuBDuU37WPAO0,19927
|
|
6
|
+
osmosis_ai/providers/__init__.py,sha256=yLSExLbJToZ8AUOVxt4LDplxtIuwv-etSJJyZOcOE2Q,927
|
|
7
|
+
osmosis_ai/providers/anthropic_provider.py,sha256=zrWCVP8co4v8xhcJDFLASwvwEADKN-1p34cY_GH4q5M,3758
|
|
8
|
+
osmosis_ai/providers/base.py,sha256=fN5cnWXYAHN53RR_x6ykbUkM4bictNPDj4U8yd4b2a0,1492
|
|
9
|
+
osmosis_ai/providers/gemini_provider.py,sha256=xqklXRO5K1YZ4SKq5lfU3bDUaF8QN2MIBP4DHGKwLVo,10611
|
|
10
|
+
osmosis_ai/providers/openai_family.py,sha256=DeQWPMcafEvG4xcI97m3AADTKP2pYw9KwcQTcQg-h_4,26078
|
|
11
|
+
osmosis_ai/providers/shared.py,sha256=dmVe8JDgafPmo6HkP-Kl0aWfffhAT6u3ElV_wLlYD34,2957
|
|
12
|
+
osmosis_ai-0.2.2.dist-info/licenses/LICENSE,sha256=FV2ZmyhdCYinoLLvU_ci-7pZ3DeNYY9XqZjVjOd3h94,1064
|
|
13
|
+
osmosis_ai-0.2.2.dist-info/METADATA,sha256=MPovk4NSQ_viTMd-zx1lp7Uo2EGB3EotjftcNddy4HU,10448
|
|
14
|
+
osmosis_ai-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
15
|
+
osmosis_ai-0.2.2.dist-info/top_level.txt,sha256=UPNRTKIBSrxsJVNxwXnLCqSoBS4bAiL_3jMtjvf5zEY,11
|
|
16
|
+
osmosis_ai-0.2.2.dist-info/RECORD,,
|