modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +13 -9
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +402 -398
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -60
- modal/_resources.py +26 -7
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1025 -0
- modal/{execution_context.py → _runtime/execution_context.py} +11 -2
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +123 -6
- modal/_traceback.py +47 -187
- modal/_tunnel.py +50 -14
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +386 -104
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +299 -98
- modal/_utils/grpc_testing.py +47 -34
- modal/_utils/grpc_utils.py +54 -21
- modal/_utils/hash_utils.py +51 -10
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +3 -3
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +12 -10
- modal/app.py +561 -323
- modal/app.pyi +474 -262
- modal/call_graph.py +7 -6
- modal/cli/_download.py +22 -6
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +203 -42
- modal/cli/config.py +12 -5
- modal/cli/container.py +61 -13
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +21 -48
- modal/cli/launch.py +28 -14
- modal/cli/network_file_system.py +57 -21
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +34 -9
- modal/cli/programs/vscode.py +58 -8
- modal/cli/queues.py +131 -0
- modal/cli/run.py +199 -96
- modal/cli/secret.py +5 -4
- modal/cli/token.py +7 -2
- modal/cli/utils.py +74 -8
- modal/cli/volume.py +97 -56
- modal/client.py +248 -144
- modal/client.pyi +156 -124
- modal/cloud_bucket_mount.py +43 -30
- modal/cloud_bucket_mount.pyi +32 -25
- modal/cls.py +528 -141
- modal/cls.pyi +189 -145
- modal/config.py +32 -15
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +50 -54
- modal/dict.pyi +120 -164
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +30 -43
- modal/experimental.py +62 -2
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +196 -0
- modal/functions.py +846 -428
- modal/functions.pyi +446 -387
- modal/gpu.py +57 -44
- modal/image.py +943 -417
- modal/image.pyi +584 -245
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +223 -90
- modal/mount.pyi +241 -243
- modal/network_file_system.py +85 -86
- modal/network_file_system.pyi +151 -110
- modal/object.py +66 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +73 -47
- modal/parallel_map.pyi +51 -63
- modal/partial_function.py +272 -107
- modal/partial_function.pyi +219 -120
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +96 -72
- modal/queue.pyi +210 -135
- modal/requirements/2024.04.txt +2 -1
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +45 -4
- modal/runner.py +325 -203
- modal/runner.pyi +124 -110
- modal/running_app.py +27 -4
- modal/sandbox.py +509 -231
- modal/sandbox.pyi +396 -169
- modal/schedule.py +2 -2
- modal/scheduler_placement.py +20 -3
- modal/secret.py +41 -25
- modal/secret.pyi +62 -42
- modal/serving.py +39 -49
- modal/serving.pyi +37 -43
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +123 -137
- modal/volume.pyi +228 -221
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
- modal-0.72.13.dist-info/RECORD +174 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +1 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1231 -531
- modal_proto/api_grpc.py +750 -430
- modal_proto/api_pb2.py +2102 -1176
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1329 -675
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_exec.py +0 -128
- modal/_container_io_manager.py +0 -646
- modal/_container_io_manager.pyi +0 -412
- modal/_sandbox_shell.py +0 -49
- modal/app_utils.py +0 -20
- modal/app_utils.pyi +0 -17
- modal/execution_context.pyi +0 -37
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal-0.62.115.dist-info/RECORD +0 -207
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -279
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -674
- test/client_test.py +0 -203
- test/cloud_bucket_mount_test.py +0 -22
- test/cls_test.py +0 -636
- test/config_test.py +0 -149
- test/conftest.py +0 -1485
- test/container_app_test.py +0 -50
- test/container_test.py +0 -1405
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -51
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -791
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -82
- test/helpers.py +0 -47
- test/image_test.py +0 -814
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -327
- test/network_file_system_test.py +0 -188
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -115
- test/resolver_test.py +0 -59
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -57
- test/secret_test.py +0 -89
- test/serialization_test.py +0 -50
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -361
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -397
- test/watcher_test.py +0 -58
- test/webhook_test.py +0 -145
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
- {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
modal/_utils/async_utils.py
CHANGED
@@ -3,25 +3,34 @@ import asyncio
|
|
3
3
|
import concurrent.futures
|
4
4
|
import functools
|
5
5
|
import inspect
|
6
|
+
import itertools
|
6
7
|
import time
|
7
8
|
import typing
|
9
|
+
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
|
8
10
|
from contextlib import asynccontextmanager
|
9
|
-
from
|
11
|
+
from dataclasses import dataclass
|
12
|
+
from typing import (
|
13
|
+
Any,
|
14
|
+
Callable,
|
15
|
+
Optional,
|
16
|
+
TypeVar,
|
17
|
+
Union,
|
18
|
+
cast,
|
19
|
+
)
|
10
20
|
|
11
21
|
import synchronicity
|
12
|
-
from
|
22
|
+
from synchronicity.async_utils import Runner
|
23
|
+
from synchronicity.exceptions import NestedEventLoops
|
24
|
+
from typing_extensions import ParamSpec, assert_type
|
13
25
|
|
14
26
|
from ..exception import InvalidError
|
15
27
|
from .logger import logger
|
16
28
|
|
17
29
|
synchronizer = synchronicity.Synchronizer()
|
18
|
-
# atexit.register(synchronizer.close)
|
19
30
|
|
20
31
|
|
21
32
|
def synchronize_api(obj, target_module=None):
|
22
|
-
if inspect.isclass(obj):
|
23
|
-
blocking_name = obj.__name__.lstrip("_")
|
24
|
-
elif inspect.isfunction(object):
|
33
|
+
if inspect.isclass(obj) or inspect.isfunction(obj):
|
25
34
|
blocking_name = obj.__name__.lstrip("_")
|
26
35
|
elif isinstance(obj, TypeVar):
|
27
36
|
blocking_name = "_BLOCKING_" + obj.__name__
|
@@ -103,7 +112,7 @@ class TaskContext:
|
|
103
112
|
```
|
104
113
|
"""
|
105
114
|
|
106
|
-
_loops:
|
115
|
+
_loops: set[asyncio.Task]
|
107
116
|
|
108
117
|
def __init__(self, grace: Optional[float] = None):
|
109
118
|
self._grace = grace
|
@@ -140,7 +149,6 @@ class TaskContext:
|
|
140
149
|
if gather_future:
|
141
150
|
try:
|
142
151
|
await gather_future
|
143
|
-
# pre Python3.8, CancelledErrors were a subclass of exception
|
144
152
|
except asyncio.CancelledError:
|
145
153
|
pass
|
146
154
|
|
@@ -150,11 +158,12 @@ class TaskContext:
|
|
150
158
|
# Only tasks without a done_callback will still be present in self._tasks
|
151
159
|
task.result()
|
152
160
|
|
153
|
-
if task.done() or task in self._loops:
|
161
|
+
if task.done() or task in self._loops: # Note: Legacy code, we can probably cancel loops.
|
154
162
|
continue
|
155
163
|
|
156
164
|
# Cancel any remaining unfinished tasks.
|
157
165
|
task.cancel()
|
166
|
+
await asyncio.sleep(0) # wake up coroutines waiting for cancellations
|
158
167
|
|
159
168
|
async def __aexit__(self, exc_type, value, tb):
|
160
169
|
await self.stop()
|
@@ -171,28 +180,32 @@ class TaskContext:
|
|
171
180
|
task.add_done_callback(self._tasks.discard)
|
172
181
|
return task
|
173
182
|
|
174
|
-
def infinite_loop(
|
175
|
-
|
183
|
+
def infinite_loop(
|
184
|
+
self, async_f, timeout: Optional[float] = 90, sleep: float = 10, log_exception: bool = True
|
185
|
+
) -> asyncio.Task:
|
186
|
+
if isinstance(async_f, functools.partial):
|
187
|
+
function_name = async_f.func.__qualname__
|
188
|
+
else:
|
189
|
+
function_name = async_f.__qualname__
|
176
190
|
|
177
191
|
async def loop_coro() -> None:
|
178
192
|
logger.debug(f"Starting infinite loop {function_name}")
|
179
|
-
while
|
180
|
-
t0 = time.time()
|
193
|
+
while not self.exited:
|
181
194
|
try:
|
182
195
|
await asyncio.wait_for(async_f(), timeout=timeout)
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
196
|
+
except Exception as exc:
|
197
|
+
if log_exception and isinstance(exc, asyncio.TimeoutError):
|
198
|
+
# Asyncio sends an empty message in this case, so let's use logger.error
|
199
|
+
logger.error(f"Loop attempt for {function_name} timed out")
|
200
|
+
elif log_exception:
|
201
|
+
# Propagate the exception to the logger
|
202
|
+
logger.exception(f"Loop attempt for {function_name} failed")
|
189
203
|
try:
|
190
204
|
await asyncio.wait_for(self._exited.wait(), timeout=sleep)
|
191
205
|
except asyncio.TimeoutError:
|
192
206
|
continue
|
193
|
-
|
194
|
-
|
195
|
-
break
|
207
|
+
|
208
|
+
logger.debug(f"Exiting infinite loop for {function_name}")
|
196
209
|
|
197
210
|
t = self.create_task(loop_coro())
|
198
211
|
t.set_name(f"{function_name} loop")
|
@@ -200,29 +213,39 @@ class TaskContext:
|
|
200
213
|
t.add_done_callback(self._loops.discard)
|
201
214
|
return t
|
202
215
|
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
216
|
+
@staticmethod
|
217
|
+
async def gather(*coros: Awaitable) -> Any:
|
218
|
+
"""Wait for a sequence of coroutines to finish, concurrently.
|
219
|
+
|
220
|
+
This is similar to `asyncio.gather()`, but it uses TaskContext to cancel all remaining tasks
|
221
|
+
if one fails with an exception other than `asyncio.CancelledError`. The native `asyncio`
|
222
|
+
function does not cancel remaining tasks in this case, which can lead to surprises.
|
223
|
+
|
224
|
+
For example, if you use `asyncio.gather(t1, t2, t3)` and t2 raises an exception, then t1 and
|
225
|
+
t3 would continue running. With `TaskContext.gather(t1, t2, t3)`, they are cancelled.
|
226
|
+
|
227
|
+
(It's still acceptable to use `asyncio.gather()` if you don't need cancellation — for
|
228
|
+
example, if you're just gathering quick coroutines with no side-effects. Or if you're
|
229
|
+
gathering the tasks with `return_exceptions=True`.)
|
230
|
+
|
231
|
+
Usage:
|
232
|
+
|
233
|
+
```python notest
|
234
|
+
# Example 1: Await three coroutines
|
235
|
+
created_object, other_work, new_plumbing = await TaskContext.gather(
|
236
|
+
create_my_object(),
|
237
|
+
do_some_other_work(),
|
238
|
+
fix_plumbing(),
|
239
|
+
)
|
240
|
+
|
241
|
+
# Example 2: Gather a list of coroutines
|
242
|
+
coros = [a.load() for a in objects]
|
243
|
+
results = await TaskContext.gather(*coros)
|
244
|
+
```
|
245
|
+
"""
|
246
|
+
async with TaskContext() as tc:
|
247
|
+
results = await asyncio.gather(*(tc.create_task(coro) for coro in coros))
|
248
|
+
return results
|
226
249
|
|
227
250
|
|
228
251
|
def run_coro_blocking(coro):
|
@@ -243,7 +266,7 @@ async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_ti
|
|
243
266
|
|
244
267
|
Treats a None value as end of queue items
|
245
268
|
"""
|
246
|
-
item_list:
|
269
|
+
item_list: list[Any] = []
|
247
270
|
|
248
271
|
while True:
|
249
272
|
if q.empty() and len(item_list) > 0:
|
@@ -290,10 +313,18 @@ class _WarnIfGeneratorIsNotConsumed:
|
|
290
313
|
if not self.iterated and not self.warned:
|
291
314
|
self.warned = True
|
292
315
|
logger.warning(
|
293
|
-
f"Warning: the results of a call to {self.function_name} was not consumed,
|
294
|
-
|
316
|
+
f"Warning: the results of a call to {self.function_name} was not consumed, "
|
317
|
+
"so the call will never be executed."
|
318
|
+
f" Consider a for-loop like `for x in {self.function_name}(...)` or "
|
319
|
+
"unpacking the generator using `list(...)`"
|
295
320
|
)
|
296
321
|
|
322
|
+
async def athrow(self, exc):
|
323
|
+
return await self.gen.athrow(exc)
|
324
|
+
|
325
|
+
async def aclose(self):
|
326
|
+
return await self.gen.aclose()
|
327
|
+
|
297
328
|
|
298
329
|
synchronize_api(_WarnIfGeneratorIsNotConsumed)
|
299
330
|
|
@@ -331,7 +362,7 @@ def warn_if_generator_is_not_consumed(function_name: Optional[str] = None):
|
|
331
362
|
return decorator
|
332
363
|
|
333
364
|
|
334
|
-
class
|
365
|
+
class AsyncOrSyncIterable:
|
335
366
|
"""Compatibility class for non-synchronicity wrapped async iterables to get
|
336
367
|
both async and sync interfaces in the same way that synchronicity does (but on the main thread)
|
337
368
|
so they can be "lazily" iterated using either `for _ in x` or `async for _ in x`
|
@@ -340,7 +371,7 @@ class AsyncOrSyncIteratable:
|
|
340
371
|
from an already async context, since that would otherwise deadlock the event loop
|
341
372
|
"""
|
342
373
|
|
343
|
-
def __init__(self, async_iterable: typing.
|
374
|
+
def __init__(self, async_iterable: typing.AsyncGenerator[Any, None], nested_async_message):
|
344
375
|
self._async_iterable = async_iterable
|
345
376
|
self.nested_async_message = nested_async_message
|
346
377
|
|
@@ -349,9 +380,9 @@ class AsyncOrSyncIteratable:
|
|
349
380
|
|
350
381
|
def __iter__(self):
|
351
382
|
try:
|
352
|
-
|
353
|
-
yield
|
354
|
-
except
|
383
|
+
with Runner() as runner:
|
384
|
+
yield from run_async_gen(runner, self._async_iterable)
|
385
|
+
except NestedEventLoops:
|
355
386
|
raise InvalidError(self.nested_async_message)
|
356
387
|
|
357
388
|
|
@@ -372,6 +403,7 @@ def on_shutdown(coro):
|
|
372
403
|
|
373
404
|
T = TypeVar("T")
|
374
405
|
P = ParamSpec("P")
|
406
|
+
V = TypeVar("V")
|
375
407
|
|
376
408
|
|
377
409
|
def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
|
@@ -397,40 +429,6 @@ async def iterate_blocking(iterator: Iterator[T]) -> AsyncGenerator[T, None]:
|
|
397
429
|
yield cast(T, obj)
|
398
430
|
|
399
431
|
|
400
|
-
class ConcurrencyPool:
|
401
|
-
def __init__(self, concurrency_limit: int):
|
402
|
-
self.semaphore = asyncio.Semaphore(concurrency_limit)
|
403
|
-
|
404
|
-
async def run_coros(self, coros: typing.Iterable[typing.Coroutine], return_exceptions=False):
|
405
|
-
async def blocking_wrapper(coro):
|
406
|
-
# Not using async with on the semaphore is intentional here - if return_exceptions=False
|
407
|
-
# manual release prevents starting extraneous tasks after exceptions.
|
408
|
-
try:
|
409
|
-
await self.semaphore.acquire()
|
410
|
-
except asyncio.CancelledError:
|
411
|
-
coro.close() # avoid "coroutine was never awaited" warnings
|
412
|
-
|
413
|
-
try:
|
414
|
-
res = await coro
|
415
|
-
self.semaphore.release()
|
416
|
-
return res
|
417
|
-
except BaseException as e:
|
418
|
-
if return_exceptions:
|
419
|
-
self.semaphore.release()
|
420
|
-
raise e
|
421
|
-
|
422
|
-
# asyncio.gather() is weird - it doesn't cancel outstanding awaitables on exceptions when
|
423
|
-
# return_exceptions=False --> wrap the coros in tasks are cancel them explicitly on exception.
|
424
|
-
tasks = [asyncio.create_task(blocking_wrapper(coro)) for coro in coros]
|
425
|
-
g = asyncio.gather(*tasks, return_exceptions=return_exceptions)
|
426
|
-
try:
|
427
|
-
return await g
|
428
|
-
except BaseException as e:
|
429
|
-
for t in tasks:
|
430
|
-
t.cancel()
|
431
|
-
raise e
|
432
|
-
|
433
|
-
|
434
432
|
@asynccontextmanager
|
435
433
|
async def asyncnullcontext(*args, **kwargs):
|
436
434
|
"""Async noop context manager.
|
@@ -448,21 +446,11 @@ YIELD_TYPE = typing.TypeVar("YIELD_TYPE")
|
|
448
446
|
SEND_TYPE = typing.TypeVar("SEND_TYPE")
|
449
447
|
|
450
448
|
|
451
|
-
|
452
|
-
|
453
|
-
|
454
|
-
|
455
|
-
def run_generator_sync(
|
449
|
+
def run_async_gen(
|
450
|
+
runner: Runner,
|
456
451
|
gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE],
|
457
452
|
) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]:
|
458
|
-
|
459
|
-
asyncio.get_running_loop()
|
460
|
-
except RuntimeError:
|
461
|
-
pass # no event loop - this is what we expect!
|
462
|
-
else:
|
463
|
-
raise NestedAsyncCalls()
|
464
|
-
loop = asyncio.new_event_loop() # set up new event loop for the map so we can use async logic
|
465
|
-
|
453
|
+
"""Convert an async generator into a sync one"""
|
466
454
|
# more or less copied from synchronicity's implementation:
|
467
455
|
next_send: typing.Union[SEND_TYPE, None] = None
|
468
456
|
next_yield: YIELD_TYPE
|
@@ -470,14 +458,308 @@ def run_generator_sync(
|
|
470
458
|
while True:
|
471
459
|
try:
|
472
460
|
if exc:
|
473
|
-
next_yield =
|
461
|
+
next_yield = runner.run(gen.athrow(exc))
|
474
462
|
else:
|
475
|
-
next_yield =
|
463
|
+
next_yield = runner.run(gen.asend(next_send)) # type: ignore[arg-type]
|
464
|
+
except KeyboardInterrupt as e:
|
465
|
+
raise e from None
|
476
466
|
except StopAsyncIteration:
|
477
|
-
break
|
467
|
+
break # typically a graceful exit of the async generator
|
478
468
|
try:
|
479
469
|
next_send = yield next_yield
|
480
470
|
exc = None
|
481
471
|
except BaseException as err:
|
482
472
|
exc = err
|
483
|
-
|
473
|
+
|
474
|
+
|
475
|
+
class aclosing(typing.Generic[T]): # noqa
|
476
|
+
# backport of Python contextlib.aclosing from Python 3.10
|
477
|
+
def __init__(self, agen: AsyncGenerator[T, None]):
|
478
|
+
self.agen = agen
|
479
|
+
|
480
|
+
async def __aenter__(self) -> AsyncGenerator[T, None]:
|
481
|
+
return self.agen
|
482
|
+
|
483
|
+
async def __aexit__(self, exc, exc_type, tb):
|
484
|
+
await self.agen.aclose()
|
485
|
+
|
486
|
+
|
487
|
+
async def sync_or_async_iter(iter: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
|
488
|
+
if hasattr(iter, "__aiter__"):
|
489
|
+
agen = typing.cast(AsyncGenerator[T, None], iter)
|
490
|
+
try:
|
491
|
+
async for item in agen:
|
492
|
+
yield item
|
493
|
+
finally:
|
494
|
+
if hasattr(agen, "aclose"):
|
495
|
+
# All AsyncGenerator's have an aclose method
|
496
|
+
# but some AsyncIterable's don't necessarily
|
497
|
+
await agen.aclose()
|
498
|
+
else:
|
499
|
+
assert hasattr(iter, "__iter__"), "sync_or_async_iter requires an Iterable or AsyncGenerator"
|
500
|
+
# This intentionally could block the event loop for the duration of calling __iter__ and __next__,
|
501
|
+
# so in non-trivial cases (like passing lists and ranges) this could be quite a foot gun for users #
|
502
|
+
# w/ async code (but they can work around it by always using async iterators)
|
503
|
+
for item in typing.cast(Iterable[T], iter):
|
504
|
+
yield item
|
505
|
+
|
506
|
+
|
507
|
+
@typing.overload
|
508
|
+
def async_zip(g1: AsyncGenerator[T, None], g2: AsyncGenerator[V, None], /) -> AsyncGenerator[tuple[T, V], None]:
|
509
|
+
...
|
510
|
+
|
511
|
+
|
512
|
+
@typing.overload
|
513
|
+
def async_zip(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]:
|
514
|
+
...
|
515
|
+
|
516
|
+
|
517
|
+
async def async_zip(*generators):
|
518
|
+
tasks = []
|
519
|
+
try:
|
520
|
+
while True:
|
521
|
+
try:
|
522
|
+
|
523
|
+
async def next_item(gen):
|
524
|
+
return await gen.__anext__()
|
525
|
+
|
526
|
+
tasks = [asyncio.create_task(next_item(gen)) for gen in generators]
|
527
|
+
items = await asyncio.gather(*tasks)
|
528
|
+
yield tuple(items)
|
529
|
+
except StopAsyncIteration:
|
530
|
+
break
|
531
|
+
finally:
|
532
|
+
cancelled_tasks = []
|
533
|
+
for task in tasks:
|
534
|
+
if not task.done():
|
535
|
+
task.cancel()
|
536
|
+
cancelled_tasks.append(task)
|
537
|
+
try:
|
538
|
+
await asyncio.gather(*cancelled_tasks)
|
539
|
+
except asyncio.CancelledError:
|
540
|
+
pass
|
541
|
+
|
542
|
+
first_exception = None
|
543
|
+
for gen in generators:
|
544
|
+
try:
|
545
|
+
await gen.aclose()
|
546
|
+
except BaseException as e:
|
547
|
+
if first_exception is None:
|
548
|
+
first_exception = e
|
549
|
+
logger.exception(f"Error closing async generator: {e}")
|
550
|
+
if first_exception is not None:
|
551
|
+
raise first_exception
|
552
|
+
|
553
|
+
|
554
|
+
@dataclass
|
555
|
+
class ValueWrapper(typing.Generic[T]):
|
556
|
+
value: T
|
557
|
+
|
558
|
+
|
559
|
+
@dataclass
|
560
|
+
class ExceptionWrapper:
|
561
|
+
value: Exception
|
562
|
+
|
563
|
+
|
564
|
+
class StopSentinelType:
|
565
|
+
...
|
566
|
+
|
567
|
+
|
568
|
+
STOP_SENTINEL = StopSentinelType()
|
569
|
+
|
570
|
+
|
571
|
+
async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
572
|
+
"""
|
573
|
+
Asynchronously merges multiple async generators into a single async generator.
|
574
|
+
|
575
|
+
This function takes multiple async generators and yields their values in the order
|
576
|
+
they are produced. If any generator raises an exception, the exception is propagated.
|
577
|
+
|
578
|
+
Args:
|
579
|
+
*generators: One or more async generators to be merged.
|
580
|
+
|
581
|
+
Yields:
|
582
|
+
The values produced by the input async generators.
|
583
|
+
|
584
|
+
Raises:
|
585
|
+
Exception: If any of the input generators raises an exception, it is propagated.
|
586
|
+
|
587
|
+
Usage:
|
588
|
+
```python
|
589
|
+
import asyncio
|
590
|
+
from modal._utils.async_utils import async_merge
|
591
|
+
|
592
|
+
async def gen1():
|
593
|
+
yield 1
|
594
|
+
yield 2
|
595
|
+
|
596
|
+
async def gen2():
|
597
|
+
yield "a"
|
598
|
+
yield "b"
|
599
|
+
|
600
|
+
async def example():
|
601
|
+
values = set()
|
602
|
+
async for value in async_merge(gen1(), gen2()):
|
603
|
+
values.add(value)
|
604
|
+
|
605
|
+
return values
|
606
|
+
|
607
|
+
# Output could be: {1, "a", 2, "b"} (order may vary)
|
608
|
+
values = asyncio.run(example())
|
609
|
+
assert values == {1, "a", 2, "b"}
|
610
|
+
```
|
611
|
+
"""
|
612
|
+
queue: asyncio.Queue[Union[ValueWrapper[T], ExceptionWrapper]] = asyncio.Queue(maxsize=len(generators) * 10)
|
613
|
+
|
614
|
+
async def producer(generator: AsyncGenerator[T, None]):
|
615
|
+
try:
|
616
|
+
async for item in generator:
|
617
|
+
await queue.put(ValueWrapper(item))
|
618
|
+
except Exception as e:
|
619
|
+
await queue.put(ExceptionWrapper(e))
|
620
|
+
|
621
|
+
tasks = {asyncio.create_task(producer(gen)) for gen in generators}
|
622
|
+
new_output_task = asyncio.create_task(queue.get())
|
623
|
+
|
624
|
+
try:
|
625
|
+
while tasks:
|
626
|
+
done, _ = await asyncio.wait(
|
627
|
+
[*tasks, new_output_task],
|
628
|
+
return_when=asyncio.FIRST_COMPLETED,
|
629
|
+
)
|
630
|
+
|
631
|
+
if new_output_task in done:
|
632
|
+
item = new_output_task.result()
|
633
|
+
if isinstance(item, ValueWrapper):
|
634
|
+
yield item.value
|
635
|
+
else:
|
636
|
+
assert_type(item, ExceptionWrapper)
|
637
|
+
raise item.value
|
638
|
+
|
639
|
+
new_output_task = asyncio.create_task(queue.get())
|
640
|
+
|
641
|
+
finished_producers = done & tasks
|
642
|
+
tasks -= finished_producers
|
643
|
+
for finished_producer in finished_producers:
|
644
|
+
# this is done in order to catch potential raised errors/cancellations
|
645
|
+
# from within worker tasks as soon as they happen.
|
646
|
+
await finished_producer
|
647
|
+
|
648
|
+
while not queue.empty():
|
649
|
+
item = await new_output_task
|
650
|
+
if isinstance(item, ValueWrapper):
|
651
|
+
yield item.value
|
652
|
+
else:
|
653
|
+
assert_type(item, ExceptionWrapper)
|
654
|
+
raise item.value
|
655
|
+
|
656
|
+
new_output_task = asyncio.create_task(queue.get())
|
657
|
+
|
658
|
+
finally:
|
659
|
+
if not new_output_task.done():
|
660
|
+
new_output_task.cancel()
|
661
|
+
for task in tasks:
|
662
|
+
if not task.done():
|
663
|
+
try:
|
664
|
+
task.cancel()
|
665
|
+
await task
|
666
|
+
except asyncio.CancelledError:
|
667
|
+
pass
|
668
|
+
|
669
|
+
|
670
|
+
async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
|
671
|
+
yield await awaitable()
|
672
|
+
|
673
|
+
|
674
|
+
async def gather_cancel_on_exc(*coros_or_futures):
|
675
|
+
input_tasks = [asyncio.ensure_future(t) for t in coros_or_futures]
|
676
|
+
try:
|
677
|
+
return await asyncio.gather(*input_tasks)
|
678
|
+
except BaseException:
|
679
|
+
for t in input_tasks:
|
680
|
+
t.cancel()
|
681
|
+
await asyncio.gather(*input_tasks, return_exceptions=False) # handle cancellations
|
682
|
+
raise
|
683
|
+
|
684
|
+
|
685
|
+
async def async_map(
|
686
|
+
input_generator: AsyncGenerator[T, None],
|
687
|
+
async_mapper_func: Callable[[T], Awaitable[V]],
|
688
|
+
concurrency: int,
|
689
|
+
) -> AsyncGenerator[V, None]:
|
690
|
+
queue: asyncio.Queue[Union[ValueWrapper[T], StopSentinelType]] = asyncio.Queue(maxsize=concurrency * 2)
|
691
|
+
|
692
|
+
async def producer() -> AsyncGenerator[V, None]:
|
693
|
+
async for item in input_generator:
|
694
|
+
await queue.put(ValueWrapper(item))
|
695
|
+
|
696
|
+
for _ in range(concurrency):
|
697
|
+
await queue.put(STOP_SENTINEL)
|
698
|
+
|
699
|
+
if False:
|
700
|
+
# Need it to be an async generator for async_merge
|
701
|
+
# but we don't want to yield anything
|
702
|
+
yield
|
703
|
+
|
704
|
+
async def worker() -> AsyncGenerator[V, None]:
|
705
|
+
while True:
|
706
|
+
item = await queue.get()
|
707
|
+
if isinstance(item, ValueWrapper):
|
708
|
+
yield await async_mapper_func(item.value)
|
709
|
+
elif isinstance(item, ExceptionWrapper):
|
710
|
+
raise item.value
|
711
|
+
else:
|
712
|
+
assert_type(item, StopSentinelType)
|
713
|
+
break
|
714
|
+
|
715
|
+
async with aclosing(async_merge(*[worker() for _ in range(concurrency)], producer())) as stream:
|
716
|
+
async for item in stream:
|
717
|
+
yield item
|
718
|
+
|
719
|
+
|
720
|
+
async def async_map_ordered(
|
721
|
+
input_generator: AsyncGenerator[T, None],
|
722
|
+
async_mapper_func: Callable[[T], Awaitable[V]],
|
723
|
+
concurrency: int,
|
724
|
+
buffer_size: Optional[int] = None,
|
725
|
+
) -> AsyncGenerator[V, None]:
|
726
|
+
semaphore = asyncio.Semaphore(buffer_size or concurrency)
|
727
|
+
|
728
|
+
async def mapper_func_wrapper(tup: tuple[int, T]) -> tuple[int, V]:
|
729
|
+
return (tup[0], await async_mapper_func(tup[1]))
|
730
|
+
|
731
|
+
async def counter() -> AsyncGenerator[int, None]:
|
732
|
+
for i in itertools.count():
|
733
|
+
await semaphore.acquire()
|
734
|
+
yield i
|
735
|
+
|
736
|
+
next_idx = 0
|
737
|
+
buffer = {}
|
738
|
+
|
739
|
+
async with aclosing(async_map(async_zip(counter(), input_generator), mapper_func_wrapper, concurrency)) as stream:
|
740
|
+
async for output_idx, output_item in stream:
|
741
|
+
buffer[output_idx] = output_item
|
742
|
+
|
743
|
+
while next_idx in buffer:
|
744
|
+
yield buffer[next_idx]
|
745
|
+
semaphore.release()
|
746
|
+
del buffer[next_idx]
|
747
|
+
next_idx += 1
|
748
|
+
|
749
|
+
|
750
|
+
async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
|
751
|
+
try:
|
752
|
+
for gen in generators:
|
753
|
+
async for item in gen:
|
754
|
+
yield item
|
755
|
+
finally:
|
756
|
+
first_exception = None
|
757
|
+
for gen in generators:
|
758
|
+
try:
|
759
|
+
await gen.aclose()
|
760
|
+
except BaseException as e:
|
761
|
+
if first_exception is None:
|
762
|
+
first_exception = e
|
763
|
+
logger.exception(f"Error closing async generator: {e}")
|
764
|
+
if first_exception is not None:
|
765
|
+
raise first_exception
|