modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- modal/__init__.py +17 -13
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +420 -937
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -59
- modal/_resources.py +51 -0
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1036 -0
- modal/_runtime/execution_context.py +89 -0
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +134 -9
- modal/_traceback.py +47 -187
- modal/_tunnel.py +52 -16
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +479 -100
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +460 -171
- modal/_utils/grpc_testing.py +47 -31
- modal/_utils/grpc_utils.py +62 -109
- modal/_utils/hash_utils.py +61 -19
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +5 -7
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +14 -12
- modal/app.py +1003 -314
- modal/app.pyi +540 -264
- modal/call_graph.py +7 -6
- modal/cli/_download.py +63 -53
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +205 -45
- modal/cli/config.py +12 -5
- modal/cli/container.py +62 -14
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +64 -58
- modal/cli/launch.py +32 -18
- modal/cli/network_file_system.py +64 -83
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +35 -10
- modal/cli/programs/vscode.py +60 -10
- modal/cli/queues.py +131 -0
- modal/cli/run.py +234 -131
- modal/cli/secret.py +8 -7
- modal/cli/token.py +7 -2
- modal/cli/utils.py +79 -10
- modal/cli/volume.py +110 -109
- modal/client.py +250 -144
- modal/client.pyi +157 -118
- modal/cloud_bucket_mount.py +108 -34
- modal/cloud_bucket_mount.pyi +32 -38
- modal/cls.py +535 -148
- modal/cls.pyi +190 -146
- modal/config.py +41 -19
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +111 -65
- modal/dict.pyi +136 -131
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +34 -43
- modal/experimental.py +61 -2
- modal/extensions/ipython.py +5 -5
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +197 -0
- modal/functions.py +906 -911
- modal/functions.pyi +466 -430
- modal/gpu.py +57 -44
- modal/image.py +1089 -479
- modal/image.pyi +584 -228
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +314 -101
- modal/mount.pyi +241 -235
- modal/network_file_system.py +92 -92
- modal/network_file_system.pyi +152 -110
- modal/object.py +67 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +434 -0
- modal/parallel_map.pyi +75 -0
- modal/partial_function.py +282 -117
- modal/partial_function.pyi +222 -129
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +182 -65
- modal/queue.pyi +218 -118
- modal/requirements/2024.04.txt +29 -0
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +48 -7
- modal/runner.py +459 -156
- modal/runner.pyi +135 -71
- modal/running_app.py +38 -0
- modal/sandbox.py +514 -236
- modal/sandbox.pyi +397 -169
- modal/schedule.py +4 -4
- modal/scheduler_placement.py +20 -3
- modal/secret.py +56 -31
- modal/secret.pyi +62 -42
- modal/serving.py +51 -56
- modal/serving.pyi +44 -36
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +285 -157
- modal/volume.pyi +249 -184
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
- modal-0.72.11.dist-info/RECORD +174 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +5 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1288 -533
- modal_proto/api_grpc.py +856 -456
- modal_proto/api_pb2.py +2165 -1157
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1674 -855
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_entrypoint.pyi +0 -378
- modal/_container_exec.py +0 -128
- modal/_sandbox_shell.py +0 -49
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal/stub.py +0 -783
- modal/stub.pyi +0 -332
- modal-0.62.16.dist-info/RECORD +0 -198
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -262
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -659
- test/client_test.py +0 -194
- test/cls_test.py +0 -630
- test/config_test.py +0 -137
- test/conftest.py +0 -1420
- test/container_app_test.py +0 -32
- test/container_test.py +0 -1389
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -33
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -653
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -141
- test/helpers.py +0 -42
- test/image_test.py +0 -669
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -329
- test/network_file_system_test.py +0 -181
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -97
- test/resolver_test.py +0 -58
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -29
- test/secret_test.py +0 -78
- test/serialization_test.py +0 -42
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -360
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -341
- test/watcher_test.py +0 -30
- test/webhook_test.py +0 -146
- /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
- /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/_utils/async_utils.py
CHANGED
@@ -3,25 +3,34 @@ import asyncio
 import concurrent.futures
 import functools
 import inspect
-import
+import itertools
 import time
 import typing
