modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. modal/__init__.py +17 -13
  2. modal/__main__.py +41 -3
  3. modal/_clustered_functions.py +80 -0
  4. modal/_clustered_functions.pyi +22 -0
  5. modal/_container_entrypoint.py +420 -937
  6. modal/_ipython.py +3 -13
  7. modal/_location.py +17 -10
  8. modal/_output.py +243 -99
  9. modal/_pty.py +2 -2
  10. modal/_resolver.py +55 -59
  11. modal/_resources.py +51 -0
  12. modal/_runtime/__init__.py +1 -0
  13. modal/_runtime/asgi.py +519 -0
  14. modal/_runtime/container_io_manager.py +1036 -0
  15. modal/_runtime/execution_context.py +89 -0
  16. modal/_runtime/telemetry.py +169 -0
  17. modal/_runtime/user_code_imports.py +356 -0
  18. modal/_serialization.py +134 -9
  19. modal/_traceback.py +47 -187
  20. modal/_tunnel.py +52 -16
  21. modal/_tunnel.pyi +19 -36
  22. modal/_utils/app_utils.py +3 -17
  23. modal/_utils/async_utils.py +479 -100
  24. modal/_utils/blob_utils.py +157 -186
  25. modal/_utils/bytes_io_segment_payload.py +97 -0
  26. modal/_utils/deprecation.py +89 -0
  27. modal/_utils/docker_utils.py +98 -0
  28. modal/_utils/function_utils.py +460 -171
  29. modal/_utils/grpc_testing.py +47 -31
  30. modal/_utils/grpc_utils.py +62 -109
  31. modal/_utils/hash_utils.py +61 -19
  32. modal/_utils/http_utils.py +39 -9
  33. modal/_utils/logger.py +2 -1
  34. modal/_utils/mount_utils.py +34 -16
  35. modal/_utils/name_utils.py +58 -0
  36. modal/_utils/package_utils.py +14 -1
  37. modal/_utils/pattern_utils.py +205 -0
  38. modal/_utils/rand_pb_testing.py +5 -7
  39. modal/_utils/shell_utils.py +15 -49
  40. modal/_vendor/a2wsgi_wsgi.py +62 -72
  41. modal/_vendor/cloudpickle.py +1 -1
  42. modal/_watcher.py +14 -12
  43. modal/app.py +1003 -314
  44. modal/app.pyi +540 -264
  45. modal/call_graph.py +7 -6
  46. modal/cli/_download.py +63 -53
  47. modal/cli/_traceback.py +200 -0
  48. modal/cli/app.py +205 -45
  49. modal/cli/config.py +12 -5
  50. modal/cli/container.py +62 -14
  51. modal/cli/dict.py +128 -0
  52. modal/cli/entry_point.py +26 -13
  53. modal/cli/environment.py +40 -9
  54. modal/cli/import_refs.py +64 -58
  55. modal/cli/launch.py +32 -18
  56. modal/cli/network_file_system.py +64 -83
  57. modal/cli/profile.py +1 -1
  58. modal/cli/programs/run_jupyter.py +35 -10
  59. modal/cli/programs/vscode.py +60 -10
  60. modal/cli/queues.py +131 -0
  61. modal/cli/run.py +234 -131
  62. modal/cli/secret.py +8 -7
  63. modal/cli/token.py +7 -2
  64. modal/cli/utils.py +79 -10
  65. modal/cli/volume.py +110 -109
  66. modal/client.py +250 -144
  67. modal/client.pyi +157 -118
  68. modal/cloud_bucket_mount.py +108 -34
  69. modal/cloud_bucket_mount.pyi +32 -38
  70. modal/cls.py +535 -148
  71. modal/cls.pyi +190 -146
  72. modal/config.py +41 -19
  73. modal/container_process.py +177 -0
  74. modal/container_process.pyi +82 -0
  75. modal/dict.py +111 -65
  76. modal/dict.pyi +136 -131
  77. modal/environments.py +106 -5
  78. modal/environments.pyi +77 -25
  79. modal/exception.py +34 -43
  80. modal/experimental.py +61 -2
  81. modal/extensions/ipython.py +5 -5
  82. modal/file_io.py +537 -0
  83. modal/file_io.pyi +235 -0
  84. modal/file_pattern_matcher.py +197 -0
  85. modal/functions.py +906 -911
  86. modal/functions.pyi +466 -430
  87. modal/gpu.py +57 -44
  88. modal/image.py +1089 -479
  89. modal/image.pyi +584 -228
  90. modal/io_streams.py +434 -0
  91. modal/io_streams.pyi +122 -0
  92. modal/mount.py +314 -101
  93. modal/mount.pyi +241 -235
  94. modal/network_file_system.py +92 -92
  95. modal/network_file_system.pyi +152 -110
  96. modal/object.py +67 -36
  97. modal/object.pyi +166 -143
  98. modal/output.py +63 -0
  99. modal/parallel_map.py +434 -0
  100. modal/parallel_map.pyi +75 -0
  101. modal/partial_function.py +282 -117
  102. modal/partial_function.pyi +222 -129
  103. modal/proxy.py +15 -12
  104. modal/proxy.pyi +3 -8
  105. modal/queue.py +182 -65
  106. modal/queue.pyi +218 -118
  107. modal/requirements/2024.04.txt +29 -0
  108. modal/requirements/2024.10.txt +16 -0
  109. modal/requirements/README.md +21 -0
  110. modal/requirements/base-images.json +22 -0
  111. modal/retries.py +48 -7
  112. modal/runner.py +459 -156
  113. modal/runner.pyi +135 -71
  114. modal/running_app.py +38 -0
  115. modal/sandbox.py +514 -236
  116. modal/sandbox.pyi +397 -169
  117. modal/schedule.py +4 -4
  118. modal/scheduler_placement.py +20 -3
  119. modal/secret.py +56 -31
  120. modal/secret.pyi +62 -42
  121. modal/serving.py +51 -56
  122. modal/serving.pyi +44 -36
  123. modal/stream_type.py +15 -0
  124. modal/token_flow.py +5 -3
  125. modal/token_flow.pyi +37 -32
  126. modal/volume.py +285 -157
  127. modal/volume.pyi +249 -184
  128. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
  129. modal-0.72.11.dist-info/RECORD +174 -0
  130. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
  131. modal_docs/gen_reference_docs.py +3 -1
  132. modal_docs/mdmd/mdmd.py +0 -1
  133. modal_docs/mdmd/signatures.py +5 -2
  134. modal_global_objects/images/base_images.py +28 -0
  135. modal_global_objects/mounts/python_standalone.py +2 -2
  136. modal_proto/__init__.py +1 -1
  137. modal_proto/api.proto +1288 -533
  138. modal_proto/api_grpc.py +856 -456
  139. modal_proto/api_pb2.py +2165 -1157
  140. modal_proto/api_pb2.pyi +8859 -0
  141. modal_proto/api_pb2_grpc.py +1674 -855
  142. modal_proto/api_pb2_grpc.pyi +1416 -0
  143. modal_proto/modal_api_grpc.py +149 -0
  144. modal_proto/modal_options_grpc.py +3 -0
  145. modal_proto/options_pb2.pyi +20 -0
  146. modal_proto/options_pb2_grpc.pyi +7 -0
  147. modal_proto/py.typed +0 -0
  148. modal_version/__init__.py +1 -1
  149. modal_version/_version_generated.py +2 -2
  150. modal/_asgi.py +0 -370
  151. modal/_container_entrypoint.pyi +0 -378
  152. modal/_container_exec.py +0 -128
  153. modal/_sandbox_shell.py +0 -49
  154. modal/shared_volume.py +0 -23
  155. modal/shared_volume.pyi +0 -24
  156. modal/stub.py +0 -783
  157. modal/stub.pyi +0 -332
  158. modal-0.62.16.dist-info/RECORD +0 -198
  159. modal_global_objects/images/conda.py +0 -15
  160. modal_global_objects/images/debian_slim.py +0 -15
  161. modal_global_objects/images/micromamba.py +0 -15
  162. test/__init__.py +0 -1
  163. test/aio_test.py +0 -12
  164. test/async_utils_test.py +0 -262
  165. test/blob_test.py +0 -67
  166. test/cli_imports_test.py +0 -149
  167. test/cli_test.py +0 -659
  168. test/client_test.py +0 -194
  169. test/cls_test.py +0 -630
  170. test/config_test.py +0 -137
  171. test/conftest.py +0 -1420
  172. test/container_app_test.py +0 -32
  173. test/container_test.py +0 -1389
  174. test/cpu_test.py +0 -23
  175. test/decorator_test.py +0 -85
  176. test/deprecation_test.py +0 -34
  177. test/dict_test.py +0 -33
  178. test/e2e_test.py +0 -68
  179. test/error_test.py +0 -7
  180. test/function_serialization_test.py +0 -32
  181. test/function_test.py +0 -653
  182. test/function_utils_test.py +0 -101
  183. test/gpu_test.py +0 -159
  184. test/grpc_utils_test.py +0 -141
  185. test/helpers.py +0 -42
  186. test/image_test.py +0 -669
  187. test/live_reload_test.py +0 -80
  188. test/lookup_test.py +0 -70
  189. test/mdmd_test.py +0 -329
  190. test/mount_test.py +0 -162
  191. test/mounted_files_test.py +0 -329
  192. test/network_file_system_test.py +0 -181
  193. test/notebook_test.py +0 -66
  194. test/object_test.py +0 -41
  195. test/package_utils_test.py +0 -25
  196. test/queue_test.py +0 -97
  197. test/resolver_test.py +0 -58
  198. test/retries_test.py +0 -67
  199. test/runner_test.py +0 -85
  200. test/sandbox_test.py +0 -191
  201. test/schedule_test.py +0 -15
  202. test/scheduler_placement_test.py +0 -29
  203. test/secret_test.py +0 -78
  204. test/serialization_test.py +0 -42
  205. test/stub_composition_test.py +0 -10
  206. test/stub_test.py +0 -360
  207. test/test_asgi_wrapper.py +0 -234
  208. test/token_flow_test.py +0 -18
  209. test/traceback_test.py +0 -135
  210. test/tunnel_test.py +0 -29
  211. test/utils_test.py +0 -88
  212. test/version_test.py +0 -14
  213. test/volume_test.py +0 -341
  214. test/watcher_test.py +0 -30
  215. test/webhook_test.py +0 -146
  216. /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
  217. /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
  218. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
  219. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
  220. {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/parallel_map.py ADDED
@@ -0,0 +1,434 @@
1
+ # Copyright Modal Labs 2024
2
+ import asyncio
3
+ import time
4
+ import typing
5
+ from dataclasses import dataclass
6
+ from typing import Any, Callable, Optional
7
+
8
+ from grpclib import GRPCError, Status
9
+
10
+ from modal._runtime.execution_context import current_input_id
11
+ from modal._utils.async_utils import (
12
+ AsyncOrSyncIterable,
13
+ aclosing,
14
+ async_map_ordered,
15
+ async_merge,
16
+ async_zip,
17
+ queue_batch_iterator,
18
+ sync_or_async_iter,
19
+ synchronize_api,
20
+ synchronizer,
21
+ warn_if_generator_is_not_consumed,
22
+ )
23
+ from modal._utils.blob_utils import BLOB_MAX_PARALLELISM
24
+ from modal._utils.function_utils import (
25
+ ATTEMPT_TIMEOUT_GRACE_PERIOD,
26
+ OUTPUTS_TIMEOUT,
27
+ _create_input,
28
+ _process_result,
29
+ )
30
+ from modal._utils.grpc_utils import retry_transient_errors
31
+ from modal.config import logger
32
+ from modal_proto import api_pb2
33
+
34
+ if typing.TYPE_CHECKING:
35
+ import modal.client
36
+
37
+
class _SynchronizedQueue:
    """mdmd:hidden"""

    # Thin wrapper that makes an asyncio.Queue usable from more than one thread: every
    # method is routed through synchronicity via the `SynchronizedQueue` wrapper defined
    # right after this class.

    async def init(self):
        # Not done in __init__: on Python 3.8 an asyncio.Queue binds to the event loop that is
        # current at construction time, so it has to be built inside a synchronicity-wrapped
        # method.
        self.q = asyncio.Queue()

    @synchronizer.no_io_translation
    async def put(self, item):
        queue = self.q
        await queue.put(item)

    @synchronizer.no_io_translation
    async def get(self):
        queue = self.q
        item = await queue.get()
        return item


SynchronizedQueue = synchronize_api(_SynchronizedQueue)


59
+ @dataclass
60
+ class _OutputValue:
61
+ # box class for distinguishing None results from non-existing/None markers
62
+ value: Any
63
+
64
+
# Maximum number of inputs batched into a single FunctionPutInputs request
# (passed to queue_batch_iterator in _map_invocation's pump_inputs).
# NOTE(review): the reason for the specific value 49 is not visible here — confirm with the server-side limit.
MAP_INVOCATION_CHUNK_SIZE = 49

# Imported for type annotations only ("modal.functions._Function"); avoids an import cycle at runtime.
if typing.TYPE_CHECKING:
    import modal.functions


async def _map_invocation(
    function: "modal.functions._Function",
    raw_input_queue: _SynchronizedQueue,
    client: "modal.client._Client",
    order_outputs: bool,
    return_exceptions: bool,
    count_update_callback: Optional[Callable[[int, int], None]],
):
    """Run a single map invocation against the server and yield its output values.

    A function call is created with ``FunctionMap``; then three async generators are run
    concurrently through ``async_merge``:

    * ``drain_input_generator`` reads ``(args, kwargs)`` tuples from ``raw_input_queue`` until a
      ``None`` sentinel, serializes them with ``_create_input`` (up to ``BLOB_MAX_PARALLELISM`` at
      a time) and forwards them to a local ``input_queue``, followed by ``None``.
    * ``pump_inputs`` sends batches of up to ``MAP_INVOCATION_CHUNK_SIZE`` inputs with
      ``FunctionPutInputs``, retrying while the server answers ``RESOURCE_EXHAUSTED``.
    * ``poll_outputs`` polls ``FunctionGetOutputs``, decodes each result and yields it.

    Args:
        function: the function being mapped; its ``object_id`` and ``_use_method_name`` are used.
        raw_input_queue: source of ``(args, kwargs)`` tuples; ``None`` marks the end of input.
        client: client whose ``stub`` is used for all RPCs.
        order_outputs: if True, outputs are yielded in input order; otherwise as they arrive.
        return_exceptions: if True, an exception raised while processing a result is yielded
            in place of the value instead of being raised.
        count_update_callback: optional callback invoked as ``(num_outputs, num_inputs)`` after
            each pushed batch and each received output.

    Yields:
        The output value of each input.
    """
    assert client.stub
    request = api_pb2.FunctionMapRequest(
        function_id=function.object_id,
        parent_input_id=current_input_id() or "",
        function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
        return_exceptions=return_exceptions,
    )
    response = await retry_transient_errors(client.stub.FunctionMap, request)

    function_call_id = response.function_call_id

    # Shared state mutated by the nested generators below (via `nonlocal`).
    have_all_inputs = False  # set by pump_inputs once every input has been pushed
    num_inputs = 0  # number of inputs created so far; also used as the next input idx
    num_outputs = 0  # number of distinct outputs received so far

    def count_update():
        if count_update_callback is not None:
            count_update_callback(num_outputs, num_inputs)

    pending_outputs: dict[str, int] = {}  # Map input_id -> next expected gen_index value
    completed_outputs: set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)

    # Serialized inputs waiting to be pushed to the server; `None` marks the end.
    input_queue: asyncio.Queue = asyncio.Queue()

    async def create_input(argskwargs):
        # Assign each input a sequential idx, which poll_outputs uses to restore input order.
        nonlocal num_inputs
        idx = num_inputs
        num_inputs += 1
        (args, kwargs) = argskwargs
        return await _create_input(args, kwargs, client, idx=idx, method_name=function._use_method_name)

    async def input_iter():
        while 1:
            raw_input = await raw_input_queue.get()
            if raw_input is None:  # end of input sentinel
                break
            yield raw_input  # args, kwargs

    async def drain_input_generator():
        # Parallelize uploading blobs
        async with aclosing(
            async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
        ) as streamer:
            async for item in streamer:
                await input_queue.put(item)

        # close queue iterator
        await input_queue.put(None)
        # Yield once so this can be merged by async_merge below; the None is filtered out there.
        yield

    async def pump_inputs():
        assert client.stub
        nonlocal have_all_inputs, num_inputs
        # presumably queue_batch_iterator stops at the None put by drain_input_generator — confirm in async_utils
        async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE):
            request = api_pb2.FunctionPutInputsRequest(
                function_id=function.object_id, inputs=items, function_call_id=function_call_id
            )
            logger.debug(
                f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
            )
            # Keep retrying the same batch while the server reports RESOURCE_EXHAUSTED,
            # warning the user each time the retry budget below runs out.
            while True:
                try:
                    resp = await retry_transient_errors(
                        client.stub.FunctionPutInputs,
                        request,
                        # with 8 retries we log the warning below about every 30 seconds, which isn't too spammy.
                        max_retries=8,
                        max_delay=15,
                        additional_status_codes=[Status.RESOURCE_EXHAUSTED],
                    )
                    break
                except GRPCError as err:
                    if err.status != Status.RESOURCE_EXHAUSTED:
                        raise err
                    logger.warning(
                        "Warning: map progress is limited. Common bottlenecks "
                        "include slow iteration over results, or function backlogs."
                    )

            count_update()
            # Register every accepted input so get_all_outputs knows to wait for its output.
            for item in resp.inputs:
                pending_outputs.setdefault(item.input_id, 0)
            logger.debug(
                f"Successfully pushed {len(items)} inputs to server. "
                f"Num queued inputs awaiting push is {input_queue.qsize()}."
            )

        have_all_inputs = True
        # Yield once so this can be merged by async_merge below; the None is filtered out there.
        yield

    async def get_all_outputs():
        assert client.stub
        nonlocal num_inputs, num_outputs, have_all_inputs
        # Cursor into the output stream; advanced to the server-reported last_entry_id after each non-empty poll.
        last_entry_id = "0-0"
        # Poll until all inputs are pushed and every input_id seen so far has a completed output.
        while not have_all_inputs or len(pending_outputs) > len(completed_outputs):
            request = api_pb2.FunctionGetOutputsRequest(
                function_call_id=function_call_id,
                timeout=OUTPUTS_TIMEOUT,
                last_entry_id=last_entry_id,
                # Results are only cleared by the final "ack" request in get_all_outputs_and_clean_up.
                clear_on_success=False,
                requested_at=time.time(),
            )
            response = await retry_transient_errors(
                client.stub.FunctionGetOutputs,
                request,
                max_retries=20,
                attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
            )

            if len(response.outputs) == 0:
                continue

            last_entry_id = response.last_entry_id
            for item in response.outputs:
                pending_outputs.setdefault(item.input_id, 0)
                if item.input_id in completed_outputs:
                    # If this input is already completed, it means the output has already been
                    # processed and was received again due to a duplicate.
                    continue
                completed_outputs.add(item.input_id)
                num_outputs += 1
                yield item

    async def get_all_outputs_and_clean_up():
        assert client.stub
        try:
            async with aclosing(get_all_outputs()) as output_items:
                async for item in output_items:
                    yield item
        finally:
            # "ack" that we have all outputs we are interested in and let backend clear results
            request = api_pb2.FunctionGetOutputsRequest(
                function_call_id=function_call_id,
                timeout=0,
                last_entry_id="0-0",
                clear_on_success=True,
                requested_at=time.time(),
            )
            await retry_transient_errors(client.stub.FunctionGetOutputs, request)

    async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
        # Decode one output; returns (input idx, value), or (idx, exception) when return_exceptions is set.
        try:
            output = await _process_result(item.result, item.data_format, client.stub, client)
        except Exception as e:
            if return_exceptions:
                output = e
            else:
                raise e
        return (item.idx, output)

    async def poll_outputs():
        # map to store out-of-order outputs received
        received_outputs = {}
        # idx of the next output to emit when order_outputs is True
        output_idx = 0

        async with aclosing(
            async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
        ) as streamer:
            async for idx, output in streamer:
                count_update()
                if not order_outputs:
                    yield _OutputValue(output)
                else:
                    # hold on to outputs for function maps, so we can reorder them correctly.
                    received_outputs[idx] = output
                    while output_idx in received_outputs:
                        output = received_outputs.pop(output_idx)
                        yield _OutputValue(output)
                        output_idx += 1

        assert len(received_outputs) == 0

    # Only poll_outputs yields _OutputValue boxes; the bare None yields of the two input
    # generators are skipped here, so a None result (boxed) is still delivered.
    async with aclosing(async_merge(drain_input_generator(), pump_inputs(), poll_outputs())) as streamer:
        async for response in streamer:
            if response is not None:
                yield response.value


