arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. arize/__init__.py +70 -1
  2. arize/_flight/client.py +163 -43
  3. arize/_flight/types.py +1 -0
  4. arize/_generated/api_client/__init__.py +5 -1
  5. arize/_generated/api_client/api/datasets_api.py +6 -6
  6. arize/_generated/api_client/api/experiments_api.py +924 -61
  7. arize/_generated/api_client/api_client.py +1 -1
  8. arize/_generated/api_client/configuration.py +1 -1
  9. arize/_generated/api_client/exceptions.py +1 -1
  10. arize/_generated/api_client/models/__init__.py +3 -1
  11. arize/_generated/api_client/models/dataset.py +2 -2
  12. arize/_generated/api_client/models/dataset_version.py +1 -1
  13. arize/_generated/api_client/models/datasets_create_request.py +3 -3
  14. arize/_generated/api_client/models/datasets_list200_response.py +1 -1
  15. arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
  16. arize/_generated/api_client/models/error.py +1 -1
  17. arize/_generated/api_client/models/experiment.py +6 -6
  18. arize/_generated/api_client/models/experiments_create_request.py +98 -0
  19. arize/_generated/api_client/models/experiments_list200_response.py +1 -1
  20. arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
  21. arize/_generated/api_client/rest.py +1 -1
  22. arize/_generated/api_client/test/test_dataset.py +2 -1
  23. arize/_generated/api_client/test/test_dataset_version.py +1 -1
  24. arize/_generated/api_client/test/test_datasets_api.py +1 -1
  25. arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
  26. arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
  27. arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
  28. arize/_generated/api_client/test/test_error.py +1 -1
  29. arize/_generated/api_client/test/test_experiment.py +6 -1
  30. arize/_generated/api_client/test/test_experiments_api.py +23 -2
  31. arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
  32. arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
  33. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
  34. arize/_generated/api_client_README.md +13 -8
  35. arize/client.py +19 -2
  36. arize/config.py +50 -3
  37. arize/constants/config.py +8 -2
  38. arize/constants/openinference.py +14 -0
  39. arize/constants/pyarrow.py +1 -0
  40. arize/datasets/__init__.py +0 -70
  41. arize/datasets/client.py +106 -19
  42. arize/datasets/errors.py +61 -0
  43. arize/datasets/validation.py +46 -0
  44. arize/experiments/client.py +455 -0
  45. arize/experiments/evaluators/__init__.py +0 -0
  46. arize/experiments/evaluators/base.py +255 -0
  47. arize/experiments/evaluators/exceptions.py +10 -0
  48. arize/experiments/evaluators/executors.py +502 -0
  49. arize/experiments/evaluators/rate_limiters.py +277 -0
  50. arize/experiments/evaluators/types.py +122 -0
  51. arize/experiments/evaluators/utils.py +198 -0
  52. arize/experiments/functions.py +920 -0
  53. arize/experiments/tracing.py +276 -0
  54. arize/experiments/types.py +394 -0
  55. arize/models/client.py +4 -1
  56. arize/spans/client.py +16 -20
  57. arize/utils/arrow.py +4 -3
  58. arize/utils/openinference_conversion.py +56 -0
  59. arize/utils/proto.py +13 -0
  60. arize/utils/size.py +22 -0
  61. arize/version.py +1 -1
  62. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
  63. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
  64. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
  65. {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0
@@ -0,0 +1,502 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import contextlib
5
+ import logging
6
+ import signal
7
+ import threading
8
+ import time
9
+ import traceback
10
+ from contextlib import contextmanager
11
+ from enum import Enum
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ Coroutine,
16
+ Generator,
17
+ List,
18
+ Optional,
19
+ Protocol,
20
+ Sequence,
21
+ Tuple,
22
+ Union,
23
+ )
24
+
25
+ from tqdm.auto import tqdm
26
+
27
+ from arize.experiments.evaluators.exceptions import ArizeException
28
+
29
# Module-level logger named after this module (standard `logging` convention).
logger = logging.getLogger(__name__)
30
+
31
+
32
class Unset:
    """Sentinel type used to mark "no value was supplied".

    An instance of this type (rather than ``None``, which may be a legitimate
    value) signals that a fallback return value was not provided.
    """
