arize 8.0.0a14__py3-none-any.whl → 8.0.0a16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arize/__init__.py +70 -1
- arize/_flight/client.py +163 -43
- arize/_flight/types.py +1 -0
- arize/_generated/api_client/__init__.py +5 -1
- arize/_generated/api_client/api/datasets_api.py +6 -6
- arize/_generated/api_client/api/experiments_api.py +924 -61
- arize/_generated/api_client/api_client.py +1 -1
- arize/_generated/api_client/configuration.py +1 -1
- arize/_generated/api_client/exceptions.py +1 -1
- arize/_generated/api_client/models/__init__.py +3 -1
- arize/_generated/api_client/models/dataset.py +2 -2
- arize/_generated/api_client/models/dataset_version.py +1 -1
- arize/_generated/api_client/models/datasets_create_request.py +3 -3
- arize/_generated/api_client/models/datasets_list200_response.py +1 -1
- arize/_generated/api_client/models/datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/models/error.py +1 -1
- arize/_generated/api_client/models/experiment.py +6 -6
- arize/_generated/api_client/models/experiments_create_request.py +98 -0
- arize/_generated/api_client/models/experiments_list200_response.py +1 -1
- arize/_generated/api_client/models/experiments_runs_list200_response.py +92 -0
- arize/_generated/api_client/rest.py +1 -1
- arize/_generated/api_client/test/test_dataset.py +2 -1
- arize/_generated/api_client/test/test_dataset_version.py +1 -1
- arize/_generated/api_client/test/test_datasets_api.py +1 -1
- arize/_generated/api_client/test/test_datasets_create_request.py +2 -1
- arize/_generated/api_client/test/test_datasets_list200_response.py +1 -1
- arize/_generated/api_client/test/test_datasets_list_examples200_response.py +1 -1
- arize/_generated/api_client/test/test_error.py +1 -1
- arize/_generated/api_client/test/test_experiment.py +6 -1
- arize/_generated/api_client/test/test_experiments_api.py +23 -2
- arize/_generated/api_client/test/test_experiments_create_request.py +61 -0
- arize/_generated/api_client/test/test_experiments_list200_response.py +1 -1
- arize/_generated/api_client/test/test_experiments_runs_list200_response.py +56 -0
- arize/_generated/api_client_README.md +13 -8
- arize/client.py +19 -2
- arize/config.py +50 -3
- arize/constants/config.py +8 -2
- arize/constants/openinference.py +14 -0
- arize/constants/pyarrow.py +1 -0
- arize/datasets/__init__.py +0 -70
- arize/datasets/client.py +106 -19
- arize/datasets/errors.py +61 -0
- arize/datasets/validation.py +46 -0
- arize/experiments/client.py +455 -0
- arize/experiments/evaluators/__init__.py +0 -0
- arize/experiments/evaluators/base.py +255 -0
- arize/experiments/evaluators/exceptions.py +10 -0
- arize/experiments/evaluators/executors.py +502 -0
- arize/experiments/evaluators/rate_limiters.py +277 -0
- arize/experiments/evaluators/types.py +122 -0
- arize/experiments/evaluators/utils.py +198 -0
- arize/experiments/functions.py +920 -0
- arize/experiments/tracing.py +276 -0
- arize/experiments/types.py +394 -0
- arize/models/client.py +4 -1
- arize/spans/client.py +16 -20
- arize/utils/arrow.py +4 -3
- arize/utils/openinference_conversion.py +56 -0
- arize/utils/proto.py +13 -0
- arize/utils/size.py +22 -0
- arize/version.py +1 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/METADATA +3 -1
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/RECORD +65 -44
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/WHEEL +0 -0
- {arize-8.0.0a14.dist-info → arize-8.0.0a16.dist-info}/licenses/LICENSE.md +0 -0
|
@@ -0,0 +1,502 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import contextlib
|
|
5
|
+
import logging
|
|
6
|
+
import signal
|
|
7
|
+
import threading
|
|
8
|
+
import time
|
|
9
|
+
import traceback
|
|
10
|
+
from contextlib import contextmanager
|
|
11
|
+
from enum import Enum
|
|
12
|
+
from typing import (
|
|
13
|
+
Any,
|
|
14
|
+
Callable,
|
|
15
|
+
Coroutine,
|
|
16
|
+
Generator,
|
|
17
|
+
List,
|
|
18
|
+
Optional,
|
|
19
|
+
Protocol,
|
|
20
|
+
Sequence,
|
|
21
|
+
Tuple,
|
|
22
|
+
Union,
|
|
23
|
+
)
|
|
24
|
+
|
|
25
|
+
from tqdm.auto import tqdm
|
|
26
|
+
|
|
27
|
+
from arize.experiments.evaluators.exceptions import ArizeException
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)


