modal 0.62.16__py3-none-any.whl → 0.72.11__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- modal/__init__.py +17 -13
- modal/__main__.py +41 -3
- modal/_clustered_functions.py +80 -0
- modal/_clustered_functions.pyi +22 -0
- modal/_container_entrypoint.py +420 -937
- modal/_ipython.py +3 -13
- modal/_location.py +17 -10
- modal/_output.py +243 -99
- modal/_pty.py +2 -2
- modal/_resolver.py +55 -59
- modal/_resources.py +51 -0
- modal/_runtime/__init__.py +1 -0
- modal/_runtime/asgi.py +519 -0
- modal/_runtime/container_io_manager.py +1036 -0
- modal/_runtime/execution_context.py +89 -0
- modal/_runtime/telemetry.py +169 -0
- modal/_runtime/user_code_imports.py +356 -0
- modal/_serialization.py +134 -9
- modal/_traceback.py +47 -187
- modal/_tunnel.py +52 -16
- modal/_tunnel.pyi +19 -36
- modal/_utils/app_utils.py +3 -17
- modal/_utils/async_utils.py +479 -100
- modal/_utils/blob_utils.py +157 -186
- modal/_utils/bytes_io_segment_payload.py +97 -0
- modal/_utils/deprecation.py +89 -0
- modal/_utils/docker_utils.py +98 -0
- modal/_utils/function_utils.py +460 -171
- modal/_utils/grpc_testing.py +47 -31
- modal/_utils/grpc_utils.py +62 -109
- modal/_utils/hash_utils.py +61 -19
- modal/_utils/http_utils.py +39 -9
- modal/_utils/logger.py +2 -1
- modal/_utils/mount_utils.py +34 -16
- modal/_utils/name_utils.py +58 -0
- modal/_utils/package_utils.py +14 -1
- modal/_utils/pattern_utils.py +205 -0
- modal/_utils/rand_pb_testing.py +5 -7
- modal/_utils/shell_utils.py +15 -49
- modal/_vendor/a2wsgi_wsgi.py +62 -72
- modal/_vendor/cloudpickle.py +1 -1
- modal/_watcher.py +14 -12
- modal/app.py +1003 -314
- modal/app.pyi +540 -264
- modal/call_graph.py +7 -6
- modal/cli/_download.py +63 -53
- modal/cli/_traceback.py +200 -0
- modal/cli/app.py +205 -45
- modal/cli/config.py +12 -5
- modal/cli/container.py +62 -14
- modal/cli/dict.py +128 -0
- modal/cli/entry_point.py +26 -13
- modal/cli/environment.py +40 -9
- modal/cli/import_refs.py +64 -58
- modal/cli/launch.py +32 -18
- modal/cli/network_file_system.py +64 -83
- modal/cli/profile.py +1 -1
- modal/cli/programs/run_jupyter.py +35 -10
- modal/cli/programs/vscode.py +60 -10
- modal/cli/queues.py +131 -0
- modal/cli/run.py +234 -131
- modal/cli/secret.py +8 -7
- modal/cli/token.py +7 -2
- modal/cli/utils.py +79 -10
- modal/cli/volume.py +110 -109
- modal/client.py +250 -144
- modal/client.pyi +157 -118
- modal/cloud_bucket_mount.py +108 -34
- modal/cloud_bucket_mount.pyi +32 -38
- modal/cls.py +535 -148
- modal/cls.pyi +190 -146
- modal/config.py +41 -19
- modal/container_process.py +177 -0
- modal/container_process.pyi +82 -0
- modal/dict.py +111 -65
- modal/dict.pyi +136 -131
- modal/environments.py +106 -5
- modal/environments.pyi +77 -25
- modal/exception.py +34 -43
- modal/experimental.py +61 -2
- modal/extensions/ipython.py +5 -5
- modal/file_io.py +537 -0
- modal/file_io.pyi +235 -0
- modal/file_pattern_matcher.py +197 -0
- modal/functions.py +906 -911
- modal/functions.pyi +466 -430
- modal/gpu.py +57 -44
- modal/image.py +1089 -479
- modal/image.pyi +584 -228
- modal/io_streams.py +434 -0
- modal/io_streams.pyi +122 -0
- modal/mount.py +314 -101
- modal/mount.pyi +241 -235
- modal/network_file_system.py +92 -92
- modal/network_file_system.pyi +152 -110
- modal/object.py +67 -36
- modal/object.pyi +166 -143
- modal/output.py +63 -0
- modal/parallel_map.py +434 -0
- modal/parallel_map.pyi +75 -0
- modal/partial_function.py +282 -117
- modal/partial_function.pyi +222 -129
- modal/proxy.py +15 -12
- modal/proxy.pyi +3 -8
- modal/queue.py +182 -65
- modal/queue.pyi +218 -118
- modal/requirements/2024.04.txt +29 -0
- modal/requirements/2024.10.txt +16 -0
- modal/requirements/README.md +21 -0
- modal/requirements/base-images.json +22 -0
- modal/retries.py +48 -7
- modal/runner.py +459 -156
- modal/runner.pyi +135 -71
- modal/running_app.py +38 -0
- modal/sandbox.py +514 -236
- modal/sandbox.pyi +397 -169
- modal/schedule.py +4 -4
- modal/scheduler_placement.py +20 -3
- modal/secret.py +56 -31
- modal/secret.pyi +62 -42
- modal/serving.py +51 -56
- modal/serving.pyi +44 -36
- modal/stream_type.py +15 -0
- modal/token_flow.py +5 -3
- modal/token_flow.pyi +37 -32
- modal/volume.py +285 -157
- modal/volume.pyi +249 -184
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/METADATA +7 -7
- modal-0.72.11.dist-info/RECORD +174 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/top_level.txt +0 -1
- modal_docs/gen_reference_docs.py +3 -1
- modal_docs/mdmd/mdmd.py +0 -1
- modal_docs/mdmd/signatures.py +5 -2
- modal_global_objects/images/base_images.py +28 -0
- modal_global_objects/mounts/python_standalone.py +2 -2
- modal_proto/__init__.py +1 -1
- modal_proto/api.proto +1288 -533
- modal_proto/api_grpc.py +856 -456
- modal_proto/api_pb2.py +2165 -1157
- modal_proto/api_pb2.pyi +8859 -0
- modal_proto/api_pb2_grpc.py +1674 -855
- modal_proto/api_pb2_grpc.pyi +1416 -0
- modal_proto/modal_api_grpc.py +149 -0
- modal_proto/modal_options_grpc.py +3 -0
- modal_proto/options_pb2.pyi +20 -0
- modal_proto/options_pb2_grpc.pyi +7 -0
- modal_proto/py.typed +0 -0
- modal_version/__init__.py +1 -1
- modal_version/_version_generated.py +2 -2
- modal/_asgi.py +0 -370
- modal/_container_entrypoint.pyi +0 -378
- modal/_container_exec.py +0 -128
- modal/_sandbox_shell.py +0 -49
- modal/shared_volume.py +0 -23
- modal/shared_volume.pyi +0 -24
- modal/stub.py +0 -783
- modal/stub.pyi +0 -332
- modal-0.62.16.dist-info/RECORD +0 -198
- modal_global_objects/images/conda.py +0 -15
- modal_global_objects/images/debian_slim.py +0 -15
- modal_global_objects/images/micromamba.py +0 -15
- test/__init__.py +0 -1
- test/aio_test.py +0 -12
- test/async_utils_test.py +0 -262
- test/blob_test.py +0 -67
- test/cli_imports_test.py +0 -149
- test/cli_test.py +0 -659
- test/client_test.py +0 -194
- test/cls_test.py +0 -630
- test/config_test.py +0 -137
- test/conftest.py +0 -1420
- test/container_app_test.py +0 -32
- test/container_test.py +0 -1389
- test/cpu_test.py +0 -23
- test/decorator_test.py +0 -85
- test/deprecation_test.py +0 -34
- test/dict_test.py +0 -33
- test/e2e_test.py +0 -68
- test/error_test.py +0 -7
- test/function_serialization_test.py +0 -32
- test/function_test.py +0 -653
- test/function_utils_test.py +0 -101
- test/gpu_test.py +0 -159
- test/grpc_utils_test.py +0 -141
- test/helpers.py +0 -42
- test/image_test.py +0 -669
- test/live_reload_test.py +0 -80
- test/lookup_test.py +0 -70
- test/mdmd_test.py +0 -329
- test/mount_test.py +0 -162
- test/mounted_files_test.py +0 -329
- test/network_file_system_test.py +0 -181
- test/notebook_test.py +0 -66
- test/object_test.py +0 -41
- test/package_utils_test.py +0 -25
- test/queue_test.py +0 -97
- test/resolver_test.py +0 -58
- test/retries_test.py +0 -67
- test/runner_test.py +0 -85
- test/sandbox_test.py +0 -191
- test/schedule_test.py +0 -15
- test/scheduler_placement_test.py +0 -29
- test/secret_test.py +0 -78
- test/serialization_test.py +0 -42
- test/stub_composition_test.py +0 -10
- test/stub_test.py +0 -360
- test/test_asgi_wrapper.py +0 -234
- test/token_flow_test.py +0 -18
- test/traceback_test.py +0 -135
- test/tunnel_test.py +0 -29
- test/utils_test.py +0 -88
- test/version_test.py +0 -14
- test/volume_test.py +0 -341
- test/watcher_test.py +0 -30
- test/webhook_test.py +0 -146
- /modal/{requirements.312.txt → requirements/2023.12.312.txt} +0 -0
- /modal/{requirements.txt → requirements/2023.12.txt} +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/LICENSE +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/WHEEL +0 -0
- {modal-0.62.16.dist-info → modal-0.72.11.dist-info}/entry_points.txt +0 -0
modal/parallel_map.py
ADDED
@@ -0,0 +1,434 @@
# Copyright Modal Labs 2024
import asyncio
import time
import typing
from dataclasses import dataclass
from typing import Any, Callable, Optional