34
+
35
+
36
# Shared sentinel instance: distinguishes "no fallback provided" from an
# explicit fallback of None.
_unset = Unset()
37
+
38
+
39
class ExecutionStatus(Enum):
    """Terminal status of a single task execution.

    ``COMPLETED_WITH_RETRIES`` means the task ultimately succeeded but logged
    at least one exception along the way.
    """

    DID_NOT_RUN = "DID NOT RUN"
    COMPLETED = "COMPLETED"
    COMPLETED_WITH_RETRIES = "COMPLETED WITH RETRIES"
    FAILED = "FAILED"
44
+
45
+
46
class ExecutionDetails:
    """Mutable per-task record: status, logged exceptions, and total runtime.

    One instance is kept per input; executors update it as the task is
    attempted, retried, completed, or abandoned.
    """

    def __init__(self) -> None:
        self.exceptions: List[Exception] = []
        self.status = ExecutionStatus.DID_NOT_RUN
        self.execution_seconds: float = 0

    def fail(self) -> None:
        """Mark the task as permanently failed (retries exhausted)."""
        self.status = ExecutionStatus.FAILED

    def complete(self) -> None:
        """Mark the task as done, noting whether retries were needed."""
        self.status = (
            ExecutionStatus.COMPLETED_WITH_RETRIES
            if self.exceptions
            else ExecutionStatus.COMPLETED
        )

    def log_exception(self, exc: Exception) -> None:
        """Record an exception raised during one attempt."""
        self.exceptions.append(exc)

    def log_runtime(self, start_time: float) -> None:
        """Accumulate wall-clock seconds elapsed since ``start_time``."""
        elapsed = time.time() - start_time
        self.execution_seconds += elapsed
66
+
67
+
68
class Executor(Protocol):
    """Structural interface shared by the sync and async executors.

    ``run`` takes a sequence of inputs and returns the per-input outputs
    alongside the per-input :class:`ExecutionDetails`.
    """

    def run(
        self, inputs: Sequence[Any]
    ) -> Tuple[List[Any], List[ExecutionDetails]]: ...
