osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of osmosis-ai might be problematic. Click here for more details.

osmosis_ai/utils.py CHANGED
@@ -1,7 +1,8 @@
1
1
 
2
2
  import functools
3
3
  import inspect
4
- from typing import Callable
4
+ import types
5
+ from typing import Any, Callable, Mapping, Union, get_args, get_origin, get_type_hints
5
6
 
6
7
 
7
8
  def osmosis_reward(func: Callable) -> Callable:
@@ -27,10 +28,6 @@ def osmosis_reward(func: Callable) -> Callable:
27
28
  sig = inspect.signature(func)
28
29
  params = list(sig.parameters.values())
29
30
 
30
- # Check parameter count
31
- if len(params) < 2 or len(params) > 3:
32
- raise TypeError(f"Function {func.__name__} must have 2-3 parameters, got {len(params)}")
33
-
34
31
  # Check first parameter: solution_str: str
35
32
  if params[0].name != 'solution_str':
36
33
  raise TypeError(f"First parameter must be named 'solution_str', got '{params[0].name}'")
@@ -61,3 +58,393 @@ def osmosis_reward(func: Callable) -> Callable:
61
58
  return result
62
59
 
63
60
  return wrapper
61
+
62
+
63
# Values a message's "role" field may take.
ALLOWED_ROLES = {"assistant", "developer", "function", "system", "tool", "user"}

# Origins that get_origin() reports for union annotations: typing.Union always,
# plus types.UnionType (PEP 604 `X | Y`) on interpreters that provide it.
_types_union_type = getattr(types, "UnionType", None)
_UNION_TYPES = {Union} if _types_union_type is None else {Union, _types_union_type}
69
+
70
+
71
def _is_str_annotation(annotation: Any) -> bool:
    """Return True if *annotation* denotes ``str``.

    Recognises ``str`` and its subclasses, the unresolved strings ``"str"`` /
    ``"builtins.str"``, and forward references to those names. A missing
    annotation (``inspect.Parameter.empty``) is not a str annotation.
    """
    str_spellings = {"str", "builtins.str"}

    if annotation is inspect.Parameter.empty:
        return False

    if isinstance(annotation, type):
        # Covers `str` itself as well as user-defined subclasses of str.
        try:
            return issubclass(annotation, str)
        except TypeError:
            return False

    if isinstance(annotation, str):
        return annotation in str_spellings

    # typing.ForwardRef-style objects carry the referenced name in __forward_arg__.
    forward_name = getattr(annotation, "__forward_arg__", None)
    return isinstance(forward_name, str) and forward_name in str_spellings
87
+
88
+
89
# Whitespace-free string spellings of "optional str". These are used when annotations
# could not be resolved (e.g. `from __future__ import annotations` combined with an
# unresolvable forward reference elsewhere in the signature).
_OPTIONAL_STR_SPELLINGS = frozenset(
    template.format(name)
    for name in ("str", "builtins.str")
    for template in (
        "Optional[{}]",
        "typing.Optional[{}]",
        "Union[{},None]",
        "Union[None,{}]",
        "typing.Union[{},None]",
        "typing.Union[None,{}]",
        "{}|None",
        "None|{}",
    )
)


def _is_optional_str(annotation: Any) -> bool:
    """Return True if *annotation* is ``str``, ``Optional[str]`` or ``str | None``.

    Resolved annotations are checked structurally: a union (``typing.Union`` or a
    PEP 604 ``X | Y`` union) whose only non-``None`` member is a str annotation.
    Unresolved string annotations are compared, ignoring spaces, against the
    spellings in ``_OPTIONAL_STR_SPELLINGS``.
    """
    if _is_str_annotation(annotation):
        return True
    if isinstance(annotation, str):
        return annotation.replace(" ", "") in _OPTIONAL_STR_SPELLINGS
    if get_origin(annotation) in _UNION_TYPES:
        args = tuple(arg for arg in get_args(annotation) if arg is not type(None))  # noqa: E721
        return len(args) == 1 and _is_str_annotation(args[0])
    return False
