modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
@@ -3,25 +3,34 @@ import asyncio
3
3
  import concurrent.futures
4
4
  import functools
5
5
  import inspect
6
- import sys
6
+ import itertools
7
7
  import time
8
8
  import typing
9
+ from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
9
10
  from contextlib import asynccontextmanager
10
- from typing import Any, AsyncGenerator, Callable, Iterator, List, Optional, Set, TypeVar, cast
11
+ from dataclasses import dataclass
12
+ from typing import (
13
+ Any,
14
+ Callable,
15
+ Optional,
16
+ TypeVar,
17
+ Union,
18
+ cast,
19
+ )
11
20
 
12
21
  import synchronicity
13
- from typing_extensions import ParamSpec
22
+ from synchronicity.async_utils import Runner
23
+ from synchronicity.exceptions import NestedEventLoops
24
+ from typing_extensions import ParamSpec, assert_type
14
25
 
26
+ from ..exception import InvalidError
15
27
  from .logger import logger
16
28
 
17
29
  synchronizer = synchronicity.Synchronizer()
18
- # atexit.register(synchronizer.close)
19
30
 
20
31
 
21
32
  def synchronize_api(obj, target_module=None):
22
- if inspect.isclass(obj):
23
- blocking_name = obj.__name__.lstrip("_")
24
- elif inspect.isfunction(object):
33
+ if inspect.isclass(obj) or inspect.isfunction(obj):
25
34
  blocking_name = obj.__name__.lstrip("_")
26
35
  elif isinstance(obj, TypeVar):
27
36
  blocking_name = "_BLOCKING_" + obj.__name__
@@ -86,14 +95,24 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
86
95
 
87
96
 
88
97
  class TaskContext:
89
- """Simple thing to make sure we don't have stray tasks.
98
+ """A structured group that helps manage stray tasks.
99
+
100
+ This differs from the standard library `asyncio.TaskGroup` in that it cancels all tasks still
101
+ running after exiting the context manager, rather than waiting for them to finish.
102
+
103
+ A `TaskContext` can have an optional `grace` period in seconds, which will wait for a certain
104
+ amount of time before cancelling all remaining tasks. This is useful for allowing tasks to
105
+ gracefully exit when they determine that the context is shutting down.
90
106
 
91
107
  Usage:
