osmosis-ai 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of osmosis-ai might be problematic. Click here for more details.
- osmosis_ai/__init__.py +13 -4
- osmosis_ai/cli.py +50 -0
- osmosis_ai/cli_commands.py +181 -0
- osmosis_ai/cli_services/__init__.py +67 -0
- osmosis_ai/cli_services/config.py +407 -0
- osmosis_ai/cli_services/dataset.py +229 -0
- osmosis_ai/cli_services/engine.py +251 -0
- osmosis_ai/cli_services/errors.py +7 -0
- osmosis_ai/cli_services/reporting.py +307 -0
- osmosis_ai/cli_services/session.py +174 -0
- osmosis_ai/cli_services/shared.py +209 -0
- osmosis_ai/consts.py +1 -1
- osmosis_ai/providers/__init__.py +36 -0
- osmosis_ai/providers/anthropic_provider.py +85 -0
- osmosis_ai/providers/base.py +60 -0
- osmosis_ai/providers/gemini_provider.py +314 -0
- osmosis_ai/providers/openai_family.py +607 -0
- osmosis_ai/providers/shared.py +92 -0
- osmosis_ai/rubric_eval.py +498 -0
- osmosis_ai/rubric_types.py +49 -0
- osmosis_ai/utils.py +392 -5
- osmosis_ai-0.2.3.dist-info/METADATA +303 -0
- osmosis_ai-0.2.3.dist-info/RECORD +27 -0
- osmosis_ai-0.2.3.dist-info/entry_points.txt +4 -0
- osmosis_ai-0.2.1.dist-info/METADATA +0 -143
- osmosis_ai-0.2.1.dist-info/RECORD +0 -8
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/WHEEL +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/licenses/LICENSE +0 -0
- {osmosis_ai-0.2.1.dist-info → osmosis_ai-0.2.3.dist-info}/top_level.txt +0 -0
osmosis_ai/utils.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
|
|
2
2
|
import functools
|
|
3
3
|
import inspect
|
|
4
|
-
|
|
4
|
+
import types
|
|
5
|
+
from typing import Any, Callable, Mapping, Union, get_args, get_origin, get_type_hints
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
def osmosis_reward(func: Callable) -> Callable:
|
|
@@ -27,10 +28,6 @@ def osmosis_reward(func: Callable) -> Callable:
|
|
|
27
28
|
sig = inspect.signature(func)
|
|
28
29
|
params = list(sig.parameters.values())
|
|
29
30
|
|
|
30
|
-
# Check parameter count
|
|
31
|
-
if len(params) < 2 or len(params) > 3:
|
|
32
|
-
raise TypeError(f"Function {func.__name__} must have 2-3 parameters, got {len(params)}")
|
|
33
|
-
|
|
34
31
|
# Check first parameter: solution_str: str
|
|
35
32
|
if params[0].name != 'solution_str':
|
|
36
33
|
raise TypeError(f"First parameter must be named 'solution_str', got '{params[0].name}'")
|
|
@@ -61,3 +58,393 @@ def osmosis_reward(func: Callable) -> Callable:
|
|
|
61
58
|
return result
|
|
62
59
|
|
|
63
60
|
return wrapper
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Chat-message roles accepted in the `messages` argument of an @osmosis_rubric function.
ALLOWED_ROLES = {"user", "system", "assistant", "developer", "tool", "function"}

