modal 0.62.115__py3-none-any.whl → 0.72.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +13 -9
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +402 -398
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -60
  11. modal/_resources.py +26 -7
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1025 -0
  15. modal/{execution_context.py → _runtime/execution_context.py} +11 -2
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +123 -6
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +50 -14
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +386 -104
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +299 -98
  29. modal/_utils/grpc_testing.py +47 -34
  30. modal/_utils/grpc_utils.py +54 -21
  31. modal/_utils/hash_utils.py +51 -10
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +3 -3
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +12 -10
  43. modal/app.py +561 -323
  44. modal/app.pyi +474 -262
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +22 -6
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +203 -42
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +61 -13
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +21 -48
  55. modal/cli/launch.py +28 -14
  56. modal/cli/network_file_system.py +57 -21
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +34 -9
  59. modal/cli/programs/vscode.py +58 -8
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +199 -96
  62. modal/cli/secret.py +5 -4
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +74 -8
  65. modal/cli/volume.py +97 -56
  66. modal/client.py +248 -144
  67. modal/client.pyi +156 -124
  68. modal/cloud_bucket_mount.py +43 -30
  69. modal/cloud_bucket_mount.pyi +32 -25
  70. modal/cls.py +528 -141
  71. modal/cls.pyi +189 -145
  72. modal/config.py +32 -15
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +50 -54
  76. modal/dict.pyi +120 -164
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +30 -43
  80. modal/experimental.py +62 -2
  81. modal/file_io.py +537 -0
  82. modal/file_io.pyi +235 -0
  83. modal/file_pattern_matcher.py +196 -0
  84. modal/functions.py +846 -428
  85. modal/functions.pyi +446 -387
  86. modal/gpu.py +57 -44
  87. modal/image.py +943 -417
  88. modal/image.pyi +584 -245
  89. modal/io_streams.py +434 -0
  90. modal/io_streams.pyi +122 -0
  91. modal/mount.py +223 -90
  92. modal/mount.pyi +241 -243
  93. modal/network_file_system.py +85 -86
  94. modal/network_file_system.pyi +151 -110
  95. modal/object.py +66 -36
  96. modal/object.pyi +166 -143
  97. modal/output.py +63 -0
  98. modal/parallel_map.py +73 -47
  99. modal/parallel_map.pyi +51 -63
  100. modal/partial_function.py +272 -107
  101. modal/partial_function.pyi +219 -120
  102. modal/proxy.py +15 -12
  103. modal/proxy.pyi +3 -8
  104. modal/queue.py +96 -72
  105. modal/queue.pyi +210 -135
  106. modal/requirements/2024.04.txt +2 -1
  107. modal/requirements/2024.10.txt +16 -0
  108. modal/requirements/README.md +21 -0
  109. modal/requirements/base-images.json +22 -0
  110. modal/retries.py +45 -4
  111. modal/runner.py +325 -203
  112. modal/runner.pyi +124 -110
  113. modal/running_app.py +27 -4
  114. modal/sandbox.py +509 -231
  115. modal/sandbox.pyi +396 -169
  116. modal/schedule.py +2 -2
  117. modal/scheduler_placement.py +20 -3
  118. modal/secret.py +41 -25
  119. modal/secret.pyi +62 -42
  120. modal/serving.py +39 -49
  121. modal/serving.pyi +37 -43
  122. modal/stream_type.py +15 -0
  123. modal/token_flow.py +5 -3
  124. modal/token_flow.pyi +37 -32
  125. modal/volume.py +123 -137
  126. modal/volume.pyi +228 -221
  127. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/METADATA +5 -5
  128. modal-0.72.13.dist-info/RECORD +174 -0
  129. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/top_level.txt +0 -1
  130. modal_docs/gen_reference_docs.py +3 -1
  131. modal_docs/mdmd/mdmd.py +0 -1
  132. modal_docs/mdmd/signatures.py +1 -2
  133. modal_global_objects/images/base_images.py +28 -0
  134. modal_global_objects/mounts/python_standalone.py +2 -2
  135. modal_proto/__init__.py +1 -1
  136. modal_proto/api.proto +1231 -531
  137. modal_proto/api_grpc.py +750 -430
  138. modal_proto/api_pb2.py +2102 -1176
  139. modal_proto/api_pb2.pyi +8859 -0
  140. modal_proto/api_pb2_grpc.py +1329 -675
  141. modal_proto/api_pb2_grpc.pyi +1416 -0
  142. modal_proto/modal_api_grpc.py +149 -0
  143. modal_proto/modal_options_grpc.py +3 -0
  144. modal_proto/options_pb2.pyi +20 -0
  145. modal_proto/options_pb2_grpc.pyi +7 -0
  146. modal_proto/py.typed +0 -0
  147. modal_version/__init__.py +1 -1
  148. modal_version/_version_generated.py +2 -2
  149. modal/_asgi.py +0 -370
  150. modal/_container_exec.py +0 -128
  151. modal/_container_io_manager.py +0 -646
  152. modal/_container_io_manager.pyi +0 -412
  153. modal/_sandbox_shell.py +0 -49
  154. modal/app_utils.py +0 -20
  155. modal/app_utils.pyi +0 -17
  156. modal/execution_context.pyi +0 -37
  157. modal/shared_volume.py +0 -23
  158. modal/shared_volume.pyi +0 -24
  159. modal-0.62.115.dist-info/RECORD +0 -207
  160. modal_global_objects/images/conda.py +0 -15
  161. modal_global_objects/images/debian_slim.py +0 -15
  162. modal_global_objects/images/micromamba.py +0 -15
  163. test/__init__.py +0 -1
  164. test/aio_test.py +0 -12
  165. test/async_utils_test.py +0 -279
  166. test/blob_test.py +0 -67
  167. test/cli_imports_test.py +0 -149
  168. test/cli_test.py +0 -674
  169. test/client_test.py +0 -203
  170. test/cloud_bucket_mount_test.py +0 -22
  171. test/cls_test.py +0 -636
  172. test/config_test.py +0 -149
  173. test/conftest.py +0 -1485
  174. test/container_app_test.py +0 -50
  175. test/container_test.py +0 -1405
  176. test/cpu_test.py +0 -23
  177. test/decorator_test.py +0 -85
  178. test/deprecation_test.py +0 -34
  179. test/dict_test.py +0 -51
  180. test/e2e_test.py +0 -68
  181. test/error_test.py +0 -7
  182. test/function_serialization_test.py +0 -32
  183. test/function_test.py +0 -791
  184. test/function_utils_test.py +0 -101
  185. test/gpu_test.py +0 -159
  186. test/grpc_utils_test.py +0 -82
  187. test/helpers.py +0 -47
  188. test/image_test.py +0 -814
  189. test/live_reload_test.py +0 -80
  190. test/lookup_test.py +0 -70
  191. test/mdmd_test.py +0 -329
  192. test/mount_test.py +0 -162
  193. test/mounted_files_test.py +0 -327
  194. test/network_file_system_test.py +0 -188
  195. test/notebook_test.py +0 -66
  196. test/object_test.py +0 -41
  197. test/package_utils_test.py +0 -25
  198. test/queue_test.py +0 -115
  199. test/resolver_test.py +0 -59
  200. test/retries_test.py +0 -67
  201. test/runner_test.py +0 -85
  202. test/sandbox_test.py +0 -191
  203. test/schedule_test.py +0 -15
  204. test/scheduler_placement_test.py +0 -57
  205. test/secret_test.py +0 -89
  206. test/serialization_test.py +0 -50
  207. test/stub_composition_test.py +0 -10
  208. test/stub_test.py +0 -361
  209. test/test_asgi_wrapper.py +0 -234
  210. test/token_flow_test.py +0 -18
  211. test/traceback_test.py +0 -135
  212. test/tunnel_test.py +0 -29
  213. test/utils_test.py +0 -88
  214. test/version_test.py +0 -14
  215. test/volume_test.py +0 -397
  216. test/watcher_test.py +0 -58
  217. test/webhook_test.py +0 -145
  218. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/LICENSE +0 -0
  219. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/WHEEL +0 -0
  220. {modal-0.62.115.dist-info → modal-0.72.13.dist-info}/entry_points.txt +0 -0