@warn_if_generator_is_not_consumed(function_name="Function.map")
def _map_sync(
    self,
    *input_iterators: typing.Iterable[Any],  # one input iterator per argument in the mapped-over function/generator
    kwargs={},  # any extra keyword arguments for the function
    order_outputs: bool = True,  # return outputs in order
    return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
) -> AsyncOrSyncIterable:
    """Parallel map over a set of inputs.

    Takes one iterator argument per argument in the function being mapped over.
    `kwargs`, if given, are passed as extra keyword arguments to every call.

    Example:
    ```python
    @app.function()
    def my_func(a):
        return a ** 2


    @app.local_entrypoint()
    def main():
        assert list(my_func.map([1, 2, 3, 4])) == [1, 4, 9, 16]
    ```

    If applied to an `app.function`, `map()` returns one result per input and the output order
    is guaranteed to be the same as the input order. Set `order_outputs=False` to return results
    in the order that they are completed instead.

    `return_exceptions` can be used to treat exceptions as successful results:

    ```python
    @app.function()
    def my_func(a):
        if a == 2:
            raise Exception("ohno")
        return a ** 2


    @app.local_entrypoint()
    def main():
        # [0, 1, UserCodeException(Exception('ohno'))]
        print(list(my_func.map(range(3), return_exceptions=True)))
    ```
    """

    # The returned object can be iterated synchronously; iterating it from inside a running
    # event loop raises with the message below, pointing users at `.aio()`.
    return AsyncOrSyncIterable(
        _map_async(
            self, *input_iterators, kwargs=kwargs, order_outputs=order_outputs, return_exceptions=return_exceptions
        ),
        nested_async_message=(
            "You can't iter(Function.map()) or Function.for_each() from an async function. "
            "Use async for ... Function.map.aio() or Function.for_each.aio() instead."
        ),
    )