# Origins that `typing.get_origin` reports for union annotations: `typing.Union`
# (Optional[X] / Union[X, Y]) always, plus `types.UnionType` (PEP 604 `X | Y`)
# on interpreters that provide it.
_types_union_type = getattr(types, "UnionType", None)
_UNION_TYPES = {Union} if _types_union_type is None else {Union, _types_union_type}
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _is_str_annotation(annotation: Any) -> bool:
    """Report whether *annotation* names ``str`` (or a subclass of it).

    Recognised forms: the ``str`` class or any ``str`` subclass, the textual
    spellings ``"str"`` / ``"builtins.str"``, and a forward reference (any
    object with a string ``__forward_arg__``) holding one of those spellings.
    A missing annotation (``inspect.Parameter.empty``) never matches.
    """
    accepted_spellings = {"str", "builtins.str"}

    if annotation is inspect.Parameter.empty:
        return False

    if isinstance(annotation, type):
        # Covers ``str`` itself as well as user-defined subclasses.
        try:
            return issubclass(annotation, str)
        except TypeError:
            return False

    if isinstance(annotation, str):
        return annotation in accepted_spellings

    forward_arg = getattr(annotation, "__forward_arg__", None)
    return isinstance(forward_arg, str) and forward_arg in accepted_spellings
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _is_optional_str(annotation: Any) -> bool:
    """Report whether *annotation* denotes ``str`` or ``Optional[str]``.

    Accepted forms:
      * anything ``_is_str_annotation`` accepts (plain ``str``);
      * runtime unions whose only non-``None`` member is a str annotation
        (``Optional[str]``, ``Union[str, None]``, ``str | None``);
      * the textual spellings of those unions (whitespace ignored):
        ``Optional[...]`` / ``typing.Optional[...]``, ``Union[...]`` /
        ``typing.Union[...]`` and ``A | B``. Raw strings reach this function
        when ``get_type_hints`` could not resolve the annotations.
    """
    if _is_str_annotation(annotation):
        return True

    if isinstance(annotation, str):
        normalized = annotation.replace(" ", "")
        # Optional[X]: accept when X itself is a str spelling.
        for prefix in ("typing.Optional[", "Optional["):
            if normalized.startswith(prefix) and normalized.endswith("]"):
                return _is_str_annotation(normalized[len(prefix):-1])
        # Union[A, B, ...] or A | B: exactly one non-None member, and it is str.
        for prefix in ("typing.Union[", "Union["):
            if normalized.startswith(prefix) and normalized.endswith("]"):
                members = normalized[len(prefix):-1].split(",")
                break
        else:
            members = normalized.split("|")
        non_none = [member for member in members if member != "None"]
        return len(non_none) == 1 and _is_str_annotation(non_none[0])

    if get_origin(annotation) in _UNION_TYPES:
        args = tuple(arg for arg in get_args(annotation) if arg is not type(None))  # noqa: E721
        return len(args) == 1 and _is_str_annotation(args[0])
    return False
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _is_list_annotation(annotation: Any) -> bool:
    """Report whether *annotation* denotes a list type.

    Matches the bare ``list`` class, any alias whose ``get_origin`` is
    ``list`` (``list[dict]``, ``typing.List``, ``typing.List[dict]``), and the
    equivalent textual spellings with whitespace ignored (``"list"``,
    ``"List[dict]"``, ``"typing.List"``, ``"builtins.list[...]"``).
    """
    if not isinstance(annotation, str):
        return annotation is list or get_origin(annotation) is list

    compact = annotation.replace(" ", "")
    bare_names = {"list", "builtins.list", "typing.List", "List"}
    generic_prefixes = ("list[", "builtins.list[", "typing.List[", "List[")
    return compact in bare_names or compact.startswith(generic_prefixes)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def _is_float_annotation(annotation: Any) -> bool:
    """Report whether *annotation* is acceptable for a ``float`` score bound.

    A missing annotation (``inspect.Parameter.empty``) is accepted, as are the
    ``float`` class, the spellings ``"float"`` / ``"builtins.float"``
    (surrounding whitespace ignored) and a forward reference to one of those
    spellings.

    Identity/isinstance checks are used instead of set membership, so an
    unhashable annotation (e.g. ``[float]``) is rejected cleanly rather than
    raising ``TypeError: unhashable type``.
    """
    float_spellings = {"float", "builtins.float"}
    if annotation is inspect.Parameter.empty or annotation is float:
        return True
    if isinstance(annotation, str):
        return annotation.strip() in float_spellings
    # ForwardRef("float"), handled consistently with _is_str_annotation.
    forward_arg = getattr(annotation, "__forward_arg__", None)
    if isinstance(forward_arg, str):
        return forward_arg.strip() in float_spellings
    return get_origin(annotation) is float
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _is_numeric(value: Any) -> bool:
    """Return True for ``int``/``float`` values, excluding ``bool``.

    ``bool`` is a subclass of ``int``, so it is ruled out explicitly first.
    """
    if isinstance(value, bool):
        return False
    return isinstance(value, (int, float))
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def _is_dict_annotation(annotation: Any) -> bool:
    """Report whether *annotation* denotes a dict or mapping type.

    Accepted forms:
      * ``dict``, any ``dict`` subclass, ``collections.abc.Mapping`` and any
        other ``collections.abc.Mapping`` subclass;
      * their generic aliases, resolved via ``typing.get_origin``:
        ``typing.Dict``, ``typing.Mapping``, ``dict[str, Any]``,
        ``typing.Mapping[str, Any]``, ``collections.abc.Mapping[str, Any]``;
      * the textual spellings of those forms (whitespace ignored), which
        appear when annotations could not be resolved;
      * a forward reference wrapping one of those spellings.

    Exotic or unhashable annotations return False instead of raising.
    """
    import collections.abc

    # Generic aliases expose their runtime class through get_origin(). Note that
    # typing.Mapping / typing.Mapping[...] resolve to collections.abc.Mapping,
    # which is NOT equal to typing.Mapping, so a membership test against
    # {dict, typing.Mapping} would miss them.
    candidate = get_origin(annotation) or annotation
    if isinstance(candidate, type):
        try:
            # dict is registered as a collections.abc.Mapping, so this also
            # covers dict and its subclasses.
            return issubclass(candidate, collections.abc.Mapping)
        except TypeError:
            return False

    if isinstance(annotation, str):
        normalized = annotation.replace(" ", "")
        bare_names = {
            "dict",
            "builtins.dict",
            "typing.Dict",
            "Dict",
            "typing.Mapping",
            "collections.abc.Mapping",
            "Mapping",
        }
        generic_prefixes = (
            "dict[",
            "builtins.dict[",
            "typing.Dict[",
            "Dict[",
            "typing.Mapping[",
            "collections.abc.Mapping[",
            "Mapping[",
        )
        return normalized in bare_names or normalized.startswith(generic_prefixes)

    forward_arg = getattr(annotation, "__forward_arg__", None)
    if isinstance(forward_arg, str):
        return _is_dict_annotation(forward_arg)
    return False
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def osmosis_rubric(func: Callable) -> Callable:
    """
    Decorator for rubric functions that enforces the signature:
        (model_info: dict, rubric: str, messages: list, ground_truth: Optional[str] = None,
         system_message: Optional[str] = None, extra_info: dict = None,
         score_min: float = 0.0, score_max: float = 1.0) -> float

    Only `model_info`, `rubric` and `messages` are required. The remaining
    parameters are optional, but when present they must appear in exactly this
    order with these names and defaults. `extra_info` may also be annotated as
    `Optional[dict]`; on Python < 3.11 `typing.get_type_hints` reports
    `extra_info: dict = None` in that form.

    The `model_info` mapping must provide non-empty string entries for both `provider` and `model`.

    At call time the wrapper discards a `data_source` keyword argument and
    validates the bound arguments:
      * the `model_info` fields;
      * the types of `rubric`, `ground_truth` and `system_message`;
      * each message is a dict with `type`, `role` and list-valued `content`;
      * numeric score bounds with `score_max > score_min`.
    It then calls `func` and checks that it returned a float.

    Args:
        func: The rubric function to be wrapped

    Returns:
        The wrapped function

    Raises:
        TypeError: If the function doesn't have the required signature (at decoration
            time); when called, if an argument has the wrong type or the function
            doesn't return a float
        ValueError: When called with incomplete `model_info`, malformed messages,
            or `score_max <= score_min`

    Example:
        @osmosis_rubric
        def evaluate_response(
            model_info: dict,
            rubric: str,
            messages: list,
            ground_truth: str | None = None,
            system_message: str | None = None,
            extra_info: dict = None,
            score_min: float = 0.0,
            score_max: float = 1.0,
        ) -> float:
            return some_evaluation(model_info, messages, ground_truth)
    """
    # Validate function signature
    sig = inspect.signature(func)
    params = list(sig.parameters.values())
    try:
        # Resolve string / forward-reference annotations against the function's
        # module globals. include_extras is left at its default (False) so that
        # Annotated[T, ...] wrappers are stripped and T itself is checked.
        resolved_annotations = get_type_hints(
            func,
            globalns=getattr(func, "__globals__", {}),
        )
    except Exception:  # pragma: no cover - best effort for forward refs
        # Unresolvable annotations: fall back to the raw ones from the signature.
        resolved_annotations = {}

    # Check parameter count
    if len(params) < 3 or len(params) > 8:
        raise TypeError(f"Function {func.__name__} must have between 3 and 8 parameters, got {len(params)}")

    # Check first parameter: model_info: dict
    model_info_param = params[0]
    if model_info_param.name != "model_info":
        raise TypeError(f"First parameter must be named 'model_info', got '{model_info_param.name}'")
    model_info_annotation = resolved_annotations.get(model_info_param.name, model_info_param.annotation)
    if not _is_dict_annotation(model_info_annotation):
        raise TypeError(
            f"First parameter 'model_info' must be annotated as a dict or mapping, got {model_info_annotation}"
        )
    if model_info_param.default is not inspect.Parameter.empty:
        raise TypeError("First parameter 'model_info' cannot have a default value")

    # Check second parameter: rubric: str
    rubric_param = params[1]
    if rubric_param.name != "rubric":
        raise TypeError(f"Second parameter must be named 'rubric', got '{rubric_param.name}'")
    rubric_annotation = resolved_annotations.get(rubric_param.name, rubric_param.annotation)
    if not _is_str_annotation(rubric_annotation):
        raise TypeError(f"Second parameter 'rubric' must be annotated as str, got {rubric_annotation}")
    if rubric_param.default is not inspect.Parameter.empty:
        raise TypeError("Second parameter 'rubric' cannot have a default value")

    # Check third parameter: messages: list
    messages_param = params[2]
    if messages_param.name != "messages":
        raise TypeError(f"Third parameter must be named 'messages', got '{messages_param.name}'")
    messages_annotation = resolved_annotations.get(messages_param.name, messages_param.annotation)
    if messages_annotation is inspect.Parameter.empty:
        raise TypeError("Third parameter 'messages' must be annotated as list")
    if not _is_list_annotation(messages_annotation):
        raise TypeError(f"Third parameter 'messages' must be annotated as list, got {messages_annotation}")
    if messages_param.default is not inspect.Parameter.empty:
        raise TypeError("Third parameter 'messages' cannot have a default value")

    optional_params = params[3:]

    if optional_params:
        ground_truth_param = optional_params[0]
        # Check fourth parameter: ground_truth: Optional[str]
        if ground_truth_param.name != "ground_truth":
            raise TypeError(f"Fourth parameter must be named 'ground_truth', got '{ground_truth_param.name}'")
        ground_truth_annotation = resolved_annotations.get(
            ground_truth_param.name,
            ground_truth_param.annotation,
        )
        if ground_truth_annotation is inspect.Parameter.empty or not _is_optional_str(ground_truth_annotation):
            raise TypeError(
                "Fourth parameter 'ground_truth' must be annotated as Optional[str] or str"
            )
        if ground_truth_param.default is inspect.Parameter.empty:
            raise TypeError("Fourth parameter 'ground_truth' must have a default value of None")
        if ground_truth_param.default is not None:
            raise TypeError("Fourth parameter 'ground_truth' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        system_message_param = optional_params[0]
        # Check fifth parameter: system_message: Optional[str]
        if system_message_param.name != "system_message":
            raise TypeError(f"Fifth parameter must be named 'system_message', got '{system_message_param.name}'")
        system_message_annotation = resolved_annotations.get(
            system_message_param.name,
            system_message_param.annotation,
        )
        if system_message_annotation is inspect.Parameter.empty or not _is_optional_str(system_message_annotation):
            raise TypeError(
                "Fifth parameter 'system_message' must be annotated as Optional[str] or str"
            )
        if system_message_param.default is inspect.Parameter.empty:
            raise TypeError("Fifth parameter 'system_message' must have a default value of None")
        if system_message_param.default is not None:
            raise TypeError("Fifth parameter 'system_message' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        extra_info_param = optional_params[0]
        # Check sixth parameter: extra_info: dict = None (or Optional[dict] = None)
        if extra_info_param.name != "extra_info":
            raise TypeError(f"Sixth parameter must be named 'extra_info', got '{extra_info_param.name}'")
        extra_info_annotation = resolved_annotations.get(
            extra_info_param.name,
            extra_info_param.annotation,
        )
        # On Python < 3.11, get_type_hints() rewrites `extra_info: dict = None` to
        # Optional[dict]; users may also write Optional[dict] explicitly. Unwrap a
        # single-member Optional before checking for a dict/mapping annotation.
        extra_info_checked = extra_info_annotation
        if get_origin(extra_info_checked) in _UNION_TYPES:
            non_none_args = [arg for arg in get_args(extra_info_checked) if arg is not type(None)]  # noqa: E721
            if len(non_none_args) == 1:
                extra_info_checked = non_none_args[0]
        if extra_info_checked is inspect.Parameter.empty or not _is_dict_annotation(extra_info_checked):
            raise TypeError(
                f"Sixth parameter 'extra_info' must be annotated as dict, got {extra_info_annotation}"
            )
        if extra_info_param.default is inspect.Parameter.empty:
            raise TypeError("Sixth parameter 'extra_info' must have a default value of None")
        if extra_info_param.default is not None:
            raise TypeError("Sixth parameter 'extra_info' must default to None")
        optional_params = optional_params[1:]

    if optional_params:
        score_min_param = optional_params[0]
        # Check seventh parameter: score_min: float = 0.0
        if score_min_param.name != "score_min":
            raise TypeError(f"Seventh parameter must be named 'score_min', got '{score_min_param.name}'")
        score_min_annotation = resolved_annotations.get(
            score_min_param.name,
            score_min_param.annotation,
        )
        if not _is_float_annotation(score_min_annotation):
            raise TypeError(
                f"Seventh parameter 'score_min' must be annotated as float, got {score_min_annotation}"
            )
        if score_min_param.default is inspect.Parameter.empty:
            raise TypeError("Seventh parameter 'score_min' must have a default value of 0.0")
        if not _is_numeric(score_min_param.default):
            raise TypeError("Seventh parameter 'score_min' must default to a numeric value")
        if float(score_min_param.default) != 0.0:
            raise TypeError("Seventh parameter 'score_min' must default to 0.0")
        optional_params = optional_params[1:]

    if optional_params:
        score_max_param = optional_params[0]
        # Check eighth parameter: score_max: float = 1.0
        if score_max_param.name != "score_max":
            raise TypeError(f"Eighth parameter must be named 'score_max', got '{score_max_param.name}'")
        score_max_annotation = resolved_annotations.get(
            score_max_param.name,
            score_max_param.annotation,
        )
        if not _is_float_annotation(score_max_annotation):
            raise TypeError(
                f"Eighth parameter 'score_max' must be annotated as float, got {score_max_annotation}"
            )
        if score_max_param.default is inspect.Parameter.empty:
            raise TypeError("Eighth parameter 'score_max' must have a default value of 1.0")
        if not _is_numeric(score_max_param.default):
            raise TypeError("Eighth parameter 'score_max' must default to a numeric value")
        if float(score_max_param.default) != 1.0:
            raise TypeError("Eighth parameter 'score_max' must default to 1.0")
        optional_params = optional_params[1:]

    if optional_params:
        unexpected_param = optional_params[0]
        raise TypeError(f"Function {func.__name__} has unexpected parameter '{unexpected_param.name}'")

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Remove unsupported kwargs
        kwargs.pop("data_source", None)
        bound = sig.bind_partial(*args, **kwargs)
        # After apply_defaults, every declared optional parameter appears in
        # bound.arguments (with its default if the caller omitted it).
        bound.apply_defaults()

        # Validate model_info argument
        if "model_info" not in bound.arguments:
            raise TypeError("'model_info' argument is required")
        model_info_value = bound.arguments["model_info"]
        if not isinstance(model_info_value, Mapping):
            raise TypeError(f"'model_info' must be a mapping, got {type(model_info_value).__name__}")
        required_model_fields = {"provider", "model"}
        missing_model_fields = required_model_fields - set(model_info_value.keys())
        if missing_model_fields:
            raise ValueError(f"'model_info' is missing required fields: {sorted(missing_model_fields)}")
        provider_value = model_info_value.get("provider")
        if not isinstance(provider_value, str) or not provider_value.strip():
            raise TypeError("'model_info[\"provider\"]' must be a non-empty string")
        model_value = model_info_value.get("model")
        if not isinstance(model_value, str) or not model_value.strip():
            raise TypeError("'model_info[\"model\"]' must be a non-empty string")

        # Validate rubric argument
        if "rubric" not in bound.arguments:
            raise TypeError("'rubric' argument is required")
        rubric_value = bound.arguments["rubric"]
        if not isinstance(rubric_value, str):
            raise TypeError(f"'rubric' must be a string, got {type(rubric_value).__name__}")

        # Validate messages argument
        if "messages" not in bound.arguments:
            raise TypeError("'messages' argument is required")
        messages_value = bound.arguments["messages"]
        if not isinstance(messages_value, list):
            raise TypeError(f"'messages' must be a list, got {type(messages_value).__name__}")

        # Validate optional ground_truth argument
        ground_truth_value = bound.arguments.get("ground_truth")
        if ground_truth_value is not None and not isinstance(ground_truth_value, str):
            raise TypeError(
                f"'ground_truth' must be a string or None, got {type(ground_truth_value).__name__}"
            )

        # Validate optional system_message argument
        system_message_value = bound.arguments.get("system_message")
        if system_message_value is not None and not isinstance(system_message_value, str):
            raise TypeError(
                f"'system_message' must be a string or None, got {type(system_message_value).__name__}"
            )

        # Validate messages structure
        for index, message in enumerate(messages_value):
            if not isinstance(message, dict):
                raise TypeError(f"'messages[{index}]' must be a dict, got {type(message).__name__}")
            missing_fields = {"type", "role", "content"} - message.keys()
            if missing_fields:
                # sorted() keeps the message deterministic (set order varies per run).
                raise ValueError(f"'messages[{index}]' is missing required fields: {sorted(missing_fields)}")
            role = message["role"]
            # The isinstance guard avoids an opaque "unhashable type" error for
            # roles such as lists or dicts.
            if not isinstance(role, str) or role not in ALLOWED_ROLES:
                raise ValueError(
                    f"'messages[{index}]['role']' must be one of {sorted(ALLOWED_ROLES)}, "
                    f"got '{role}'"
                )
            if not isinstance(message["content"], list):
                raise TypeError(f"'messages[{index}]['content']' must be a list")

        score_min_present = "score_min" in bound.arguments
        score_max_present = "score_max" in bound.arguments

        if score_min_present:
            score_min_value = bound.arguments["score_min"]
            if not _is_numeric(score_min_value):
                raise TypeError(
                    f"'score_min' must be a numeric type, got {type(score_min_value).__name__}"
                )
        else:
            score_min_value = None

        if score_max_present:
            score_max_value = bound.arguments["score_max"]
            if not _is_numeric(score_max_value):
                raise TypeError(
                    f"'score_max' must be a numeric type, got {type(score_max_value).__name__}"
                )
        else:
            score_max_value = None

        if score_min_present and score_max_present:
            if float(score_max_value) <= float(score_min_value):
                raise ValueError("'score_max' must be greater than 'score_min'")

        # Validate return type
        result = func(*args, **kwargs)
        if not isinstance(result, float):
            raise TypeError(f"Function {func.__name__} must return a float, got {type(result).__name__}")
        return result

    return wrapper
|