from grpclib import GRPCError, Status

from modal._runtime.execution_context import current_input_id
from modal._utils.async_utils import (
    AsyncOrSyncIterable,
    aclosing,
    async_map_ordered,
    async_merge,
    async_zip,
    queue_batch_iterator,
    sync_or_async_iter,
    synchronize_api,
    synchronizer,
    warn_if_generator_is_not_consumed,
)
from modal._utils.blob_utils import BLOB_MAX_PARALLELISM
from modal._utils.function_utils import (
    ATTEMPT_TIMEOUT_GRACE_PERIOD,
    OUTPUTS_TIMEOUT,
    _create_input,
    _process_result,
)
from modal._utils.grpc_utils import retry_transient_errors
from modal.config import logger
from modal_proto import api_pb2

if typing.TYPE_CHECKING:
    import modal.client


class _SynchronizedQueue:
    """mdmd:hidden"""

    # small wrapper around asyncio.Queue to make it cross-thread compatible through synchronicity
    async def init(self):
        # in Python 3.8 the asyncio.Queue is bound to the event loop on creation
        # so it needs to be created in a synchronicity-wrapped init method
        self.q = asyncio.Queue()

    @synchronizer.no_io_translation
    async def put(self, item):
        await self.q.put(item)

    @synchronizer.no_io_translation
    async def get(self):
        return await self.q.get()