class Unset:
    """Marker type meaning "no value was supplied".

    Used as the type of the default ``fallback_return_value`` in this module, presumably so
    that ``None`` can still be passed explicitly as a real fallback value.
    """


# Shared sentinel instance used as the default ``fallback_return_value``.
_unset = Unset()
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class ExecutionStatus(Enum):
    """Lifecycle state recorded for a single input processed by an executor."""

    # Initial state: the input was never attempted (or the run stopped before reaching it).
    DID_NOT_RUN = "DID NOT RUN"
    # Finished successfully with no exceptions logged along the way.
    COMPLETED = "COMPLETED"
    # Finished successfully, but only after at least one logged exception.
    COMPLETED_WITH_RETRIES = "COMPLETED WITH RETRIES"
    # Gave up on this input.
    FAILED = "FAILED"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class ExecutionDetails:
    """Per-input bookkeeping kept by an executor.

    Records every exception logged while processing the input, the resulting
    ``ExecutionStatus``, and the accumulated wall-clock seconds spent on it.
    """

    def __init__(self) -> None:
        # Every exception logged for this input, in the order it was logged.
        self.exceptions: List[Exception] = []
        self.status = ExecutionStatus.DID_NOT_RUN
        # Summed across all attempts via ``log_runtime``.
        self.execution_seconds: float = 0

    def fail(self) -> None:
        """Mark this input as FAILED."""
        self.status = ExecutionStatus.FAILED

    def complete(self) -> None:
        """Mark this input as finished.

        The status is COMPLETED_WITH_RETRIES if any exception was logged first, and
        COMPLETED otherwise.
        """
        self.status = (
            ExecutionStatus.COMPLETED_WITH_RETRIES
            if self.exceptions
            else ExecutionStatus.COMPLETED
        )

    def log_exception(self, exc: Exception) -> None:
        """Record ``exc`` as having been raised while processing this input."""
        self.exceptions.append(exc)

    def log_runtime(self, start_time: float) -> None:
        """Add the seconds elapsed since ``start_time`` (a ``time.time()`` value)."""
        elapsed = time.time() - start_time
        self.execution_seconds += elapsed
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class Executor(Protocol):
    """Structural interface shared by the executors in this module.

    Any object with a compatible ``run`` method satisfies it.
    """

    def run(
        self, inputs: Sequence[Any]
    ) -> Tuple[List[Any], List[ExecutionDetails]]:
        """Process ``inputs`` and return ``(outputs, execution_details)``, one entry per input."""
        ...
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class AsyncExecutor(Executor):
    """
    A class that provides asynchronous execution of tasks using a producer-consumer pattern.

    An async interface is provided by the `execute` method, which returns a coroutine, and a sync
    interface is provided by the `run` method.

    A single producer feeds ``(priority, (index, input))`` items into a bounded
    ``asyncio.PriorityQueue``, and ``concurrency`` consumers pull from it. A failed attempt is
    requeued with a lower (more negative) priority, so retries are picked up ahead of fresh
    inputs. A call that exceeds ``task_timeout`` is cancelled and requeued at its current
    priority.

    Args:
        generation_fn (Callable[[Any], Coroutine[Any, Any, Any]]): A coroutine function that
            generates tasks to be executed.

        concurrency (int, optional): The number of concurrent consumers. Must be at least 1.
            Defaults to 3.

        tqdm_bar_format (Optional[str], optional): The format string for the progress bar.
            Defaults to None.

        max_retries (int, optional): The maximum number of times to retry on exceptions.
            Defaults to 10.

        exit_on_error (bool, optional): Whether to exit execution on the first encountered error.
            Defaults to True.

        fallback_return_value (Union[Unset, Any], optional): The fallback return value for tasks
            that encounter errors. Defaults to _unset.

        termination_signal (signal.Signals, optional): The signal handled to terminate the
            executor. Defaults to signal.SIGINT.

        task_timeout (Optional[float], optional): Seconds a single call to ``generation_fn`` may
            run before it is cancelled and its input requeued. Timeouts do not count against
            ``max_retries``. ``None`` disables the timeout. Defaults to 120.

    Raises:
        ValueError: If ``concurrency`` is less than 1.
    """

    def __init__(
        self,
        generation_fn: Callable[[Any], Coroutine[Any, Any, Any]],
        concurrency: int = 3,
        tqdm_bar_format: Optional[str] = None,
        max_retries: int = 10,
        exit_on_error: bool = True,
        fallback_return_value: Union[Unset, Any] = _unset,
        termination_signal: signal.Signals = signal.SIGINT,
        task_timeout: Optional[float] = 120,
    ):
        if concurrency < 1:
            # With no consumers, the producer's "wait for queue space" loop never ends.
            raise ValueError(f"concurrency must be at least 1, got {concurrency}")
        self.generate = generation_fn
        self.fallback_return_value = fallback_return_value
        self.concurrency = concurrency
        self.tqdm_bar_format = tqdm_bar_format
        self.max_retries = max_retries
        self.exit_on_error = exit_on_error
        self.base_priority = 0
        self.termination_signal = termination_signal
        self.task_timeout = task_timeout

    async def producer(
        self,
        inputs: Sequence[Any],
        queue: asyncio.PriorityQueue[Tuple[int, Any]],
        max_fill: int,
        done_producing: asyncio.Event,
        termination_signal: asyncio.Event,
    ) -> None:
        """Enqueue every input at ``base_priority``, keeping the queue at most ``max_fill`` full.

        Stops early once ``termination_signal`` is set. Always sets ``done_producing`` on exit.
        """
        try:
            for index, input in enumerate(inputs):
                if termination_signal.is_set():
                    break
                while queue.qsize() >= max_fill:
                    # keep room in the queue for requeues
                    await asyncio.sleep(1)
                await queue.put((self.base_priority, (index, input)))
        finally:
            done_producing.set()

    @staticmethod
    async def _cancel_task(task: asyncio.Task[Any]) -> None:
        """Cancel ``task`` if it is still running and wait for its cleanup to finish."""
        if task.done():
            return
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            # allow any cleanup to finish for the cancelled task
            await task

    async def consumer(
        self,
        outputs: List[Any],
        execution_details: List[ExecutionDetails],
        queue: asyncio.PriorityQueue[Tuple[int, Any]],
        done_producing: asyncio.Event,
        termination_event: asyncio.Event,
        progress_bar: tqdm[Any],
    ) -> None:
        """Process queued items until production has finished and the queue is drained.

        A successful result is written to ``outputs[index]`` and recorded in
        ``execution_details[index]``. Once ``termination_event`` is set, any remaining items
        are acknowledged (``task_done``) without being processed.
        """
        termination_event_watcher = None
        while True:
            marked_done = False
            try:
                priority, item = await asyncio.wait_for(queue.get(), timeout=1)
            except asyncio.TimeoutError:
                if done_producing.is_set() and queue.empty():
                    break
                continue
            if termination_event.is_set():
                # discard any remaining items in the queue
                queue.task_done()
                marked_done = True
                continue

            index, payload = item

            try:
                task_start_time = time.time()
                generate_task = asyncio.create_task(self.generate(payload))
                termination_event_watcher = asyncio.create_task(
                    termination_event.wait()
                )
                done, _ = await asyncio.wait(
                    [generate_task, termination_event_watcher],
                    timeout=self.task_timeout,
                    return_when=asyncio.FIRST_COMPLETED,
                )

                if generate_task in done:
                    # .result() re-raises the task's exception, handled below
                    outputs[index] = generate_task.result()
                    execution_details[index].complete()
                    execution_details[index].log_runtime(task_start_time)
                    progress_bar.update()
                elif termination_event.is_set():
                    # discard the pending task and remaining items in the queue
                    await self._cancel_task(generate_task)
                    queue.task_done()
                    marked_done = True
                    continue
                else:
                    # Timed out: stop the in-flight call before requeueing, otherwise it keeps
                    # running alongside its own retry and exceeds the concurrency bound.
                    await self._cancel_task(generate_task)
                    tqdm.write("Worker timeout, requeuing")
                    # task timeouts are requeued at the same priority
                    await queue.put((priority, item))
                    execution_details[index].log_runtime(task_start_time)
            except Exception as exc:
                execution_details[index].log_exception(exc)
                execution_details[index].log_runtime(task_start_time)
                is_arize_exception = isinstance(exc, ArizeException)
                # priority goes 0, -1, -2, ... so its magnitude is the number of prior retries
                if (
                    retry_count := abs(priority)
                ) < self.max_retries and not is_arize_exception:
                    tqdm.write(
                        f"Exception in worker on attempt {retry_count + 1}: raised {repr(exc)}"
                    )
                    tqdm.write("Requeuing...")
                    await queue.put((priority - 1, item))
                else:
                    execution_details[index].fail()
                    tqdm.write(
                        f"Retries exhausted after {retry_count + 1} attempts: {traceback.format_exc()}"
                    )
                    if self.exit_on_error:
                        termination_event.set()
                    else:
                        progress_bar.update()
            finally:
                if not marked_done:
                    queue.task_done()
                if (
                    termination_event_watcher
                    and not termination_event_watcher.done()
                ):
                    termination_event_watcher.cancel()

    async def _run_workers(
        self,
        inputs: Sequence[Any],
        outputs: List[Any],
        execution_details: List[ExecutionDetails],
        termination_event: asyncio.Event,
        progress_bar: tqdm[Any],
    ) -> None:
        """Run one producer and ``concurrency`` consumers until the inputs are exhausted."""
        # limit the queue to bound memory usage
        max_queue_size = 5 * self.concurrency
        # ensure there is always room to requeue
        max_fill = max_queue_size - 2 * self.concurrency
        queue: asyncio.PriorityQueue[Tuple[int, Any]] = asyncio.PriorityQueue(
            maxsize=max_queue_size
        )
        done_producing = asyncio.Event()

        producer = asyncio.create_task(
            self.producer(
                inputs, queue, max_fill, done_producing, termination_event
            )
        )
        consumers = [
            asyncio.create_task(
                self.consumer(
                    outputs,
                    execution_details,
                    queue,
                    done_producing,
                    termination_event,
                    progress_bar,
                )
            )
            for _ in range(self.concurrency)
        ]
        workers = [producer, *consumers]
        try:
            await asyncio.gather(*workers)
            join_task = asyncio.create_task(queue.join())
            termination_event_watcher = asyncio.create_task(
                termination_event.wait()
            )
            done, _ = await asyncio.wait(
                [join_task, termination_event_watcher],
                return_when=asyncio.FIRST_COMPLETED,
            )
            if termination_event_watcher in done and not join_task.done():
                join_task.cancel()
            if not termination_event_watcher.done():
                termination_event_watcher.cancel()
        finally:
            # If gather raised, the remaining workers would otherwise keep running unattended.
            for task in workers:
                if not task.done():
                    task.cancel()

    @staticmethod
    def _restore_signal_handler(signum: int, handler: Any) -> None:
        """Reinstall ``handler`` for ``signum``.

        ``signal.signal`` returns None when the previous handler was not installed from
        Python, and None cannot be passed back in, so fall back to ``SIG_DFL`` in that case.
        """
        signal.signal(signum, signal.SIG_DFL if handler is None else handler)

    async def execute(
        self, inputs: Sequence[Any]
    ) -> Tuple[List[Any], List[ExecutionDetails]]:
        """Process ``inputs`` concurrently.

        Returns:
            ``(outputs, execution_details)``, aligned with ``inputs``. Entries that did not
            complete keep ``fallback_return_value``.
        """
        outputs = [self.fallback_return_value] * len(inputs)
        execution_details = [ExecutionDetails() for _ in range(len(inputs))]
        if len(inputs) == 0:
            # Nothing to do; avoids the consumers' 1 s idle poll before shutdown.
            return outputs, execution_details

        termination_event = asyncio.Event()

        def termination_handler(signum: int, frame: Any) -> None:
            termination_event.set()
            tqdm.write(
                "Process was interrupted. The return value will be incomplete..."
            )

        original_handler = signal.signal(
            self.termination_signal, termination_handler
        )
        try:
            with tqdm(
                total=len(inputs), bar_format=self.tqdm_bar_format
            ) as progress_bar:
                await self._run_workers(
                    inputs,
                    outputs,
                    execution_details,
                    termination_event,
                    progress_bar,
                )
        finally:
            # restore the previous handler even if the run raised
            self._restore_signal_handler(self.termination_signal, original_handler)
        return outputs, execution_details

    def run(
        self, inputs: Sequence[Any]
    ) -> Tuple[List[Any], List[ExecutionDetails]]:
        """Synchronous entry point: run ``execute`` in a new event loop via ``asyncio.run``."""
        return asyncio.run(self.execute(inputs))
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
class SyncExecutor(Executor):
    """
    Synchronous executor for generating outputs from inputs using a given generation function.

    Inputs are processed one at a time, in order. Each input is attempted up to
    ``max_retries + 1`` times. An ``ArizeException`` is never retried.

    Args:
        generation_fn (Callable[[Any], Any]): The generation function that takes an input and
            returns an output.

        tqdm_bar_format (Optional[str], optional): The format string for the progress bar. Defaults
            to None.

        max_retries (int, optional): The maximum number of times to retry on exceptions. Defaults to
            10.

        exit_on_error (bool, optional): Whether to exit execution on the first encountered error.
            Defaults to True.

        fallback_return_value (Union[Unset, Any], optional): The fallback return value for tasks
            that encounter errors. Defaults to _unset.

        termination_signal (Optional[signal.Signals], optional): Signal that requests an early
            stop. ``None`` installs no handler; Python only allows setting handlers from the
            main thread. Defaults to signal.SIGINT.
    """

    def __init__(
        self,
        generation_fn: Callable[[Any], Any],
        tqdm_bar_format: Optional[str] = None,
        max_retries: int = 10,
        exit_on_error: bool = True,
        fallback_return_value: Union[Unset, Any] = _unset,
        termination_signal: Optional[signal.Signals] = signal.SIGINT,
    ):
        self.generate = generation_fn
        self.fallback_return_value = fallback_return_value
        self.tqdm_bar_format = tqdm_bar_format
        self.max_retries = max_retries
        self.exit_on_error = exit_on_error
        self.termination_signal = termination_signal

        # Set by the signal handler; checked before every attempt.
        self._TERMINATE = False

    def _signal_handler(self, signum: int, frame: Any) -> None:
        tqdm.write(
            "Process was interrupted. The return value will be incomplete..."
        )
        self._TERMINATE = True

    @contextmanager
    def _executor_signal_handling(
        self, signum: Optional[int]
    ) -> Generator[None, None, None]:
        """Install ``_signal_handler`` for ``signum`` for the duration of the block.

        Does nothing when ``signum`` is None.
        """
        if signum is None:
            yield
            return
        original_handler = signal.signal(signum, self._signal_handler)
        try:
            yield
        finally:
            # signal.signal returns None when the previous handler was not installed from
            # Python; None is rejected when re-installing, so fall back to SIG_DFL.
            signal.signal(
                signum,
                signal.SIG_DFL if original_handler is None else original_handler,
            )

    def run(self, inputs: Sequence[Any]) -> Tuple[List[Any], List[Any]]:
        """Process ``inputs`` sequentially.

        Returns:
            ``(outputs, execution_details)``, aligned with ``inputs``. Entries that did not
            complete keep ``fallback_return_value``.
        """
        # An interrupt during a previous run must not short-circuit this one.
        self._TERMINATE = False
        outputs = [self.fallback_return_value] * len(inputs)
        execution_details: List[ExecutionDetails] = [
            ExecutionDetails() for _ in range(len(inputs))
        ]

        with self._executor_signal_handling(self.termination_signal), tqdm(
            total=len(inputs), bar_format=self.tqdm_bar_format
        ) as progress_bar:
            for index, payload in enumerate(inputs):
                task_start_time = time.time()
                try:
                    for attempt in range(self.max_retries + 1):
                        if self._TERMINATE:
                            return outputs, execution_details
                        try:
                            result = self.generate(payload)
                        except Exception as exc:
                            execution_details[index].log_exception(exc)
                            is_arize_exception = isinstance(exc, ArizeException)
                            if attempt >= self.max_retries or is_arize_exception:
                                # handled by the outer except as a final failure
                                raise
                            tqdm.write(
                                f"Exception in worker on attempt {attempt + 1}: {exc}"
                            )
                            tqdm.write("Retrying...")
                        else:
                            outputs[index] = result
                            execution_details[index].complete()
                            progress_bar.update()
                            break
                except Exception as exc:
                    execution_details[index].fail()
                    tqdm.write(
                        f"Retries exhausted after {attempt + 1} attempts: {exc}"
                    )
                    if self.exit_on_error:
                        return outputs, execution_details
                    progress_bar.update()
                finally:
                    execution_details[index].log_runtime(task_start_time)
        return outputs, execution_details
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
def get_executor_on_sync_context(
    sync_fn: Callable[[Any], Any],
    async_fn: Callable[[Any], Coroutine[Any, Any, Any]],
    run_sync: bool = False,
    concurrency: int = 3,
    tqdm_bar_format: Optional[str] = None,
    max_retries: int = 10,
    exit_on_error: bool = True,
    fallback_return_value: Union[Unset, Any] = _unset,
) -> Executor:
    """Choose an executor suited to the calling context.

    Decision order:
      * off the main thread: a ``SyncExecutor`` with no termination signal handler;
      * ``run_sync is True``: a ``SyncExecutor``;
      * no running event loop, or ``asyncio`` patched by nest_asyncio: an ``AsyncExecutor``;
      * otherwise (running loop, unpatched): a ``SyncExecutor``, with a warning.
    """
    shared_options = dict(
        tqdm_bar_format=tqdm_bar_format,
        max_retries=max_retries,
        exit_on_error=exit_on_error,
        fallback_return_value=fallback_return_value,
    )

    on_main_thread = threading.current_thread() is threading.main_thread()
    if not on_main_thread:
        # run evals synchronously if not in the main thread; no signal handler is
        # installed here, since Python only permits that from the main thread
        if run_sync is False:
            logger.warning(
                "Async evals execution is not supported in non-main threads. Falling back to sync."
            )
        return SyncExecutor(sync_fn, termination_signal=None, **shared_options)

    if run_sync is True:
        return SyncExecutor(sync_fn, **shared_options)

    loop_is_running = _running_event_loop_exists()
    if not loop_is_running or getattr(asyncio, "_nest_patched", False):
        return AsyncExecutor(async_fn, concurrency=concurrency, **shared_options)

    logger.warning(
        "🐌!! If running inside a notebook, patching the event loop with "
        "nest_asyncio will allow asynchronous eval submission, and is significantly "
        "faster. To patch the event loop, run `nest_asyncio.apply()`."
    )
    return SyncExecutor(sync_fn, **shared_options)
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def _running_event_loop_exists() -> bool:
|
|
491
|
+
"""
|
|
492
|
+
Checks for a running event loop.
|
|
493
|
+
|
|
494
|
+
Returns:
|
|
495
|
+
bool: True if a running event loop exists, False otherwise.
|
|
496
|
+
|
|
497
|
+
"""
|
|
498
|
+
try:
|
|
499
|
+
asyncio.get_running_loop()
|
|
500
|
+
return True
|
|
501
|
+
except RuntimeError:
|
|
502
|
+
return False
|