109
+
110
+
111
def _is_list_annotation(annotation: Any) -> bool:
    """Return True if *annotation* denotes a list type.

    Accepts ``list``, ``typing.List`` and any subscripted form (``list[dict]``,
    ``List[dict]``), plus the equivalent unresolved string spellings (spaces ignored).
    """
    if isinstance(annotation, str):
        compact = annotation.replace(" ", "")
        if compact in {"list", "builtins.list", "typing.List", "List"}:
            return True
        return compact.startswith(("list[", "builtins.list[", "typing.List[", "List["))
    # Bare `list` has no origin; generic aliases report `list` as their origin.
    return annotation is list or get_origin(annotation) is list
125
+
126
+
127
def _is_float_annotation(annotation: Any) -> bool:
    """Return True if *annotation* is ``float``, its string spelling, or absent.

    A missing annotation (``inspect.Parameter.empty``) counts as float. Identity
    checks are used rather than set membership, so an unhashable annotation
    (e.g. ``Annotated[float, {...}]``) is rejected with ``False`` instead of raising
    ``TypeError`` from hashing.
    """
    if annotation is inspect.Parameter.empty or annotation is float:
        return True
    if isinstance(annotation, str):
        return annotation in {"float", "builtins.float"}
    return False
134
+
135
+
136
def _is_numeric(value: Any) -> bool:
    """Return True for int or float values, excluding bool (which subclasses int)."""
    if isinstance(value, bool):
        return False
    return isinstance(value, (int, float))
138
+
139
+
140
def _is_dict_annotation(annotation: Any) -> bool:
    """Return True if *annotation* denotes a dict or mapping type.

    Accepted forms:

    * ``dict``, ``typing.Dict``, ``typing.Mapping`` and ``collections.abc.Mapping``,
      bare or subscripted (``dict[str, Any]``, ``Mapping[str, Any]``, ...).
    * Any class that is a subclass of ``collections.abc.Mapping`` (this includes dict
      subclasses such as ``OrderedDict``).
    * The equivalent unresolved string spellings (spaces ignored).
    """
    from collections.abc import Mapping as AbcMapping

    if isinstance(annotation, str):
        normalized = annotation.replace(" ", "")
        return normalized in {
            "dict",
            "builtins.dict",
            "Dict",
            "typing.Dict",
            "Mapping",
            "typing.Mapping",
            "collections.abc.Mapping",
        } or normalized.startswith(
            (
                "dict[",
                "builtins.dict[",
                "Dict[",
                "typing.Dict[",
                "Mapping[",
                "typing.Mapping[",
                "collections.abc.Mapping[",
            )
        )

    # Generic aliases (typing.Mapping, Mapping[str, Any], dict[str, Any], ...) report
    # their runtime class via get_origin(). Note that this is collections.abc.Mapping,
    # not the typing.Mapping alias. Bare classes have no origin and are checked directly.
    candidate = get_origin(annotation) or annotation
    if isinstance(candidate, type):
        try:
            return issubclass(candidate, AbcMapping)
        except TypeError:
            return False
    return False
163
+
164
+
165
def _strip_optional(annotation: Any) -> Any:
    """Return ``T`` when *annotation* is ``Optional[T]`` / ``T | None``; otherwise return it unchanged.

    Before Python 3.11, ``typing.get_type_hints`` rewrites a parameter annotated ``T``
    whose default is ``None`` into ``Optional[T]``. On those interpreters,
    ``extra_info: dict = None`` is therefore reported as ``Optional[dict]``.
    """
    if get_origin(annotation) in _UNION_TYPES:
        members = [arg for arg in get_args(annotation) if arg is not type(None)]  # noqa: E721
        if len(members) == 1:
            return members[0]
    return annotation