SynchronizedQueue = synchronize_api(_SynchronizedQueue)


@dataclass
class _OutputValue:
    # box class for distinguishing None results from non-existing/None markers
    value: Any


MAP_INVOCATION_CHUNK_SIZE = 49

if typing.TYPE_CHECKING:
    import modal.functions


async def _map_invocation(
    function: "modal.functions._Function",
    raw_input_queue: _SynchronizedQueue,
    client: "modal.client._Client",
    order_outputs: bool,
    return_exceptions: bool,
    count_update_callback: Optional[Callable[[int, int], None]],
):
    assert client.stub
    request = api_pb2.FunctionMapRequest(
        function_id=function.object_id,
        parent_input_id=current_input_id() or "",
        function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
        return_exceptions=return_exceptions,
    )
    response = await retry_transient_errors(client.stub.FunctionMap, request)

    function_call_id = response.function_call_id

    have_all_inputs = False
    num_inputs = 0
    num_outputs = 0

    def count_update():
        if count_update_callback is not None:
            count_update_callback(num_outputs, num_inputs)

    pending_outputs: dict[str, int] = {}  # Map input_id -> next expected gen_index value
    completed_outputs: set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)

    input_queue: asyncio.Queue = asyncio.Queue()

    async def create_input(argskwargs):
        nonlocal num_inputs
        idx = num_inputs
        num_inputs += 1
        (args, kwargs) = argskwargs
        return await _create_input(args, kwargs, client, idx=idx, method_name=function._use_method_name)

    async def input_iter():
        while 1:
            raw_input = await raw_input_queue.get()
            if raw_input is None:  # end of input sentinel
                break
            yield raw_input  # args, kwargs

    async def drain_input_generator():
        # Parallelize uploading blobs
        async with aclosing(
            async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
        ) as streamer:
            async for item in streamer:
                await input_queue.put(item)

        # close queue iterator
        await input_queue.put(None)
        yield

    async def pump_inputs():
        assert client.stub
        nonlocal have_all_inputs, num_inputs
        async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE):
            request = api_pb2.FunctionPutInputsRequest(
                function_id=function.object_id, inputs=items, function_call_id=function_call_id
            )
            logger.debug(
                f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
            )
            while True:
                try:
                    resp = await retry_transient_errors(
                        client.stub.FunctionPutInputs,
                        request,
                        # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
                        max_retries=8,
                        max_delay=15,
                        additional_status_codes=[Status.RESOURCE_EXHAUSTED],
                    )
                    break
                except GRPCError as err:
                    if err.status != Status.RESOURCE_EXHAUSTED:
                        raise err
                    logger.warning(
                        "Warning: map progress is limited. Common bottlenecks "
                        "include slow iteration over results, or function backlogs."
                    )

            count_update()
            for item in resp.inputs:
                pending_outputs.setdefault(item.input_id, 0)
            logger.debug(
                f"Successfully pushed {len(items)} inputs to server. "
                f"Num queued inputs awaiting push is {input_queue.qsize()}."
            )

        have_all_inputs = True
        yield

    async def get_all_outputs():
        assert client.stub
        nonlocal num_inputs, num_outputs, have_all_inputs
        last_entry_id = "0-0"
        while not have_all_inputs or len(pending_outputs) > len(completed_outputs):
            request = api_pb2.FunctionGetOutputsRequest(
                function_call_id=function_call_id,
                timeout=OUTPUTS_TIMEOUT,
                last_entry_id=last_entry_id,
                clear_on_success=False,
                requested_at=time.time(),
            )
            response = await retry_transient_errors(
                client.stub.FunctionGetOutputs,
                request,
                max_retries=20,
                attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
            )

            if len(response.outputs) == 0:
                continue

            last_entry_id = response.last_entry_id
            for item in response.outputs:
                pending_outputs.setdefault(item.input_id, 0)
                if item.input_id in completed_outputs:
                    # If this input is already completed, it means the output has already been
                    # processed and was received again due to a duplicate.
                    continue
                completed_outputs.add(item.input_id)
                num_outputs += 1
                yield item

    async def get_all_outputs_and_clean_up():
        assert client.stub
        try:
            async with aclosing(get_all_outputs()) as output_items:
                async for item in output_items:
                    yield item
        finally:
            # "ack" that we have all outputs we are interested in and let backend clear results
            request = api_pb2.FunctionGetOutputsRequest(
                function_call_id=function_call_id,
                timeout=0,
                last_entry_id="0-0",
                clear_on_success=True,
                requested_at=time.time(),
            )
            await retry_transient_errors(client.stub.FunctionGetOutputs, request)

    async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
        try:
            output = await _process_result(item.result, item.data_format, client.stub, client)
        except Exception as e:
            if return_exceptions:
                output = e
            else:
                raise e
        return (item.idx, output)

    async def poll_outputs():
        # map to store out-of-order outputs received
        received_outputs = {}
        output_idx = 0

        async with aclosing(
            async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
        ) as streamer:
            async for idx, output in streamer:
                count_update()
                if not order_outputs:
                    yield _OutputValue(output)
                else:
                    # hold on to outputs for function maps, so we can reorder them correctly.
                    received_outputs[idx] = output
                    while output_idx in received_outputs:
                        output = received_outputs.pop(output_idx)
                        yield _OutputValue(output)
                        output_idx += 1

        assert len(received_outputs) == 0

    async with aclosing(async_merge(drain_input_generator(), pump_inputs(), poll_outputs())) as streamer:
        async for response in streamer:
            if response is not None:
                yield response.value


