arize 8.0.0a22__py3-none-any.whl → 8.0.0b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. arize/__init__.py +28 -19
  2. arize/_exporter/client.py +56 -37
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +181 -58
  65. arize/config.py +324 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +304 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +43 -18
  83. arize/embeddings/tabular_generators.py +46 -31
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +13 -0
  94. arize/experiments/client.py +394 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/ml/__init__.py +1 -0
  107. arize/ml/batch_validation/__init__.py +1 -0
  108. arize/{models → ml}/batch_validation/errors.py +545 -67
  109. arize/{models → ml}/batch_validation/validator.py +344 -303
  110. arize/ml/bounded_executor.py +47 -0
  111. arize/{models → ml}/casting.py +118 -108
  112. arize/{models → ml}/client.py +339 -118
  113. arize/{models → ml}/proto.py +97 -42
  114. arize/{models → ml}/stream_validation.py +43 -15
  115. arize/ml/surrogate_explainer/__init__.py +1 -0
  116. arize/{models → ml}/surrogate_explainer/mimic.py +25 -10
  117. arize/{types.py → ml/types.py} +355 -354
  118. arize/pre_releases.py +44 -0
  119. arize/projects/__init__.py +1 -0
  120. arize/projects/client.py +134 -0
  121. arize/regions.py +40 -0
  122. arize/spans/__init__.py +1 -0
  123. arize/spans/client.py +204 -175
  124. arize/spans/columns.py +13 -0
  125. arize/spans/conversion.py +60 -37
  126. arize/spans/validation/__init__.py +1 -0
  127. arize/spans/validation/annotations/__init__.py +1 -0
  128. arize/spans/validation/annotations/annotations_validation.py +6 -4
  129. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  130. arize/spans/validation/annotations/value_validation.py +35 -11
  131. arize/spans/validation/common/__init__.py +1 -0
  132. arize/spans/validation/common/argument_validation.py +33 -8
  133. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  134. arize/spans/validation/common/errors.py +211 -11
  135. arize/spans/validation/common/value_validation.py +81 -14
  136. arize/spans/validation/evals/__init__.py +1 -0
  137. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  138. arize/spans/validation/evals/evals_validation.py +34 -4
  139. arize/spans/validation/evals/value_validation.py +26 -3
  140. arize/spans/validation/metadata/__init__.py +1 -1
  141. arize/spans/validation/metadata/argument_validation.py +14 -5
  142. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  143. arize/spans/validation/metadata/value_validation.py +24 -10
  144. arize/spans/validation/spans/__init__.py +1 -0
  145. arize/spans/validation/spans/dataframe_form_validation.py +35 -14
  146. arize/spans/validation/spans/spans_validation.py +35 -4
  147. arize/spans/validation/spans/value_validation.py +78 -8
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +20 -3
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +58 -47
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/utils/types.py +105 -0
  158. arize/version.py +3 -1
  159. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/METADATA +13 -6
  160. arize-8.0.0b0.dist-info/RECORD +175 -0
  161. {arize-8.0.0a22.dist-info → arize-8.0.0b0.dist-info}/WHEEL +1 -1
  162. arize-8.0.0b0.dist-info/licenses/LICENSE +176 -0
  163. arize-8.0.0b0.dist-info/licenses/NOTICE +13 -0
  164. arize/_generated/protocol/flight/export_pb2.py +0 -61
  165. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  166. arize/models/__init__.py +0 -0
  167. arize/models/batch_validation/__init__.py +0 -0
  168. arize/models/bounded_executor.py +0 -34
  169. arize/models/surrogate_explainer/__init__.py +0 -0
  170. arize-8.0.0a22.dist-info/RECORD +0 -146
  171. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
@@ -1,8 +1,11 @@
1
+ """Rate limiting utilities for evaluator execution."""
2
+
1
3
  import asyncio
2
4
  import time
5
+ from collections.abc import Callable, Coroutine
3
6
  from functools import wraps
4
7
  from math import exp
5
- from typing import Any, Callable, Coroutine, Optional, Tuple, Type, TypeVar
8
+ from typing import Any, TypeVar
6
9
 
7
10
  from typing_extensions import ParamSpec
8
11
 
@@ -15,12 +18,11 @@ AsyncCallable = Callable[ParameterSpec, Coroutine[Any, Any, GenericType]]
15
18
 