+from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
 from contextlib import asynccontextmanager
-from
+from dataclasses import dataclass
+from typing import (
+    Any,
+    Callable,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)

 import synchronicity
-from
+from synchronicity.async_utils import Runner
+from synchronicity.exceptions import NestedEventLoops
+from typing_extensions import ParamSpec, assert_type

+from ..exception import InvalidError
 from .logger import logger

 synchronizer = synchronicity.Synchronizer()
-# atexit.register(synchronizer.close)


 def synchronize_api(obj, target_module=None):
-    if inspect.isclass(obj):
-        blocking_name = obj.__name__.lstrip("_")
-    elif inspect.isfunction(object):
+    if inspect.isclass(obj) or inspect.isfunction(obj):
         blocking_name = obj.__name__.lstrip("_")
     elif isinstance(obj, TypeVar):
         blocking_name = "_BLOCKING_" + obj.__name__
@@ -86,14 +95,24 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout


 class TaskContext:
-    """
+    """A structured group that helps manage stray tasks.
+
+    This differs from the standard library `asyncio.TaskGroup` in that it cancels all tasks still
+    running after exiting the context manager, rather than waiting for them to finish.
+
+    A `TaskContext` can have an optional `grace` period in seconds, which will wait for a certain
+    amount of time before cancelling all remaining tasks. This is useful for allowing tasks to
+    gracefully exit when they determine that the context is shutting down.

     Usage:
+
+    ```python notest
     async with TaskContext() as task_context:
-        task = task_context.
+        task = task_context.create_task(coro())
+    ```
     """

-    _loops:
+    _loops: set[asyncio.Task]

     def __init__(self, grace: Optional[float] = None):
         self._grace = grace
@@ -130,7 +149,6 @@ class TaskContext:
             if gather_future:
                 try:
                     await gather_future
-                # pre Python3.8, CancelledErrors were a subclass of exception
                 except asyncio.CancelledError:
                     pass

@@ -140,15 +158,12 @@ class TaskContext:
                     # Only tasks without a done_callback will still be present in self._tasks
                     task.result()

-                if task.done() or task in self._loops:
+                if task.done() or task in self._loops:  # Note: Legacy code, we can probably cancel loops.
                     continue

-
-                already_cancelling = task.cancelling() > 0
-                if not already_cancelling:
-                    logger.warning(f"Canceling remaining unfinished task: {task}")
-
+                # Cancel any remaining unfinished tasks.
                 task.cancel()
+        await asyncio.sleep(0)  # wake up coroutines waiting for cancellations

     async def __aexit__(self, exc_type, value, tb):
         await self.stop()
@@ -165,28 +180,32 @@ class TaskContext:
         task.add_done_callback(self._tasks.discard)
         return task

-    def infinite_loop(
-
+    def infinite_loop(
+        self, async_f, timeout: Optional[float] = 90, sleep: float = 10, log_exception: bool = True
+    ) -> asyncio.Task:
+        if isinstance(async_f, functools.partial):
+            function_name = async_f.func.__qualname__
+        else:
+            function_name = async_f.__qualname__

         async def loop_coro() -> None:
             logger.debug(f"Starting infinite loop {function_name}")
-            while
-                t0 = time.time()
+            while not self.exited:
                 try:
                     await asyncio.wait_for(async_f(), timeout=timeout)
-
-
-
-
-
-
+                except Exception as exc:
+                    if log_exception and isinstance(exc, asyncio.TimeoutError):
+                        # Asyncio sends an empty message in this case, so let's use logger.error
+                        logger.error(f"Loop attempt for {function_name} timed out")
+                    elif log_exception:
+                        # Propagate the exception to the logger
+                        logger.exception(f"Loop attempt for {function_name} failed")
                 try:
                     await asyncio.wait_for(self._exited.wait(), timeout=sleep)
                 except asyncio.TimeoutError:
                     continue
-
-
-                break
+
+            logger.debug(f"Exiting infinite loop for {function_name}")

         t = self.create_task(loop_coro())
         t.set_name(f"{function_name} loop")
@@ -194,29 +213,39 @@ class TaskContext:
         t.add_done_callback(self._loops.discard)
         return t

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @staticmethod
+    async def gather(*coros: Awaitable) -> Any:
+        """Wait for a sequence of coroutines to finish, concurrently.
+
+        This is similar to `asyncio.gather()`, but it uses TaskContext to cancel all remaining tasks
+        if one fails with an exception other than `asyncio.CancelledError`. The native `asyncio`
+        function does not cancel remaining tasks in this case, which can lead to surprises.
+
+        For example, if you use `asyncio.gather(t1, t2, t3)` and t2 raises an exception, then t1 and
+        t3 would continue running. With `TaskContext.gather(t1, t2, t3)`, they are cancelled.
+
+        (It's still acceptable to use `asyncio.gather()` if you don't need cancellation — for
+        example, if you're just gathering quick coroutines with no side-effects. Or if you're
+        gathering the tasks with `return_exceptions=True`.)
+
+        Usage:
+
+        ```python notest
+        # Example 1: Await three coroutines
+        created_object, other_work, new_plumbing = await TaskContext.gather(
+            create_my_object(),
+            do_some_other_work(),
+            fix_plumbing(),
+        )
+
+        # Example 2: Gather a list of coroutines
+        coros = [a.load() for a in objects]
+        results = await TaskContext.gather(*coros)
+        ```
+        """
+        async with TaskContext() as tc:
+            results = await asyncio.gather(*(tc.create_task(coro) for coro in coros))
+        return results


 def run_coro_blocking(coro):
@@ -234,8 +263,10 @@ def run_coro_blocking(coro):
 async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_time=0.015):
     """
     Read from a queue but return lists of items when queue is large