@warn_if_generator_is_not_consumed(function_name="Function.map")
def _map_sync(
    self,
    *input_iterators: typing.Iterable[Any],  # one input iterator per argument in the mapped-over function/generator
    kwargs={},  # any extra keyword arguments for the function
    order_outputs: bool = True,  # return outputs in order
    return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
) -> AsyncOrSyncIterable:
    """Parallel map over a set of inputs.

    Takes one iterator argument per argument in the function being mapped over.

    Example:
    ```python
    @app.function()
    def my_func(a):
        return a ** 2


    @app.local_entrypoint()
    def main():
        assert list(my_func.map([1, 2, 3, 4])) == [1, 4, 9, 16]
    ```

    If applied to a `stub.function`, `map()` returns one result per input and the output order
    is guaranteed to be the same as the input order. Set `order_outputs=False` to return results
    in the order that they are completed instead.

    `return_exceptions` can be used to treat exceptions as successful results:

    ```python
    @app.function()
    def my_func(a):
        if a == 2:
            raise Exception("ohno")
        return a ** 2


    @app.local_entrypoint()
    def main():
        # [0, 1, UserCodeException(Exception('ohno'))]
        print(list(my_func.map(range(3), return_exceptions=True)))
    ```
    """

    return AsyncOrSyncIterable(
        _map_async(
            self, *input_iterators, kwargs=kwargs, order_outputs=order_outputs, return_exceptions=return_exceptions
        ),
        nested_async_message=(
            "You can't iter(Function.map()) or Function.for_each() from an async function. "
            "Use async for ... Function.map.aio() or Function.for_each.aio() instead."
        ),
    )