@@ -3,25 +3,34 @@ import asyncio
3
3
  import concurrent.futures
4
4
  import functools
5
5
  import inspect
6
+ import itertools
6
7
  import time
7
8
  import typing
9
+ from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
8
10
  from contextlib import asynccontextmanager
9
- from typing import Any, AsyncGenerator, Callable, Iterator, List, Optional, Set, TypeVar, cast
11
+ from dataclasses import dataclass
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ Optional,
16
+ TypeVar,
17
+ Union,
18
+ cast,
19
+ )
10
20
 
11
21
  import synchronicity
12
- from typing_extensions import ParamSpec
22
+ from synchronicity.async_utils import Runner
23
+ from synchronicity.exceptions import NestedEventLoops
24
+ from typing_extensions import ParamSpec, assert_type
13
25
 
14
26
  from ..exception import InvalidError
15
27
  from .logger import logger
16
28
 
17
29
  synchronizer = synchronicity.Synchronizer()
18
- # atexit.register(synchronizer.close)
19
30
 
20
31
 
21
32
  def synchronize_api(obj, target_module=None):
22
- if inspect.isclass(obj):
23
- blocking_name = obj.__name__.lstrip("_")
24
- elif inspect.isfunction(object):
33
+ if inspect.isclass(obj) or inspect.isfunction(obj):
25
34
  blocking_name = obj.__name__.lstrip("_")
