arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. arize/__init__.py +70 -1
  2. arize/_flight/client.py +163 -43
  3. arize/_flight/types.py +1 -0
  4. arize/_generated/api_client/__init__.py +5 -1
  5. arize/_generated/api_client/api/datasets_api.py +6 -6
  6. arize/_generated/api_client/api/experiments_api.py +924 -61
  7. arize/_generated/api_client/api_client.py +1 -1
  8. arize/_generated/api_client/configuration.py +1 -1
  9. arize/_generated/api_client/exceptions.py +1 -1
  10. arize/_generated/api_client/models/__init__.py +3 -1
  11. arize/_generated/api_client/models/dataset.py +2 -2
  12. arize/_generated/api_client/models/dataset_version.py +1 -1
  13. arize/_generated/api_client/models/datasets_create_request.py +3 -3
  14. arize/_generated/api_client/models/datasets_list200_response.py +1 -1
  15. arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
  16. arize/_generated/api_client/models/error.py +1 -1
  17. arize/_generated/api_client/models/experiment.py +6 -6
  18. arize/_generated/api_client/models/experiments_create_request.py +98 -0
  19. arize/_generated/api_client/models/experiments_list200_response.py +1 -1
  20. arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
  21. arize/_generated/api_client/rest.py +1 -1
  22. arize/_generated/api_client/test/test_dataset.py +2 -1
  23. arize/_generated/api_client/test/test_dataset_version.py +1 -1
  24. arize/_generated/api_client/test/test_datasets_api.py +1 -1
  25. arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
  26. arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
  27. arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
  28. arize/_generated/api_client/test/test_error.py +1 -1
  29. arize/_generated/api_client/test/test_experiment.py +6 -1
  30. arize/_generated/api_client/test/test_experiments_api.py +23 -2
  31. arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
  32. arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
  33. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
  34. arize/_generated/api_client_README.md +13 -8
  35. arize/client.py +19 -2
  36. arize/config.py +50 -3
  37. arize/constants/config.py +8 -2
  38. arize/constants/openinference.py +14 -0
  39. arize/constants/pyarrow.py +1 -0
  40. arize/datasets/__init__.py +0 -70
  41. arize/datasets/client.py +106 -19
  42. arize/datasets/errors.py +61 -0
  43. arize/datasets/validation.py +46 -0
  44. arize/experiments/client.py +455 -0
  45. arize/experiments/evaluators/__init__.py +0 -0
  46. arize/experiments/evaluators/base.py +255 -0
  47. arize/experiments/evaluators/exceptions.py +10 -0
  48. arize/experiments/evaluators/executors.py +502 -0
  49. arize/experiments/evaluators/rate_limiters.py +277 -0
  50. arize/experiments/evaluators/types.py +122 -0
  51. arize/experiments/evaluators/utils.py +198 -0
  52. arize/experiments/functions.py +920 -0
  53. arize/experiments/tracing.py +276 -0
  54. arize/experiments/types.py +394 -0
  55. arize/models/client.py +4 -1
  56. arize/spans/client.py +16 -20
  57. arize/utils/arrow.py +4 -3
  58. arize/utils/openinference_conversion.py +56 -0
  59. arize/utils/proto.py +13 -0
  60. arize/utils/size.py +22 -0
  61. arize/version.py +1 -1
  62. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
  63. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
  64. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
  65. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0
arize/experiments/evaluators/rate_limiters.py (new file)
@@ -0,0 +1,277 @@
+ import asyncio
+ import time
+ from functools import wraps
+ from math import exp
+ from typing import Any, Callable, Coroutine, Optional, Tuple, Type, TypeVar
+
+ from typing_extensions import ParamSpec
+
+ from .exceptions import ArizeException
+ from .utils import printif
+
+ ParameterSpec = ParamSpec("ParameterSpec")
+ GenericType = TypeVar("GenericType")
+ AsyncCallable = Callable[ParameterSpec, Coroutine[Any, Any, GenericType]]
+
+
+ class UnavailableTokensError(ArizeException):
+     pass
+
+
+ class AdaptiveTokenBucket:
+     """
+     An adaptive rate-limiter that adjusts the rate based on the number of rate limit errors.
+
+     This rate limiter does not need to know the exact rate limit. Instead, it starts with a high
+     rate and reduces it whenever a rate limit error occurs. The rate is increased slowly over time
+     if no further errors occur.
+
+     Args:
+         initial_per_second_request_rate (float): The initial allowed request rate.
+         maximum_per_second_request_rate (float): The maximum allowed request rate.
+         minimum_per_second_request_rate (float): The floor below which the rate is never reduced.
+         enforcement_window_minutes (float): The time window over which the rate limit is enforced.
+         rate_reduction_factor (float): Multiplier used to reduce the rate limit after an error.
+         rate_increase_factor (float): Exponential factor increasing the rate limit over time.
+         cooldown_seconds (float): The minimum time before allowing the rate limit to decrease again.
+     """
+
+     def __init__(
+         self,
+         initial_per_second_request_rate: float,
+         maximum_per_second_request_rate: Optional[float] = None,
+         minimum_per_second_request_rate: float = 0.1,
+         enforcement_window_minutes: float = 1,
+         rate_reduction_factor: float = 0.5,
+         rate_increase_factor: float = 0.01,
+         cooldown_seconds: float = 5,
+     ):
+         self._initial_rate = initial_per_second_request_rate
+         self.rate_reduction_factor = rate_reduction_factor
+         self.enforcement_window = enforcement_window_minutes * 60
+         self.rate_increase_factor = rate_increase_factor
+         self.rate = initial_per_second_request_rate
+         self.minimum_rate = minimum_per_second_request_rate
+
+         if maximum_per_second_request_rate is None:
+             # if unset, do not allow the maximum rate to exceed 3 consecutive rate reductions
+             # assuming the initial rate is the advertised API rate limit
+             maximum_rate_multiple = (1 / rate_reduction_factor) ** 3
+             maximum_per_second_request_rate = (
+                 initial_per_second_request_rate * maximum_rate_multiple
+             )
+
+         maximum_per_second_request_rate = float(maximum_per_second_request_rate)
+         assert isinstance(maximum_per_second_request_rate, float)
+         self.maximum_rate = maximum_per_second_request_rate
+
+         self.cooldown = cooldown_seconds
+
+         now = time.time()
+         self.last_rate_update = now
+         self.last_checked = now
+         self.last_error = now - self.cooldown
+         self.tokens = 0.0
+
+     def increase_rate(self) -> None:
+         time_since_last_update = time.time() - self.last_rate_update
+         if time_since_last_update > self.enforcement_window:
+             self.rate = self._initial_rate
+         else:
+             self.rate *= exp(self.rate_increase_factor * time_since_last_update)
+             self.rate = min(self.rate, self.maximum_rate)
+         self.last_rate_update = time.time()
+
+     def on_rate_limit_error(
+         self, request_start_time: float, verbose: bool = False
+     ) -> None:
+         now = time.time()
+         if request_start_time < (self.last_error + self.cooldown):
+             # do not reduce the rate for concurrent requests
+             return
+
+         original_rate = self.rate
+
+         self.rate = original_rate * self.rate_reduction_factor
+         printif(
+             verbose,
+             f"Reducing rate from {original_rate} to {self.rate} after rate limit error",
+         )
+
+         self.rate = max(self.rate, self.minimum_rate)
+
+         # reset request tokens on a rate limit error
+         self.tokens = 0
+         self.last_checked = now
+         self.last_rate_update = now
+         self.last_error = now
+         time.sleep(self.cooldown)  # block for a bit to let the rate limit reset
+
+     def max_tokens(self) -> float:
+         return self.rate * self.enforcement_window
+
+     def available_requests(self) -> float:
+         now = time.time()
+         time_since_last_checked = time.time() - self.last_checked
+         self.tokens = min(
+             self.max_tokens(), self.rate * time_since_last_checked + self.tokens
+         )
+         self.last_checked = now
+         return self.tokens
+
+     def make_request_if_ready(self) -> None:
+         if self.available_requests() <= 1:
+             raise UnavailableTokensError
+         self.tokens -= 1
+
+     def wait_until_ready(
+         self,
+         max_wait_time: float = 300,
+     ) -> None:
+         start = time.time()
+         while (time.time() - start) < max_wait_time:
+             try:
+                 self.increase_rate()
+                 self.make_request_if_ready()
+                 break
+             except UnavailableTokensError:
+                 time.sleep(0.1 / self.rate)
+                 continue
+
+     async def async_wait_until_ready(
+         self,
+         max_wait_time: float = 10,  # defeat the token bucket rate limiter at low rates (<.1 req/s)
+     ) -> None:
+         start = time.time()
+         while (time.time() - start) < max_wait_time:
+             try:
+                 self.increase_rate()
+                 self.make_request_if_ready()
+                 break
+             except UnavailableTokensError:
+                 await asyncio.sleep(0.1 / self.rate)
+                 continue
+
+
+ class RateLimitError(ArizeException): ...
+
+
+ class RateLimiter:
+     def __init__(
+         self,
+         rate_limit_error: Optional[Type[BaseException]] = None,
+         max_rate_limit_retries: int = 3,
+         initial_per_second_request_rate: float = 1.0,
+         maximum_per_second_request_rate: Optional[float] = None,
+         enforcement_window_minutes: float = 1,
+         rate_reduction_factor: float = 0.5,
+         rate_increase_factor: float = 0.01,
+         cooldown_seconds: float = 5,
+         verbose: bool = False,
+     ) -> None:
+         self._rate_limit_error: Tuple[Type[BaseException], ...]
+         self._rate_limit_error = (
+             (rate_limit_error,) if rate_limit_error is not None else tuple()
+         )
+
+         self._max_rate_limit_retries = max_rate_limit_retries
+         self._throttler = AdaptiveTokenBucket(
+             initial_per_second_request_rate=initial_per_second_request_rate,
+             maximum_per_second_request_rate=maximum_per_second_request_rate,
+             enforcement_window_minutes=enforcement_window_minutes,
+             rate_reduction_factor=rate_reduction_factor,
+             rate_increase_factor=rate_increase_factor,
+             cooldown_seconds=cooldown_seconds,
+         )
+         self._rate_limit_handling: Optional[asyncio.Event] = None
+         self._rate_limit_handling_lock: Optional[asyncio.Lock] = None
+         self._current_loop: Optional[asyncio.AbstractEventLoop] = None
+         self._verbose = verbose
+
+     def limit(
+         self, fn: Callable[ParameterSpec, GenericType]
+     ) -> Callable[ParameterSpec, GenericType]:
+         @wraps(fn)
+         def wrapper(*args: Any, **kwargs: Any) -> GenericType:
+             try:
+                 self._throttler.wait_until_ready()
+                 request_start_time = time.time()
+                 return fn(*args, **kwargs)
+             except self._rate_limit_error:
+                 self._throttler.on_rate_limit_error(
+                     request_start_time, verbose=self._verbose
+                 )
+                 for _attempt in range(self._max_rate_limit_retries):
+                     try:
+                         request_start_time = time.time()
+                         self._throttler.wait_until_ready()
+                         return fn(*args, **kwargs)
+                     except self._rate_limit_error:
+                         self._throttler.on_rate_limit_error(
+                             request_start_time, verbose=self._verbose
+                         )
+                         continue
+             raise RateLimitError(
+                 f"Exceeded max ({self._max_rate_limit_retries}) retries"
+             )
+
+         return wrapper
+
+     def _initialize_async_primitives(self) -> None:
+         """
+         Lazily initialize async primitives to ensure they are created in the correct event loop.
+         """
+         loop = asyncio.get_running_loop()
+         if loop is not self._current_loop:
+             self._current_loop = loop
+             self._rate_limit_handling = asyncio.Event()
+             self._rate_limit_handling.set()
+             self._rate_limit_handling_lock = asyncio.Lock()
+
+     def alimit(
+         self, fn: AsyncCallable[ParameterSpec, GenericType]
+     ) -> AsyncCallable[ParameterSpec, GenericType]:
+         @wraps(fn)
+         async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
+             self._initialize_async_primitives()
+             assert self._rate_limit_handling_lock is not None and isinstance(
+                 self._rate_limit_handling_lock, asyncio.Lock
+             )
+             assert self._rate_limit_handling is not None and isinstance(
+                 self._rate_limit_handling, asyncio.Event
+             )
+             try:
+                 try:
+                     await asyncio.wait_for(
+                         self._rate_limit_handling.wait(), 120
+                     )
+                 except asyncio.TimeoutError:
+                     self._rate_limit_handling.set()  # Set the event as a failsafe
+                 await self._throttler.async_wait_until_ready()
+                 request_start_time = time.time()
+                 return await fn(*args, **kwargs)
+             except self._rate_limit_error:
+                 async with self._rate_limit_handling_lock:
+                     self._rate_limit_handling.clear()  # prevent new requests from starting
+                     self._throttler.on_rate_limit_error(
+                         request_start_time, verbose=self._verbose
+                     )
+                     try:
+                         for _attempt in range(self._max_rate_limit_retries):
+                             try:
+                                 request_start_time = time.time()
+                                 await self._throttler.async_wait_until_ready()
+                                 return await fn(*args, **kwargs)
+                             except self._rate_limit_error:
+                                 self._throttler.on_rate_limit_error(
+                                     request_start_time, verbose=self._verbose
+                                 )
+                                 continue
+                     finally:
+                         self._rate_limit_handling.set()  # allow new requests to start
+             raise RateLimitError(
+                 f"Exceeded max ({self._max_rate_limit_retries}) retries"
+             )
+
+         return wrapper
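
The `RateLimiter` is applied by wrapping the rate-limited callable with `limit` (sync) or `alimit` (async): calls are throttled by the adaptive bucket, and on a rate limit error the rate is halved and the call retried up to `max_rate_limit_retries` times. A minimal usage sketch against this API; `ApiRateLimitError` and `call_api` are illustrative stand-ins, not part of the package:

```python
import asyncio

from arize.experiments.evaluators.rate_limiters import RateLimiter


class ApiRateLimitError(Exception):
    """Hypothetical error an API client might raise on HTTP 429."""


# Start at ~2 requests/second; halve the rate on each rate limit error.
limiter = RateLimiter(
    rate_limit_error=ApiRateLimitError,
    max_rate_limit_retries=3,
    initial_per_second_request_rate=2.0,
)


@limiter.limit
def call_api(prompt: str) -> str:
    # A real client call would go here, raising ApiRateLimitError on 429.
    return f"response to {prompt!r}"


@limiter.alimit
async def call_api_async(prompt: str) -> str:
    return f"response to {prompt!r}"


print(call_api("hello"))
print(asyncio.run(call_api_async("world")))
```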
arize/experiments/evaluators/types.py (new file)
@@ -0,0 +1,122 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from enum import Enum
+ from typing import Any, Dict, List, Mapping, Tuple
+
+ JSONSerializable = Dict[str, Any] | List[Any] | str | int | float | bool
+
+
+ class AnnotatorKind(Enum):
+     CODE = "CODE"
+     LLM = "LLM"
+
+
+ EvaluatorKind = str
+ EvaluatorName = str
+
+ Score = bool | int | float | None
+ Label = str | None
+ Explanation = str | None
+
+
+ @dataclass(frozen=True)
+ class EvaluationResult:
+     """
+     Represents the result of an evaluation.
+
+     Args:
+         score: The score of the evaluation.
+         label: The label of the evaluation.
+         explanation: The explanation of the evaluation.
+         metadata: Additional metadata for the evaluation.
+     """
+
+     score: float | None = None
+     label: str | None = None
+     explanation: str | None = None
+     metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)
+
+     @classmethod
+     def from_dict(
+         cls, obj: Mapping[str, Any] | None
+     ) -> EvaluationResult | None:
+         if not obj:
+             return None
+         return cls(
+             score=obj.get("score"),
+             label=obj.get("label"),
+             explanation=obj.get("explanation"),
+             metadata=obj.get("metadata") or {},
+         )
+
+     def __post_init__(self) -> None:
+         if self.score is None and not self.label:
+             raise ValueError("Must specify score or label, or both")
+         for k in ("label", "explanation"):
+             v = getattr(self, k, None)
+             if v is not None:
+                 object.__setattr__(self, k, str(v) or None)
+
+
+ EvaluatorOutput = (
+     EvaluationResult
+     | bool
+     | int
+     | float
+     | str
+     | Tuple[Score, Label, Explanation]
+ )
+
+
+ @dataclass
+ class EvaluationResultFieldNames:
+     """Column names for mapping evaluation results in a DataFrame.
+
+     Args:
+         score: Optional name of the column containing evaluation scores
+         label: Optional name of the column containing evaluation labels
+         explanation: Optional name of the column containing evaluation explanations
+         metadata: Optional mapping of metadata keys to column names. If a column name
+             is None or an empty string, the metadata key is used as the column name.
+
+     Examples:
+         >>> # Basic usage with score and label columns
+         >>> EvaluationResultFieldNames(
+         ...     score="quality.score", label="quality.label"
+         ... )
+
+         >>> # Using metadata with the same key and column name
+         >>> EvaluationResultFieldNames(
+         ...     score="quality.score",
+         ...     metadata={
+         ...         "version": None
+         ...     },  # Will look for a column named "version"
+         ... )
+
+         >>> # Using metadata with different keys and column names
+         >>> EvaluationResultFieldNames(
+         ...     score="quality.score",
+         ...     metadata={
+         ...         # Will look for the "column_in_my_df.version" column and ingest as
+         ...         # "eval.{EvaluatorName}.metadata.model_version"
+         ...         "model_version": "column_in_my_df.version",
+         ...         # Will look for the "column_in_my_df.ts" column and ingest as
+         ...         # "eval.{EvaluatorName}.metadata.timestamp"
+         ...         "timestamp": "column_in_my_df.ts",
+         ...     },
+         ... )
+
+     Raises:
+         ValueError: If neither score nor label column names are specified
+     """
+
+     score: str | None = None
+     label: str | None = None
+     explanation: str | None = None
+     metadata: Dict[str, str | None] | None = None
+
+     def __post_init__(self) -> None:
+         if self.score is None and self.label is None:
+             raise ValueError("Must specify score or label column name, or both")
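
To make the contracts above concrete, here is a short sketch of how these types behave; the column names and literal values are illustrative:

```python
from arize.experiments.evaluators.types import (
    EvaluationResult,
    EvaluationResultFieldNames,
)

# Build a result directly or from a plain mapping (e.g. a parsed LLM judgment).
result = EvaluationResult.from_dict(
    {"score": 0.9, "label": "relevant", "explanation": "cites the source"}
)
assert result is not None and result.score == 0.9

# __post_init__ rejects a result that carries neither a score nor a label.
try:
    EvaluationResult(explanation="no score or label")
except ValueError as err:
    print(err)  # Must specify score or label, or both

# Map evaluation columns in a DataFrame; "quality.score" and the metadata
# column name are assumptions for this example.
fields = EvaluationResultFieldNames(
    score="quality.score",
    label="quality.label",
    metadata={"model_version": "column_in_my_df.version"},
)
```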
arize/experiments/evaluators/utils.py (new file)
@@ -0,0 +1,198 @@
+ import functools
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable, Optional
+
+ from tqdm.auto import tqdm
+
+ from arize.experiments.evaluators.types import (
+     EvaluationResult,
+     JSONSerializable,
+ )
+
+
+ def get_func_name(fn: Callable[..., Any]) -> str:
+     """
+     Makes a best-effort attempt to get the name of the function.
+     """
+     if isinstance(fn, functools.partial):
+         return fn.func.__qualname__
+     if hasattr(fn, "__qualname__") and not fn.__qualname__.endswith("<lambda>"):
+         return fn.__qualname__.split(".<locals>.")[-1]
+     return str(fn)
+
+
+ if TYPE_CHECKING:
+     from ..evaluators.base import Evaluator
+
+
+ def unwrap_json(obj: JSONSerializable) -> JSONSerializable:
+     if isinstance(obj, dict) and len(obj) == 1:
+         key = next(iter(obj.keys()))
+         output = obj[key]
+         assert isinstance(
+             output, (dict, list, str, int, float, bool, type(None))
+         ), "Output must be JSON serializable"
+         return output
+     return obj
+
+
+ def validate_evaluator_signature(sig: inspect.Signature) -> None:
+     # Check that the wrapped function has a valid signature for use as an evaluator.
+     # If it does not, raise an error to exit early before running evaluations.
+     params = sig.parameters
+     valid_named_params = {
+         "input",
+         "output",
+         "experiment_output",
+         "dataset_output",
+         "metadata",
+         "dataset_row",
+     }
+     if len(params) == 0:
+         raise ValueError(
+             "Evaluation function must have at least one parameter."
+         )
+     if len(params) > 1:
+         for not_found in set(params) - valid_named_params:
+             param = params[not_found]
+             if (
+                 param.kind is inspect.Parameter.VAR_KEYWORD
+                 or param.default is not inspect.Parameter.empty
+             ):
+                 continue
+             raise ValueError(
+                 f"Invalid parameter name in evaluation function: {not_found}. "
+                 "Parameter names for multi-argument functions must be "
+                 f"any of: {', '.join(valid_named_params)}."
+             )
+
+
+ def _bind_evaluator_signature(
+     sig: inspect.Signature, **kwargs: Any
+ ) -> inspect.BoundArguments:
+     parameter_mapping = {
+         "input": kwargs.get("input"),
+         "output": kwargs.get("output"),
+         "experiment_output": kwargs.get("experiment_output"),
+         "dataset_output": kwargs.get("dataset_output"),
+         "metadata": kwargs.get("metadata"),
+         "dataset_row": kwargs.get("dataset_row"),
+     }
+     params = sig.parameters
+     if len(params) == 1:
+         parameter_name = next(iter(params))
+         if parameter_name in parameter_mapping:
+             return sig.bind(parameter_mapping[parameter_name])
+         else:
+             return sig.bind(parameter_mapping["experiment_output"])
+     return sig.bind_partial(
+         **{
+             name: parameter_mapping[name]
+             for name in set(parameter_mapping).intersection(params)
+         }
+     )
+
+
+ def create_evaluator(
+     name: Optional[str] = None,
+     scorer: Optional[Callable[[Any], EvaluationResult]] = None,
+ ) -> Callable[[Callable[..., Any]], "Evaluator"]:
+     if scorer is None:
+         scorer = _default_eval_scorer
+
+     def wrapper(func: Callable[..., Any]) -> "Evaluator":
+         nonlocal name
+         if not name:
+             name = get_func_name(func)
+         assert name is not None
+
+         wrapped_signature = inspect.signature(func)
+         validate_evaluator_signature(wrapped_signature)
+
+         if inspect.iscoroutinefunction(func):
+             return _wrap_coroutine_evaluation_function(
+                 name, wrapped_signature, scorer
+             )(func)
+
+         return _wrap_sync_evaluation_function(name, wrapped_signature, scorer)(
+             func
+         )
+
+     return wrapper
+
+
+ def _wrap_coroutine_evaluation_function(
+     name: str,
+     sig: inspect.Signature,
+     convert_to_score: Callable[[Any], EvaluationResult],
+ ) -> Callable[[Callable[..., Any]], "Evaluator"]:
+     from ..evaluators.base import Evaluator
+
+     def wrapper(func: Callable[..., Any]) -> "Evaluator":
+         class AsyncEvaluator(Evaluator):
+             def __init__(self) -> None:
+                 self._name = name
+
+             @functools.wraps(func)
+             async def __call__(self, *args: Any, **kwargs: Any) -> Any:
+                 return await func(*args, **kwargs)
+
+             async def async_evaluate(self, **kwargs: Any) -> EvaluationResult:
+                 bound_signature = _bind_evaluator_signature(sig, **kwargs)
+                 result = await func(
+                     *bound_signature.args, **bound_signature.kwargs
+                 )
+                 return convert_to_score(result)
+
+         return AsyncEvaluator()
+
+     return wrapper
+
+
+ def _wrap_sync_evaluation_function(
+     name: str,
+     sig: inspect.Signature,
+     convert_to_score: Callable[[Any], EvaluationResult],
+ ) -> Callable[[Callable[..., Any]], "Evaluator"]:
+     from ..evaluators.base import Evaluator
+
+     def wrapper(func: Callable[..., Any]) -> "Evaluator":
+         class SyncEvaluator(Evaluator):
+             def __init__(self) -> None:
+                 self._name = name
+
+             @functools.wraps(func)
+             def __call__(self, *args: Any, **kwargs: Any) -> Any:
+                 return func(*args, **kwargs)
+
+             def evaluate(self, **kwargs: Any) -> EvaluationResult:
+                 bound_signature = _bind_evaluator_signature(sig, **kwargs)
+                 result = func(*bound_signature.args, **bound_signature.kwargs)
+                 return convert_to_score(result)
+
+         return SyncEvaluator()
+
+     return wrapper
+
+
+ def _default_eval_scorer(result: Any) -> EvaluationResult:
+     if isinstance(result, EvaluationResult):
+         return result
+     if isinstance(result, bool):
+         return EvaluationResult(score=float(result), label=str(result))
+     if hasattr(result, "__float__"):
+         return EvaluationResult(score=float(result))
+     if isinstance(result, str):
+         return EvaluationResult(label=result)
+     if isinstance(result, (tuple, list)) and len(result) == 2:
+         # If the result is a 2-tuple, the first item is recorded as the score
+         # and the second item is recorded as the explanation.
+         return EvaluationResult(
+             score=float(result[0]), explanation=str(result[1])
+         )
+     raise ValueError(f"Unsupported evaluation result type: {type(result)}")
+
+
+ def printif(condition: bool, *args: Any, **kwargs: Any) -> None:
+     if condition:
+         tqdm.write(*args, **kwargs)
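
Putting `create_evaluator`, signature binding, and `_default_eval_scorer` together: a plain function becomes an `Evaluator`, its named parameters are filled from the evaluation context, and its return value is coerced into an `EvaluationResult` (bool → score plus label, float-like → score, str → label, 2-tuple → score plus explanation). A sketch based on the code in this diff; `exact_match` and the literal values are illustrative:

```python
from arize.experiments.evaluators.utils import create_evaluator


# Parameter names must come from the validated set: input, output,
# experiment_output, dataset_output, metadata, dataset_row.
@create_evaluator(name="exact_match")
def exact_match(output, dataset_output):
    # The bool return value is coerced by _default_eval_scorer into
    # EvaluationResult(score=1.0, label="True") on a match.
    return output == dataset_output


result = exact_match.evaluate(output="Paris", dataset_output="Paris")
print(result.score, result.label)  # 1.0 True
```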