@warn_if_generator_is_not_consumed(function_name="Function.map.aio")
async def _map_async(
    self,
    *input_iterators: typing.Union[
        typing.Iterable[Any], typing.AsyncIterable[Any]
    ],  # one input iterator per argument in the mapped-over function/generator
    kwargs={},  # any extra keyword arguments for the function
    order_outputs: bool = True,  # return outputs in order
    return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
) -> typing.AsyncGenerator[Any, None]:
    """mdmd:hidden
    This runs in an event loop on the main thread

    It concurrently feeds new input to the input queue and yields available outputs
    to the caller. An exception raised by any of the input iterators is propagated
    to the caller.
    Note that since the iterator(s) can block, it's a bit opaque how often the event
    loop decides to get a new input vs how often it will emit a new output.
    We could make this explicit as an improvement or even let users decide what they
    prefer: throughput (prioritize queueing inputs) or latency (prioritize yielding results)
    """
    raw_input_queue: Any = SynchronizedQueue()  # type: ignore
    raw_input_queue.init()

    async def feed_queue():
        # This runs in a main thread event loop, so it doesn't block the synchronizer loop
        async with aclosing(async_zip(*[sync_or_async_iter(it) for it in input_iterators])) as streamer:
            async for args in streamer:
                await raw_input_queue.put.aio((args, kwargs))
        await raw_input_queue.put.aio(None)  # end-of-input sentinel
        if False:
            # Never executed: the `yield` only makes this an async generator so it can be merged
            # with the output stream below. Merging (rather than running it as a detached task)
            # makes an exception from a user-supplied iterator reach the caller. As a detached
            # task, such an exception would stop it before the end-of-input sentinel is queued,
            # and the map would wait for more inputs indefinitely.
            yield

    # note that `map()` and `map.aio()` are not synchronicity-wrapped, since
    # they accept executable code in the form of
    # iterators that we don't want to run inside the synchronicity thread.
    # Instead, we delegate to `._map()` with a safer Queue as input
    async with aclosing(
        async_merge(self._map.aio(raw_input_queue, order_outputs, return_exceptions), feed_queue())
    ) as map_output_stream:
        async for output in map_output_stream:
            yield output