108
+
109
+ ```python notest
92
110
  async with TaskContext() as task_context:
93
- task = task_context.create(coro())
111
+ task = task_context.create_task(coro())
112
+ ```
94
113
  """
95
114
 
96
- _loops: Set[asyncio.Task]
115
+ _loops: set[asyncio.Task]
97
116
 
98
117
  def __init__(self, grace: Optional[float] = None):
99
118
  self._grace = grace
@@ -130,7 +149,6 @@ class TaskContext:
130
149
  if gather_future:
131
150
  try:
132
151
  await gather_future
133
- # pre Python3.8, CancelledErrors were a subclass of exception
134
152
  except asyncio.CancelledError:
135
153
  pass
136
154
 
@@ -140,15 +158,12 @@ class TaskContext:
140
158
  # Only tasks without a done_callback will still be present in self._tasks
141
159
  task.result()
142
160
 
143
- if task.done() or task in self._loops:
161
+ if task.done() or task in self._loops: # Note: Legacy code, we can probably cancel loops.
144
162
  continue
145
163
 
146
- if sys.version_info >= (3, 11):
147
- already_cancelling = task.cancelling() > 0
148
- if not already_cancelling:
149
- logger.warning(f"Canceling remaining unfinished task: {task}")
150
-
164
+ # Cancel any remaining unfinished tasks.
151
165
  task.cancel()
166
+ await asyncio.sleep(0) # wake up coroutines waiting for cancellations
152
167
 
153
168
  async def __aexit__(self, exc_type, value, tb):
154
169
  await self.stop()
@@ -165,28 +180,32 @@ class TaskContext:
165
180
  task.add_done_callback(self._tasks.discard)
166
181
  return task
167
182
 
168
- def infinite_loop(self, async_f, timeout: Optional[float] = 90, sleep: float = 10) -> asyncio.Task:
169
- function_name = async_f.__qualname__
183
+ def infinite_loop(
184
+ self, async_f, timeout: Optional[float] = 90, sleep: float = 10, log_exception: bool = True
185
+ ) -> asyncio.Task:
186
+ if isinstance(async_f, functools.partial):
187
+ function_name = async_f.func.__qualname__
188
+ else:
189
+ function_name = async_f.__qualname__
170
190
 
171
191
  async def loop_coro() -> None:
172
192
  logger.debug(f"Starting infinite loop {function_name}")
173
- while True:
174
- t0 = time.time()
193
+ while not self.exited:
175
194
  try:
176
195
  await asyncio.wait_for(async_f(), timeout=timeout)
177
- # pre Python3.8, CancelledErrors were a subclass of exception
178
- except asyncio.CancelledError:
179
- raise
180
- except Exception:
181
- time_elapsed = time.time() - t0
182
- logger.exception(f"Loop attempt failed for {function_name} (time_elapsed={time_elapsed})")
196
+ except Exception as exc:
197
+ if log_exception and isinstance(exc, asyncio.TimeoutError):
198
+ # Asyncio sends an empty message in this case, so let's use logger.error
199
+ logger.error(f"Loop attempt for {function_name} timed out")
200
+ elif log_exception:
201
+ # Propagate the exception to the logger
202
+ logger.exception(f"Loop attempt for {function_name} failed")
183
203
  try:
184
204
  await asyncio.wait_for(self._exited.wait(), timeout=sleep)
185
205
  except asyncio.TimeoutError:
186
206
  continue
187
- # Only reached if self._exited got set.
188
- logger.debug(f"Exiting infinite loop for {function_name}")
189
- break
207
+
208
+ logger.debug(f"Exiting infinite loop for {function_name}")
190
209
 
191
210
  t = self.create_task(loop_coro())
192
211
  t.set_name(f"{function_name} loop")
@@ -194,29 +213,39 @@ class TaskContext:
194
213
  t.add_done_callback(self._loops.discard)
195
214
  return t
196
215
 
197
- async def wait(self, *tasks):
198
- # Waits until all of tasks have finished
199
- # This is slightly different than asyncio.wait since the `tasks` argument
200
- # may be a subset of all the tasks.
201
- # If any of the task context's task raises, throw that exception
202
- # This is probably O(n^2) sadly but I guess it's fine
203
- unfinished_tasks = set(tasks)
204
- while True:
205
- unfinished_tasks &= self._tasks
206
- if not unfinished_tasks:
207
- break
208
- try:
209
- done, pending = await asyncio.wait_for(
210
- asyncio.wait(self._tasks, return_when=asyncio.FIRST_COMPLETED), timeout=30.0
211
- )
212
- except asyncio.TimeoutError:
213
- continue
214
- for task in done:
215
- task.result() # Raise exception if needed
216
- if task in unfinished_tasks:
217
- unfinished_tasks.remove(task)
218
- if task in self._tasks:
219
- self._tasks.remove(task)
216
+ @staticmethod
217
+ async def gather(*coros: Awaitable) -> Any:
218
+ """Wait for a sequence of coroutines to finish, concurrently.
219
+
220
+ This is similar to `asyncio.gather()`, but it uses TaskContext to cancel all remaining tasks
221
+ if one fails with an exception other than `asyncio.CancelledError`. The native `asyncio`
222
+ function does not cancel remaining tasks in this case, which can lead to surprises.
223
+
224
+ For example, if you use `asyncio.gather(t1, t2, t3)` and t2 raises an exception, then t1 and
225
+ t3 would continue running. With `TaskContext.gather(t1, t2, t3)`, they are cancelled.
226
+
227
+ (It's still acceptable to use `asyncio.gather()` if you don't need cancellation — for
228
+ example, if you're just gathering quick coroutines with no side-effects. Or if you're
229
+ gathering the tasks with `return_exceptions=True`.)
230
+
231
+ Usage:
232
+
233
+ ```python notest
234
+ # Example 1: Await three coroutines
235
+ created_object, other_work, new_plumbing = await TaskContext.gather(
236
+ create_my_object(),
237
+ do_some_other_work(),
238
+ fix_plumbing(),
239
+ )
240
+
241
+ # Example 2: Gather a list of coroutines
242
+ coros = [a.load() for a in objects]
243
+ results = await TaskContext.gather(*coros)
244
+ ```
245
+ """
246
+ async with TaskContext() as tc:
247
+ results = await asyncio.gather(*(tc.create_task(coro) for coro in coros))
248
+ return results
220
249
 
221
250
 
222
251
  def run_coro_blocking(coro):
@@ -234,8 +263,10 @@ def run_coro_blocking(coro):
234
263
  async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_time=0.015):
235
264
  """
