arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +70 -1
- arize/_flight/client.py +163 -43
- arize/_flight/types.py +1 -0
- arize/_generated/api_client/__init__.py +5 -1
- arize/_generated/api_client/api/datasets_api.py +6 -6
- arize/_generated/api_client/api/experiments_api.py +924 -61
- arize/_generated/api_client/api_client.py +1 -1
- arize/_generated/api_client/configuration.py +1 -1
- arize/_generated/api_client/exceptions.py +1 -1
- arize/_generated/api_client/models/__init__.py +3 -1
- arize/_generated/api_client/models/dataset.py +2 -2
- arize/_generated/api_client/models/dataset_version.py +1 -1
- arize/_generated/api_client/models/datasets_create_request.py +3 -3
- arize/_generated/api_client/models/datasets_list200_response.py +1 -1
- arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/models/error.py +1 -1
- arize/_generated/api_client/models/experiment.py +6 -6
- arize/_generated/api_client/models/experiments_create_request.py +98 -0
- arize/_generated/api_client/models/experiments_list200_response.py +1 -1
- arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
- arize/_generated/api_client/rest.py +1 -1
- arize/_generated/api_client/test/test_dataset.py +2 -1
- arize/_generated/api_client/test/test_dataset_version.py +1 -1
- arize/_generated/api_client/test/test_datasets_api.py +1 -1
- arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
- arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
- arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/test/test_error.py +1 -1
- arize/_generated/api_client/test/test_experiment.py +6 -1
- arize/_generated/api_client/test/test_experiments_api.py +23 -2
- arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
- arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
- arize/_generated/api_client_README.md +13 -8
- arize/client.py +19 -2
- arize/config.py +50 -3
- arize/constants/config.py +8 -2
- arize/constants/openinference.py +14 -0
- arize/constants/pyarrow.py +1 -0
- arize/datasets/__init__.py +0 -70
- arize/datasets/client.py +106 -19
- arize/datasets/errors.py +61 -0
- arize/datasets/validation.py +46 -0
- arize/experiments/client.py +455 -0
- arize/experiments/evaluators/__init__.py +0 -0
- arize/experiments/evaluators/base.py +255 -0
- arize/experiments/evaluators/exceptions.py +10 -0
- arize/experiments/evaluators/executors.py +502 -0
- arize/experiments/evaluators/rate_limiters.py +277 -0
- arize/experiments/evaluators/types.py +122 -0
- arize/experiments/evaluators/utils.py +198 -0
- arize/experiments/functions.py +920 -0
- arize/experiments/tracing.py +276 -0
- arize/experiments/types.py +394 -0
- arize/models/client.py +4 -1
- arize/spans/client.py +16 -20
- arize/utils/arrow.py +4 -3
- arize/utils/openinference_conversion.py +56 -0
- arize/utils/proto.py +13 -0
- arize/utils/size.py +22 -0
- arize/version.py +1 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0

New file (per the listing above, this matches arize/experiments/evaluators/rate_limiters.py):

@@ -0,0 +1,277 @@

```python
import asyncio
import time
from functools import wraps
from math import exp
from typing import Any, Callable, Coroutine, Optional, Tuple, Type, TypeVar

from typing_extensions import ParamSpec

from .exceptions import ArizeException
from .utils import printif

ParameterSpec = ParamSpec("ParameterSpec")
GenericType = TypeVar("GenericType")
AsyncCallable = Callable[ParameterSpec, Coroutine[Any, Any, GenericType]]


class UnavailableTokensError(ArizeException):
    pass


class AdaptiveTokenBucket:
    """
    An adaptive rate-limiter that adjusts the rate based on the number of rate limit errors.

    This rate limiter does not need to know the exact rate limit. Instead, it starts with a high
    rate and reduces it whenever a rate limit error occurs. The rate is increased slowly over time
    if no further errors occur.

    Args:
        initial_per_second_request_rate (float): The allowed request rate.
        maximum_per_second_request_rate (float): The maximum allowed request rate.
        enforcement_window_minutes (float): The time window over which the rate limit is enforced.
        rate_reduction_factor (float): Multiplier used to reduce the rate limit after an error.
        rate_increase_factor (float): Exponential factor increasing the rate limit over time.
        cooldown_seconds (float): The minimum time before allowing the rate limit to decrease again.

    """

    def __init__(
        self,
        initial_per_second_request_rate: float,
        maximum_per_second_request_rate: Optional[float] = None,
        minimum_per_second_request_rate: float = 0.1,
        enforcement_window_minutes: float = 1,
        rate_reduction_factor: float = 0.5,
        rate_increase_factor: float = 0.01,
        cooldown_seconds: float = 5,
    ):
        self._initial_rate = initial_per_second_request_rate
        self.rate_reduction_factor = rate_reduction_factor
        self.enforcement_window = enforcement_window_minutes * 60
        self.rate_increase_factor = rate_increase_factor
        self.rate = initial_per_second_request_rate
        self.minimum_rate = minimum_per_second_request_rate

        if maximum_per_second_request_rate is None:
            # if unset, do not allow the maximum rate to exceed 3 consecutive rate reductions
            # assuming the initial rate is the advertised API rate limit

            maximum_rate_multiple = (1 / rate_reduction_factor) ** 3
            maximum_per_second_request_rate = (
                initial_per_second_request_rate * maximum_rate_multiple
            )

        maximum_per_second_request_rate = float(maximum_per_second_request_rate)
        assert isinstance(maximum_per_second_request_rate, float)
        self.maximum_rate = maximum_per_second_request_rate

        self.cooldown = cooldown_seconds

        now = time.time()
        self.last_rate_update = now
        self.last_checked = now
        self.last_error = now - self.cooldown
        self.tokens = 0.0

    def increase_rate(self) -> None:
        time_since_last_update = time.time() - self.last_rate_update
        if time_since_last_update > self.enforcement_window:
            self.rate = self._initial_rate
        else:
            self.rate *= exp(self.rate_increase_factor * time_since_last_update)
            self.rate = min(self.rate, self.maximum_rate)
        self.last_rate_update = time.time()

    def on_rate_limit_error(
        self, request_start_time: float, verbose: bool = False
    ) -> None:
        now = time.time()
        if request_start_time < (self.last_error + self.cooldown):
            # do not reduce the rate for concurrent requests
            return

        original_rate = self.rate

        self.rate = original_rate * self.rate_reduction_factor
        printif(
            verbose,
            f"Reducing rate from {original_rate} to {self.rate} after rate limit error",
        )

        self.rate = max(self.rate, self.minimum_rate)

        # reset request tokens on a rate limit error
        self.tokens = 0
        self.last_checked = now
        self.last_rate_update = now
        self.last_error = now
        time.sleep(self.cooldown)  # block for a bit to let the rate limit reset

    def max_tokens(self) -> float:
        return self.rate * self.enforcement_window

    def available_requests(self) -> float:
        now = time.time()
        time_since_last_checked = time.time() - self.last_checked
        self.tokens = min(
            self.max_tokens(), self.rate * time_since_last_checked + self.tokens
        )
        self.last_checked = now
        return self.tokens

    def make_request_if_ready(self) -> None:
        if self.available_requests() <= 1:
            raise UnavailableTokensError
        self.tokens -= 1

    def wait_until_ready(
        self,
        max_wait_time: float = 300,
    ) -> None:
        start = time.time()
        while (time.time() - start) < max_wait_time:
            try:
                self.increase_rate()
                self.make_request_if_ready()
                break
            except UnavailableTokensError:
                time.sleep(0.1 / self.rate)
                continue

    async def async_wait_until_ready(
        self,
        max_wait_time: float = 10,  # defeat the token bucket rate limiter at low rates (<.1 req/s)
    ) -> None:
        start = time.time()
        while (time.time() - start) < max_wait_time:
            try:
                self.increase_rate()
                self.make_request_if_ready()
                break
            except UnavailableTokensError:
                await asyncio.sleep(0.1 / self.rate)
                continue


class RateLimitError(ArizeException): ...


class RateLimiter:
    def __init__(
        self,
        rate_limit_error: Optional[Type[BaseException]] = None,
        max_rate_limit_retries: int = 3,
        initial_per_second_request_rate: float = 1.0,
        maximum_per_second_request_rate: Optional[float] = None,
        enforcement_window_minutes: float = 1,
        rate_reduction_factor: float = 0.5,
        rate_increase_factor: float = 0.01,
        cooldown_seconds: float = 5,
        verbose: bool = False,
    ) -> None:
        self._rate_limit_error: Tuple[Type[BaseException], ...]
        self._rate_limit_error = (
            (rate_limit_error,) if rate_limit_error is not None else tuple()
        )

        self._max_rate_limit_retries = max_rate_limit_retries
        self._throttler = AdaptiveTokenBucket(
            initial_per_second_request_rate=initial_per_second_request_rate,
            maximum_per_second_request_rate=maximum_per_second_request_rate,
            enforcement_window_minutes=enforcement_window_minutes,
            rate_reduction_factor=rate_reduction_factor,
            rate_increase_factor=rate_increase_factor,
            cooldown_seconds=cooldown_seconds,
        )
        self._rate_limit_handling: Optional[asyncio.Event] = None
        self._rate_limit_handling_lock: Optional[asyncio.Lock] = None
        self._current_loop: Optional[asyncio.AbstractEventLoop] = None
        self._verbose = verbose

    def limit(
        self, fn: Callable[ParameterSpec, GenericType]
    ) -> Callable[ParameterSpec, GenericType]:
        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> GenericType:
            try:
                self._throttler.wait_until_ready()
                request_start_time = time.time()
                return fn(*args, **kwargs)
            except self._rate_limit_error:
                self._throttler.on_rate_limit_error(
                    request_start_time, verbose=self._verbose
                )
                for _attempt in range(self._max_rate_limit_retries):
                    try:
                        request_start_time = time.time()
                        self._throttler.wait_until_ready()
                        return fn(*args, **kwargs)
                    except self._rate_limit_error:
                        self._throttler.on_rate_limit_error(
                            request_start_time, verbose=self._verbose
                        )
                        continue
            raise RateLimitError(
                f"Exceeded max ({self._max_rate_limit_retries}) retries"
            )

        return wrapper

    def _initialize_async_primitives(self) -> None:
        """
        Lazily initialize async primitives to ensure they are created in the correct event loop.
        """
        loop = asyncio.get_running_loop()
        if loop is not self._current_loop:
            self._current_loop = loop
            self._rate_limit_handling = asyncio.Event()
            self._rate_limit_handling.set()
            self._rate_limit_handling_lock = asyncio.Lock()

    def alimit(
        self, fn: AsyncCallable[ParameterSpec, GenericType]
    ) -> AsyncCallable[ParameterSpec, GenericType]:
        @wraps(fn)
        async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
            self._initialize_async_primitives()
            assert self._rate_limit_handling_lock is not None and isinstance(
                self._rate_limit_handling_lock, asyncio.Lock
            )
            assert self._rate_limit_handling is not None and isinstance(
                self._rate_limit_handling, asyncio.Event
            )
            try:
                try:
                    await asyncio.wait_for(
                        self._rate_limit_handling.wait(), 120
                    )
                except asyncio.TimeoutError:
                    self._rate_limit_handling.set()  # Set the event as a failsafe
                await self._throttler.async_wait_until_ready()
                request_start_time = time.time()
                return await fn(*args, **kwargs)
            except self._rate_limit_error:
                async with self._rate_limit_handling_lock:
                    self._rate_limit_handling.clear()  # prevent new requests from starting
                    self._throttler.on_rate_limit_error(
                        request_start_time, verbose=self._verbose
                    )
                    try:
                        for _attempt in range(self._max_rate_limit_retries):
                            try:
                                request_start_time = time.time()
                                await self._throttler.async_wait_until_ready()
                                return await fn(*args, **kwargs)
                            except self._rate_limit_error:
                                self._throttler.on_rate_limit_error(
                                    request_start_time, verbose=self._verbose
                                )
                                continue
                    finally:
                        self._rate_limit_handling.set()  # allow new requests to start
            raise RateLimitError(
                f"Exceeded max ({self._max_rate_limit_retries}) retries"
            )

        return wrapper
```
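
Taken together, `AdaptiveTokenBucket` is the adaptive token bucket and `RateLimiter.limit` / `RateLimiter.alimit` wrap sync and async callables so that rate-limit errors shrink the request rate and trigger bounded retries. Below is a minimal usage sketch, not taken from the package: `call_llm` and `MyProviderRateLimitError` are hypothetical placeholders, and the import path assumes the module location shown in the file listing above.

```python
# Sketch only: call_llm and MyProviderRateLimitError stand in for a real client
# call and the exception it raises on HTTP 429 / rate-limit responses.
from arize.experiments.evaluators.rate_limiters import RateLimiter


class MyProviderRateLimitError(Exception):
    """Placeholder for a provider-specific rate-limit exception."""


def call_llm(prompt: str) -> str:
    # Replace with a real client call that may raise MyProviderRateLimitError.
    return f"echo: {prompt}"


limiter = RateLimiter(
    rate_limit_error=MyProviderRateLimitError,  # which exception counts as a rate-limit hit
    max_rate_limit_retries=3,                   # bounded retries after the first hit
    initial_per_second_request_rate=2.0,        # starting rate for the adaptive bucket
)

rate_limited_call = limiter.limit(call_llm)  # limiter.alimit wraps coroutines instead
print(rate_limited_call("hello"))
```

Note that the bucket starts empty, so the first wrapped call may block briefly while tokens accrue at the initial per-second rate.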

New file (per the listing above, this matches arize/experiments/evaluators/types.py):

@@ -0,0 +1,122 @@

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Mapping, Tuple

JSONSerializable = Dict[str, Any] | List[Any] | str | int | float | bool


class AnnotatorKind(Enum):
    CODE = "CODE"
    LLM = "LLM"


EvaluatorKind = str
EvaluatorName = str

Score = bool | int | float | None
Label = str | None
Explanation = str | None


@dataclass(frozen=True)
class EvaluationResult:
    """
    Represents the result of an evaluation.
    Args:
        score: The score of the evaluation.
        label: The label of the evaluation.
        explanation: The explanation of the evaluation.
        metadata: Additional metadata for the evaluation.
    """

    score: float | None = None
    label: str | None = None
    explanation: str | None = None
    metadata: Mapping[str, JSONSerializable] = field(default_factory=dict)

    @classmethod
    def from_dict(
        cls, obj: Mapping[str, Any] | None
    ) -> EvaluationResult | None:
        if not obj:
            return None
        return cls(
            score=obj.get("score"),
            label=obj.get("label"),
            explanation=obj.get("explanation"),
            metadata=obj.get("metadata") or {},
        )

    def __post_init__(self) -> None:
        if self.score is None and not self.label:
            raise ValueError("Must specify score or label, or both")
        if self.score is None and not self.label:
            object.__setattr__(self, "score", 0)
        for k in ("label", "explanation"):
            v = getattr(self, k, None)
            if v is not None:
                object.__setattr__(self, k, str(v) or None)


EvaluatorOutput = (
    EvaluationResult
    | bool
    | int
    | float
    | str
    | Tuple[Score, Label, Explanation]
)


@dataclass
class EvaluationResultFieldNames:
    """Column names for mapping evaluation results in a DataFrame.

    Args:
        score: Optional name of column containing evaluation scores
        label: Optional name of column containing evaluation labels
        explanation: Optional name of column containing evaluation explanations
        metadata: Optional mapping of metadata keys to column names. If a column name
            is None or empty string, the metadata key will be used as the column name.

    Examples:
        >>> # Basic usage with score and label columns
        >>> EvaluationResultColumnNames(
        ...     score="quality.score", label="quality.label"
        ... )

        >>> # Using metadata with same key and column name
        >>> EvaluationResultColumnNames(
        ...     score="quality.score",
        ...     metadata={
        ...         "version": None
        ...     },  # Will look for column named "version"
        ... )

        >>> # Using metadata with different key and column name
        >>> EvaluationResultColumnNames(
        ...     score="quality.score",
        ...     metadata={
        ...         # Will look for "column_in_my_df.version" column and ingest as
        ...         # "eval.{EvaluatorName}.meatadata.model_version"
        ...         "model_version": "column_in_my_df.version",
        ...         # Will look for "column_in_my_df.ts" column and ingest as
        ...         # "eval.{EvaluatorName}.metadata.timestamp"
        ...         "timestamp": "column_in_my_df.ts",
        ...     },
        ... )

    Raises:
        ValueError: If neither score nor label column names are specified
    """

    score: str | None = None
    label: str | None = None
    explanation: str | None = None
    metadata: Dict[str, str | None] | None = None

    def __post_init__(self) -> None:
        if self.score is None and self.label is None:
            raise ValueError("Must specify score or label column name, or both")
```
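
For reference, `EvaluationResult` requires at least a score or a label, and `from_dict` returns `None` for falsy input. The short sketch below illustrates those rules; it is not from the package, and the import path assumes the module location shown in the file listing above.

```python
from arize.experiments.evaluators.types import EvaluationResult

# Valid: a score alone (or a label alone) is enough.
ok = EvaluationResult(score=0.9, explanation="matched the reference answer")

# from_dict tolerates missing metadata and returns None for falsy input.
parsed = EvaluationResult.from_dict({"label": "correct", "metadata": None})
assert parsed is not None and parsed.label == "correct"
assert EvaluationResult.from_dict(None) is None

# Invalid: providing neither score nor label raises in __post_init__.
try:
    EvaluationResult(explanation="missing both score and label")
except ValueError as err:
    print(err)  # Must specify score or label, or both
```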

New file (per the listing above, this matches arize/experiments/evaluators/utils.py):

@@ -0,0 +1,198 @@

```python
import functools
import inspect
from typing import TYPE_CHECKING, Any, Callable, Optional

from tqdm.auto import tqdm

from arize.experiments.evaluators.types import (
    EvaluationResult,
    JSONSerializable,
)


def get_func_name(fn: Callable[..., Any]) -> str:
    """
    Makes a best-effort attempt to get the name of the function.
    """
    if isinstance(fn, functools.partial):
        return fn.func.__qualname__
    if hasattr(fn, "__qualname__") and not fn.__qualname__.endswith("<lambda>"):
        return fn.__qualname__.split(".<locals>.")[-1]
    return str(fn)


if TYPE_CHECKING:
    from ..evaluators.base import Evaluator


def unwrap_json(obj: JSONSerializable) -> JSONSerializable:
    if isinstance(obj, dict) and len(obj) == 1:
        key = next(iter(obj.keys()))
        output = obj[key]
        assert isinstance(
            output, (dict, list, str, int, float, bool, type(None))
        ), "Output must be JSON serializable"
        return output
    return obj


def validate_evaluator_signature(sig: inspect.Signature) -> None:
    # Check that the wrapped function has a valid signature for use as an evaluator
    # If it does not, raise an error to exit early before running evaluations
    params = sig.parameters
    valid_named_params = {
        "input",
        "output",
        "experiment_output",
        "dataset_output",
        "metadata",
        "dataset_row",
    }
    if len(params) == 0:
        raise ValueError(
            "Evaluation function must have at least one parameter."
        )
    if len(params) > 1:
        for not_found in set(params) - valid_named_params:
            param = params[not_found]
            if (
                param.kind is inspect.Parameter.VAR_KEYWORD
                or param.default is not inspect.Parameter.empty
            ):
                continue
            raise ValueError(
                f"Invalid parameter names in evaluation function: {', '.join(not_found)}. "
                "Parameters names for multi-argument functions must be "
                f"any of: {', '.join(valid_named_params)}."
            )


def _bind_evaluator_signature(
    sig: inspect.Signature, **kwargs: Any
) -> inspect.BoundArguments:
    parameter_mapping = {
        "input": kwargs.get("input"),
        "output": kwargs.get("output"),
        "experiment_output": kwargs.get("experiment_output"),
        "dataset_output": kwargs.get("dataset_output"),
        "metadata": kwargs.get("metadata"),
        "dataset_row": kwargs.get("dataset_row"),
    }
    params = sig.parameters
    if len(params) == 1:
        parameter_name = next(iter(params))
        if parameter_name in parameter_mapping:
            return sig.bind(parameter_mapping[parameter_name])
        else:
            return sig.bind(parameter_mapping["experiment_output"])
    return sig.bind_partial(
        **{
            name: parameter_mapping[name]
            for name in set(parameter_mapping).intersection(params)
        }
    )


def create_evaluator(
    name: Optional[str] = None,
    scorer: Optional[Callable[[Any], EvaluationResult]] = None,
) -> Callable[[Callable[..., Any]], "Evaluator"]:
    if scorer is None:
        scorer = _default_eval_scorer

    def wrapper(func: Callable[..., Any]) -> "Evaluator":
        nonlocal name
        if not name:
            name = get_func_name(func)
        assert name is not None

        wrapped_signature = inspect.signature(func)
        validate_evaluator_signature(wrapped_signature)

        if inspect.iscoroutinefunction(func):
            return _wrap_coroutine_evaluation_function(
                name, wrapped_signature, scorer
            )(func)

        return _wrap_sync_evaluation_function(name, wrapped_signature, scorer)(
            func
        )

    return wrapper


def _wrap_coroutine_evaluation_function(
    name: str,
    sig: inspect.Signature,
    convert_to_score: Callable[[Any], EvaluationResult],
) -> Callable[[Callable[..., Any]], "Evaluator"]:
    from ..evaluators.base import Evaluator

    def wrapper(func: Callable[..., Any]) -> "Evaluator":
        class AsyncEvaluator(Evaluator):
            def __init__(self) -> None:
                self._name = name

            @functools.wraps(func)
            async def __call__(self, *args: Any, **kwargs: Any) -> Any:
                return await func(*args, **kwargs)

            async def async_evaluate(self, **kwargs: Any) -> EvaluationResult:
                bound_signature = _bind_evaluator_signature(sig, **kwargs)
                result = await func(
                    *bound_signature.args, **bound_signature.kwargs
                )
                return convert_to_score(result)

        return AsyncEvaluator()

    return wrapper


def _wrap_sync_evaluation_function(
    name: str,
    sig: inspect.Signature,
    convert_to_score: Callable[[Any], EvaluationResult],
) -> Callable[[Callable[..., Any]], "Evaluator"]:
    from ..evaluators.base import Evaluator

    def wrapper(func: Callable[..., Any]) -> "Evaluator":
        class SyncEvaluator(Evaluator):
            def __init__(self) -> None:
                self._name = name

            @functools.wraps(func)
            def __call__(self, *args: Any, **kwargs: Any) -> Any:
                return func(*args, **kwargs)

            def evaluate(self, **kwargs: Any) -> EvaluationResult:
                bound_signature = _bind_evaluator_signature(sig, **kwargs)
                result = func(*bound_signature.args, **bound_signature.kwargs)
                return convert_to_score(result)

        return SyncEvaluator()

    return wrapper


def _default_eval_scorer(result: Any) -> EvaluationResult:
    if isinstance(result, EvaluationResult):
        return result
    if isinstance(result, bool):
        return EvaluationResult(score=float(result), label=str(result))
    if hasattr(result, "__float__"):
        return EvaluationResult(score=float(result))
    if isinstance(result, str):
        return EvaluationResult(label=result)
    if isinstance(result, (tuple, list)) and len(result) == 2:
        # If the result is a 2-tuple, the first item will be recorded as the score
        # and the second item will recorded as the explanation.
        return EvaluationResult(
            score=float(result[0]), explanation=str(result[1])
        )
    raise ValueError(f"Unsupported evaluation result type: {type(result)}")


def printif(condition: bool, *args: Any, **kwargs: Any) -> None:
    if condition:
        tqdm.write(*args, **kwargs)
```