@warn_if_generator_is_not_consumed(function_name="Function.map.aio")
async def _map_async(
    self,
    *input_iterators: typing.Union[
        typing.Iterable[Any], typing.AsyncIterable[Any]
    ],  # one input iterator per argument in the mapped-over function/generator
    kwargs={},  # any extra keyword arguments for the function
    order_outputs: bool = True,  # return outputs in order
    return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
) -> typing.AsyncGenerator[Any, None]:
    """mdmd:hidden
    This runs in an event loop on the main thread

    It concurrently feeds new input to the input queue and yields available outputs
    to the caller.
    Note that since the iterator(s) can block, it's a bit opaque how often the event
    loop decides to get a new input vs how often it will emit a new output.
    We could make this explicit as an improvement or even let users decide what they
    prefer: throughput (prioritize queueing inputs) or latency (prioritize yielding results)
    """
    raw_input_queue: Any = SynchronizedQueue()  # type: ignore
    raw_input_queue.init()

    async def feed_queue():
        # This runs in a main thread event loop, so it doesn't block the synchronizer loop
        async with aclosing(async_zip(*[sync_or_async_iter(it) for it in input_iterators])) as streamer:
            async for args in streamer:
                await raw_input_queue.put.aio((args, kwargs))
        await raw_input_queue.put.aio(None)  # end-of-input sentinel

    feed_input_task = asyncio.create_task(feed_queue())

    try:
        # note that `map()` and `map.aio()` are not synchronicity-wrapped, since
        # they accept executable code in the form of
        # iterators that we don't want to run inside the synchronicity thread.
        # Instead, we delegate to `._map()` with a safer Queue as input
        async with aclosing(self._map.aio(raw_input_queue, order_outputs, return_exceptions)) as map_output_stream:
            async for output in map_output_stream:
                yield output
    finally:
        feed_input_task.cancel()  # should only be needed in case of exceptions


def _for_each_sync(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False):
    """Execute function for all inputs, ignoring outputs.

    Convenient alias for `.map()` in cases where the function just needs to be called.
    as the caller doesn't have to consume the generator to process the inputs.
    """
    # TODO(erikbern): it would be better if this is more like a map_spawn that immediately exits
    # rather than iterating over the result
    for _ in self.map(*input_iterators, kwargs=kwargs, order_outputs=False, return_exceptions=ignore_exceptions):
        pass


async def _for_each_async(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False):
    async for _ in self.map.aio(  # type: ignore
        *input_iterators, kwargs=kwargs, order_outputs=False, return_exceptions=ignore_exceptions
    ):
        pass