def osmosis_rubric(func: Callable) -> Callable:
    """
    Decorator for rubric functions that enforces the signature:
    (model_info: dict, rubric: str, messages: list, ground_truth: Optional[str] = None,
    system_message: Optional[str] = None, extra_info: Optional[dict] = None,
    score_min: float = 0.0, score_max: float = 1.0) -> float

    The first three parameters are required. Optional parameters may be dropped from
    the end of the signature, but those present must appear in exactly this order,
    with exactly these names and defaults.

    The signature is checked once, when the decorator is applied. Each call of the
    wrapped function then validates its arguments before delegating to ``func``:

    * ``model_info`` must be a mapping with non-empty string entries for both
      ``provider`` and ``model``.
    * ``rubric`` must be a ``str``.
    * ``messages`` must be a ``list`` of dicts. Each dict needs ``type``, ``role`` (one of
      ``ALLOWED_ROLES``) and ``content`` (a list).
    * ``ground_truth`` and ``system_message`` must be ``str`` or ``None``.
    * ``score_min`` and ``score_max`` must be numeric (not ``bool``). When both are
      declared, ``score_max`` must exceed ``score_min``.
    * A ``data_source`` keyword argument, if supplied, is discarded.

    Args:
        func: The rubric function to be wrapped

    Returns:
        The wrapped function

    Raises:
        TypeError: If the function doesn't have the required signature, if an argument
            has the wrong type, or if the function doesn't return a float.
        ValueError: If ``model_info`` or a message lacks required fields, a message role
            is not allowed, or ``score_max`` is not greater than ``score_min``.

    Example:
        @osmosis_rubric
        def evaluate_response(
            model_info: dict,
            rubric: str,
            messages: list,
            ground_truth: str | None = None,
            system_message: str | None = None,
            extra_info: dict = None,
            score_min: float = 0.0,
            score_max: float = 1.0,
        ) -> float:
            return some_evaluation(model_info, messages, ground_truth)
    """
    # Validate function signature
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    try:
        # include_extras=False strips Annotated[...] metadata, so e.g. Annotated[str, "..."]
        # reaches the annotation helpers as plain `str`.
        resolved_annotations = get_type_hints(
            func,
            globalns=getattr(func, "__globals__", {}),
            include_extras=False,
        )
    except Exception:  # pragma: no cover - best effort for forward refs
        resolved_annotations = {}

    def annotation_of(param: inspect.Parameter) -> Any:
        # Prefer the resolved hint; fall back to the raw (possibly string) annotation.
        return resolved_annotations.get(param.name, param.annotation)

    # Check parameter count
    if len(params) < 3 or len(params) > 8:
        raise TypeError(f"Function {func.__name__} must have between 3 and 8 parameters, got {len(params)}")

    # Check first parameter: model_info: dict
    model_info_param = params[0]
    if model_info_param.name != "model_info":
        raise TypeError(f"First parameter must be named 'model_info', got '{model_info_param.name}'")
    model_info_annotation = annotation_of(model_info_param)
    if not _is_dict_annotation(model_info_annotation):
        raise TypeError(
            f"First parameter 'model_info' must be annotated as a dict or mapping, got {model_info_annotation}"
        )
    if model_info_param.default is not inspect.Parameter.empty:
        raise TypeError("First parameter 'model_info' cannot have a default value")

    # Check second parameter: rubric: str
    rubric_param = params[1]
    if rubric_param.name != "rubric":
        raise TypeError(f"Second parameter must be named 'rubric', got '{rubric_param.name}'")
    rubric_annotation = annotation_of(rubric_param)
    if not _is_str_annotation(rubric_annotation):
        raise TypeError(f"Second parameter 'rubric' must be annotated as str, got {rubric_annotation}")
    if rubric_param.default is not inspect.Parameter.empty:
        raise TypeError("Second parameter 'rubric' cannot have a default value")

    # Check third parameter: messages: list
    messages_param = params[2]
    if messages_param.name != "messages":
        raise TypeError(f"Third parameter must be named 'messages', got '{messages_param.name}'")
    messages_annotation = annotation_of(messages_param)
    if messages_annotation is inspect.Parameter.empty:
        raise TypeError("Third parameter 'messages' must be annotated as list")
    if not _is_list_annotation(messages_annotation):
        raise TypeError(f"Third parameter 'messages' must be annotated as list, got {messages_annotation}")
    if messages_param.default is not inspect.Parameter.empty:
        raise TypeError("Third parameter 'messages' cannot have a default value")

    # Remaining parameters are optional and consumed strictly in order.
    optional_params = params[3:]

    if optional_params:
        # Check fourth parameter: ground_truth: Optional[str] = None
        ground_truth_param = optional_params[0]
        if ground_truth_param.name != "ground_truth":
            raise TypeError(f"Fourth parameter must be named 'ground_truth', got '{ground_truth_param.name}'")
        ground_truth_annotation = annotation_of(ground_truth_param)
        if ground_truth_annotation is inspect.Parameter.empty or not _is_optional_str(ground_truth_annotation):
            raise TypeError(
                "Fourth parameter 'ground_truth' must be annotated as Optional[str] or str"
            )
        if ground_truth_param.default is inspect.Parameter.empty:
            raise TypeError("Fourth parameter 'ground_truth' must have a default value of None")
        if ground_truth_param.default is not None:
            raise TypeError("Fourth parameter 'ground_truth' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        # Check fifth parameter: system_message: Optional[str] = None
        system_message_param = optional_params[0]
        if system_message_param.name != "system_message":
            raise TypeError(f"Fifth parameter must be named 'system_message', got '{system_message_param.name}'")
        system_message_annotation = annotation_of(system_message_param)
        if system_message_annotation is inspect.Parameter.empty or not _is_optional_str(system_message_annotation):
            raise TypeError(
                "Fifth parameter 'system_message' must be annotated as Optional[str] or str"
            )
        if system_message_param.default is inspect.Parameter.empty:
            raise TypeError("Fifth parameter 'system_message' must have a default value of None")
        if system_message_param.default is not None:
            raise TypeError("Fifth parameter 'system_message' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        # Check sixth parameter: extra_info: dict = None (Optional[dict] also accepted)
        extra_info_param = optional_params[0]
        if extra_info_param.name != "extra_info":
            raise TypeError(f"Sixth parameter must be named 'extra_info', got '{extra_info_param.name}'")
        extra_info_annotation = annotation_of(extra_info_param)
        # Unwrap Optional[...]: get_type_hints adds it implicitly before Python 3.11 for
        # `dict = None`, and callers may also write Optional[dict] explicitly.
        if extra_info_annotation is inspect.Parameter.empty or not _is_dict_annotation(
            _strip_optional(extra_info_annotation)
        ):
            raise TypeError(
                f"Sixth parameter 'extra_info' must be annotated as dict, got {extra_info_annotation}"
            )
        if extra_info_param.default is inspect.Parameter.empty:
            raise TypeError("Sixth parameter 'extra_info' must have a default value of None")
        if extra_info_param.default is not None:
            raise TypeError("Sixth parameter 'extra_info' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        # Check seventh parameter: score_min: float = 0.0
        score_min_param = optional_params[0]
        if score_min_param.name != "score_min":
            raise TypeError(f"Seventh parameter must be named 'score_min', got '{score_min_param.name}'")
        score_min_annotation = annotation_of(score_min_param)
        if not _is_float_annotation(score_min_annotation):
            raise TypeError(
                f"Seventh parameter 'score_min' must be annotated as float, got {score_min_annotation}"
            )
        if score_min_param.default is inspect.Parameter.empty:
            raise TypeError("Seventh parameter 'score_min' must have a default value of 0.0")
        if not _is_numeric(score_min_param.default):
            raise TypeError("Seventh parameter 'score_min' must default to a numeric value")
        if float(score_min_param.default) != 0.0:
            raise TypeError("Seventh parameter 'score_min' must default to 0.0")
        optional_params = optional_params[1:]

    if optional_params:
        # Check eighth parameter: score_max: float = 1.0
        score_max_param = optional_params[0]
        if score_max_param.name != "score_max":
            raise TypeError(f"Eighth parameter must be named 'score_max', got '{score_max_param.name}'")
        score_max_annotation = annotation_of(score_max_param)
        if not _is_float_annotation(score_max_annotation):
            raise TypeError(
                f"Eighth parameter 'score_max' must be annotated as float, got {score_max_annotation}"
            )
        if score_max_param.default is inspect.Parameter.empty:
            raise TypeError("Eighth parameter 'score_max' must have a default value of 1.0")
        if not _is_numeric(score_max_param.default):
            raise TypeError("Eighth parameter 'score_max' must default to a numeric value")
        if float(score_max_param.default) != 1.0:
            raise TypeError("Eighth parameter 'score_max' must default to 1.0")
        optional_params = optional_params[1:]

    # Defensive guard. Given the 3-8 parameter limit above and five optional slots,
    # this cannot trigger today; it protects against future edits to those limits.
    if optional_params:
        unexpected_param = optional_params[0]
        raise TypeError(f"Function {func.__name__} has unexpected parameter '{unexpected_param.name}'")

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Remove unsupported kwargs
        kwargs.pop("data_source", None)
        bound = sig.bind_partial(*args, **kwargs)
        bound.apply_defaults()

        # Validate model_info argument
        if "model_info" not in bound.arguments:
            raise TypeError("'model_info' argument is required")
        model_info_value = bound.arguments["model_info"]
        if not isinstance(model_info_value, Mapping):
            raise TypeError(f"'model_info' must be a mapping, got {type(model_info_value).__name__}")
        required_model_fields = {"provider", "model"}
        missing_model_fields = required_model_fields - set(model_info_value.keys())
        if missing_model_fields:
            raise ValueError(f"'model_info' is missing required fields: {sorted(missing_model_fields)}")
        provider_value = model_info_value.get("provider")
        if not isinstance(provider_value, str) or not provider_value.strip():
            raise TypeError("'model_info[\"provider\"]' must be a non-empty string")
        model_value = model_info_value.get("model")
        if not isinstance(model_value, str) or not model_value.strip():
            raise TypeError("'model_info[\"model\"]' must be a non-empty string")

        # Validate rubric argument
        if "rubric" not in bound.arguments:
            raise TypeError("'rubric' argument is required")
        rubric_value = bound.arguments["rubric"]
        if not isinstance(rubric_value, str):
            raise TypeError(f"'rubric' must be a string, got {type(rubric_value).__name__}")

        # Validate messages argument
        if "messages" not in bound.arguments:
            raise TypeError("'messages' argument is required")
        messages_value = bound.arguments["messages"]
        if not isinstance(messages_value, list):
            raise TypeError(f"'messages' must be a list, got {type(messages_value).__name__}")

        # Validate optional ground_truth argument
        ground_truth_value = bound.arguments.get("ground_truth")
        if ground_truth_value is not None and not isinstance(ground_truth_value, str):
            raise TypeError(
                f"'ground_truth' must be a string or None, got {type(ground_truth_value).__name__}"
            )

        # Validate optional system_message argument
        system_message_value = bound.arguments.get("system_message")
        if system_message_value is not None and not isinstance(system_message_value, str):
            raise TypeError(
                f"'system_message' must be a string or None, got {type(system_message_value).__name__}"
            )

        # Validate messages structure
        for index, message in enumerate(messages_value):
            if not isinstance(message, dict):
                raise TypeError(f"'messages[{index}]' must be a dict, got {type(message).__name__}")
            missing_fields = {"type", "role", "content"} - message.keys()
            if missing_fields:
                # sorted() keeps the message deterministic; set iteration order is not.
                raise ValueError(f"'messages[{index}]' is missing required fields: {sorted(missing_fields)}")
            role = message["role"]
            # Check the type first. Otherwise an unhashable role (e.g. a list) makes the
            # set lookup raise an unrelated "unhashable type" error.
            if not isinstance(role, str) or role not in ALLOWED_ROLES:
                raise ValueError(
                    f"'messages[{index}]['role']' must be one of {sorted(ALLOWED_ROLES)}, "
                    f"got '{role}'"
                )
            if not isinstance(message["content"], list):
                raise TypeError(f"'messages[{index}]['content']' must be a list")

        # After apply_defaults(), these are present whenever the function declares them.
        score_min_present = "score_min" in bound.arguments
        score_max_present = "score_max" in bound.arguments

        if score_min_present:
            score_min_value = bound.arguments["score_min"]
            if not _is_numeric(score_min_value):
                raise TypeError(
                    f"'score_min' must be a numeric type, got {type(score_min_value).__name__}"
                )
        else:
            score_min_value = None

        if score_max_present:
            score_max_value = bound.arguments["score_max"]
            if not _is_numeric(score_max_value):
                raise TypeError(
                    f"'score_max' must be a numeric type, got {type(score_max_value).__name__}"
                )
        else:
            score_max_value = None

        if score_min_present and score_max_present:
            if float(score_max_value) <= float(score_min_value):
                raise ValueError("'score_max' must be greater than 'score_min'")

        # Validate return type
        result = func(*args, **kwargs)
        if not isinstance(result, float):
            raise TypeError(f"Function {func.__name__} must return a float, got {type(result).__name__}")
        return result

    return wrapper