236
265
  Read from a queue but return lists of items when queue is large
266
+
267
+ Treats a None value as end of queue items
237
268
  """
238
- item_list: List[Any] = []
269
+ item_list: list[Any] = []
239
270
 
240
271
  while True:
241
272
  if q.empty() and len(item_list) > 0:
@@ -257,44 +288,102 @@ async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_ti
257
288
 
258
289
 
259
290
  class _WarnIfGeneratorIsNotConsumed:
260
- def __init__(self, gen, gen_f):
291
+ def __init__(self, gen, function_name: str):
261
292
  self.gen = gen
262
- self.gen_f = gen_f
293
+ self.function_name = function_name
263
294
  self.iterated = False
264
295
  self.warned = False
265
296
 
266
297
  def __aiter__(self):
267
298
  self.iterated = True
268
- return self.gen
299
+ return self.gen.__aiter__()
269
300
 
270
301
  async def __anext__(self):
271
302
  self.iterated = True
272
303
  return await self.gen.__anext__()
273
304
 
305
+ async def asend(self, value):
306
+ self.iterated = True
307
+ return await self.gen.asend(value)
308
+
274
309
  def __repr__(self):
275
310
  return repr(self.gen)
276
311
 
277
312
  def __del__(self):
278
313
  if not self.iterated and not self.warned:
279
314
  self.warned = True
280
- name = self.gen_f.__name__
281
315
  logger.warning(
282
- f"Warning: the results of a call to {name} was not consumed, so the call will never be executed."
283
- f" Consider a for-loop like `for x in {name}(...)` or unpacking the generator using `list(...)`"
316
+ f"Warning: the results of a call to {self.function_name} was not consumed, "
317
+ "so the call will never be executed."
318
+ f" Consider a for-loop like `for x in {self.function_name}(...)` or "
319
+ "unpacking the generator using `list(...)`"
284
320
  )
285
321
 
322
+ async def athrow(self, exc):
323
+ return await self.gen.athrow(exc)
324
+
325
+ async def aclose(self):
326
+ return await self.gen.aclose()
327
+
286
328
 
287
329
  synchronize_api(_WarnIfGeneratorIsNotConsumed)
288
330
 
289
331
 
290
- def warn_if_generator_is_not_consumed(gen_f):
332
class _WarnIfNonWrappedGeneratorIsNotConsumed(_WarnIfGeneratorIsNotConsumed):
    """Variant of `_WarnIfGeneratorIsNotConsumed` for plain (non-synchronicity-wrapped) generators.

    Adds the synchronous iteration protocol (`__iter__`, `__next__`, `send`). Each entry
    point marks the wrapper as iterated, so the "results were not consumed" warning
    emitted by the base class's `__del__` is suppressed once the caller touches it.
    """

    def __iter__(self):
        self.iterated = True
        inner = self.gen
        return iter(inner)

    def __next__(self):
        self.iterated = True
        inner = self.gen
        return inner.__next__()

    def send(self, value):
        self.iterated = True
        inner = self.gen
        return inner.send(value)
+
346
+
347
+ def warn_if_generator_is_not_consumed(function_name: Optional[str] = None):
291
348
  # https://gist.github.com/erikbern/01ae78d15f89edfa7f77e5c0a827a94d
292
- @functools.wraps(gen_f)
293
- def f_wrapped(*args, **kwargs):
294
- gen = gen_f(*args, **kwargs)
295
- return _WarnIfGeneratorIsNotConsumed(gen, gen_f)
349
+ def decorator(gen_f):
350
+ presented_func_name = function_name if function_name is not None else gen_f.__name__
296
351
 
297
- return f_wrapped
352
+ @functools.wraps(gen_f)
353
+ def f_wrapped(*args, **kwargs):
354
+ gen = gen_f(*args, **kwargs)
355
+ if inspect.isasyncgen(gen):
356
+ return _WarnIfGeneratorIsNotConsumed(gen, presented_func_name)
357
+ else:
358
+ return _WarnIfNonWrappedGeneratorIsNotConsumed(gen, presented_func_name)
359
+
360
+ return f_wrapped
361
+
362
+ return decorator
363
+
364
+
365
class AsyncOrSyncIterable:
    """Expose one (non-synchronicity-wrapped) async iterable via both `async for` and `for`.

    This mimics what synchronicity provides for wrapped objects, but runs on the calling
    thread: `async for` iterates the underlying async iterable directly, while `for`
    drives it step by step on a private `Runner`. Either way the iteration is lazy.

    If `NestedEventLoops` is raised while iterating *synchronously* (presumably because an
    event loop is already running in this thread, where blocking would deadlock it), it
    is replaced by an `InvalidError` whose message is `nested_async_message`.
    """

    def __init__(self, async_iterable: typing.AsyncGenerator[Any, None], nested_async_message):
        self._async_iterable = async_iterable
        self.nested_async_message = nested_async_message

    def __aiter__(self):
        # Async consumers get the underlying iterable as-is.
        return self._async_iterable

    def __iter__(self):
        try:
            with Runner() as runner:
                sync_view = run_async_gen(runner, self._async_iterable)
                yield from sync_view
        except NestedEventLoops:
            raise InvalidError(self.nested_async_message)
298
387
 
299
388
 
300
389
  _shutdown_tasks = []
@@ -314,6 +403,7 @@ def on_shutdown(coro):
314
403
 
315
404
  T = TypeVar("T")
316
405
  P = ParamSpec("P")
406
+ V = TypeVar("V")
317
407
 
318
408
 
319
409
  def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
@@ -339,40 +429,6 @@ async def iterate_blocking(iterator: Iterator[T]) -> AsyncGenerator[T, None]:
339
429
  yield cast(T, obj)
340
430
 
341
431
 
342
- class ConcurrencyPool:
343
- def __init__(self, concurrency_limit: int):
344
- self.semaphore = asyncio.Semaphore(concurrency_limit)
345
-
346
- async def run_coros(self, coros: typing.Iterable[typing.Coroutine], return_exceptions=False):
347
- async def blocking_wrapper(coro):
348
- # Not using async with on the semaphore is intentional here - if return_exceptions=False
349
- # manual release prevents starting extraneous tasks after exceptions.
350
- try:
351
- await self.semaphore.acquire()
352
- except asyncio.CancelledError:
353
- coro.close() # avoid "coroutine was never awaited" warnings
354
-
355
- try:
356
- res = await coro
357
- self.semaphore.release()
358
- return res
359
- except BaseException as e:
360
- if return_exceptions:
361
- self.semaphore.release()
362
- raise e
363
-
364
- # asyncio.gather() is weird - it doesn't cancel outstanding awaitables on exceptions when
365
- # return_exceptions=False --> wrap the coros in tasks are cancel them explicitly on exception.
366
- tasks = [asyncio.create_task(blocking_wrapper(coro)) for coro in coros]
367
- g = asyncio.gather(*tasks, return_exceptions=return_exceptions)
368
- try:
369
- return await g
370
- except BaseException as e:
371
- for t in tasks:
372
- t.cancel()
373
- raise e
374
-
375
-
376
432
  @asynccontextmanager
377
433
  async def asyncnullcontext(*args, **kwargs):
378
434
  """Async noop context manager.