@warn_if_generator_is_not_consumed(function_name="Function.starmap")
async def _starmap_async(
    self,
    input_iterator: typing.Union[typing.Iterable[typing.Sequence[Any]], typing.AsyncIterable[typing.Sequence[Any]]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> typing.AsyncIterable[Any]:
    raw_input_queue: Any = SynchronizedQueue()  # type: ignore
    raw_input_queue.init()

    async def feed_queue():
        # This runs in a main thread event loop, so it doesn't block the synchronizer loop
        async with aclosing(sync_or_async_iter(input_iterator)) as streamer:
            async for args in streamer:
                await raw_input_queue.put.aio((args, kwargs))
        await raw_input_queue.put.aio(None)  # end-of-input sentinel

    feed_input_task = asyncio.create_task(feed_queue())
    try:
        async for output in self._map.aio(raw_input_queue, order_outputs, return_exceptions):  # type: ignore[reportFunctionMemberAccess]
            yield output
    finally:
        feed_input_task.cancel()  # should only be needed in case of exceptions


@warn_if_generator_is_not_consumed(function_name="Function.starmap.aio")
def _starmap_sync(
    self,
    input_iterator: typing.Iterable[typing.Sequence[Any]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> AsyncOrSyncIterable:
    """Like `map`, but spreads arguments over multiple function arguments.

    Assumes every input is a sequence (e.g. a tuple).

    Example:
    ```python
    @app.function()
    def my_func(a, b):
        return a + b


    @app.local_entrypoint()
    def main():
        assert list(my_func.starmap([(1, 2), (3, 4)])) == [3, 7]
    ```
    """
    return AsyncOrSyncIterable(
        _starmap_async(
            self, input_iterator, kwargs=kwargs, order_outputs=order_outputs, return_exceptions=return_exceptions
        ),
        nested_async_message=(
            "You can't run Function.map() or Function.for_each() from an async function. "
            "Use Function.map.aio()/Function.for_each.aio() instead."
        ),
    )
modal/parallel_map.pyi
ADDED
@@ -0,0 +1,75 @@
import modal._utils.async_utils
import modal.client
import modal.functions
import typing
import typing_extensions

class _SynchronizedQueue:
    async def init(self): ...
    async def put(self, item): ...
    async def get(self): ...

class SynchronizedQueue:
    def __init__(self, /, *args, **kwargs): ...

    class __init_spec(typing_extensions.Protocol):
        def __call__(self): ...
        async def aio(self): ...

    init: __init_spec

    class __put_spec(typing_extensions.Protocol):
        def __call__(self, item): ...
        async def aio(self, item): ...

    put: __put_spec

    class __get_spec(typing_extensions.Protocol):
        def __call__(self): ...
        async def aio(self): ...

    get: __get_spec

class _OutputValue:
    value: typing.Any

    def __init__(self, value: typing.Any) -> None: ...
    def __repr__(self): ...
    def __eq__(self, other): ...

def _map_invocation(
    function: modal.functions._Function,
    raw_input_queue: _SynchronizedQueue,
    client: modal.client._Client,
    order_outputs: bool,
    return_exceptions: bool,
    count_update_callback: typing.Optional[typing.Callable[[int, int], None]],
): ...
def _map_sync(
    self, *input_iterators, kwargs={}, order_outputs: bool = True, return_exceptions: bool = False
) -> modal._utils.async_utils.AsyncOrSyncIterable: ...
def _map_async(
    self,
    *input_iterators: typing.Union[typing.Iterable[typing.Any], typing.AsyncIterable[typing.Any]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> typing.AsyncGenerator[typing.Any, None]: ...
def _for_each_sync(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
async def _for_each_async(self, *input_iterators, kwargs={}, ignore_exceptions: bool = False): ...
def _starmap_async(
    self,
    input_iterator: typing.Union[
        typing.Iterable[typing.Sequence[typing.Any]], typing.AsyncIterable[typing.Sequence[typing.Any]]
    ],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> typing.AsyncIterable[typing.Any]: ...
def _starmap_sync(
    self,
    input_iterator: typing.Iterable[typing.Sequence[typing.Any]],
    kwargs={},
    order_outputs: bool = True,
    return_exceptions: bool = False,
) -> modal._utils.async_utils.AsyncOrSyncIterable: ...