arize 8.0.0a21__py3-none-any.whl → 8.0.0a23__py3-none-any.whl
This diff compares publicly available package versions as they have been released to their respective public registries. It is provided for informational purposes only.
- arize/__init__.py +17 -9
- arize/_exporter/client.py +55 -36
- arize/_exporter/parsers/tracing_data_parser.py +41 -30
- arize/_exporter/validation.py +3 -3
- arize/_flight/client.py +208 -77
- arize/_generated/api_client/__init__.py +30 -6
- arize/_generated/api_client/api/__init__.py +1 -0
- arize/_generated/api_client/api/datasets_api.py +864 -190
- arize/_generated/api_client/api/experiments_api.py +167 -131
- arize/_generated/api_client/api/projects_api.py +1197 -0
- arize/_generated/api_client/api_client.py +2 -2
- arize/_generated/api_client/configuration.py +42 -34
- arize/_generated/api_client/exceptions.py +2 -2
- arize/_generated/api_client/models/__init__.py +15 -4
- arize/_generated/api_client/models/dataset.py +10 -10
- arize/_generated/api_client/models/dataset_example.py +111 -0
- arize/_generated/api_client/models/dataset_example_update.py +100 -0
- arize/_generated/api_client/models/dataset_version.py +13 -13
- arize/_generated/api_client/models/datasets_create_request.py +16 -8
- arize/_generated/api_client/models/datasets_examples_insert_request.py +100 -0
- arize/_generated/api_client/models/datasets_examples_list200_response.py +106 -0
- arize/_generated/api_client/models/datasets_examples_update_request.py +102 -0
- arize/_generated/api_client/models/datasets_list200_response.py +10 -4
- arize/_generated/api_client/models/experiment.py +14 -16
- arize/_generated/api_client/models/experiment_run.py +108 -0
- arize/_generated/api_client/models/experiment_run_create.py +102 -0
- arize/_generated/api_client/models/experiments_create_request.py +16 -10
- arize/_generated/api_client/models/experiments_list200_response.py +10 -4
- arize/_generated/api_client/models/experiments_runs_list200_response.py +19 -5
- arize/_generated/api_client/models/{error.py → pagination_metadata.py} +13 -11
- arize/_generated/api_client/models/primitive_value.py +172 -0
- arize/_generated/api_client/models/problem.py +100 -0
- arize/_generated/api_client/models/project.py +99 -0
- arize/_generated/api_client/models/{datasets_list_examples200_response.py → projects_create_request.py} +13 -11
- arize/_generated/api_client/models/projects_list200_response.py +106 -0
- arize/_generated/api_client/rest.py +2 -2
- arize/_generated/api_client/test/test_dataset.py +4 -2
- arize/_generated/api_client/test/test_dataset_example.py +56 -0
- arize/_generated/api_client/test/test_dataset_example_update.py +52 -0
- arize/_generated/api_client/test/test_dataset_version.py +7 -2
- arize/_generated/api_client/test/test_datasets_api.py +27 -13
- arize/_generated/api_client/test/test_datasets_create_request.py +8 -4
- arize/_generated/api_client/test/{test_datasets_list_examples200_response.py → test_datasets_examples_insert_request.py} +19 -15
- arize/_generated/api_client/test/test_datasets_examples_list200_response.py +66 -0
- arize/_generated/api_client/test/test_datasets_examples_update_request.py +61 -0
- arize/_generated/api_client/test/test_datasets_list200_response.py +9 -3
- arize/_generated/api_client/test/test_experiment.py +2 -4
- arize/_generated/api_client/test/test_experiment_run.py +56 -0
- arize/_generated/api_client/test/test_experiment_run_create.py +54 -0
- arize/_generated/api_client/test/test_experiments_api.py +6 -6
- arize/_generated/api_client/test/test_experiments_create_request.py +9 -6
- arize/_generated/api_client/test/test_experiments_list200_response.py +9 -5
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +15 -5
- arize/_generated/api_client/test/test_pagination_metadata.py +53 -0
- arize/_generated/api_client/test/{test_error.py → test_primitive_value.py} +13 -14
- arize/_generated/api_client/test/test_problem.py +57 -0
- arize/_generated/api_client/test/test_project.py +58 -0
- arize/_generated/api_client/test/test_projects_api.py +59 -0
- arize/_generated/api_client/test/test_projects_create_request.py +54 -0
- arize/_generated/api_client/test/test_projects_list200_response.py +70 -0
- arize/_generated/api_client_README.md +43 -29
- arize/_generated/protocol/flight/flight_pb2.py +400 -0
- arize/_lazy.py +27 -19
- arize/client.py +269 -55
- arize/config.py +365 -116
- arize/constants/__init__.py +1 -0
- arize/constants/config.py +11 -4
- arize/constants/ml.py +6 -4
- arize/constants/openinference.py +2 -0
- arize/constants/pyarrow.py +2 -0
- arize/constants/spans.py +3 -1
- arize/datasets/__init__.py +1 -0
- arize/datasets/client.py +299 -84
- arize/datasets/errors.py +32 -2
- arize/datasets/validation.py +18 -8
- arize/embeddings/__init__.py +2 -0
- arize/embeddings/auto_generator.py +23 -19
- arize/embeddings/base_generators.py +89 -36
- arize/embeddings/constants.py +2 -0
- arize/embeddings/cv_generators.py +26 -4
- arize/embeddings/errors.py +27 -5
- arize/embeddings/nlp_generators.py +31 -12
- arize/embeddings/tabular_generators.py +32 -20
- arize/embeddings/usecases.py +12 -2
- arize/exceptions/__init__.py +1 -0
- arize/exceptions/auth.py +11 -1
- arize/exceptions/base.py +29 -4
- arize/exceptions/models.py +21 -2
- arize/exceptions/parameters.py +31 -0
- arize/exceptions/spaces.py +12 -1
- arize/exceptions/types.py +86 -7
- arize/exceptions/values.py +220 -20
- arize/experiments/__init__.py +1 -0
- arize/experiments/client.py +390 -286
- arize/experiments/evaluators/__init__.py +1 -0
- arize/experiments/evaluators/base.py +74 -41
- arize/experiments/evaluators/exceptions.py +6 -3
- arize/experiments/evaluators/executors.py +121 -73
- arize/experiments/evaluators/rate_limiters.py +106 -57
- arize/experiments/evaluators/types.py +34 -7
- arize/experiments/evaluators/utils.py +65 -27
- arize/experiments/functions.py +103 -101
- arize/experiments/tracing.py +52 -44
- arize/experiments/types.py +56 -31
- arize/logging.py +54 -22
- arize/models/__init__.py +1 -0
- arize/models/batch_validation/__init__.py +1 -0
- arize/models/batch_validation/errors.py +543 -65
- arize/models/batch_validation/validator.py +339 -300
- arize/models/bounded_executor.py +20 -7
- arize/models/casting.py +75 -29
- arize/models/client.py +326 -107
- arize/models/proto.py +95 -40
- arize/models/stream_validation.py +42 -14
- arize/models/surrogate_explainer/__init__.py +1 -0
- arize/models/surrogate_explainer/mimic.py +24 -13
- arize/pre_releases.py +43 -0
- arize/projects/__init__.py +1 -0
- arize/projects/client.py +129 -0
- arize/regions.py +40 -0
- arize/spans/__init__.py +1 -0
- arize/spans/client.py +130 -106
- arize/spans/columns.py +13 -0
- arize/spans/conversion.py +54 -38
- arize/spans/validation/__init__.py +1 -0
- arize/spans/validation/annotations/__init__.py +1 -0
- arize/spans/validation/annotations/annotations_validation.py +6 -4
- arize/spans/validation/annotations/dataframe_form_validation.py +13 -11
- arize/spans/validation/annotations/value_validation.py +35 -11
- arize/spans/validation/common/__init__.py +1 -0
- arize/spans/validation/common/argument_validation.py +33 -8
- arize/spans/validation/common/dataframe_form_validation.py +35 -9
- arize/spans/validation/common/errors.py +211 -11
- arize/spans/validation/common/value_validation.py +80 -13
- arize/spans/validation/evals/__init__.py +1 -0
- arize/spans/validation/evals/dataframe_form_validation.py +28 -8
- arize/spans/validation/evals/evals_validation.py +34 -4
- arize/spans/validation/evals/value_validation.py +26 -3
- arize/spans/validation/metadata/__init__.py +1 -1
- arize/spans/validation/metadata/argument_validation.py +14 -5
- arize/spans/validation/metadata/dataframe_form_validation.py +26 -10
- arize/spans/validation/metadata/value_validation.py +24 -10
- arize/spans/validation/spans/__init__.py +1 -0
- arize/spans/validation/spans/dataframe_form_validation.py +34 -13
- arize/spans/validation/spans/spans_validation.py +35 -4
- arize/spans/validation/spans/value_validation.py +76 -7
- arize/types.py +293 -157
- arize/utils/__init__.py +1 -0
- arize/utils/arrow.py +31 -15
- arize/utils/cache.py +34 -6
- arize/utils/dataframe.py +19 -2
- arize/utils/online_tasks/__init__.py +2 -0
- arize/utils/online_tasks/dataframe_preprocessor.py +53 -41
- arize/utils/openinference_conversion.py +44 -5
- arize/utils/proto.py +10 -0
- arize/utils/size.py +5 -3
- arize/version.py +3 -1
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/METADATA +4 -3
- arize-8.0.0a23.dist-info/RECORD +174 -0
- {arize-8.0.0a21.dist-info → arize-8.0.0a23.dist-info}/WHEEL +1 -1
- arize-8.0.0a23.dist-info/licenses/LICENSE +176 -0
- arize-8.0.0a23.dist-info/licenses/NOTICE +13 -0
- arize/_generated/protocol/flight/export_pb2.py +0 -61
- arize/_generated/protocol/flight/ingest_pb2.py +0 -365
- arize-8.0.0a21.dist-info/RECORD +0 -146
- arize-8.0.0a21.dist-info/licenses/LICENSE.md +0 -12
arize/experiments/evaluators/executors.py

@@ -1,3 +1,5 @@
+"""Evaluator executors for managing concurrent evaluation tasks."""
+
 from __future__ import annotations
 
 import asyncio
@@ -10,18 +12,14 @@ import traceback
 from contextlib import contextmanager
 from enum import Enum
 from typing import (
+    TYPE_CHECKING,
     Any,
-    Callable,
-    Coroutine,
-    Generator,
-    List,
-    Optional,
     Protocol,
-    Sequence,
-    Tuple,
-    Union,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Callable, Coroutine, Generator, Sequence
+
 from tqdm.auto import tqdm
 
 from arize.experiments.evaluators.exceptions import ArizeException
@@ -30,13 +28,15 @@ logger = logging.getLogger(__name__)
 
 
 class Unset:
-
+    """Sentinel class representing an unset or undefined value."""
 
 
 _unset = Unset()
 
 
 class ExecutionStatus(Enum):
+    """Enum representing execution status states for experiment tasks."""
+
     DID_NOT_RUN = "DID NOT RUN"
     COMPLETED = "COMPLETED"
     COMPLETED_WITH_RETRIES = "COMPLETED WITH RETRIES"
@@ -44,42 +44,52 @@ class ExecutionStatus(Enum):
 
 
 class ExecutionDetails:
+    """Container for tracking execution status, exceptions, and runtime metrics."""
+
     def __init__(self) -> None:
-
+        """Initialize execution details with default state."""
+        self.exceptions: list[Exception] = []
         self.status = ExecutionStatus.DID_NOT_RUN
         self.execution_seconds: float = 0
 
     def fail(self) -> None:
+        """Mark this execution as failed."""
         self.status = ExecutionStatus.FAILED
 
     def complete(self) -> None:
+        """Mark this execution as completed, with or without retries."""
         if self.exceptions:
             self.status = ExecutionStatus.COMPLETED_WITH_RETRIES
         else:
             self.status = ExecutionStatus.COMPLETED
 
     def log_exception(self, exc: Exception) -> None:
+        """Log an exception that occurred during execution."""
         self.exceptions.append(exc)
 
     def log_runtime(self, start_time: float) -> None:
+        """Log the runtime duration for this execution."""
         self.execution_seconds += time.time() - start_time
 
 
 class Executor(Protocol):
+    """Protocol defining the interface for experiment task executors."""
+
     def run(
         self, inputs: Sequence[Any]
-    ) ->
+    ) -> tuple[list[object], list[ExecutionDetails]]:
+        """Execute the generation function on all inputs and return outputs with execution details."""
+        ...
 
 
 class AsyncExecutor(Executor):
-    """
-    A class that provides asynchronous execution of tasks using a producer-consumer pattern.
+    """A class that provides asynchronous execution of tasks using a producer-consumer pattern.
 
     An async interface is provided by the `execute` method, which returns a coroutine, and a sync
     interface is provided by the `run` method.
 
     Args:
-        generation_fn (Callable[[
+        generation_fn (Callable[[object], Coroutine[Any, Any, Any]]): A coroutine function that
             generates tasks to be executed.
 
         concurrency (int, optional): The number of concurrent consumers. Defaults to 3.
@@ -102,14 +112,25 @@ class AsyncExecutor(Executor):
 
     def __init__(
         self,
-        generation_fn: Callable[[
+        generation_fn: Callable[[object], Coroutine[Any, Any, Any]],
         concurrency: int = 3,
-        tqdm_bar_format:
+        tqdm_bar_format: str | None = None,
         max_retries: int = 10,
         exit_on_error: bool = True,
-        fallback_return_value:
+        fallback_return_value: Unset | object = _unset,
         termination_signal: signal.Signals = signal.SIGINT,
-    ):
+    ) -> None:
+        """Initialize the async executor with configuration parameters.
+
+        Args:
+            generation_fn: Async function to execute for each item.
+            concurrency: Maximum number of concurrent executions.
+            tqdm_bar_format: Custom format string for progress bar.
+            max_retries: Maximum retry attempts for failed executions.
+            exit_on_error: Whether to exit on first error.
+            fallback_return_value: Value to return when execution fails.
+            termination_signal: Signal to handle for graceful termination.
+        """
         self.generate = generation_fn
         self.fallback_return_value = fallback_return_value
         self.concurrency = concurrency
@@ -122,11 +143,12 @@ class AsyncExecutor(Executor):
    async def producer(
         self,
         inputs: Sequence[Any],
-        queue: asyncio.PriorityQueue[
+        queue: asyncio.PriorityQueue[tuple[int, Any]],
         max_fill: int,
         done_producing: asyncio.Event,
         termination_signal: asyncio.Event,
     ) -> None:
+        """Produce tasks by adding inputs to the queue for consumers to process."""
         try:
             for index, input in enumerate(inputs):
                 if termination_signal.is_set():
@@ -140,13 +162,14 @@ class AsyncExecutor(Executor):
 
     async def consumer(
         self,
-        outputs:
-        execution_details:
-        queue: asyncio.PriorityQueue[
+        outputs: list[object],
+        execution_details: list[ExecutionDetails],
+        queue: asyncio.PriorityQueue[tuple[int, Any]],
         done_producing: asyncio.Event,
         termination_event: asyncio.Event,
         progress_bar: tqdm[Any],
     ) -> None:
+        """Consume tasks from the queue and execute the generation function on each input."""
         termination_event_watcher = None
         while True:
             marked_done = False
@@ -170,7 +193,7 @@ class AsyncExecutor(Executor):
                 termination_event_watcher = asyncio.create_task(
                     termination_event.wait()
                 )
-                done,
+                done, _pending = await asyncio.wait(
                     [generate_task, termination_event_watcher],
                     timeout=120,
                     return_when=asyncio.FIRST_COMPLETED,
@@ -205,7 +228,7 @@ class AsyncExecutor(Executor):
                     retry_count := abs(priority)
                 ) < self.max_retries and not is_arize_exception:
                     tqdm.write(
-                        f"Exception in worker on attempt {retry_count + 1}: raised {
+                        f"Exception in worker on attempt {retry_count + 1}: raised {exc!r}"
                     )
                     tqdm.write("Requeuing...")
                     await queue.put((priority - 1, item))
@@ -229,10 +252,11 @@ class AsyncExecutor(Executor):
 
     async def execute(
         self, inputs: Sequence[Any]
-    ) ->
+    ) -> tuple[list[object], list[ExecutionDetails]]:
+        """Execute all inputs asynchronously using producer-consumer pattern."""
         termination_event = asyncio.Event()
 
-        def termination_handler(signum: int, frame:
+        def termination_handler(signum: int, frame: object) -> None:
             termination_event.set()
             tqdm.write(
                 "Process was interrupted. The return value will be incomplete..."
@@ -251,7 +275,7 @@ class AsyncExecutor(Executor):
         max_fill = max_queue_size - (
             2 * self.concurrency
         )  # ensure there is always room to requeue
-        queue: asyncio.PriorityQueue[
+        queue: asyncio.PriorityQueue[tuple[int, Any]] = asyncio.PriorityQueue(
             maxsize=max_queue_size
         )
         done_producing = asyncio.Event()
@@ -280,7 +304,7 @@ class AsyncExecutor(Executor):
         termination_event_watcher = asyncio.create_task(
             termination_event.wait()
         )
-        done,
+        done, _pending = await asyncio.wait(
            [join_task, termination_event_watcher],
             return_when=asyncio.FIRST_COMPLETED,
         )
@@ -305,16 +329,16 @@ class AsyncExecutor(Executor):
 
     def run(
         self, inputs: Sequence[Any]
-    ) ->
+    ) -> tuple[list[object], list[ExecutionDetails]]:
+        """Execute all inputs asynchronously and return outputs with execution details."""
         return asyncio.run(self.execute(inputs))
 
 
 class SyncExecutor(Executor):
-    """
-    Synchronous executor for generating outputs from inputs using a given generation function.
+    """Synchronous executor for generating outputs from inputs using a given generation function.
 
     Args:
-        generation_fn (Callable[[
+        generation_fn (Callable[[object], Any]): The generation function that takes an input and
             returns an output.
 
         tqdm_bar_format (Optional[str], optional): The format string for the progress bar. Defaults
@@ -333,13 +357,23 @@ class SyncExecutor(Executor):
 
     def __init__(
         self,
-        generation_fn: Callable[[
-        tqdm_bar_format:
+        generation_fn: Callable[[object], Any],
+        tqdm_bar_format: str | None = None,
         max_retries: int = 10,
         exit_on_error: bool = True,
-        fallback_return_value:
-        termination_signal:
-    ):
+        fallback_return_value: Unset | object = _unset,
+        termination_signal: signal.Signals | None = signal.SIGINT,
+    ) -> None:
+        """Initialize the sync executor with configuration parameters.
+
+        Args:
+            generation_fn: Synchronous function to execute for each item.
+            tqdm_bar_format: Custom format string for progress bar.
+            max_retries: Maximum retry attempts for failed executions.
+            exit_on_error: Whether to exit on first error.
+            fallback_return_value: Value to return when execution fails.
+            termination_signal: Signal to handle for graceful termination.
+        """
         self.generate = generation_fn
         self.fallback_return_value = fallback_return_value
         self.tqdm_bar_format = tqdm_bar_format
@@ -349,7 +383,7 @@ class SyncExecutor(Executor):
 
         self._TERMINATE = False
 
-    def _signal_handler(self, signum: int, frame:
+    def _signal_handler(self, signum: int, frame: object) -> None:
         tqdm.write(
             "Process was interrupted. The return value will be incomplete..."
         )
@@ -357,7 +391,7 @@ class SyncExecutor(Executor):
 
     @contextmanager
     def _executor_signal_handling(
-        self, signum:
+        self, signum: int | None
     ) -> Generator[None, None, None]:
         original_handler = None
         if signum is not None:
@@ -369,10 +403,11 @@ class SyncExecutor(Executor):
         else:
             yield
 
-    def run(self, inputs: Sequence[Any]) ->
+    def run(self, inputs: Sequence[Any]) -> tuple[list[object], list[object]]:
+        """Execute all inputs synchronously and return outputs with execution details."""
         with self._executor_signal_handling(self.termination_signal):
             outputs = [self.fallback_return_value] * len(inputs)
-            execution_details:
+            execution_details: list[ExecutionDetails] = [
                 ExecutionDetails() for _ in range(len(inputs))
             ]
             progress_bar = tqdm(
@@ -398,12 +433,11 @@ class SyncExecutor(Executor):
                                attempt >= self.max_retries
                                or is_arize_exception
                            ):
-                                raise
-
-
-
-
-                            tqdm.write("Retrying...")
+                                raise
+                            tqdm.write(
+                                f"Exception in worker on attempt {attempt + 1}: {exc}"
+                            )
+                            tqdm.write("Retrying...")
                except Exception as exc:
                    execution_details[index].fail()
                    tqdm.write(
@@ -411,23 +445,40 @@ class SyncExecutor(Executor):
                    )
                    if self.exit_on_error:
                        return outputs, execution_details
-
-                    progress_bar.update()
+                    progress_bar.update()
                finally:
                    execution_details[index].log_runtime(task_start_time)
        return outputs, execution_details
 
 
 def get_executor_on_sync_context(
-    sync_fn: Callable[[
-    async_fn: Callable[[
+    sync_fn: Callable[[object], Any],
+    async_fn: Callable[[object], Coroutine[Any, Any, Any]],
     run_sync: bool = False,
     concurrency: int = 3,
-    tqdm_bar_format:
+    tqdm_bar_format: str | None = None,
     max_retries: int = 10,
     exit_on_error: bool = True,
-    fallback_return_value:
+    fallback_return_value: Unset | object = _unset,
 ) -> Executor:
+    """Get an appropriate executor based on the current threading context.
+
+    Automatically selects between sync and async execution based on thread context.
+    Falls back to sync execution when not in the main thread.
+
+    Args:
+        sync_fn: Synchronous function to execute.
+        async_fn: Asynchronous function to execute.
+        run_sync: Force synchronous execution. Defaults to False.
+        concurrency: Number of concurrent executions for async. Defaults to 3.
+        tqdm_bar_format: Format string for progress bar. Defaults to None.
+        max_retries: Maximum number of retry attempts. Defaults to 10.
+        exit_on_error: Whether to exit on first error. Defaults to True.
+        fallback_return_value: Value to return on failure. Defaults to unset.
+
+    Returns:
+        An Executor instance configured for the current context.
+    """
     if threading.current_thread() is not threading.main_thread():
         # run evals synchronously if not in the main thread
 
@@ -463,33 +514,30 @@ def get_executor_on_sync_context(
             exit_on_error=exit_on_error,
             fallback_return_value=fallback_return_value,
         )
-
-
-
-
-
-
-
-            sync_fn,
-            tqdm_bar_format=tqdm_bar_format,
-            max_retries=max_retries,
-            exit_on_error=exit_on_error,
-            fallback_return_value=fallback_return_value,
-        )
-    else:
-        return AsyncExecutor(
-            async_fn,
-            concurrency=concurrency,
+        logger.warning(
+            "🐌!! If running inside a notebook, patching the event loop with "
+            "nest_asyncio will allow asynchronous eval submission, and is significantly "
+            "faster. To patch the event loop, run `nest_asyncio.apply()`."
+        )
+        return SyncExecutor(
+            sync_fn,
             tqdm_bar_format=tqdm_bar_format,
             max_retries=max_retries,
             exit_on_error=exit_on_error,
             fallback_return_value=fallback_return_value,
         )
+    return AsyncExecutor(
+        async_fn,
+        concurrency=concurrency,
+        tqdm_bar_format=tqdm_bar_format,
+        max_retries=max_retries,
+        exit_on_error=exit_on_error,
+        fallback_return_value=fallback_return_value,
+    )
 
 
 def _running_event_loop_exists() -> bool:
-    """
-    Checks for a running event loop.
+    """Checks for a running event loop.
 
     Returns:
         bool: True if a running event loop exists, False otherwise.
@@ -497,6 +545,6 @@ def _running_event_loop_exists() -> bool:
     """
     try:
         asyncio.get_running_loop()
-        return True
     except RuntimeError:
         return False
+    return True
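
To make the reworked executors surface easier to follow, here is a minimal usage sketch assembled only from the signatures and docstrings visible in the hunks above (`get_executor_on_sync_context`, `Executor.run`, `ExecutionDetails`). The `score_sync`/`score_async` evaluators and the example inputs are hypothetical placeholders, and the module is an internal part of the package, so treat this as an illustration of the call pattern rather than a documented API.

```python
import asyncio

# Internal module path taken from the changed-files list above; not a documented public API.
from arize.experiments.evaluators.executors import get_executor_on_sync_context


# Hypothetical evaluation functions; any sync/async pair with the same semantics works.
def score_sync(example: object) -> float:
    return float(len(str(example)))


async def score_async(example: object) -> float:
    await asyncio.sleep(0)  # stand-in for a real asynchronous call (e.g. an LLM request)
    return float(len(str(example)))


# Per the docstring added in the diff, this picks AsyncExecutor or SyncExecutor
# based on the current thread and event-loop context, warning about nest_asyncio
# when it has to fall back to the sync path inside a notebook event loop.
executor = get_executor_on_sync_context(
    score_sync,
    score_async,
    concurrency=3,               # number of concurrent consumers on the async path
    max_retries=10,              # failed items are retried up to this many times
    exit_on_error=True,          # stop early on an unrecoverable failure
    fallback_return_value=None,  # recorded for items that never succeed
)

# Executor.run returns (outputs, execution_details) per the updated annotations.
outputs, details = executor.run(["example-1", "example-2", "example-3"])
for output, detail in zip(outputs, details):
    print(detail.status, round(detail.execution_seconds, 3), len(detail.exceptions), output)
```

On the async path, retries are implemented by requeuing a failed item with a decremented priority, so `abs(priority)` doubles as its retry count (the `retry_count := abs(priority)` check in the consumer hunk); the sync path simply loops over `attempt` up to `max_retries`.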