26
35
  elif isinstance(obj, TypeVar):
27
36
  blocking_name = "_BLOCKING_" + obj.__name__
@@ -103,7 +112,7 @@ class TaskContext:
103
112
  ```
104
113
  """
105
114
 
106
- _loops: Set[asyncio.Task]
115
+ _loops: set[asyncio.Task]
107
116
 
108
117
  def __init__(self, grace: Optional[float] = None):
109
118
  self._grace = grace
@@ -140,7 +149,6 @@ class TaskContext:
140
149
  if gather_future:
141
150
  try:
142
151
  await gather_future
143
- # pre Python3.8, CancelledErrors were a subclass of exception
144
152
  except asyncio.CancelledError:
145
153
  pass
146
154
 
@@ -150,11 +158,12 @@ class TaskContext:
150
158
  # Only tasks without a done_callback will still be present in self._tasks
151
159
  task.result()
152
160
 
153
- if task.done() or task in self._loops:
161
+ if task.done() or task in self._loops: # Note: Legacy code, we can probably cancel loops.
154
162
  continue
155
163
 
156
164
  # Cancel any remaining unfinished tasks.
157
165
  task.cancel()
166
+ await asyncio.sleep(0) # wake up coroutines waiting for cancellations
158
167
 
159
168
  async def __aexit__(self, exc_type, value, tb):
160
169
  await self.stop()
@@ -171,28 +180,32 @@ class TaskContext:
171
180
  task.add_done_callback(self._tasks.discard)
172
181
  return task
173
182
 
174
- def infinite_loop(self, async_f, timeout: Optional[float] = 90, sleep: float = 10) -> asyncio.Task:
175
- function_name = async_f.__qualname__
183
+ def infinite_loop(
184
+ self, async_f, timeout: Optional[float] = 90, sleep: float = 10, log_exception: bool = True
185
+ ) -> asyncio.Task:
186
+ if isinstance(async_f, functools.partial):
187
+ function_name = async_f.func.__qualname__
188
+ else:
189
+ function_name = async_f.__qualname__
176
190
 
177
191
  async def loop_coro() -> None:
178
192
  logger.debug(f"Starting infinite loop {function_name}")
179
- while True:
180
- t0 = time.time()
193
+ while not self.exited:
181
194
  try:
182
195
  await asyncio.wait_for(async_f(), timeout=timeout)
183
- # pre Python3.8, CancelledErrors were a subclass of exception
184
- except asyncio.CancelledError:
185
- raise
186
- except Exception:
187
- time_elapsed = time.time() - t0
188
- logger.exception(f"Loop attempt failed for {function_name} (time_elapsed={time_elapsed})")
196
+ except Exception as exc:
197
+ if log_exception and isinstance(exc, asyncio.TimeoutError):
198
+ # Asyncio sends an empty message in this case, so let's use logger.error
199
+ logger.error(f"Loop attempt for {function_name} timed out")
200
+ elif log_exception:
201
+ # Propagate the exception to the logger
202
+ logger.exception(f"Loop attempt for {function_name} failed")
189
203
  try:
190
204
  await asyncio.wait_for(self._exited.wait(), timeout=sleep)
191
205
  except asyncio.TimeoutError:
192
206
  continue
193
- # Only reached if self._exited got set.
194
- logger.debug(f"Exiting infinite loop for {function_name}")
195
- break
207
+
208
+ logger.debug(f"Exiting infinite loop for {function_name}")
196
209
 
197
210
  t = self.create_task(loop_coro())
198
211
  t.set_name(f"{function_name} loop")
@@ -200,29 +213,39 @@ class TaskContext:
200
213
  t.add_done_callback(self._loops.discard)
201
214
  return t
202
215
 
203
- async def wait(self, *tasks):
204
- # Waits until all of tasks have finished
205
- # This is slightly different than asyncio.wait since the `tasks` argument
206
- # may be a subset of all the tasks.
207
- # If any of the task context's task raises, throw that exception
208
- # This is probably O(n^2) sadly but I guess it's fine
209
- unfinished_tasks = set(tasks)
210
- while True:
211
- unfinished_tasks &= self._tasks
212
- if not unfinished_tasks:
213
- break
214
- try:
215
- done, pending = await asyncio.wait_for(
216
- asyncio.wait(self._tasks, return_when=asyncio.FIRST_COMPLETED), timeout=30.0
217
- )
218
- except asyncio.TimeoutError:
219
- continue
220
- for task in done:
221
- task.result() # Raise exception if needed
222
- if task in unfinished_tasks:
223
- unfinished_tasks.remove(task)
224
- if task in self._tasks:
225
- self._tasks.remove(task)
216
@staticmethod
async def gather(*coros: Awaitable) -> Any:
    """Run the given awaitables concurrently and return their results in order.

    Behaves like `asyncio.gather()`, except that every awaitable is scheduled inside a
    fresh TaskContext. If one of them fails with an exception other than
    `asyncio.CancelledError`, leaving the TaskContext cancels the tasks that are still
    running. Plain `asyncio.gather()` leaves them running, which is easy to overlook.

    For example, with `asyncio.gather(t1, t2, t3)`, an exception in t2 leaves t1 and t3
    running. With `TaskContext.gather(t1, t2, t3)`, they are cancelled.

    (`asyncio.gather()` is still fine when cancellation doesn't matter. Examples are
    gathering quick coroutines with no side-effects, or gathering with
    `return_exceptions=True`.)

    Usage:

    ```python notest
    # Example 1: Await three coroutines
    created_object, other_work, new_plumbing = await TaskContext.gather(
        create_my_object(),
        do_some_other_work(),
        fix_plumbing(),
    )

    # Example 2: Gather a list of coroutines
    coros = [a.load() for a in objects]
    results = await TaskContext.gather(*coros)
    ```
    """
    async with TaskContext() as task_context:
        # Schedule everything up front so the context owns every task it may need to cancel.
        scheduled = [task_context.create_task(coro) for coro in coros]
        return await asyncio.gather(*scheduled)
226
249
 
227
250
 
228
251
  def run_coro_blocking(coro):
@@ -243,7 +266,7 @@ async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_ti
243
266
 
244
267
  Treats a None value as end of queue items
245
268
  """
246
- item_list: List[Any] = []
269
+ item_list: list[Any] = []
247
270
 
248
271
  while True:
249
272
  if q.empty() and len(item_list) > 0:
@@ -290,10 +313,18 @@ class _WarnIfGeneratorIsNotConsumed:
290
313
  if not self.iterated and not self.warned:
291
314
  self.warned = True
292
315
  logger.warning(
293
- f"Warning: the results of a call to {self.function_name} was not consumed, so the call will never be executed."
294
- f" Consider a for-loop like `for x in {self.function_name}(...)` or unpacking the generator using `list(...)`"
316
+ f"Warning: the results of a call to {self.function_name} was not consumed, "
317
+ "so the call will never be executed."
318
+ f" Consider a for-loop like `for x in {self.function_name}(...)` or "
319
+ "unpacking the generator using `list(...)`"
295
320
  )
296
321
 
322
async def athrow(self, exc):
    """Throw `exc` into the wrapped generator (`self.gen`) and return its result."""
    wrapped = self.gen
    return await wrapped.athrow(exc)

async def aclose(self):
    """Close the wrapped generator (`self.gen`)."""
    wrapped = self.gen
    return await wrapped.aclose()
327
+
297
328
 
298
329
  synchronize_api(_WarnIfGeneratorIsNotConsumed)
299
330
 
@@ -331,7 +362,7 @@ def warn_if_generator_is_not_consumed(function_name: Optional[str] = None):
331
362
  return decorator
332
363
 
333
364
 
334
- class AsyncOrSyncIteratable:
365
+ class AsyncOrSyncIterable:
335
366
  """Compatibility class for non-synchronicity wrapped async iterables to get
336
367
  both async and sync interfaces in the same way that synchronicity does (but on the main thread)
337
368
  so they can be "lazily" iterated using either `for _ in x` or `async for _ in x`
@@ -340,7 +371,7 @@ class AsyncOrSyncIteratable:
340
371
  from an already async context, since that would otherwise deadlock the event loop
341
372
  """
342
373
 
343
- def __init__(self, async_iterable: typing.AsyncIterable[Any], nested_async_message):
374
+ def __init__(self, async_iterable: typing.AsyncGenerator[Any, None], nested_async_message):
344
375
  self._async_iterable = async_iterable
345
376
  self.nested_async_message = nested_async_message
346
377
 
@@ -349,9 +380,9 @@ class AsyncOrSyncIteratable:
349
380
 
350
381
  def __iter__(self):
351
382
  try:
352
- for output in run_generator_sync(self._async_iterable): # type: ignore
353
- yield output
354
- except NestedAsyncCalls:
383
+ with Runner() as runner:
384
+ yield from run_async_gen(runner, self._async_iterable)
385
+ except NestedEventLoops:
355
386
  raise InvalidError(self.nested_async_message)
356
387
 
357
388
 
@@ -372,6 +403,7 @@ def on_shutdown(coro):
372
403
 
373
404
  T = TypeVar("T")
374
405
  P = ParamSpec("P")
406
+ V = TypeVar("V")
375
407
 
376
408
 
377
409
  def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
@@ -397,40 +429,6 @@ async def iterate_blocking(iterator: Iterator[T]) -> AsyncGenerator[T, None]:
397
429
  yield cast(T, obj)
398
430
 
399
431
 
400
- class ConcurrencyPool:
401
- def __init__(self, concurrency_limit: int):
402
- self.semaphore = asyncio.Semaphore(concurrency_limit)
403
-
404
- async def run_coros(self, coros: typing.Iterable[typing.Coroutine], return_exceptions=False):
405
- async def blocking_wrapper(coro):
406
- # Not using async with on the semaphore is intentional here - if return_exceptions=False
407
- # manual release prevents starting extraneous tasks after exceptions.
408
- try:
409
- await self.semaphore.acquire()
410
- except asyncio.CancelledError:
411
- coro.close() # avoid "coroutine was never awaited" warnings
412
-
413
- try:
414
- res = await coro
415
- self.semaphore.release()
416
- return res
417
- except BaseException as e:
418
- if return_exceptions:
419
- self.semaphore.release()
420
- raise e
421
-
422
- # asyncio.gather() is weird - it doesn't cancel outstanding awaitables on exceptions when
423
- # return_exceptions=False --> wrap the coros in tasks are cancel them explicitly on exception.
424
- tasks = [asyncio.create_task(blocking_wrapper(coro)) for coro in coros]
425
- g = asyncio.gather(*tasks, return_exceptions=return_exceptions)
426
- try:
427
- return await g
428
- except BaseException as e:
429
- for t in tasks:
430
- t.cancel()
431
- raise e
432
-
433
-
434
432
  @asynccontextmanager
435
433
  async def asyncnullcontext(*args, **kwargs):
436
434
  """Async noop context manager.
@@ -448,21 +446,11 @@ YIELD_TYPE = typing.TypeVar("YIELD_TYPE")
448
446
  SEND_TYPE = typing.TypeVar("SEND_TYPE")
449
447
 
450
448
 
451
- class NestedAsyncCalls(Exception):
452
- pass
453
-
454
-
455
- def run_generator_sync(
449
+ def run_async_gen(
450
+ runner: Runner,
456
451
  gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE],
457
452
  ) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]:
458
- try:
459
- asyncio.get_running_loop()
460
- except RuntimeError:
461
- pass # no event loop - this is what we expect!
462
- else:
463
- raise NestedAsyncCalls()
464
- loop = asyncio.new_event_loop() # set up new event loop for the map so we can use async logic
465
-
453
+ """Convert an async generator into a sync one"""
466
454
  # more or less copied from synchronicity's implementation:
467
455
  next_send: typing.Union[SEND_TYPE, None] = None
468
456
  next_yield: YIELD_TYPE
@@ -470,14 +458,308 @@ def run_generator_sync(
470
458
  while True:
471
459
  try:
472
460
  if exc:
473
- next_yield = loop.run_until_complete(gen.athrow(exc))
461
+ next_yield = runner.run(gen.athrow(exc))
474
462
  else:
475
- next_yield = loop.run_until_complete(gen.asend(next_send)) # type: ignore[arg-type]
463
+ next_yield = runner.run(gen.asend(next_send)) # type: ignore[arg-type]
464
+ except KeyboardInterrupt as e:
465
+ raise e from None
476
466
  except StopAsyncIteration:
477
- break
467
+ break # typically a graceful exit of the async generator
478
468
  try:
479
469
  next_send = yield next_yield
480
470
  exc = None
481
471
  except BaseException as err:
482
472
  exc = err
483
- loop.close()
473
+
474
+
475
class aclosing(typing.Generic[T]):  # noqa
    """Async context manager that calls `aclose()` on an async generator on exit.

    Backport of `contextlib.aclosing` from Python 3.10. The generator is closed
    whether the `async with` block exits normally or with an exception. Exceptions are
    never suppressed.

    Usage::

        async with aclosing(my_agen()) as agen:
            async for item in agen:
                ...
    """

    def __init__(self, agen: AsyncGenerator[T, None]):
        self.agen = agen

    async def __aenter__(self) -> AsyncGenerator[T, None]:
        return self.agen

    async def __aexit__(self, exc_type, exc_value, tb) -> None:
        # Standard (exc_type, exc_value, traceback) order. Returning None (falsy) lets any
        # exception raised inside the block propagate after the generator is closed.
        await self.agen.aclose()
485
+
486
+
487
async def sync_or_async_iter(iter: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
    """Expose either a sync iterable or an async iterable as an async generator.

    Objects with `__aiter__` are iterated asynchronously. When iteration ends, or this
    generator is closed, `aclose()` is awaited on them if they provide it. Anything
    else must be a plain iterable and is iterated synchronously.
    """
    if not hasattr(iter, "__aiter__"):
        assert hasattr(iter, "__iter__"), "sync_or_async_iter requires an Iterable or AsyncGenerator"
        # Deliberately synchronous: calling __iter__/__next__ here can block the event loop.
        # That is harmless for lists and ranges but a foot gun for expensive iterators.
        # Callers with async code can avoid it by passing an async iterator instead.
        for element in typing.cast(Iterable[T], iter):
            yield element
        return

    async_source = typing.cast(AsyncGenerator[T, None], iter)
    try:
        async for element in async_source:
            yield element
    finally:
        # Every AsyncGenerator has aclose(), but an arbitrary AsyncIterable may not.
        if hasattr(async_source, "aclose"):
            await async_source.aclose()
505
+
506
+
507
@typing.overload
def async_zip(g1: AsyncGenerator[T, None], g2: AsyncGenerator[V, None], /) -> AsyncGenerator[tuple[T, V], None]:
    ...


@typing.overload
def async_zip(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]:
    ...


async def async_zip(*generators):
    """Iterate async generators in lockstep, yielding one tuple of items per round.

    Each round awaits `__anext__()` on every generator concurrently. Iteration stops as
    soon as any generator is exhausted. Like the builtin `zip`, calling it with no
    generators yields nothing.

    On exit (exhaustion, an error, or the caller closing this generator):
    - any `__anext__` calls still in flight are cancelled and awaited;
    - every input generator is then closed with `aclose()`;
    - closing errors are logged, and the first one is re-raised after all generators
      have been closed.
    """
    if not generators:
        # gather() over zero tasks completes immediately with [], so without this guard
        # the loop below would yield () forever.
        return

    async def next_item(gen):
        return await gen.__anext__()

    tasks = []
    try:
        while True:
            try:
                tasks = [asyncio.create_task(next_item(gen)) for gen in generators]
                items = await asyncio.gather(*tasks)
                yield tuple(items)
            except StopAsyncIteration:
                break
    finally:
        cancelled_tasks = [task for task in tasks if not task.done()]
        for task in cancelled_tasks:
            task.cancel()
        try:
            # return_exceptions=True waits until every cancelled task has finished.
            # With return_exceptions=False, gather returns at the first cancelled task,
            # leaving siblings possibly still inside __anext__, and aclose() below would
            # then hit a running generator. It also keeps an error from one task from
            # skipping the aclose() loop.
            await asyncio.gather(*cancelled_tasks, return_exceptions=True)
        except asyncio.CancelledError:
            pass

        first_exception = None
        for gen in generators:
            try:
                await gen.aclose()
            except BaseException as e:
                if first_exception is None:
                    first_exception = e
                logger.exception(f"Error closing async generator: {e}")
        if first_exception is not None:
            raise first_exception
552
+
553
+
554
@dataclass
class ValueWrapper(typing.Generic[T]):
    """Box for a successfully produced value, passed through an internal queue."""

    value: T


@dataclass
class ExceptionWrapper:
    """Box for an exception captured from a producer, for the consumer to re-raise."""

    value: Exception


class StopSentinelType:
    """Type of `STOP_SENTINEL`, the end-of-input marker placed on internal queues."""


STOP_SENTINEL = StopSentinelType()
569
+
570
+
571
async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Asynchronously merges multiple async generators into a single async generator.

    This function takes multiple async generators and yields their values in the order
    they are produced. If any generator raises an exception, the exception is propagated.

    Each input is drained by its own producer task into a shared queue bounded at
    `10 * len(generators)` entries. When the merged generator finishes, fails, or is
    closed early, two things are cancelled and then awaited together before it returns:
    the pending queue read and any producer still running.

    Args:
        *generators: One or more async generators to be merged.

    Yields:
        The values produced by the input async generators.

    Raises:
        Exception: If any of the input generators raises an exception, it is propagated.

    Usage:
    ```python
    import asyncio
    from modal._utils.async_utils import async_merge

    async def gen1():
        yield 1
        yield 2

    async def gen2():
        yield "a"
        yield "b"

    async def example():
        values = set()
        async for value in async_merge(gen1(), gen2()):
            values.add(value)

        return values

    # Output could be: {1, "a", 2, "b"} (order may vary)
    values = asyncio.run(example())
    assert values == {1, "a", 2, "b"}
    ```
    """
    queue: asyncio.Queue[Union[ValueWrapper[T], ExceptionWrapper]] = asyncio.Queue(maxsize=len(generators) * 10)

    async def producer(generator: AsyncGenerator[T, None]):
        # Forward every item, or the first exception, from one input into the shared queue.
        try:
            async for item in generator:
                await queue.put(ValueWrapper(item))
        except Exception as e:
            await queue.put(ExceptionWrapper(e))

    tasks = {asyncio.create_task(producer(gen)) for gen in generators}
    new_output_task = asyncio.create_task(queue.get())

    try:
        while tasks:
            done, _ = await asyncio.wait(
                [*tasks, new_output_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

            if new_output_task in done:
                item = new_output_task.result()
                if isinstance(item, ValueWrapper):
                    yield item.value
                else:
                    assert_type(item, ExceptionWrapper)
                    raise item.value

                new_output_task = asyncio.create_task(queue.get())

            finished_producers = done & tasks
            tasks -= finished_producers
            for finished_producer in finished_producers:
                # this is done in order to catch potential raised errors/cancellations
                # from within worker tasks as soon as they happen.
                await finished_producer

        # Every producer has finished; drain whatever they left in the queue.
        while not queue.empty():
            item = await new_output_task
            if isinstance(item, ValueWrapper):
                yield item.value
            else:
                assert_type(item, ExceptionWrapper)
                raise item.value

            new_output_task = asyncio.create_task(queue.get())

    finally:
        # Cancel the outstanding queue read and every producer still running, then wait
        # for all of them together. return_exceptions=True keeps an error raised by one
        # producer while unwinding from aborting cleanup of the others. It also makes
        # sure the queue-read task is awaited instead of being left pending.
        new_output_task.cancel()
        unfinished = [task for task in tasks if not task.done()]
        for task in unfinished:
            task.cancel()
        try:
            await asyncio.gather(new_output_task, *unfinished, return_exceptions=True)
        except asyncio.CancelledError:
            # Cancellation arriving while we wait for the tasks to settle is ignored here.
            pass
668
+
669
+
670
async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
    """Turn a zero-argument async callable into an async generator of its single result.

    The callable is not invoked until the generator is first advanced.
    """
    result = await awaitable()
    yield result
672
+
673
+
674
async def gather_cancel_on_exc(*coros_or_futures):
    """Like `asyncio.gather`, but cancel every input if gathering fails.

    All inputs are wrapped into tasks/futures up front. If gathering raises, every input
    is cancelled. This happens both when an input fails and when this coroutine is
    itself cancelled. The function then waits for *all* inputs to finish unwinding
    before re-raising the original exception, so none is still running when it returns.

    Returns:
        The list of results, in input order.
    """
    input_tasks = [asyncio.ensure_future(t) for t in coros_or_futures]
    try:
        return await asyncio.gather(*input_tasks)
    except BaseException:
        for t in input_tasks:
            t.cancel()
        # return_exceptions=True waits until every cancelled task has actually finished.
        # With return_exceptions=False, gather returns at the first failure and leaves
        # siblings mid-cancellation. It also stops the tasks' own CancelledErrors from
        # replacing the original exception re-raised below.
        await asyncio.gather(*input_tasks, return_exceptions=True)
        raise
683
+
684
+
685
async def async_map(
    input_generator: AsyncGenerator[T, None],
    async_mapper_func: Callable[[T], Awaitable[V]],
    concurrency: int,
) -> AsyncGenerator[V, None]:
    """Apply `async_mapper_func` to each item of `input_generator` using `concurrency` workers.

    How it works:
    - A producer copies input items into a queue bounded at `concurrency * 2` entries,
      then adds one stop sentinel per worker.
    - Each worker awaits `async_mapper_func` on the items it takes from the queue.
    - Results are yielded in completion order, not input order. `async_map_ordered` is
      the order-preserving variant.
    - Exceptions from the input generator or the mapper propagate to the caller, and
      `async_merge` tears down the remaining work.

    Raises:
        ValueError: if `concurrency` is less than 1. With no workers nothing would ever
            be mapped, and `asyncio.Queue(maxsize=0)` is unbounded, so every input would
            be silently dropped.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be at least 1, got {concurrency}")

    queue: asyncio.Queue[Union[ValueWrapper[T], StopSentinelType]] = asyncio.Queue(maxsize=concurrency * 2)

    async def producer() -> AsyncGenerator[None, None]:
        async for item in input_generator:
            await queue.put(ValueWrapper(item))

        # One sentinel per worker, so every worker stops once the input is exhausted.
        for _ in range(concurrency):
            await queue.put(STOP_SENTINEL)

        if False:
            # Never executed. The bare `yield` only makes this an async generator so
            # async_merge can merge it with the workers; it yields no values itself.
            yield

    async def worker() -> AsyncGenerator[V, None]:
        while True:
            item = await queue.get()
            if isinstance(item, ValueWrapper):
                yield await async_mapper_func(item.value)
            else:
                # Only ValueWrapper items and STOP_SENTINEL are ever enqueued (see producer).
                assert_type(item, StopSentinelType)
                break

    async with aclosing(async_merge(*[worker() for _ in range(concurrency)], producer())) as stream:
        async for item in stream:
            yield item
718
+
719
+
720
async def async_map_ordered(
    input_generator: AsyncGenerator[T, None],
    async_mapper_func: Callable[[T], Awaitable[V]],
    concurrency: int,
    buffer_size: Optional[int] = None,
) -> AsyncGenerator[V, None]:
    """Like `async_map`, but yield results in the same order as the inputs.

    Each input is paired with its index and mapped through `async_map`. Results that
    finish early are held in a buffer until every earlier result has been yielded.

    A semaphore limits how many items can be dispatched but not yet yielded to the
    caller. Its size is `buffer_size`, or `concurrency` when `buffer_size` is None or 0.
    This keeps one slow item from causing unbounded buffering. A `buffer_size` smaller
    than `concurrency` therefore also lowers the effective concurrency.

    Raises:
        ValueError: if `concurrency` is less than 1. With the default `buffer_size`,
            the semaphore would have size 0 and the pipeline would wait forever.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be at least 1, got {concurrency}")

    semaphore = asyncio.Semaphore(buffer_size or concurrency)

    async def mapper_func_wrapper(tup: tuple[int, T]) -> tuple[int, V]:
        return (tup[0], await async_mapper_func(tup[1]))

    async def counter() -> AsyncGenerator[int, None]:
        # Each index costs one semaphore slot, which is returned once that result is yielded.
        for i in itertools.count():
            await semaphore.acquire()
            yield i

    next_idx = 0
    buffer: dict[int, V] = {}

    async with aclosing(async_map(async_zip(counter(), input_generator), mapper_func_wrapper, concurrency)) as stream:
        async for output_idx, output_item in stream:
            buffer[output_idx] = output_item

            # Flush the longest run of consecutive results starting at next_idx.
            while next_idx in buffer:
                yield buffer[next_idx]
                semaphore.release()
                del buffer[next_idx]
                next_idx += 1
748
+
749
+
750
async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """Yield every item of each generator in turn, like `itertools.chain` for async generators.

    When iteration ends for any reason, every input generator is closed with `aclose()`,
    including ones that were never reached. Errors raised while closing are logged. The
    first one is re-raised after all generators have been closed.
    """
    try:
        for source in generators:
            async for value in source:
                yield value
    finally:
        close_errors: list[BaseException] = []
        for source in generators:
            try:
                await source.aclose()
            except BaseException as err:
                close_errors.append(err)
                logger.exception(f"Error closing async generator: {err}")
        if close_errors:
            raise close_errors[0]