16
19
 
17
20
  class UnavailableTokensError(ArizeException):
18
- pass
21
+ """Raised when insufficient tokens are available for rate limiting."""
19
22
 
20
23
 
21
24
  class AdaptiveTokenBucket:
22
- """
23
- An adaptive rate-limiter that adjusts the rate based on the number of rate limit errors.
25
+ """An adaptive rate-limiter that adjusts the rate based on the number of rate limit errors.
24
26
 
25
27
  This rate limiter does not need to know the exact rate limit. Instead, it starts with a high
26
28
  rate and reduces it whenever a rate limit error occurs. The rate is increased slowly over time
@@ -39,13 +41,24 @@ class AdaptiveTokenBucket:
39
41
  def __init__(
40
42
  self,
41
43
  initial_per_second_request_rate: float,
42
- maximum_per_second_request_rate: Optional[float] = None,
44
+ maximum_per_second_request_rate: float | None = None,
43
45
  minimum_per_second_request_rate: float = 0.1,
44
46
  enforcement_window_minutes: float = 1,
45
47
  rate_reduction_factor: float = 0.5,
46
48
  rate_increase_factor: float = 0.01,
47
49
  cooldown_seconds: float = 5,
48
- ):
50
+ ) -> None:
51
+ """Initialize the adaptive rate limit state.
52
+
53
+ Args:
54
+ initial_per_second_request_rate: Starting request rate per second.
55
+ maximum_per_second_request_rate: Maximum allowed rate limit.
56
+ minimum_per_second_request_rate: Minimum allowed rate limit.
57
+ enforcement_window_minutes: Time window for rate enforcement.
58
+ rate_reduction_factor: Factor to reduce rate on errors.
59
+ rate_increase_factor: Factor to gradually increase rate.
60
+ cooldown_seconds: Cooldown period before rate adjustments.
61
+ """
49
62
  self._initial_rate = initial_per_second_request_rate
50
63
  self.rate_reduction_factor = rate_reduction_factor
51
64
  self.enforcement_window = enforcement_window_minutes * 60
@@ -63,7 +76,6 @@ class AdaptiveTokenBucket:
63
76
  )
64
77
 
65
78
  maximum_per_second_request_rate = float(maximum_per_second_request_rate)
66
- assert isinstance(maximum_per_second_request_rate, float)
67
79
  self.maximum_rate = maximum_per_second_request_rate
68
80
 
69
81
  self.cooldown = cooldown_seconds
@@ -75,6 +87,7 @@ class AdaptiveTokenBucket:
75
87
  self.tokens = 0.0
76
88
 
77
89
  def increase_rate(self) -> None:
90
+ """Increase the rate limit based on time elapsed since last update."""
78
91
  time_since_last_update = time.time() - self.last_rate_update
79
92
  if time_since_last_update > self.enforcement_window:
80
93
  self.rate = self._initial_rate
@@ -86,6 +99,7 @@ class AdaptiveTokenBucket:
86
99
  def on_rate_limit_error(
87
100
  self, request_start_time: float, verbose: bool = False
88
101
  ) -> None:
102
+ """Handle rate limit error by reducing the rate and adding cooldown period."""
89
103
  now = time.time()
90
104
  if request_start_time < (self.last_error + self.cooldown):
91
105
  # do not reduce the rate for concurrent requests
@@ -109,9 +123,11 @@ class AdaptiveTokenBucket:
109
123
  time.sleep(self.cooldown) # block for a bit to let the rate limit reset
110
124
 
111
125
  def max_tokens(self) -> float:
126
+ """Return the maximum number of tokens allowed in the enforcement window."""
112
127
  return self.rate * self.enforcement_window
113
128
 
114
129
  def available_requests(self) -> float:
130
+ """Return the current number of available request tokens."""
115
131
  now = time.time()
116
132
  time_since_last_checked = time.time() - self.last_checked
117
133
  self.tokens = min(
@@ -121,6 +137,7 @@ class AdaptiveTokenBucket:
121
137
  return self.tokens
122
138
 
123
139
  def make_request_if_ready(self) -> None:
140
+ """Make a request if tokens are available, otherwise raise error."""
124
141
  if self.available_requests() <= 1:
125
142
  raise UnavailableTokensError
126
143
  self.tokens -= 1
@@ -129,6 +146,7 @@ class AdaptiveTokenBucket:
129
146
  self,
130
147
  max_wait_time: float = 300,
131
148
  ) -> None:
149
+ """Wait until tokens are available for making a request."""
132
150
  start = time.time()
133
151
  while (time.time() - start) < max_wait_time:
134
152
  try:
@@ -143,6 +161,7 @@ class AdaptiveTokenBucket:
143
161
  self,
144
162
  max_wait_time: float = 10, # defeat the token bucket rate limiter at low rates (<.1 req/s)
145
163
  ) -> None:
164
+ """Asynchronously wait until tokens are available for making a request."""
146
165
  start = time.time()
147
166
  while (time.time() - start) < max_wait_time:
148
167
  try:
@@ -154,25 +173,41 @@ class AdaptiveTokenBucket:
154
173
  continue
155
174
 
156
175
 
157
- class RateLimitError(ArizeException): ...
176
+ class RateLimitError(ArizeException):
177
+ """Raised when a rate limit is exceeded."""
158
178
 
159
179
 
160
180
  class RateLimiter:
181
+ """Rate limiter for controlling request frequency with adaptive token bucket algorithm."""
182
+
161
183
  def __init__(
162
184
  self,
163
- rate_limit_error: Optional[Type[BaseException]] = None,
185
+ rate_limit_error: type[BaseException] | None = None,
164
186
  max_rate_limit_retries: int = 3,
165
187
  initial_per_second_request_rate: float = 1.0,
166
- maximum_per_second_request_rate: Optional[float] = None,
188
+ maximum_per_second_request_rate: float | None = None,
167
189
  enforcement_window_minutes: float = 1,
168
190
  rate_reduction_factor: float = 0.5,
169
191
  rate_increase_factor: float = 0.01,
170
192
  cooldown_seconds: float = 5,
171
193
  verbose: bool = False,
172
194
  ) -> None:
173
- self._rate_limit_error: Tuple[Type[BaseException], ...]
195
+ """Initialize the rate limiter with adaptive token bucket algorithm.
196
+
197
+ Args:
198
+ rate_limit_error: Exception type to catch for rate limiting.
199
+ max_rate_limit_retries: Maximum retries for rate limit errors.
200
+ initial_per_second_request_rate: Initial request rate per second.
201
+ maximum_per_second_request_rate: Maximum allowed rate limit.
202
+ enforcement_window_minutes: Time window for rate enforcement.
203
+ rate_reduction_factor: Factor to reduce rate on errors.
204
+ rate_increase_factor: Factor to gradually increase rate.
205
+ cooldown_seconds: Cooldown period before rate adjustments.
206
+ verbose: Whether to print rate limit adjustments.
207
+ """
208
+ self._rate_limit_error: tuple[type[BaseException], ...]
174
209
  self._rate_limit_error = (
175
- (rate_limit_error,) if rate_limit_error is not None else tuple()
210
+ (rate_limit_error,) if rate_limit_error is not None else ()
176
211
  )
177
212
 
178
213
  self._max_rate_limit_retries = max_rate_limit_retries
@@ -184,44 +219,50 @@ class RateLimiter:
184
219
  rate_increase_factor=rate_increase_factor,
185
220
  cooldown_seconds=cooldown_seconds,
186
221
  )
187
- self._rate_limit_handling: Optional[asyncio.Event] = None
188
- self._rate_limit_handling_lock: Optional[asyncio.Lock] = None
189
- self._current_loop: Optional[asyncio.AbstractEventLoop] = None
222
+ self._rate_limit_handling: asyncio.Event | None = None
223
+ self._rate_limit_handling_lock: asyncio.Lock | None = None
224
+ self._current_loop: asyncio.AbstractEventLoop | None = None
190
225
  self._verbose = verbose
191
226
 