72
+
73
+
74
class AsyncExecutor(Executor):
    """
    A class that provides asynchronous execution of tasks using a producer-consumer pattern.

    An async interface is provided by the `execute` method, which returns a coroutine, and a sync
    interface is provided by the `run` method.

    Retries are encoded in queue priority: fresh items enter at priority 0 and
    each retry decrements the priority, so `abs(priority)` is the retry count
    and retried items are serviced before fresh ones.

    Args:
        generation_fn (Callable[[Any], Coroutine[Any, Any, Any]]): A coroutine function that
            generates tasks to be executed.

        concurrency (int, optional): The number of concurrent consumers. Defaults to 3.

        tqdm_bar_format (Optional[str], optional): The format string for the progress bar.
            Defaults to None.

        max_retries (int, optional): The maximum number of times to retry on exceptions.
            Defaults to 10.

        exit_on_error (bool, optional): Whether to exit execution on the first encountered error.
            Defaults to True.

        fallback_return_value (Union[Unset, Any], optional): The fallback return value for tasks
            that encounter errors. Defaults to _unset.

        termination_signal (signal.Signals, optional): The signal handled to terminate the executor.

    """

    def __init__(
        self,
        generation_fn: Callable[[Any], Coroutine[Any, Any, Any]],
        concurrency: int = 3,
        tqdm_bar_format: Optional[str] = None,
        max_retries: int = 10,
        exit_on_error: bool = True,
        fallback_return_value: Union[Unset, Any] = _unset,
        termination_signal: signal.Signals = signal.SIGINT,
    ):
        self.generate = generation_fn
        self.fallback_return_value = fallback_return_value
        self.concurrency = concurrency
        self.tqdm_bar_format = tqdm_bar_format
        self.max_retries = max_retries
        self.exit_on_error = exit_on_error
        # Fresh items are enqueued at priority 0; retries go negative.
        self.base_priority = 0
        self.termination_signal = termination_signal

    async def producer(
        self,
        inputs: Sequence[Any],
        queue: asyncio.PriorityQueue[Tuple[int, Any]],
        max_fill: int,
        done_producing: asyncio.Event,
        termination_signal: asyncio.Event,
    ) -> None:
        """Feed `(priority, (index, input))` items into the queue.

        Stops early if the termination event is set; always sets
        `done_producing` on exit so consumers can drain and finish.
        """
        try:
            for index, input in enumerate(inputs):
                if termination_signal.is_set():
                    break
                while queue.qsize() >= max_fill:
                    # keep room in the queue for requeues
                    await asyncio.sleep(1)
                await queue.put((self.base_priority, (index, input)))
        finally:
            done_producing.set()

    async def consumer(
        self,
        outputs: List[Any],
        execution_details: List[ExecutionDetails],
        queue: asyncio.PriorityQueue[Tuple[int, Any]],
        done_producing: asyncio.Event,
        termination_event: asyncio.Event,
        progress_bar: tqdm[Any],
    ) -> None:
        """Pull items off the queue and run `self.generate` on each payload.

        Writes results into `outputs`/`execution_details` by index. Items that
        hit the per-task 120s wait timeout are requeued at the same priority;
        items that raise are requeued at `priority - 1` until `max_retries`
        is exhausted (ArizeException is never retried).
        """
        termination_event_watcher = None
        while True:
            # Tracks whether queue.task_done() has already been called for the
            # current item, so the finally block does not double-count it.
            marked_done = False
            try:
                priority, item = await asyncio.wait_for(queue.get(), timeout=1)
            except asyncio.TimeoutError:
                # No work available: exit once the producer is done and the
                # queue has fully drained; otherwise poll again.
                if done_producing.is_set() and queue.empty():
                    break
                continue
            if termination_event.is_set():
                # discard any remaining items in the queue
                queue.task_done()
                marked_done = True
                continue

            index, payload = item

            try:
                task_start_time = time.time()
                generate_task = asyncio.create_task(self.generate(payload))
                # Race the work against the termination event so a signal can
                # interrupt a long-running task.
                termination_event_watcher = asyncio.create_task(
                    termination_event.wait()
                )
                done, pending = await asyncio.wait(
                    [generate_task, termination_event_watcher],
                    timeout=120,
                    return_when=asyncio.FIRST_COMPLETED,
                )

                if generate_task in done:
                    outputs[index] = generate_task.result()
                    execution_details[index].complete()
                    execution_details[index].log_runtime(task_start_time)
                    progress_bar.update()
                elif termination_event.is_set():
                    # discard the pending task and remaining items in the queue
                    if not generate_task.done():
                        generate_task.cancel()
                        # Handle the cancellation exception
                        with contextlib.suppress(asyncio.CancelledError):
                            # allow any cleanup to finish for the cancelled task
                            await generate_task
                    queue.task_done()
                    marked_done = True
                    continue
                else:
                    tqdm.write("Worker timeout, requeuing")
                    # task timeouts are requeued at the same priority
                    await queue.put((priority, item))
                    execution_details[index].log_runtime(task_start_time)
            except Exception as exc:
                execution_details[index].log_exception(exc)
                execution_details[index].log_runtime(task_start_time)
                is_arize_exception = isinstance(exc, ArizeException)
                # abs(priority) is the number of retries already attempted.
                if (
                    retry_count := abs(priority)
                ) < self.max_retries and not is_arize_exception:
                    tqdm.write(
                        f"Exception in worker on attempt {retry_count + 1}: raised {repr(exc)}"
                    )
                    tqdm.write("Requeuing...")
                    await queue.put((priority - 1, item))
                else:
                    execution_details[index].fail()
                    tqdm.write(
                        f"Retries exhausted after {retry_count + 1} attempts: {traceback.format_exc()}"
                    )
                    if self.exit_on_error:
                        termination_event.set()
                    else:
                        progress_bar.update()
            finally:
                if not marked_done:
                    queue.task_done()
                if (
                    termination_event_watcher
                    and not termination_event_watcher.done()
                ):
                    termination_event_watcher.cancel()

    async def execute(
        self, inputs: Sequence[Any]
    ) -> Tuple[List[Any], List[ExecutionDetails]]:
        """Run all `inputs` through the producer/consumer pipeline.

        Installs a handler for `self.termination_signal` (restored before
        returning) so an interrupt stops work early; in that case the returned
        outputs may be incomplete (failed slots hold `fallback_return_value`).

        Returns:
            Tuple of (outputs, execution_details), both indexed like `inputs`.
        """
        termination_event = asyncio.Event()

        def termination_handler(signum: int, frame: Any) -> None:
            termination_event.set()
            tqdm.write(
                "Process was interrupted. The return value will be incomplete..."
            )

        # NOTE(review): signal.signal only works in the main thread; callers
        # are expected to route non-main-thread work to SyncExecutor.
        original_handler = signal.signal(
            self.termination_signal, termination_handler
        )
        outputs = [self.fallback_return_value] * len(inputs)
        execution_details = [ExecutionDetails() for _ in range(len(inputs))]
        progress_bar = tqdm(total=len(inputs), bar_format=self.tqdm_bar_format)

        max_queue_size = (
            5 * self.concurrency
        )  # limit the queue to bound memory usage
        max_fill = max_queue_size - (
            2 * self.concurrency
        )  # ensure there is always room to requeue
        queue: asyncio.PriorityQueue[Tuple[int, Any]] = asyncio.PriorityQueue(
            maxsize=max_queue_size
        )
        done_producing = asyncio.Event()

        producer = asyncio.create_task(
            self.producer(
                inputs, queue, max_fill, done_producing, termination_event
            )
        )
        consumers = [
            asyncio.create_task(
                self.consumer(
                    outputs,
                    execution_details,
                    queue,
                    done_producing,
                    termination_event,
                    progress_bar,
                )
            )
            for _ in range(self.concurrency)
        ]

        await asyncio.gather(producer, *consumers)
        # Wait for the queue to fully drain, unless termination fires first.
        join_task = asyncio.create_task(queue.join())
        termination_event_watcher = asyncio.create_task(
            termination_event.wait()
        )
        done, pending = await asyncio.wait(
            [join_task, termination_event_watcher],
            return_when=asyncio.FIRST_COMPLETED,
        )
        if termination_event_watcher in done:
            # Cancel all tasks
            if not join_task.done():
                join_task.cancel()
            if not producer.done():
                producer.cancel()
            for task in consumers:
                if not task.done():
                    task.cancel()

        if not termination_event_watcher.done():
            termination_event_watcher.cancel()

        # restore the original handler for self.termination_signal
        signal.signal(
            self.termination_signal, original_handler
        )
        return outputs, execution_details

    def run(
        self, inputs: Sequence[Any]
    ) -> Tuple[List[Any], List[ExecutionDetails]]:
        """Synchronous wrapper around `execute`.

        Uses `asyncio.run`, so it must not be called from a thread that
        already has a running event loop.
        """
        return asyncio.run(self.execute(inputs))
