arize 8.0.0a22__py3-none-any.whl → 8.0.0a23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. arize/__init__.py +17 -9
  2. arize/_exporter/client.py +55 -36
  3. arize/_exporter/parsers/tracing_data_parser.py +41 -30
  4. arize/_exporter/validation.py +3 -3
  5. arize/_flight/client.py +207 -76
  6. arize/_generated/api_client/__init__.py +30 -6
  7. arize/_generated/api_client/api/__init__.py +1 -0
  8. arize/_generated/api_client/api/datasets_api.py +864 -190
  9. arize/_generated/api_client/api/experiments_api.py +167 -131
  10. arize/_generated/api_client/api/projects_api.py +1197 -0
  11. arize/_generated/api_client/api_client.py +2 -2
  12. arize/_generated/api_client/configuration.py +42 -34
  13. arize/_generated/api_client/exceptions.py +2 -2
  14. arize/_generated/api_client/models/__init__.py +15 -4
  15. arize/_generated/api_client/models/dataset.py +10 -10
  16. arize/_generated/api_client/models/dataset_example.py +111 -0
  17. arize/_generated/api_client/models/dataset_example_update.py +100 -0
  18. arize/_generated/api_client/models/dataset_version.py +13 -13
  19. arize/_generated/api_client/models/datasets_create_request.py +16 -8
  20. arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
  21. arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
  22. arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
  23. arize/_generated/api_client/models/datasets_list200_response.py +10 -4
  24. arize/_generated/api_client/models/experiment.py +14 -16
  25. arize/_generated/api_client/models/experiment_run.py +108 -0
  26. arize/_generated/api_client/models/experiment_run_create.py +102 -0
  27. arize/_generated/api_client/models/experiments_create_request.py +16 -10
  28. arize/_generated/api_client/models/experiments_list200_response.py +10 -4
  29. arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
  30. arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
  31. arize/_generated/api_client/models/primitive_value.py +172 -0
  32. arize/_generated/api_client/models/problem.py +100 -0
  33. arize/_generated/api_client/models/project.py +99 -0
  34. arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
  35. arize/_generated/api_client/models/projects_list200_response.py +106 -0
  36. arize/_generated/api_client/rest.py +2 -2
  37. arize/_generated/api_client/test/test_dataset.py +4 -2
  38. arize/_generated/api_client/test/test_dataset_example.py +56 -0
  39. arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
  40. arize/_generated/api_client/test/test_dataset_version.py +7 -2
  41. arize/_generated/api_client/test/test_datasets_api.py +27 -13
  42. arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
  43. arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
  44. arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
  45. arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
  46. arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
  47. arize/_generated/api_client/test/test_experiment.py +2 -4
  48. arize/_generated/api_client/test/test_experiment_run.py +56 -0
  49. arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
  50. arize/_generated/api_client/test/test_experiments_api.py +6 -6
  51. arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
  52. arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
  53. arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
  54. arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
  55. arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
  56. arize/_generated/api_client/test/test_problem.py +57 -0
  57. arize/_generated/api_client/test/test_project.py +58 -0
  58. arize/_generated/api_client/test/test_projects_api.py +59 -0
  59. arize/_generated/api_client/test/test_projects_create_request.py +54 -0
  60. arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
  61. arize/_generated/api_client_README.md +43 -29
  62. arize/_generated/protocol/flight/flight_pb2.py +400 -0
  63. arize/_lazy.py +27 -19
  64. arize/client.py +268 -55
  65. arize/config.py +365 -116
  66. arize/constants/__init__.py +1 -0
  67. arize/constants/config.py +11 -4
  68. arize/constants/ml.py +6 -4
  69. arize/constants/openinference.py +2 -0
  70. arize/constants/pyarrow.py +2 -0
  71. arize/constants/spans.py +3 -1
  72. arize/datasets/__init__.py +1 -0
  73. arize/datasets/client.py +299 -84
  74. arize/datasets/errors.py +32 -2
  75. arize/datasets/validation.py +18 -8
  76. arize/embeddings/__init__.py +2 -0
  77. arize/embeddings/auto_generator.py +23 -19
  78. arize/embeddings/base_generators.py +89 -36
  79. arize/embeddings/constants.py +2 -0
  80. arize/embeddings/cv_generators.py +26 -4
  81. arize/embeddings/errors.py +27 -5
  82. arize/embeddings/nlp_generators.py +31 -12
  83. arize/embeddings/tabular_generators.py +32 -20
  84. arize/embeddings/usecases.py +12 -2
  85. arize/exceptions/__init__.py +1 -0
  86. arize/exceptions/auth.py +11 -1
  87. arize/exceptions/base.py +29 -4
  88. arize/exceptions/models.py +21 -2
  89. arize/exceptions/parameters.py +31 -0
  90. arize/exceptions/spaces.py +12 -1
  91. arize/exceptions/types.py +86 -7
  92. arize/exceptions/values.py +220 -20
  93. arize/experiments/__init__.py +1 -0
  94. arize/experiments/client.py +389 -285
  95. arize/experiments/evaluators/__init__.py +1 -0
  96. arize/experiments/evaluators/base.py +74 -41
  97. arize/experiments/evaluators/exceptions.py +6 -3
  98. arize/experiments/evaluators/executors.py +121 -73
  99. arize/experiments/evaluators/rate_limiters.py +106 -57
  100. arize/experiments/evaluators/types.py +34 -7
  101. arize/experiments/evaluators/utils.py +65 -27
  102. arize/experiments/functions.py +103 -101
  103. arize/experiments/tracing.py +52 -44
  104. arize/experiments/types.py +56 -31
  105. arize/logging.py +54 -22
  106. arize/models/__init__.py +1 -0
  107. arize/models/batch_validation/__init__.py +1 -0
  108. arize/models/batch_validation/errors.py +543 -65
  109. arize/models/batch_validation/validator.py +339 -300
  110. arize/models/bounded_executor.py +20 -7
  111. arize/models/casting.py +75 -29
  112. arize/models/client.py +326 -107
  113. arize/models/proto.py +95 -40
  114. arize/models/stream_validation.py +42 -14
  115. arize/models/surrogate_explainer/__init__.py +1 -0
  116. arize/models/surrogate_explainer/mimic.py +24 -13
  117. arize/pre_releases.py +43 -0
  118. arize/projects/__init__.py +1 -0
  119. arize/projects/client.py +129 -0
  120. arize/regions.py +40 -0
  121. arize/spans/__init__.py +1 -0
  122. arize/spans/client.py +130 -106
  123. arize/spans/columns.py +13 -0
  124. arize/spans/conversion.py +54 -38
  125. arize/spans/validation/__init__.py +1 -0
  126. arize/spans/validation/annotations/__init__.py +1 -0
  127. arize/spans/validation/annotations/annotations_validation.py +6 -4
  128. arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
  129. arize/spans/validation/annotations/value_validation.py +35 -11
  130. arize/spans/validation/common/__init__.py +1 -0
  131. arize/spans/validation/common/argument_validation.py +33 -8
  132. arize/spans/validation/common/dataframe_form_validation.py +35 -9
  133. arize/spans/validation/common/errors.py +211 -11
  134. arize/spans/validation/common/value_validation.py +80 -13
  135. arize/spans/validation/evals/__init__.py +1 -0
  136. arize/spans/validation/evals/dataframe_form_validation.py +28 -8
  137. arize/spans/validation/evals/evals_validation.py +34 -4
  138. arize/spans/validation/evals/value_validation.py +26 -3
  139. arize/spans/validation/metadata/__init__.py +1 -1
  140. arize/spans/validation/metadata/argument_validation.py +14 -5
  141. arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
  142. arize/spans/validation/metadata/value_validation.py +24 -10
  143. arize/spans/validation/spans/__init__.py +1 -0
  144. arize/spans/validation/spans/dataframe_form_validation.py +34 -13
  145. arize/spans/validation/spans/spans_validation.py +35 -4
  146. arize/spans/validation/spans/value_validation.py +76 -7
  147. arize/types.py +293 -157
  148. arize/utils/__init__.py +1 -0
  149. arize/utils/arrow.py +31 -15
  150. arize/utils/cache.py +34 -6
  151. arize/utils/dataframe.py +19 -2
  152. arize/utils/online_tasks/__init__.py +2 -0
  153. arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
  154. arize/utils/openinference_conversion.py +44 -5
  155. arize/utils/proto.py +10 -0
  156. arize/utils/size.py +5 -3
  157. arize/version.py +3 -1
  158. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
  159. arize-8.0.0a23.dist-info/RECORD +174 -0
  160. {arize-8.0.0a22.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
  161. arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
  162. arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
  163. arize/_generated/protocol/flight/export_pb2.py +0 -61
  164. arize/_generated/protocol/flight/ingest_pb2.py +0 -365
  165. arize-8.0.0a22.dist-info/RECORD +0 -146
  166. arize-8.0.0a22.dist-info/licenses/LICENSE.md +0 -12
arize/experiments/evaluators/executors.py +121 -73
@@ -1,3 +1,5 @@
+"""Evaluator executors for managing concurrent evaluation tasks."""
+
 from __future__ import annotations
 
 import asyncio
@@ -10,18 +12,14 @@ import traceback
 from contextlib import contextmanager
 from enum import Enum
 from typing import (
+    TYPE_CHECKING,
     Any,
-    Callable,
-    Coroutine,
-    Generator,
-    List,
-    Optional,
     Protocol,
-    Sequence,
-    Tuple,
-    Union,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Callable, Coroutine, Generator, Sequence
+
 from tqdm.auto import tqdm
 
 from arize.experiments.evaluators.exceptions import ArizeException
@@ -30,13 +28,15 @@ logger = logging.getLogger(__name__)
 
 
 class Unset:
-    pass
+    """Sentinel class representing an unset or undefined value."""
 
 
 _unset = Unset()
 
 
 class ExecutionStatus(Enum):
+    """Enum representing execution status states for experiment tasks."""
+
     DID_NOT_RUN = "DID NOT RUN"
     COMPLETED = "COMPLETED"
     COMPLETED_WITH_RETRIES = "COMPLETED WITH RETRIES"
@@ -44,42 +44,52 @@ class ExecutionStatus(Enum):
 
 
 class ExecutionDetails:
+    """Container for tracking execution status, exceptions, and runtime metrics."""
+
     def __init__(self) -> None:
-        self.exceptions: List[Exception] = []
+        """Initialize execution details with default state."""
+        self.exceptions: list[Exception] = []
         self.status = ExecutionStatus.DID_NOT_RUN
         self.execution_seconds: float = 0
 
     def fail(self) -> None:
+        """Mark this execution as failed."""
         self.status = ExecutionStatus.FAILED
 
     def complete(self) -> None:
+        """Mark this execution as completed, with or without retries."""
         if self.exceptions:
             self.status = ExecutionStatus.COMPLETED_WITH_RETRIES
         else:
             self.status = ExecutionStatus.COMPLETED
 
     def log_exception(self, exc: Exception) -> None:
+        """Log an exception that occurred during execution."""
        self.exceptions.append(exc)
 
     def log_runtime(self, start_time: float) -> None:
+        """Log the runtime duration for this execution."""
         self.execution_seconds += time.time() - start_time
 
 
 class Executor(Protocol):
+    """Protocol defining the interface for experiment task executors."""
+
     def run(
         self, inputs: Sequence[Any]
-    ) -> Tuple[List[Any], List[ExecutionDetails]]: ...
+    ) -> tuple[list[object], list[ExecutionDetails]]:
+        """Execute the generation function on all inputs and return outputs with execution details."""
+        ...
 
 
 class AsyncExecutor(Executor):
-    """
-    A class that provides asynchronous execution of tasks using a producer-consumer pattern.
+    """A class that provides asynchronous execution of tasks using a producer-consumer pattern.
 
     An async interface is provided by the `execute` method, which returns a coroutine, and a sync
     interface is provided by the `run` method.
 
     Args:
-        generation_fn (Callable[[Any], Coroutine[Any, Any, Any]]): A coroutine function that
+        generation_fn (Callable[[object], Coroutine[Any, Any, Any]]): A coroutine function that
             generates tasks to be executed.
 
         concurrency (int, optional): The number of concurrent consumers. Defaults to 3.
@@ -102,14 +112,25 @@ class AsyncExecutor(Executor):
 
     def __init__(
         self,
-        generation_fn: Callable[[Any], Coroutine[Any, Any, Any]],
+        generation_fn: Callable[[object], Coroutine[Any, Any, Any]],
         concurrency: int = 3,
-        tqdm_bar_format: Optional[str] = None,
+        tqdm_bar_format: str | None = None,
         max_retries: int = 10,
         exit_on_error: bool = True,
-        fallback_return_value: Union[Unset, Any] = _unset,
+        fallback_return_value: Unset | object = _unset,
         termination_signal: signal.Signals = signal.SIGINT,
-    ):
+    ) -> None:
+        """Initialize the async executor with configuration parameters.
+
+        Args:
+            generation_fn: Async function to execute for each item.
+            concurrency: Maximum number of concurrent executions.
+            tqdm_bar_format: Custom format string for progress bar.
+            max_retries: Maximum retry attempts for failed executions.
+            exit_on_error: Whether to exit on first error.
+            fallback_return_value: Value to return when execution fails.
+            termination_signal: Signal to handle for graceful termination.
+        """
         self.generate = generation_fn
         self.fallback_return_value = fallback_return_value
         self.concurrency = concurrency
@@ -122,11 +143,12 @@ class AsyncExecutor(Executor):
     async def producer(
         self,
         inputs: Sequence[Any],
-        queue: asyncio.PriorityQueue[Tuple[int, Any]],
+        queue: asyncio.PriorityQueue[tuple[int, Any]],
         max_fill: int,
         done_producing: asyncio.Event,
         termination_signal: asyncio.Event,
     ) -> None:
+        """Produce tasks by adding inputs to the queue for consumers to process."""
         try:
             for index, input in enumerate(inputs):
                 if termination_signal.is_set():
@@ -140,13 +162,14 @@
 
     async def consumer(
         self,
-        outputs: List[Any],
-        execution_details: List[ExecutionDetails],
-        queue: asyncio.PriorityQueue[Tuple[int, Any]],
+        outputs: list[object],
+        execution_details: list[ExecutionDetails],
+        queue: asyncio.PriorityQueue[tuple[int, Any]],
         done_producing: asyncio.Event,
         termination_event: asyncio.Event,
         progress_bar: tqdm[Any],
     ) -> None:
+        """Consume tasks from the queue and execute the generation function on each input."""
         termination_event_watcher = None
         while True:
             marked_done = False
@@ -170,7 +193,7 @@
                 termination_event_watcher = asyncio.create_task(
                     termination_event.wait()
                 )
-                done, pending = await asyncio.wait(
+                done, _pending = await asyncio.wait(
                     [generate_task, termination_event_watcher],
                     timeout=120,
                     return_when=asyncio.FIRST_COMPLETED,
@@ -205,7 +228,7 @@
                     retry_count := abs(priority)
                 ) < self.max_retries and not is_arize_exception:
                     tqdm.write(
-                        f"Exception in worker on attempt {retry_count + 1}: raised {repr(exc)}"
+                        f"Exception in worker on attempt {retry_count + 1}: raised {exc!r}"
                     )
                     tqdm.write("Requeuing...")
                     await queue.put((priority - 1, item))
@@ -229,10 +252,11 @@
 
     async def execute(
         self, inputs: Sequence[Any]
-    ) -> Tuple[List[Any], List[ExecutionDetails]]:
+    ) -> tuple[list[object], list[ExecutionDetails]]:
+        """Execute all inputs asynchronously using producer-consumer pattern."""
         termination_event = asyncio.Event()
 
-        def termination_handler(signum: int, frame: Any) -> None:
+        def termination_handler(signum: int, frame: object) -> None:
             termination_event.set()
             tqdm.write(
                 "Process was interrupted. The return value will be incomplete..."
@@ -251,7 +275,7 @@
         max_fill = max_queue_size - (
            2 * self.concurrency
         )  # ensure there is always room to requeue
-        queue: asyncio.PriorityQueue[Tuple[int, Any]] = asyncio.PriorityQueue(
+        queue: asyncio.PriorityQueue[tuple[int, Any]] = asyncio.PriorityQueue(
             maxsize=max_queue_size
         )
         done_producing = asyncio.Event()
@@ -280,7 +304,7 @@
         termination_event_watcher = asyncio.create_task(
             termination_event.wait()
         )
-        done, pending = await asyncio.wait(
+        done, _pending = await asyncio.wait(
            [join_task, termination_event_watcher],
            return_when=asyncio.FIRST_COMPLETED,
        )
@@ -305,16 +329,16 @@
 
     def run(
         self, inputs: Sequence[Any]
-    ) -> Tuple[List[Any], List[ExecutionDetails]]:
+    ) -> tuple[list[object], list[ExecutionDetails]]:
+        """Execute all inputs asynchronously and return outputs with execution details."""
         return asyncio.run(self.execute(inputs))
 
 
 class SyncExecutor(Executor):
-    """
-    Synchronous executor for generating outputs from inputs using a given generation function.
+    """Synchronous executor for generating outputs from inputs using a given generation function.
 
     Args:
-        generation_fn (Callable[[Any], Any]): The generation function that takes an input and
+        generation_fn (Callable[[object], Any]): The generation function that takes an input and
             returns an output.
 
         tqdm_bar_format (Optional[str], optional): The format string for the progress bar. Defaults
@@ -333,13 +357,23 @@ class SyncExecutor(Executor):
 
     def __init__(
         self,
-        generation_fn: Callable[[Any], Any],
-        tqdm_bar_format: Optional[str] = None,
+        generation_fn: Callable[[object], Any],
+        tqdm_bar_format: str | None = None,
         max_retries: int = 10,
         exit_on_error: bool = True,
-        fallback_return_value: Union[Unset, Any] = _unset,
-        termination_signal: Optional[signal.Signals] = signal.SIGINT,
-    ):
+        fallback_return_value: Unset | object = _unset,
+        termination_signal: signal.Signals | None = signal.SIGINT,
+    ) -> None:
+        """Initialize the sync executor with configuration parameters.
+
+        Args:
+            generation_fn: Synchronous function to execute for each item.
+            tqdm_bar_format: Custom format string for progress bar.
+            max_retries: Maximum retry attempts for failed executions.
+            exit_on_error: Whether to exit on first error.
+            fallback_return_value: Value to return when execution fails.
+            termination_signal: Signal to handle for graceful termination.
+        """
         self.generate = generation_fn
         self.fallback_return_value = fallback_return_value
         self.tqdm_bar_format = tqdm_bar_format
@@ -349,7 +383,7 @@
 
         self._TERMINATE = False
 
-    def _signal_handler(self, signum: int, frame: Any) -> None:
+    def _signal_handler(self, signum: int, frame: object) -> None:
         tqdm.write(
             "Process was interrupted. The return value will be incomplete..."
         )
@@ -357,7 +391,7 @@
 
     @contextmanager
     def _executor_signal_handling(
-        self, signum: Optional[int]
+        self, signum: int | None
     ) -> Generator[None, None, None]:
         original_handler = None
         if signum is not None:
@@ -369,10 +403,11 @@
         else:
             yield
 
-    def run(self, inputs: Sequence[Any]) -> Tuple[List[Any], List[Any]]:
+    def run(self, inputs: Sequence[Any]) -> tuple[list[object], list[object]]:
+        """Execute all inputs synchronously and return outputs with execution details."""
         with self._executor_signal_handling(self.termination_signal):
             outputs = [self.fallback_return_value] * len(inputs)
-            execution_details: List[ExecutionDetails] = [
+            execution_details: list[ExecutionDetails] = [
                 ExecutionDetails() for _ in range(len(inputs))
             ]
             progress_bar = tqdm(
@@ -398,12 +433,11 @@
                                 attempt >= self.max_retries
                                 or is_arize_exception
                             ):
-                                raise exc
-                            else:
-                                tqdm.write(
-                                    f"Exception in worker on attempt {attempt + 1}: {exc}"
-                                )
-                                tqdm.write("Retrying...")
+                                raise
+                            tqdm.write(
+                                f"Exception in worker on attempt {attempt + 1}: {exc}"
+                            )
+                            tqdm.write("Retrying...")
                 except Exception as exc:
                     execution_details[index].fail()
                     tqdm.write(
@@ -411,23 +445,40 @@
                     )
                     if self.exit_on_error:
                         return outputs, execution_details
-                    else:
-                        progress_bar.update()
+                    progress_bar.update()
                 finally:
                     execution_details[index].log_runtime(task_start_time)
             return outputs, execution_details
 
 
 def get_executor_on_sync_context(
-    sync_fn: Callable[[Any], Any],
-    async_fn: Callable[[Any], Coroutine[Any, Any, Any]],
+    sync_fn: Callable[[object], Any],
+    async_fn: Callable[[object], Coroutine[Any, Any, Any]],
     run_sync: bool = False,
     concurrency: int = 3,
-    tqdm_bar_format: Optional[str] = None,
+    tqdm_bar_format: str | None = None,
    max_retries: int = 10,
    exit_on_error: bool = True,
-    fallback_return_value: Union[Unset, Any] = _unset,
+    fallback_return_value: Unset | object = _unset,
 ) -> Executor:
+    """Get an appropriate executor based on the current threading context.
+
+    Automatically selects between sync and async execution based on thread context.
+    Falls back to sync execution when not in the main thread.
+
+    Args:
+        sync_fn: Synchronous function to execute.
+        async_fn: Asynchronous function to execute.
+        run_sync: Force synchronous execution. Defaults to False.
+        concurrency: Number of concurrent executions for async. Defaults to 3.
+        tqdm_bar_format: Format string for progress bar. Defaults to None.
+        max_retries: Maximum number of retry attempts. Defaults to 10.
+        exit_on_error: Whether to exit on first error. Defaults to True.
+        fallback_return_value: Value to return on failure. Defaults to unset.
+
+    Returns:
+        An Executor instance configured for the current context.
+    """
     if threading.current_thread() is not threading.main_thread():
         # run evals synchronously if not in the main thread
 
@@ -463,33 +514,30 @@ def get_executor_on_sync_context(
                 exit_on_error=exit_on_error,
                 fallback_return_value=fallback_return_value,
             )
-        else:
-            logger.warning(
-                "🐌!! If running inside a notebook, patching the event loop with "
-                "nest_asyncio will allow asynchronous eval submission, and is significantly "
-                "faster. To patch the event loop, run `nest_asyncio.apply()`."
-            )
-            return SyncExecutor(
-                sync_fn,
-                tqdm_bar_format=tqdm_bar_format,
-                max_retries=max_retries,
-                exit_on_error=exit_on_error,
-                fallback_return_value=fallback_return_value,
-            )
-    else:
-        return AsyncExecutor(
-            async_fn,
-            concurrency=concurrency,
+        logger.warning(
+            "🐌!! If running inside a notebook, patching the event loop with "
+            "nest_asyncio will allow asynchronous eval submission, and is significantly "
+            "faster. To patch the event loop, run `nest_asyncio.apply()`."
+        )
+        return SyncExecutor(
+            sync_fn,
             tqdm_bar_format=tqdm_bar_format,
             max_retries=max_retries,
             exit_on_error=exit_on_error,
             fallback_return_value=fallback_return_value,
         )
+    return AsyncExecutor(
+        async_fn,
+        concurrency=concurrency,
+        tqdm_bar_format=tqdm_bar_format,
+        max_retries=max_retries,
+        exit_on_error=exit_on_error,
+        fallback_return_value=fallback_return_value,
+    )
 
 
 def _running_event_loop_exists() -> bool:
-    """
-    Checks for a running event loop.
+    """Checks for a running event loop.
 
     Returns:
         bool: True if a running event loop exists, False otherwise.
@@ -497,6 +545,6 @@ def _running_event_loop_exists() -> bool:
     """
     try:
         asyncio.get_running_loop()
-        return True
     except RuntimeError:
         return False
+    return True
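
For orientation, below is a minimal usage sketch (not part of the diff) of the executor selection refactored above. The `evaluate_sync` and `evaluate_async` functions are hypothetical placeholders, and the import path simply follows the `arize/experiments/evaluators/executors.py` layout shown in the file list; behavior in a given environment may differ.

import asyncio

from arize.experiments.evaluators.executors import get_executor_on_sync_context


def evaluate_sync(example: object) -> str:
    # Hypothetical synchronous task, for illustration only.
    return f"sync:{example}"


async def evaluate_async(example: object) -> str:
    # Hypothetical asynchronous task, for illustration only.
    await asyncio.sleep(0)
    return f"async:{example}"


# In the main thread with no running event loop this returns an AsyncExecutor;
# inside an unpatched running event loop or a non-main thread it falls back to
# a SyncExecutor, per the selection logic in the diff above.
executor = get_executor_on_sync_context(
    evaluate_sync,
    evaluate_async,
    concurrency=3,
    exit_on_error=False,
    fallback_return_value=None,
)
outputs, details = executor.run(["a", "b", "c"])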