+
+    Treats a None value as end of queue items
     """
-    item_list:
+    item_list: list[Any] = []

     while True:
         if q.empty() and len(item_list) > 0:
@@ -257,44 +288,102 @@ async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_ti


 class _WarnIfGeneratorIsNotConsumed:
-    def __init__(self, gen,
+    def __init__(self, gen, function_name: str):
         self.gen = gen
-        self.
+        self.function_name = function_name
         self.iterated = False
         self.warned = False

     def __aiter__(self):
         self.iterated = True
-        return self.gen
+        return self.gen.__aiter__()

     async def __anext__(self):
         self.iterated = True
         return await self.gen.__anext__()

+    async def asend(self, value):
+        self.iterated = True
+        return await self.gen.asend(value)
+
     def __repr__(self):
         return repr(self.gen)

     def __del__(self):
         if not self.iterated and not self.warned:
             self.warned = True
-            name = self.gen_f.__name__
             logger.warning(
-                f"Warning: the results of a call to {
-
+                f"Warning: the results of a call to {self.function_name} was not consumed, "
+                "so the call will never be executed."
+                f" Consider a for-loop like `for x in {self.function_name}(...)` or "
+                "unpacking the generator using `list(...)`"
             )

+    async def athrow(self, exc):
+        return await self.gen.athrow(exc)
+
+    async def aclose(self):
+        return await self.gen.aclose()
+

 synchronize_api(_WarnIfGeneratorIsNotConsumed)


-
+class _WarnIfNonWrappedGeneratorIsNotConsumed(_WarnIfGeneratorIsNotConsumed):
+    # used for non-synchronicity-wrapped generators and iterators
+    def __iter__(self):
+        self.iterated = True
+        return iter(self.gen)
+
+    def __next__(self):
+        self.iterated = True
+        return self.gen.__next__()
+
+    def send(self, value):
+        self.iterated = True
+        return self.gen.send(value)
+
+
+def warn_if_generator_is_not_consumed(function_name: Optional[str] = None):
     # https://gist.github.com/erikbern/01ae78d15f89edfa7f77e5c0a827a94d
-
-
-        gen = gen_f(*args, **kwargs)
-        return _WarnIfGeneratorIsNotConsumed(gen, gen_f)
+    def decorator(gen_f):
+        presented_func_name = function_name if function_name is not None else gen_f.__name__

-
+        @functools.wraps(gen_f)
+        def f_wrapped(*args, **kwargs):
+            gen = gen_f(*args, **kwargs)
+            if inspect.isasyncgen(gen):
+                return _WarnIfGeneratorIsNotConsumed(gen, presented_func_name)
+            else:
+                return _WarnIfNonWrappedGeneratorIsNotConsumed(gen, presented_func_name)
+
+        return f_wrapped
+
+    return decorator
+
+
+class AsyncOrSyncIterable:
+    """Compatibility class for non-synchronicity wrapped async iterables to get
+    both async and sync interfaces in the same way that synchronicity does (but on the main thread)
+    so they can be "lazily" iterated using either `for _ in x` or `async for _ in x`
+
+    nested_async_message is raised as an InvalidError if the async variant is called
+    from an already async context, since that would otherwise deadlock the event loop
+    """
+
+    def __init__(self, async_iterable: typing.AsyncGenerator[Any, None], nested_async_message):
+        self._async_iterable = async_iterable
+        self.nested_async_message = nested_async_message
+
+    def __aiter__(self):
+        return self._async_iterable
+
+    def __iter__(self):
+        try:
+            with Runner() as runner:
+                yield from run_async_gen(runner, self._async_iterable)
+        except NestedEventLoops:
+            raise InvalidError(self.nested_async_message)


 _shutdown_tasks = []
@@ -314,6 +403,7 @@ def on_shutdown(coro):

 T = TypeVar("T")
 P = ParamSpec("P")
+V = TypeVar("V")


 def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
@@ -339,40 +429,6 @@ async def iterate_blocking(iterator: Iterator[T]) -> AsyncGenerator[T, None]:
         yield cast(T, obj)


-class ConcurrencyPool:
-    def __init__(self, concurrency_limit: int):
-        self.semaphore = asyncio.Semaphore(concurrency_limit)
-
-    async def run_coros(self, coros: typing.Iterable[typing.Coroutine], return_exceptions=False):
-        async def blocking_wrapper(coro):
-            # Not using async with on the semaphore is intentional here - if return_exceptions=False
-            # manual release prevents starting extraneous tasks after exceptions.
-            try:
-                await self.semaphore.acquire()
-            except asyncio.CancelledError:
-                coro.close()  # avoid "coroutine was never awaited" warnings
-
-            try:
-                res = await coro
-                self.semaphore.release()
-                return res
-            except BaseException as e:
-                if return_exceptions:
-                    self.semaphore.release()
-                raise e
-
-        # asyncio.gather() is weird - it doesn't cancel outstanding awaitables on exceptions when
-        # return_exceptions=False --> wrap the coros in tasks are cancel them explicitly on exception.
-        tasks = [asyncio.create_task(blocking_wrapper(coro)) for coro in coros]
-        g = asyncio.gather(*tasks, return_exceptions=return_exceptions)
-        try:
-            return await g
-        except BaseException as e:
-            for t in tasks:
-                t.cancel()
-            raise e
-
-
 @asynccontextmanager
 async def asyncnullcontext(*args, **kwargs):
     """Async noop context manager.