@@ -384,3 +440,326 @@ async def asyncnullcontext(*args, **kwargs):
384
440
  pass
385
441
  """
386
442
  yield
443
+
444
+
445
YIELD_TYPE = typing.TypeVar("YIELD_TYPE")
SEND_TYPE = typing.TypeVar("SEND_TYPE")


def run_async_gen(
    runner: Runner,
    gen: typing.AsyncGenerator[YIELD_TYPE, SEND_TYPE],
) -> typing.Generator[YIELD_TYPE, SEND_TYPE, None]:
    """Convert an async generator into a sync one.

    Each step of `gen` is run to completion with `runner.run(...)`:

    - Values sent into the returned generator are forwarded with `gen.asend(...)`.
    - Exceptions thrown into it, including the `GeneratorExit` raised by `close()`, are
      forwarded with `gen.athrow(...)`.
    - Iteration ends when `gen` raises `StopAsyncIteration`.
    """
    # More or less copied from synchronicity's implementation.
    next_send: typing.Union[SEND_TYPE, None] = None
    next_yield: YIELD_TYPE
    exc: Optional[BaseException] = None
    while True:
        try:
            # Compare with None explicitly. An exception instance can be falsy if its class
            # defines __bool__/__len__, and it must still be forwarded rather than dropped.
            if exc is not None:
                next_yield = runner.run(gen.athrow(exc))
            else:
                next_yield = runner.run(gen.asend(next_send))  # type: ignore[arg-type]
        except KeyboardInterrupt as e:
            # Re-raise with exception chaining suppressed.
            raise e from None
        except StopAsyncIteration:
            break  # typically a graceful exit of the async generator
        try:
            next_send = yield next_yield
            exc = None
        except BaseException as err:
            # Anything thrown into this sync generator is passed on to the async one.
            exc = err
473
+
474
+
475
class aclosing(typing.Generic[T]):  # noqa
    """Async context manager that calls `aclose()` on an async generator when the block exits.

    Backport of `contextlib.aclosing` (added in Python 3.10). `__aenter__` returns the
    generator itself. Exceptions raised inside the `async with` body are not suppressed.
    """

    def __init__(self, agen: AsyncGenerator[T, None]):
        self.agen = agen

    async def __aenter__(self) -> AsyncGenerator[T, None]:
        return self.agen

    async def __aexit__(self, exc_type, exc, tb):
        # The parameters follow the context-manager protocol order (type, value, traceback).
        # Returning None lets any exception from the body propagate.
        await self.agen.aclose()
485
+
486
+
487
async def sync_or_async_iter(iter: Union[Iterable[T], AsyncIterable[T]]) -> AsyncGenerator[T, None]:
    """Iterate either a sync or an async iterable from async code.

    Objects with `__aiter__` are consumed with `async for`. When this generator finishes
    or is closed, they are closed with `aclose()` if they provide one. Anything else must
    be a regular iterable, and it is iterated directly on the event loop thread.
    """
    if not hasattr(iter, "__aiter__"):
        assert hasattr(iter, "__iter__"), "sync_or_async_iter requires an Iterable or AsyncGenerator"
        # NOTE: this deliberately runs __iter__/__next__ on the event loop thread. It is
        # fine for cheap iterables such as lists and ranges, but an expensive one blocks
        # the loop. Async callers can avoid that by passing an async iterable instead.
        for element in typing.cast(Iterable[T], iter):
            yield element
        return

    async_source = typing.cast(AsyncGenerator[T, None], iter)
    try:
        async for element in async_source:
            yield element
    finally:
        # Every AsyncGenerator has aclose(), but a general AsyncIterable may not.
        if hasattr(async_source, "aclose"):
            await async_source.aclose()
505
+
506
+
507
# Typing-only signatures. Two generators produce precisely typed pairs; any other
# number falls back to homogeneous tuples.
@typing.overload
def async_zip(g1: AsyncGenerator[T, None], g2: AsyncGenerator[V, None], /) -> AsyncGenerator[tuple[T, V], None]: ...


@typing.overload
def async_zip(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[tuple[T, ...], None]: ...
515
+
516
+
517
async def async_zip(*generators):
    """Yield tuples containing the next item of every async generator, in lockstep.

    The `__anext__` calls for each round run concurrently as tasks. Iteration stops as
    soon as any generator is exhausted, like the built-in `zip`, and calling it with no
    generators yields nothing.

    On exit (exhaustion, error, or the consumer closing this generator), cleanup runs in
    this order:

    1. Still-pending `__anext__` tasks are cancelled.
    2. Those tasks are awaited until they have completely finished.
    3. Every input generator is closed with `aclose()`.

    Errors raised while closing are logged, and the first one is re-raised after all
    generators have been attempted.
    """
    if not generators:
        # Match built-in zip(): no inputs means no output, not an endless stream of ().
        return

    async def next_item(gen):
        return await gen.__anext__()

    tasks = []
    try:
        while True:
            try:
                tasks = [asyncio.create_task(next_item(gen)) for gen in generators]
                items = await asyncio.gather(*tasks)
                yield tuple(items)
            except StopAsyncIteration:
                break
    finally:
        cancelled_tasks = []
        for task in tasks:
            if not task.done():
                task.cancel()
                cancelled_tasks.append(task)
        # Wait for *every* cancelled task to finish unwinding. With return_exceptions=False,
        # gather() returns as soon as the first one finishes, and aclose() below could then
        # hit a generator that is still running ("asynchronous generator is already
        # running"). Cancellation of the consuming task itself is not swallowed here.
        await asyncio.gather(*cancelled_tasks, return_exceptions=True)

        first_exception = None
        for gen in generators:
            try:
                await gen.aclose()
            except BaseException as e:
                if first_exception is None:
                    first_exception = e
                logger.exception(f"Error closing async generator: {e}")
        if first_exception is not None:
            raise first_exception
552
+
553
+
554
@dataclass
class ValueWrapper(typing.Generic[T]):
    """Box for a regular item sent through an internal queue.

    Boxing lets a queue carry other kinds of messages alongside ordinary values.
    """

    value: T
557
+
558
+
559
@dataclass
class ExceptionWrapper:
    """Box for an exception sent through an internal queue, to be re-raised by the reader."""

    value: Exception


class StopSentinelType:
    """Type of the end-of-stream marker placed on internal queues."""


# Shared marker instance; readers test for it with isinstance(..., StopSentinelType).
STOP_SENTINEL = StopSentinelType()
569
+
570
+
571
async def async_merge(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """
    Merge several async generators into one, yielding values as soon as any source produces them.

    Each source is drained by its own task into a bounded queue, and this generator yields
    from that queue. The result is interleaved in production order rather than source
    order. An exception raised by a source is forwarded through the queue and re-raised
    here. When this generator exits, any producer tasks that are still running are
    cancelled.

    Args:
        *generators: One or more async generators to be merged.

    Yields:
        The values produced by the input async generators.

    Raises:
        Exception: If any of the input generators raises an exception, it is propagated.

    Usage:
    ```python
    import asyncio
    from modal._utils.async_utils import async_merge

    async def gen1():
        yield 1
        yield 2

    async def gen2():
        yield "a"
        yield "b"

    async def example():
        values = set()
        async for value in async_merge(gen1(), gen2()):
            values.add(value)

        return values

    # Output could be: {1, "a", 2, "b"} (order may vary)
    values = asyncio.run(example())
    assert values == {1, "a", 2, "b"}
    ```
    """
    pending: asyncio.Queue[Union[ValueWrapper[T], ExceptionWrapper]] = asyncio.Queue(maxsize=len(generators) * 10)

    async def pump(source: AsyncGenerator[T, None]):
        # Copy one source into the shared queue; errors travel as ExceptionWrapper messages.
        try:
            async for produced in source:
                await pending.put(ValueWrapper(produced))
        except Exception as failure:
            await pending.put(ExceptionWrapper(failure))

    def unwrap(message):
        # Return a wrapped value, or re-raise a wrapped exception.
        if isinstance(message, ValueWrapper):
            return message.value
        assert_type(message, ExceptionWrapper)
        raise message.value

    pumps = {asyncio.create_task(pump(gen)) for gen in generators}
    next_message = asyncio.create_task(pending.get())

    try:
        while pumps:
            finished, _ = await asyncio.wait(
                [*pumps, next_message],
                return_when=asyncio.FIRST_COMPLETED,
            )

            if next_message in finished:
                yield unwrap(next_message.result())
                next_message = asyncio.create_task(pending.get())

            exhausted = finished & pumps
            pumps -= exhausted
            for pump_task in exhausted:
                # Awaiting a finished producer surfaces, right away, anything it raised that
                # was not forwarded through the queue (e.g. a cancellation).
                await pump_task

        # Every producer is done; hand out whatever is still buffered.
        while not pending.empty():
            yield unwrap(await next_message)
            next_message = asyncio.create_task(pending.get())

    finally:
        if not next_message.done():
            next_message.cancel()
        for pump_task in pumps:
            if not pump_task.done():
                try:
                    pump_task.cancel()
                    await pump_task
                except asyncio.CancelledError:
                    pass
668
+
669
+
670
async def callable_to_agen(awaitable: Callable[[], Awaitable[T]]) -> AsyncGenerator[T, None]:
    """Wrap a zero-argument async callable as an async generator that yields its one result."""
    result = await awaitable()
    yield result
672
+
673
+
674
async def gather_cancel_on_exc(*coros_or_futures):
    """Like `asyncio.gather`, but cancel the remaining awaitables if gathering fails.

    Every input is wrapped with `asyncio.ensure_future` and gathered. If that raises (an
    input failed, or the calling task was cancelled), all inputs are cancelled. They are
    then awaited until each has actually finished, and the original exception is
    re-raised.

    Returns:
        The list of results, in input order, when every input succeeds.
    """
    input_tasks = [asyncio.ensure_future(t) for t in coros_or_futures]
    try:
        return await asyncio.gather(*input_tasks)
    except BaseException:
        for t in input_tasks:
            t.cancel()
        # return_exceptions=True waits for *all* inputs to finish unwinding and collects
        # their CancelledError or other errors instead of stopping at the first one. As a
        # result, no input is still running when the original exception is re-raised.
        await asyncio.gather(*input_tasks, return_exceptions=True)
        raise
683
+
684
+
685
async def async_map(
    input_generator: AsyncGenerator[T, None],
    async_mapper_func: Callable[[T], Awaitable[V]],
    concurrency: int,
) -> AsyncGenerator[V, None]:
    """Apply `async_mapper_func` to every input item, with up to `concurrency` calls in flight.

    A feeder copies the inputs onto a bounded queue and then adds one stop marker per
    worker. `concurrency` workers take items from the queue and yield the mapped results.
    Results are merged as they are produced, so their order may differ from the input
    order.
    """
    work: asyncio.Queue[Union[ValueWrapper[T], StopSentinelType]] = asyncio.Queue(maxsize=concurrency * 2)

    async def feeder() -> AsyncGenerator[V, None]:
        async for element in input_generator:
            await work.put(ValueWrapper(element))

        # One stop marker per worker so that every worker terminates.
        for _ in range(concurrency):
            await work.put(STOP_SENTINEL)

        if False:
            # This must be an async generator so it can be merged with the workers,
            # but it never yields anything itself.
            yield

    async def worker() -> AsyncGenerator[V, None]:
        while True:
            job = await work.get()
            if isinstance(job, ValueWrapper):
                yield await async_mapper_func(job.value)
                continue
            if isinstance(job, ExceptionWrapper):
                raise job.value
            assert_type(job, StopSentinelType)
            return

    streams = [worker() for _ in range(concurrency)]
    async with aclosing(async_merge(*streams, feeder())) as merged:
        async for result in merged:
            yield result
718
+
719
+
720
async def async_map_ordered(
    input_generator: AsyncGenerator[T, None],
    async_mapper_func: Callable[[T], Awaitable[V]],
    concurrency: int,
    buffer_size: Optional[int] = None,
) -> AsyncGenerator[V, None]:
    """Like `async_map`, but yield results in the same order as their inputs.

    Each input is tagged with its index. Results that finish early wait in a buffer until
    every earlier result has been yielded. A semaphore of size
    `buffer_size or concurrency` caps how many inputs can be in flight or buffered at
    once: an index is handed out only after a slot is acquired, and the slot is released
    when that result is yielded.
    """
    slots = asyncio.Semaphore(buffer_size or concurrency)

    async def indexed_mapper(pair: tuple[int, T]) -> tuple[int, V]:
        index, element = pair
        return (index, await async_mapper_func(element))

    async def indices() -> AsyncGenerator[int, None]:
        for index in itertools.count():
            await slots.acquire()
            yield index

    expected = 0
    pending_results = {}

    tagged_inputs = async_zip(indices(), input_generator)
    async with aclosing(async_map(tagged_inputs, indexed_mapper, concurrency)) as results:
        async for index, value in results:
            pending_results[index] = value

            # Flush every result that is now contiguous with what was already yielded.
            while expected in pending_results:
                yield pending_results[expected]
                slots.release()
                del pending_results[expected]
                expected += 1
748
+
749
+
750
async def async_chain(*generators: AsyncGenerator[T, None]) -> AsyncGenerator[T, None]:
    """Yield all items of each async generator in turn, like `itertools.chain`.

    However this generator exits (exhaustion, error, or early close by the consumer),
    every input generator is closed with `aclose()`. Errors raised while closing are
    logged, and the first one is re-raised after all generators have been attempted.
    """
    try:
        for source in generators:
            async for element in source:
                yield element
    finally:
        close_errors = []
        for source in generators:
            try:
                await source.aclose()
            except BaseException as e:
                close_errors.append(e)
                logger.exception(f"Error closing async generator: {e}")
        if close_errors:
            raise close_errors[0]