227
+ def _retry_with_rate_limit_sync(
228
+ self,
229
+ fn: Callable[..., GenericType],
230
+ remaining_attempts: int,
231
+ *args: object,
232
+ **kwargs: object,
233
+ ) -> GenericType:
234
+ """Recursively retry a function call with rate limiting."""
235
+ try:
236
+ request_start_time = time.time()
237
+ self._throttler.wait_until_ready()
238
+ return fn(*args, **kwargs)
239
+ except self._rate_limit_error as e:
240
+ self._throttler.on_rate_limit_error(
241
+ request_start_time, verbose=self._verbose
242
+ )
243
+ if remaining_attempts <= 1:
244
+ raise RateLimitError(
245
+ f"Exceeded max ({self._max_rate_limit_retries}) retries"
246
+ ) from e
247
+ return self._retry_with_rate_limit_sync(
248
+ fn, remaining_attempts - 1, *args, **kwargs
249
+ )
250
+
192
251
  def limit(
193
252
  self, fn: Callable[ParameterSpec, GenericType]
194
253
  ) -> Callable[ParameterSpec, GenericType]:
254
+ """Apply rate limiting to a synchronous function."""
255
+
195
256
  @wraps(fn)
196
- def wrapper(*args: Any, **kwargs: Any) -> GenericType:
197
- try:
198
- self._throttler.wait_until_ready()
199
- request_start_time = time.time()
200
- return fn(*args, **kwargs)
201
- except self._rate_limit_error:
202
- self._throttler.on_rate_limit_error(
203
- request_start_time, verbose=self._verbose
204
- )
205
- for _attempt in range(self._max_rate_limit_retries):
206
- try:
207
- request_start_time = time.time()
208
- self._throttler.wait_until_ready()
209
- return fn(*args, **kwargs)
210
- except self._rate_limit_error:
211
- self._throttler.on_rate_limit_error(
212
- request_start_time, verbose=self._verbose
213
- )
214
- continue
215
- raise RateLimitError(
216
- f"Exceeded max ({self._max_rate_limit_retries}) retries"
257
+ def wrapper(*args: object, **kwargs: object) -> GenericType:
258
+ return self._retry_with_rate_limit_sync(
259
+ fn, self._max_rate_limit_retries, *args, **kwargs
217
260
  )
218
261
 
219
262
  return wrapper
220
263
 
221
264
  def _initialize_async_primitives(self) -> None:
222
- """
223
- Lazily initialize async primitives to ensure they are created in the correct event loop.
224
- """
265
+ """Lazily initialize async primitives to ensure they are created in the correct event loop."""
225
266
  loop = asyncio.get_running_loop()
226
267
  if loop is not self._current_loop:
227
268
  self._current_loop = loop
@@ -232,15 +273,19 @@ class RateLimiter:
232
273
  def alimit(
233
274
  self, fn: AsyncCallable[ParameterSpec, GenericType]
234
275
  ) -> AsyncCallable[ParameterSpec, GenericType]:
276
+ """Apply rate limiting to an asynchronous function."""
277
+
235
278
  @wraps(fn)
236
- async def wrapper(*args: Any, **kwargs: Any) -> GenericType:
279
+ async def wrapper(*args: object, **kwargs: object) -> GenericType:
237
280
  self._initialize_async_primitives()
238
- assert self._rate_limit_handling_lock is not None and isinstance(
281
+ if self._rate_limit_handling_lock is None or not isinstance(
239
282
  self._rate_limit_handling_lock, asyncio.Lock
240
- )
241
- assert self._rate_limit_handling is not None and isinstance(
283
+ ):
284
+ raise RuntimeError("Rate limit lock not properly initialized")
285
+ if self._rate_limit_handling is None or not isinstance(
242
286
  self._rate_limit_handling, asyncio.Event
243
- )
287
+ ):
288
+ raise RuntimeError("Rate limit event not properly initialized")
244
289
  try:
245
290
  try:
246
291
  await asyncio.wait_for(
@@ -257,21 +302,25 @@ class RateLimiter:
257
302
  self._throttler.on_rate_limit_error(
258
303
  request_start_time, verbose=self._verbose
259
304
  )
305
+
306
+ async def _retry_async(remaining: int) -> GenericType:
307
+ try:
308
+ request_start_time = time.time()
309
+ await self._throttler.async_wait_until_ready()
310
+ return await fn(*args, **kwargs)
311
+ except self._rate_limit_error as e:
312
+ self._throttler.on_rate_limit_error(
313
+ request_start_time, verbose=self._verbose
314
+ )
315
+ if remaining <= 1:
316
+ raise RateLimitError(
317
+ f"Exceeded max ({self._max_rate_limit_retries}) retries"
318
+ ) from e
319
+ return await _retry_async(remaining - 1)
320
+
260
321
  try:
261
- for _attempt in range(self._max_rate_limit_retries):
262
- try:
263
- request_start_time = time.time()
264
- await self._throttler.async_wait_until_ready()
265
- return await fn(*args, **kwargs)
266
- except self._rate_limit_error:
267
- self._throttler.on_rate_limit_error(
268
- request_start_time, verbose=self._verbose
269
- )
270
- continue
322
+ return await _retry_async(self._max_rate_limit_retries)
271
323
  finally:
272
324
  self._rate_limit_handling.set() # allow new requests to start
273
- raise RateLimitError(
274
- f"Exceeded max ({self._max_rate_limit_retries}) retries"
275
- )
276
325
 
277
326
  return wrapper
@@ -1,13 +1,29 @@
1
+ """Type definitions for evaluators and evaluation results."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from dataclasses import dataclass, field
4
6
  from enum import Enum
5
- from typing import Any, Dict, List, Mapping, Tuple
7
+ from typing import TYPE_CHECKING
8
+
9
+ if TYPE_CHECKING:
10
+ from collections.abc import Mapping
6
11
 
7
- JSONSerializable = Dict[str, Any] | List[Any] | str | int | float | bool
12
+ # Recursive type alias for JSON-serializable values
13
+ JSONSerializable = (
14
+ dict[str, "JSONSerializable"]
15
+ | list["JSONSerializable"]
16
+ | str
17
+ | int
18
+ | float
19
+ | bool
20
+ | None
21
+ )
8
22
 
9
23
 
10
24
  class AnnotatorKind(Enum):
25
+ """Enum representing the type of annotator used for evaluation."""
26
+
11
27
  CODE = "CODE"
12
28
  LLM = "LLM"
13
29
 
@@ -22,8 +38,8 @@ Explanation = str | None
22
38
 
23
39
  @dataclass(frozen=True)
24
40
  class EvaluationResult:
25
- """
26
- Represents the result of an evaluation.
41
+ """Represents the result of an evaluation.
42
+
27
43
  Args:
28
44
  score: The score of the evaluation.
29
45
  label: The label of the evaluation.
@@ -38,8 +54,9 @@ class EvaluationResult:
38
54
 
39
55
  @classmethod
40
56
  def from_dict(
41
- cls, obj: Mapping[str, Any] | None
57
+ cls, obj: Mapping[str, object] | None
42
58
  ) -> EvaluationResult | None:
59
+ """Create an EvaluationResult instance from a dictionary."""
43
60
  if not obj:
44
61
  return None
45
62
  return cls(
@@ -50,6 +67,11 @@ class EvaluationResult:
50
67
  )
51
68
 
52
69
  def __post_init__(self) -> None:
70
+ """Validate and normalize evaluation result fields.
71
+
72
+ Raises:
73
+ ValueError: If neither score nor label is specified.
74
+ """
53
75
  if self.score is None and not self.label:
54
76
  raise ValueError("Must specify score or label, or both")
55
77
  if self.score is None and not self.label:
@@ -66,7 +88,7 @@ EvaluatorOutput = (
66
88
  | int
67
89
  | float
68
90
  | str
69
- | Tuple[Score, Label, Explanation]
91
+ | tuple[Score, Label, Explanation]
70
92
  )
71
93
 
72
94
 
@@ -115,8 +137,13 @@ class EvaluationResultFieldNames:
115
137
  score: str | None = None
116
138
  label: str | None = None
117
139
  explanation: str | None = None
118
- metadata: Dict[str, str | None] | None = None
140
+ metadata: dict[str, str | None] | None = None
119
141
 
120
142
  def __post_init__(self) -> None:
143
+ """Validate that at least one output column is specified.
144
+
145
+ Raises:
146
+ ValueError: If neither score nor label column name is specified.
147
+ """
121
148
  if self.score is None and self.label is None:
122
149
  raise ValueError("Must specify score or label column name, or both")
@@ -1,6 +1,9 @@
1
+ """Utility functions for evaluator operations."""
2
+
1
3
  import functools
2
4
  import inspect
3
- from typing import TYPE_CHECKING, Any, Callable, Optional
5
+ from collections.abc import Callable
6
+ from typing import TYPE_CHECKING
4
7
 
5
8
  from tqdm.auto import tqdm
6
9
 
@@ -10,10 +13,8 @@ from arize.experiments.evaluators.types import (
10
13
  )
11
14
 
12
15
 
13
- def get_func_name(fn: Callable[..., Any]) -> str:
14
- """
15
- Makes a best-effort attempt to get the name of the function.
16
- """
16
+ def get_func_name(fn: Callable[..., object]) -> str:
17
+ """Makes a best-effort attempt to get the name of the function."""
17
18
  if isinstance(fn, functools.partial):
18
19
  return fn.func.__qualname__
19
20
  if hasattr(fn, "__qualname__") and not fn.__qualname__.endswith("<lambda>"):
@@ -26,17 +27,36 @@ if TYPE_CHECKING:
26
27
 
27
28
 
28
29
  def unwrap_json(obj: JSONSerializable) -> JSONSerializable:
30
+ """Unwrap a single-key JSON object to extract its value.
31
+
32
+ Args:
33
+ obj: A JSON-serializable object to unwrap.
34
+
35
+ Returns:
36
+ The unwrapped value if obj is a single-key dict, otherwise the original obj.
37
+ """
29
38
  if isinstance(obj, dict) and len(obj) == 1:
30
39
  key = next(iter(obj.keys()))
31
40
  output = obj[key]
32
- assert isinstance(
41
+ if not isinstance(
33
42
  output, (dict, list, str, int, float, bool, type(None))
34
- ), "Output must be JSON serializable"
43
+ ):
44
+ raise TypeError(
45
+ f"Evaluator output must be JSON serializable, got {type(output).__name__}"
46
+ )
35
47
  return output
36
48
  return obj
37
49
 
38
50
 
39
51
  def validate_evaluator_signature(sig: inspect.Signature) -> None:
52
+ """Validate that a function signature is compatible for use as an evaluator.
53
+
54
+ Args:
55
+ sig: The function signature to validate.
56
+
57
+ Raises:
58
+ ValueError: If the signature is invalid for use as an evaluator.
59
+ """
40
60
  # Check that the wrapped function has a valid signature for use as an evaluator
41
61
  # If it does not, raise an error to exit early before running evaluations
42
62
  params = sig.parameters
@@ -68,7 +88,7 @@ def validate_evaluator_signature(sig: inspect.Signature) -> None:
68
88
 
69
89
 
70
90
  def _bind_evaluator_signature(
71
- sig: inspect.Signature, **kwargs: Any
91
+ sig: inspect.Signature, **kwargs: object
72
92
  ) -> inspect.BoundArguments:
73
93
  parameter_mapping = {
74
94
  "input": kwargs.get("input"),
@@ -83,8 +103,7 @@ def _bind_evaluator_signature(
83
103
  parameter_name = next(iter(params))
84
104
  if parameter_name in parameter_mapping:
85
105
  return sig.bind(parameter_mapping[parameter_name])
86
- else:
87
- return sig.bind(parameter_mapping["experiment_output"])
106
+ return sig.bind(parameter_mapping["experiment_output"])
88
107
  return sig.bind_partial(
89
108
  **{
90
109
  name: parameter_mapping[name]
@@ -94,17 +113,27 @@ def _bind_evaluator_signature(
94
113
 
95
114
 
96
115
  def create_evaluator(
97
- name: Optional[str] = None,
98
- scorer: Optional[Callable[[Any], EvaluationResult]] = None,
99
- ) -> Callable[[Callable[..., Any]], "Evaluator"]:
116
+ name: str | None = None,
117
+ scorer: Callable[[object], EvaluationResult] | None = None,
118
+ ) -> Callable[[Callable[..., object]], "Evaluator"]:
119
+ """Create an evaluator decorator for wrapping evaluation functions.
120
+
121
+ Args:
122
+ name: Optional name for the evaluator. Defaults to None (uses function name).
123
+ scorer: Optional custom scoring function. Defaults to None (uses default scorer).
124
+
125
+ Returns:
126
+ A decorator that wraps a function as an Evaluator instance.
127
+ """
100
128
  if scorer is None:
101
129
  scorer = _default_eval_scorer
102
130
 
103
- def wrapper(func: Callable[..., Any]) -> "Evaluator":
131
+ def wrapper(func: Callable[..., object]) -> "Evaluator":
104
132
  nonlocal name
105
133
  if not name:
106
134
  name = get_func_name(func)
107
- assert name is not None
135
+ if name is None:
136
+ raise ValueError("Evaluator name cannot be None")
108
137
 
109
138
  wrapped_signature = inspect.signature(func)
110
139
  validate_evaluator_signature(wrapped_signature)
@@ -124,20 +153,22 @@ def create_evaluator(
124
153
  def _wrap_coroutine_evaluation_function(
125
154
  name: str,
126
155
  sig: inspect.Signature,
127
- convert_to_score: Callable[[Any], EvaluationResult],
128
- ) -> Callable[[Callable[..., Any]], "Evaluator"]:
156
+ convert_to_score: Callable[[object], EvaluationResult],
157
+ ) -> Callable[[Callable[..., object]], "Evaluator"]:
129
158
  from ..evaluators.base import Evaluator
130
159
 
131
- def wrapper(func: Callable[..., Any]) -> "Evaluator":
160
+ def wrapper(func: Callable[..., object]) -> "Evaluator":
132
161
  class AsyncEvaluator(Evaluator):
133
162
  def __init__(self) -> None:
134
163
  self._name = name
135
164
 
136
165
  @functools.wraps(func)
137
- async def __call__(self, *args: Any, **kwargs: Any) -> Any:
166
+ async def __call__(self, *args: object, **kwargs: object) -> object:
138
167
  return await func(*args, **kwargs)
139
168
 
140
- async def async_evaluate(self, **kwargs: Any) -> EvaluationResult:
169
+ async def async_evaluate(
170
+ self, **kwargs: object
171
+ ) -> EvaluationResult:
141
172
  bound_signature = _bind_evaluator_signature(sig, **kwargs)
142
173
  result = await func(
143
174
  *bound_signature.args, **bound_signature.kwargs
@@ -152,20 +183,20 @@ def _wrap_coroutine_evaluation_function(
152
183
  def _wrap_sync_evaluation_function(
153
184
  name: str,
154
185
  sig: inspect.Signature,
155
- convert_to_score: Callable[[Any], EvaluationResult],
156
- ) -> Callable[[Callable[..., Any]], "Evaluator"]:
186
+ convert_to_score: Callable[[object], EvaluationResult],
187
+ ) -> Callable[[Callable[..., object]], "Evaluator"]:
157
188
  from ..evaluators.base import Evaluator
158
189
 
159
- def wrapper(func: Callable[..., Any]) -> "Evaluator":
190
+ def wrapper(func: Callable[..., object]) -> "Evaluator":
160
191
  class SyncEvaluator(Evaluator):
161
192
  def __init__(self) -> None:
162
193
  self._name = name
163
194
 
164
195
  @functools.wraps(func)
165
- def __call__(self, *args: Any, **kwargs: Any) -> Any:
196
+ def __call__(self, *args: object, **kwargs: object) -> object:
166
197
  return func(*args, **kwargs)
167
198
 
168
- def evaluate(self, **kwargs: Any) -> EvaluationResult:
199
+ def evaluate(self, **kwargs: object) -> EvaluationResult:
169
200
  bound_signature = _bind_evaluator_signature(sig, **kwargs)
170
201
  result = func(*bound_signature.args, **bound_signature.kwargs)
171
202
  return convert_to_score(result)
@@ -175,7 +206,7 @@ def _wrap_sync_evaluation_function(
175
206
  return wrapper
176
207
 
177
208
 
178
- def _default_eval_scorer(result: Any) -> EvaluationResult:
209
+ def _default_eval_scorer(result: object) -> EvaluationResult:
179
210
  if isinstance(result, EvaluationResult):
180
211
  return result
181
212
  if isinstance(result, bool):
@@ -193,6 +224,13 @@ def _default_eval_scorer(result: Any) -> EvaluationResult:
193
224
  raise ValueError(f"Unsupported evaluation result type: {type(result)}")
194
225
 
195
226
 
196
- def printif(condition: bool, *args: Any, **kwargs: Any) -> None:
227
+ def printif(condition: bool, *args: object, **kwargs: object) -> None:
228
+ """Print to tqdm output if the condition is true.
229
+
230
+ Args:
231
+ condition: Whether to print the message.
232
+ *args: Positional arguments to pass to tqdm.write.
233
+ **kwargs: Keyword arguments to pass to tqdm.write.
234
+ """
197
235
  if condition:
198
236
  tqdm.write(*args, **kwargs)