@@ -384,3 +440,326 @@ async def asyncnullcontext(*args, **kwargs):
         pass
     """
     yield
+
+
+YIELD_TYPE = typing.TypeVar("YIELD_TYPE")
+SEND_TYPE = typing.TypeVar("SEND_TYPE")
+
+
+def run_async_gen(
+    runner: Runner,
+    gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE],
+) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]:
+    """Convert an async generator into a sync one"""
+    # more or less copied from synchronicity's implementation:
+    next_send: typing.Union[SEND_TYPE, None] = None
+    next_yield: YIELD_TYPE
+    exc: Optional[BaseException] = None
+    while True:
+        try:
+            if exc:
+                next_yield = runner.run(gen.athrow(exc))
+            else:
+                next_yield = runner.run(gen.asend(next_send))  # type: ignore[arg-type]
+        except KeyboardInterrupt as e:
+            raise e from None
+        except StopAsyncIteration:
+            break  # typically a graceful exit of the async generator
+        try:
+            next_send = yield next_yield
+            exc = None
+        except BaseException as err:
+            exc = err
+
+
+class aclosing(typing.Generic[T]):  # noqa
+    # backport of Python contextlib.aclosing from Python 3.10
+    def __init__(self, agen: AsyncGenerator[T, None]):
+        self.agen = agen
+
+    async def __aenter__(self) -> AsyncGenerator[T, None]:
+        return self.agen
+
+    async def __aexit__(self, exc, exc_type, tb):
+        await self.agen.aclose()
+
+
+async def sync_or_async_iter(iter: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
+    if hasattr(iter, "__aiter__"):
+        agen = typing.cast(AsyncGenerator[T, None], iter)
+        try:
+            async for item in agen:
+                yield item
+        finally:
+            if hasattr(agen, "aclose"):
+                # All AsyncGenerator's have an aclose method
+                # but some AsyncIterable's don't necessarily
+                await agen.aclose()
+    else:
+        assert hasattr(iter, "__iter__"), "sync_or_async_iter requires an Iterable or AsyncGenerator"
+        # This intentionally could block the event loop for the duration of calling __iter__ and __next__,
+        # so in non-trivial cases (like passing lists and ranges) this could be quite a foot gun for users #
+        # w/ async code (but they can work around it by always using async iterators)
+        for item in typing.cast(Iterable[T], iter):
+            yield item
+
+
+@typing.overload
+def async_zip(g1: AsyncGenerator[T, None], g2: AsyncGenerator[V, None], /) -> AsyncGenerator[tuple[T, V], None]:
+    ...
+
+
+@typing.overload
+def async_zip(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]:
+    ...
+
+
+async def async_zip(*generators):
+    tasks = []
+    try:
+        while True:
+            try:
+
+                async def next_item(gen):
+                    return await gen.__anext__()
+
+                tasks = [asyncio.create_task(next_item(gen)) for gen in generators]
+                items = await asyncio.gather(*tasks)
+                yield tuple(items)
+            except StopAsyncIteration:
+                break
+    finally:
+        cancelled_tasks = []
+        for task in tasks:
+            if not task.done():
+                task.cancel()
+                cancelled_tasks.append(task)
+        try:
+            await asyncio.gather(*cancelled_tasks)
+        except asyncio.CancelledError:
+            pass
+
+        first_exception = None
+        for gen in generators:
+            try:
+                await gen.aclose()
+            except BaseException as e:
+                if first_exception is None:
+                    first_exception = e
+                logger.exception(f"Error closing async generator: {e}")
+        if first_exception is not None:
+            raise first_exception
+
+
+@dataclass
+class ValueWrapper(typing.Generic[T]):
+    value: T
+
+
+@dataclass
+class ExceptionWrapper:
+    value: Exception
+
+
+class StopSentinelType:
+    ...
+
+
+STOP_SENTINEL = StopSentinelType()
+
+
+async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+    """
+    Asynchronously merges multiple async generators into a single async generator.
+
+    This function takes multiple async generators and yields their values in the order
+    they are produced. If any generator raises an exception, the exception is propagated.
+
+    Args:
+        *generators: One or more async generators to be merged.
+
+    Yields:
+        The values produced by the input async generators.
+
+    Raises:
+        Exception: If any of the input generators raises an exception, it is propagated.
+
+    Usage:
+    ```python
+    import asyncio
+    from modal._utils.async_utils import async_merge
+
+    async def gen1():
+        yield 1
+        yield 2
+
+    async def gen2():
+        yield "a"
+        yield "b"
+
+    async def example():
+        values = set()
+        async for value in async_merge(gen1(), gen2()):
+            values.add(value)
+
+        return values
+
+    # Output could be: {1, "a", 2, "b"} (order may vary)
+    values = asyncio.run(example())
+    assert values == {1, "a", 2, "b"}
+    ```
+    """
+    queue: asyncio.Queue[Union[ValueWrapper[T], ExceptionWrapper]] = asyncio.Queue(maxsize=len(generators) * 10)
+
+    async def producer(generator: AsyncGenerator[T, None]):
+        try:
+            async for item in generator:
+                await queue.put(ValueWrapper(item))
+        except Exception as e:
+            await queue.put(ExceptionWrapper(e))
+
+    tasks = {asyncio.create_task(producer(gen)) for gen in generators}
+    new_output_task = asyncio.create_task(queue.get())
+
+    try:
+        while tasks:
+            done, _ = await asyncio.wait(
+                [*tasks, new_output_task],
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+
+            if new_output_task in done:
+                item = new_output_task.result()
+                if isinstance(item, ValueWrapper):
+                    yield item.value
+                else:
+                    assert_type(item, ExceptionWrapper)
+                    raise item.value
+
+                new_output_task = asyncio.create_task(queue.get())
+
+            finished_producers = done & tasks
+            tasks -= finished_producers
+            for finished_producer in finished_producers:
+                # this is done in order to catch potential raised errors/cancellations
+                # from within worker tasks as soon as they happen.
+                await finished_producer
+
+        while not queue.empty():
+            item = await new_output_task
+            if isinstance(item, ValueWrapper):
+                yield item.value
+            else:
+                assert_type(item, ExceptionWrapper)
+                raise item.value
+
+            new_output_task = asyncio.create_task(queue.get())
+
+    finally:
+        if not new_output_task.done():
+            new_output_task.cancel()
+        for task in tasks:
+            if not task.done():
+                try:
+                    task.cancel()
+                    await task
+                except asyncio.CancelledError:
+                    pass
+
+
+async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
+    yield await awaitable()
+
+
+async def gather_cancel_on_exc(*coros_or_futures):
+    input_tasks = [asyncio.ensure_future(t) for t in coros_or_futures]
+    try:
+        return await asyncio.gather(*input_tasks)
+    except BaseException:
+        for t in input_tasks:
+            t.cancel()
+        await asyncio.gather(*input_tasks, return_exceptions=False)  # handle cancellations
+        raise
+
+
+async def async_map(
+    input_generator: AsyncGenerator[T, None],
+    async_mapper_func: Callable[[T], Awaitable[V]],
+    concurrency: int,
+) -> AsyncGenerator[V, None]:
+    queue: asyncio.Queue[Union[ValueWrapper[T], StopSentinelType]] = asyncio.Queue(maxsize=concurrency * 2)
+
+    async def producer() -> AsyncGenerator[V, None]:
+        async for item in input_generator:
+            await queue.put(ValueWrapper(item))
+
+        for _ in range(concurrency):
+            await queue.put(STOP_SENTINEL)
+
+        if False:
+            # Need it to be an async generator for async_merge
+            # but we don't want to yield anything
+            yield
+
+    async def worker() -> AsyncGenerator[V, None]:
+        while True:
+            item = await queue.get()
+            if isinstance(item, ValueWrapper):
+                yield await async_mapper_func(item.value)
+            elif isinstance(item, ExceptionWrapper):
+                raise item.value
+            else:
+                assert_type(item, StopSentinelType)
+                break
+
+    async with aclosing(async_merge(*[worker() for _ in range(concurrency)], producer())) as stream:
+        async for item in stream:
+            yield item
+
+
+async def async_map_ordered(
+    input_generator: AsyncGenerator[T, None],
+    async_mapper_func: Callable[[T], Awaitable[V]],
+    concurrency: int,
+    buffer_size: Optional[int] = None,
+) -> AsyncGenerator[V, None]:
+    semaphore = asyncio.Semaphore(buffer_size or concurrency)
+
+    async def mapper_func_wrapper(tup: tuple[int, T]) -> tuple[int, V]:
+        return (tup[0], await async_mapper_func(tup[1]))
+
+    async def counter() -> AsyncGenerator[int, None]:
+        for i in itertools.count():
+            await semaphore.acquire()
+            yield i
+
+    next_idx = 0
+    buffer = {}
+
+    async with aclosing(async_map(async_zip(counter(), input_generator), mapper_func_wrapper, concurrency)) as stream:
+        async for output_idx, output_item in stream:
+            buffer[output_idx] = output_item
+
+            while next_idx in buffer:
+                yield buffer[next_idx]
+                semaphore.release()
+                del buffer[next_idx]
+                next_idx += 1
+
+
+async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
+    try:
+        for gen in generators:
+            async for item in gen:
+                yield item
+    finally:
+        first_exception = None
+        for gen in generators:
+            try:
+                await gen.aclose()
+            except BaseException as e:
+                if first_exception is None:
+                    first_exception = e
+                logger.exception(f"Error closing async generator: {e}")
+        if first_exception is not None:
+            raise first_exception