310
+
311
+
312
class SyncExecutor(Executor):
    """
    Synchronous executor for generating outputs from inputs using a given generation function.

    Inputs are processed one at a time, each with up to ``max_retries``
    retries; an installed signal handler (when ``termination_signal`` is not
    None) lets an interrupt stop the loop early with partial results.

    Args:
        generation_fn (Callable[[Any], Any]): The generation function that takes an input and
            returns an output.

        tqdm_bar_format (Optional[str], optional): The format string for the progress bar. Defaults
            to None.

        max_retries (int, optional): The maximum number of times to retry on exceptions. Defaults to
            10.

        exit_on_error (bool, optional): Whether to exit execution on the first encountered error.
            Defaults to True.

        fallback_return_value (Union[Unset, Any], optional): The fallback return value for tasks
            that encounter errors. Defaults to _unset.

    """

    def __init__(
        self,
        generation_fn: Callable[[Any], Any],
        tqdm_bar_format: Optional[str] = None,
        max_retries: int = 10,
        exit_on_error: bool = True,
        fallback_return_value: Union[Unset, Any] = _unset,
        termination_signal: Optional[signal.Signals] = signal.SIGINT,
    ):
        self.generate = generation_fn
        self.fallback_return_value = fallback_return_value
        self.tqdm_bar_format = tqdm_bar_format
        self.max_retries = max_retries
        self.exit_on_error = exit_on_error
        self.termination_signal = termination_signal

        # Flipped to True by the signal handler; checked before each attempt.
        self._TERMINATE = False

    def _signal_handler(self, signum: int, frame: Any) -> None:
        """Signal handler: request a graceful stop of the run loop."""
        tqdm.write(
            "Process was interrupted. The return value will be incomplete..."
        )
        self._TERMINATE = True

    @contextmanager
    def _executor_signal_handling(
        self, signum: Optional[int]
    ) -> Generator[None, None, None]:
        """Install `_signal_handler` for `signum` for the duration of the block.

        When `signum` is None (e.g. non-main thread, where installing handlers
        is not allowed), no handler is installed. The original handler is
        always restored on exit.
        """
        original_handler = None
        if signum is not None:
            original_handler = signal.signal(signum, self._signal_handler)
            try:
                yield
            finally:
                signal.signal(signum, original_handler)
        else:
            yield

    def run(
        self, inputs: Sequence[Any]
    ) -> Tuple[List[Any], List[ExecutionDetails]]:
        """Process `inputs` sequentially with retries.

        Each input gets up to `max_retries + 1` attempts; ArizeException is
        never retried. On exhausted retries the slot keeps
        `fallback_return_value` and, if `exit_on_error` is True, the partial
        results are returned immediately.

        Returns:
            Tuple of (outputs, execution_details), both indexed like `inputs`.
        """
        with self._executor_signal_handling(self.termination_signal):
            outputs = [self.fallback_return_value] * len(inputs)
            execution_details: List[ExecutionDetails] = [
                ExecutionDetails() for _ in range(len(inputs))
            ]
            progress_bar = tqdm(
                total=len(inputs), bar_format=self.tqdm_bar_format
            )

            for index, input in enumerate(inputs):
                task_start_time = time.time()
                try:
                    for attempt in range(self.max_retries + 1):
                        if self._TERMINATE:
                            # Interrupted: return what has been computed so far.
                            return outputs, execution_details
                        try:
                            result = self.generate(input)
                            outputs[index] = result
                            execution_details[index].complete()
                            progress_bar.update()
                            break
                        except Exception as exc:
                            execution_details[index].log_exception(exc)
                            is_arize_exception = isinstance(exc, ArizeException)
                            if (
                                attempt >= self.max_retries
                                or is_arize_exception
                            ):
                                # Re-raise into the outer handler to mark failure.
                                raise exc
                            else:
                                tqdm.write(
                                    f"Exception in worker on attempt {attempt + 1}: {exc}"
                                )
                                tqdm.write("Retrying...")
                except Exception as exc:
                    execution_details[index].fail()
                    tqdm.write(
                        f"Retries exhausted after {attempt + 1} attempts: {exc}"
                    )
                    if self.exit_on_error:
                        return outputs, execution_details
                    else:
                        progress_bar.update()
                finally:
                    # Runtime covers all attempts for this input.
                    execution_details[index].log_runtime(task_start_time)
        return outputs, execution_details