357
+ def _for_each_sync(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False):
358
+ """Execute function for all inputs, ignoring outputs.
359
+
360
+ Convenient alias for `.map()` in cases where the function just needs to be called.
361
+ as the caller doesn't have to consume the generator to process the inputs.
362
+ """
363
+ # TODO(erikbern): it would be better if this is more like a map_spawn that immediately exits
364
+ # rather than iterating over the result
365
+ for _ in self.map(*input_iterators, kwargs=kwargs, order_outputs=False, return_exceptions=ignore_exceptions):
366
+ pass
367
+
368
+
369
+ async def _for_each_async(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False):
370
+ async for _ in self.map.aio( # type: ignore
371
+ *input_iterators, kwargs=kwargs, order_outputs=False, return_exceptions=ignore_exceptions
372
+ ):
373
+ pass
374
+
375
+
@warn_if_generator_is_not_consumed(function_name="Function.starmap.aio")
async def _starmap_async(
    self,
    input_iterator: typing.Union[typing.Iterable[typing.Sequence[Any]], typing.AsyncIterable[typing.Sequence[Any]]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> typing.AsyncIterable[Any]:
    """mdmd:hidden
    Async implementation of `starmap`: each element of `input_iterator` is a sequence that is
    spread over the function's positional arguments.

    Inputs are fed from this (main-thread) event loop into a SynchronizedQueue consumed by
    `._map()`. An exception raised by `input_iterator` is propagated to the caller.
    """
    raw_input_queue: Any = SynchronizedQueue()  # type: ignore
    raw_input_queue.init()

    async def feed_queue():
        # This runs in a main thread event loop, so it doesn't block the synchronizer loop
        async with aclosing(sync_or_async_iter(input_iterator)) as streamer:
            async for args in streamer:
                await raw_input_queue.put.aio((args, kwargs))
        await raw_input_queue.put.aio(None)  # end-of-input sentinel
        if False:
            # Never executed: the `yield` only makes this an async generator so it can be merged
            # with the output stream below. Merging (rather than running it as a detached task)
            # makes an exception from the user-supplied iterator reach the caller. As a detached
            # task, such an exception would stop it before the end-of-input sentinel is queued,
            # and the map would wait for more inputs indefinitely.
            yield

    # aclosing guarantees the merged stream (and the underlying map) is closed promptly if the
    # caller stops iterating early, matching `_map_async`.
    async with aclosing(
        async_merge(
            self._map.aio(raw_input_queue, order_outputs, return_exceptions),  # type: ignore[reportFunctionMemberAccess]
            feed_queue(),
        )
    ) as map_output_stream:
        async for output in map_output_stream:
            yield output


@warn_if_generator_is_not_consumed(function_name="Function.starmap")
def _starmap_sync(
    self,
    input_iterator: typing.Iterable[typing.Sequence[Any]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> AsyncOrSyncIterable:
    """Like `map`, but spreads arguments over multiple function arguments.

    Assumes every input is a sequence (e.g. a tuple).

    Example:
    ```python
    @app.function()
    def my_func(a, b):
        return a + b


    @app.local_entrypoint()
    def main():
        assert list(my_func.starmap([(1, 2), (3, 4)])) == [3, 7]
    ```
    """
    # Iterating the returned object from inside a running event loop raises with the message
    # below, which points users at the `.aio()` variant of starmap.
    return AsyncOrSyncIterable(
        _starmap_async(
            self, input_iterator, kwargs=kwargs, order_outputs=order_outputs, return_exceptions=return_exceptions
        ),
        nested_async_message=(
            "You can't iter(Function.starmap()) from an async function. "
            "Use async for ... Function.starmap.aio() instead."
        ),
    )
modal/parallel_map.pyi ADDED
@@ -0,0 +1,75 @@
1
+ import modal._utils.async_utils
2
+ import modal.client
3
+ import modal.functions
4
+ import typing
5
+ import typing_extensions
6
+
# Type stub for the async implementation class in modal/parallel_map.py.
class _SynchronizedQueue:
    async def init(self): ...  # creates the underlying asyncio.Queue
    async def put(self, item): ...
    async def get(self): ...

# Type stub for the synchronicity-wrapped `SynchronizedQueue`. Each method is exposed as an
# object with a blocking `__call__` and an awaitable `.aio` variant, described by a Protocol.
class SynchronizedQueue:
    def __init__(self, /, *args, **kwargs): ...

    class __init_spec(typing_extensions.Protocol):
        def __call__(self): ...
        async def aio(self): ...

    init: __init_spec

    class __put_spec(typing_extensions.Protocol):
        def __call__(self, item): ...
        async def aio(self, item): ...

    put: __put_spec

    class __get_spec(typing_extensions.Protocol):
        def __call__(self): ...
        async def aio(self): ...

    get: __get_spec

# Type stub for the `_OutputValue` dataclass; the dunder methods are the dataclass-generated ones.
class _OutputValue:
    value: typing.Any

    def __init__(self, value: typing.Any) -> None: ...
    def __repr__(self): ...
    def __eq__(self, other): ...

# The implementation is an async generator function; calling it returns the generator.
def _map_invocation(
    function: modal.functions._Function,
    raw_input_queue: _SynchronizedQueue,
    client: modal.client._Client,
    order_outputs: bool,
    return_exceptions: bool,
    count_update_callback: typing.Optional[typing.Callable[[int, int], None]],
): ...
# Synchronous entry point backing `Function.map`.
def _map_sync(
    self, *input_iterators, kwargs={}, order_outputs: bool = True, return_exceptions: bool = False
) -> modal._utils.async_utils.AsyncOrSyncIterable: ...
# Async generator backing `Function.map.aio`; calling it returns the generator.
def _map_async(
    self,
    *input_iterators: typing.Union[typing.Iterable[typing.Any], typing.AsyncIterable[typing.Any]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> typing.AsyncGenerator[typing.Any, None]: ...
# Runs the function over all inputs via `map`, discarding the outputs.
def _for_each_sync(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
# Async counterpart of `_for_each_sync`, driven by `map.aio`.
async def _for_each_async(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
# Async-generator implementation of starmap; each input is a sequence spread over the arguments.
def _starmap_async(
    self,
    input_iterator: typing.Union[
        typing.Iterable[typing.Sequence[typing.Any]], typing.AsyncIterable[typing.Sequence[typing.Any]]
    ],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> typing.AsyncIterable[typing.Any]: ...
# Synchronous entry point for starmap; wraps `_starmap_async` in an AsyncOrSyncIterable.
def _starmap_sync(
    self,
    input_iterator: typing.Iterable[typing.Sequence[typing.Any]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> modal._utils.async_utils.AsyncOrSyncIterable: ...