419
+
420
+
421
def get_executor_on_sync_context(
    sync_fn: Callable[[Any], Any],
    async_fn: Callable[[Any], Coroutine[Any, Any, Any]],
    run_sync: bool = False,
    concurrency: int = 3,
    tqdm_bar_format: Optional[str] = None,
    max_retries: int = 10,
    exit_on_error: bool = True,
    fallback_return_value: Union[Unset, Any] = _unset,
) -> Executor:
    """Pick a SyncExecutor or AsyncExecutor appropriate for the caller's context.

    Async execution is used when possible: in the main thread, with either no
    running event loop or a nest_asyncio-patched loop. Otherwise execution
    falls back to the synchronous executor (with a warning where the fallback
    is forced rather than requested).
    """

    def _sync_executor(
        termination_signal: Optional[signal.Signals] = signal.SIGINT,
    ) -> Executor:
        # One-line factory so every fallback path builds the same executor.
        return SyncExecutor(
            sync_fn,
            tqdm_bar_format=tqdm_bar_format,
            max_retries=max_retries,
            exit_on_error=exit_on_error,
            fallback_return_value=fallback_return_value,
            termination_signal=termination_signal,
        )

    def _async_executor() -> Executor:
        return AsyncExecutor(
            async_fn,
            concurrency=concurrency,
            tqdm_bar_format=tqdm_bar_format,
            max_retries=max_retries,
            exit_on_error=exit_on_error,
            fallback_return_value=fallback_return_value,
        )

    if threading.current_thread() is not threading.main_thread():
        # Signal handlers cannot be installed outside the main thread, so run
        # synchronously with signal handling disabled.
        if run_sync is False:
            logger.warning(
                "Async evals execution is not supported in non-main threads. Falling back to sync."
            )
        return _sync_executor(termination_signal=None)

    if run_sync is True:
        # Caller explicitly requested synchronous execution.
        return _sync_executor()

    if not _running_event_loop_exists():
        # No loop running: AsyncExecutor can own its own loop via asyncio.run.
        return _async_executor()

    if getattr(asyncio, "_nest_patched", False):
        # nest_asyncio is applied, so asyncio.run works inside the live loop.
        return _async_executor()

    logger.warning(
        "🐌!! If running inside a notebook, patching the event loop with "
        "nest_asyncio will allow asynchronous eval submission, and is significantly "
        "faster. To patch the event loop, run `nest_asyncio.apply()`."
    )
    return _sync_executor()
488
+
489
+
490
+ def _running_event_loop_exists() -> bool:
491
+ """
492
+ Checks for a running event loop.
493
+
494
+ Returns:
495
+ bool: True if a running event loop exists, False otherwise.
496
+
497
+ """
498
+ try:
499
+ asyncio.get_running_loop()
500
+ return True
501
+ except RuntimeError:
502
+ return False