modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modal might be problematic. Click here for more details.
- modal/__main__.py +3 -4
- modal/_billing.py +80 -0
- modal/_clustered_functions.py +7 -3
- modal/_clustered_functions.pyi +4 -2
- modal/_container_entrypoint.py +41 -49
- modal/_functions.py +424 -195
- modal/_grpc_client.py +171 -0
- modal/_load_context.py +105 -0
- modal/_object.py +68 -20
- modal/_output.py +58 -45
- modal/_partial_function.py +36 -11
- modal/_pty.py +7 -3
- modal/_resolver.py +21 -35
- modal/_runtime/asgi.py +4 -3
- modal/_runtime/container_io_manager.py +301 -186
- modal/_runtime/container_io_manager.pyi +70 -61
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +4 -1
- modal/_runtime/gpu_memory_snapshot.py +170 -63
- modal/_runtime/user_code_imports.py +28 -58
- modal/_serialization.py +57 -1
- modal/_utils/async_utils.py +33 -12
- modal/_utils/auth_token_manager.py +2 -5
- modal/_utils/blob_utils.py +110 -53
- modal/_utils/function_utils.py +49 -42
- modal/_utils/grpc_utils.py +80 -50
- modal/_utils/mount_utils.py +26 -1
- modal/_utils/name_utils.py +17 -3
- modal/_utils/task_command_router_client.py +536 -0
- modal/_utils/time_utils.py +34 -6
- modal/app.py +219 -83
- modal/app.pyi +229 -56
- modal/billing.py +5 -0
- modal/{requirements → builder}/2025.06.txt +1 -0
- modal/{requirements → builder}/PREVIEW.txt +1 -0
- modal/cli/_download.py +19 -3
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +15 -7
- modal/cli/config.py +5 -3
- modal/cli/container.py +7 -6
- modal/cli/dict.py +22 -16
- modal/cli/entry_point.py +12 -5
- modal/cli/environment.py +5 -4
- modal/cli/import_refs.py +3 -3
- modal/cli/launch.py +102 -5
- modal/cli/network_file_system.py +9 -13
- modal/cli/profile.py +3 -2
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_jupyter.py +1 -1
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/programs/vscode.py +1 -1
- modal/cli/queues.py +57 -26
- modal/cli/run.py +58 -16
- modal/cli/secret.py +48 -22
- modal/cli/utils.py +3 -4
- modal/cli/volume.py +28 -25
- modal/client.py +13 -116
- modal/client.pyi +9 -91
- modal/cloud_bucket_mount.py +5 -3
- modal/cloud_bucket_mount.pyi +5 -1
- modal/cls.py +130 -102
- modal/cls.pyi +45 -85
- modal/config.py +29 -10
- modal/container_process.py +291 -13
- modal/container_process.pyi +95 -32
- modal/dict.py +282 -63
- modal/dict.pyi +423 -73
- modal/environments.py +15 -27
- modal/environments.pyi +5 -15
- modal/exception.py +8 -0
- modal/experimental/__init__.py +143 -38
- modal/experimental/flash.py +247 -78
- modal/experimental/flash.pyi +137 -9
- modal/file_io.py +14 -28
- modal/file_io.pyi +2 -2
- modal/file_pattern_matcher.py +25 -16
- modal/functions.pyi +134 -61
- modal/image.py +255 -86
- modal/image.pyi +300 -62
- modal/io_streams.py +436 -126
- modal/io_streams.pyi +236 -171
- modal/mount.py +62 -157
- modal/mount.pyi +45 -172
- modal/network_file_system.py +30 -53
- modal/network_file_system.pyi +16 -76
- modal/object.pyi +42 -8
- modal/parallel_map.py +821 -113
- modal/parallel_map.pyi +134 -0
- modal/partial_function.pyi +4 -1
- modal/proxy.py +16 -7
- modal/proxy.pyi +10 -2
- modal/queue.py +263 -61
- modal/queue.pyi +409 -66
- modal/runner.py +112 -92
- modal/runner.pyi +45 -27
- modal/sandbox.py +451 -124
- modal/sandbox.pyi +513 -67
- modal/secret.py +291 -67
- modal/secret.pyi +425 -19
- modal/serving.py +7 -11
- modal/serving.pyi +7 -8
- modal/snapshot.py +11 -8
- modal/token_flow.py +4 -4
- modal/volume.py +344 -98
- modal/volume.pyi +464 -68
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
- modal-1.2.3.dev7.dist-info/RECORD +195 -0
- modal_docs/mdmd/mdmd.py +11 -1
- modal_proto/api.proto +399 -67
- modal_proto/api_grpc.py +241 -1
- modal_proto/api_pb2.py +1395 -1000
- modal_proto/api_pb2.pyi +1239 -79
- modal_proto/api_pb2_grpc.py +499 -4
- modal_proto/api_pb2_grpc.pyi +162 -14
- modal_proto/modal_api_grpc.py +175 -160
- modal_proto/sandbox_router.proto +145 -0
- modal_proto/sandbox_router_grpc.py +105 -0
- modal_proto/sandbox_router_pb2.py +149 -0
- modal_proto/sandbox_router_pb2.pyi +333 -0
- modal_proto/sandbox_router_pb2_grpc.py +203 -0
- modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
- modal_proto/task_command_router.proto +144 -0
- modal_proto/task_command_router_grpc.py +105 -0
- modal_proto/task_command_router_pb2.py +149 -0
- modal_proto/task_command_router_pb2.pyi +333 -0
- modal_proto/task_command_router_pb2_grpc.py +203 -0
- modal_proto/task_command_router_pb2_grpc.pyi +75 -0
- modal_version/__init__.py +1 -1
- modal-1.0.6.dev58.dist-info/RECORD +0 -183
- modal_proto/modal_options_grpc.py +0 -3
- modal_proto/options.proto +0 -19
- modal_proto/options_grpc.py +0 -3
- modal_proto/options_pb2.py +0 -35
- modal_proto/options_pb2.pyi +0 -20
- modal_proto/options_pb2_grpc.py +0 -4
- modal_proto/options_pb2_grpc.pyi +0 -7
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- /modal/{requirements → builder}/base-images.json +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/parallel_map.py
CHANGED
|
@@ -6,7 +6,7 @@ import time
|
|
|
6
6
|
import typing
|
|
7
7
|
from asyncio import FIRST_COMPLETED
|
|
8
8
|
from dataclasses import dataclass
|
|
9
|
-
from typing import Any, Callable, Optional
|
|
9
|
+
from typing import Any, Callable, Optional, Union
|
|
10
10
|
|
|
11
11
|
from grpclib import Status
|
|
12
12
|
|
|
@@ -35,7 +35,7 @@ from modal._utils.function_utils import (
|
|
|
35
35
|
_create_input,
|
|
36
36
|
_process_result,
|
|
37
37
|
)
|
|
38
|
-
from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES,
|
|
38
|
+
from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, Retry, RetryWarningMessage
|
|
39
39
|
from modal._utils.jwt_utils import DecodedJwt
|
|
40
40
|
from modal.config import logger
|
|
41
41
|
from modal.retries import RetryManager
|
|
@@ -79,13 +79,286 @@ class _OutputValue:
|
|
|
79
79
|
|
|
80
80
|
MAX_INPUTS_OUTSTANDING_DEFAULT = 1000
|
|
81
81
|
|
|
82
|
-
#
|
|
82
|
+
# Maximum number of inputs to send to the server per FunctionPutInputs request
|
|
83
83
|
MAP_INVOCATION_CHUNK_SIZE = 49
|
|
84
|
+
SPAWN_MAP_INVOCATION_CHUNK_SIZE = 512
|
|
85
|
+
|
|
84
86
|
|
|
85
87
|
if typing.TYPE_CHECKING:
|
|
86
88
|
import modal.functions
|
|
87
89
|
|
|
88
90
|
|
|
91
|
+
class InputPreprocessor:
    """Turn raw ``(args, kwargs)`` tuples into ``FunctionPutInputsItem`` protos.

    Raw items are read from ``raw_input_queue`` until its ``None`` sentinel arrives. Each one is
    serialized with ``_create_input``, with up to ``BLOB_MAX_PARALLELISM`` serializations in
    flight via ``async_map_ordered``. The results are forwarded to ``processed_input_queue``,
    followed by a terminating ``None``.
    """

    def __init__(
        self,
        client: "modal.client._Client",
        *,
        raw_input_queue: _SynchronizedQueue,
        processed_input_queue: asyncio.Queue,
        function: "modal.functions._Function",
        created_callback: Callable[[int], None],
        done_callback: Callable[[], None],
    ):
        self.client = client
        self.function = function
        self.raw_input_queue = raw_input_queue
        self.processed_input_queue = processed_input_queue
        # Receives the running total each time a new input index is handed out.
        self.created_callback = created_callback
        # Invoked once, after the end-of-input marker has been put on the processed queue.
        self.done_callback = done_callback
        # Count of inputs handed out so far; it doubles as the next input's (0-based) idx.
        self.inputs_created = 0

    async def input_iter(self):
        """Yield raw ``(args, kwargs)`` items until the ``None`` sentinel is read."""
        pending = await self.raw_input_queue.get()
        while pending is not None:
            yield pending  # (args, kwargs)
            pending = await self.raw_input_queue.get()

    def create_input_factory(self):
        """Return a coroutine function that serializes a single ``(args, kwargs)`` pair.

        Each call claims the next index synchronously, before its first ``await``, and reports
        the new total through ``created_callback``.
        """

        async def create_input(argskwargs):
            claimed_idx = self.inputs_created
            self.inputs_created = claimed_idx + 1
            self.created_callback(self.inputs_created)
            args, kwargs = argskwargs
            return await _create_input(args, kwargs, self.client.stub, idx=claimed_idx, function=self.function)

        return create_input

    async def drain_input_generator(self):
        """Serialize every raw input and push the results onto the processed-input queue.

        This is an async generator that yields exactly once, after all inputs and the trailing
        ``None`` have been enqueued and ``done_callback`` has run.
        """
        # Blob uploads for different inputs are parallelized by async_map_ordered.
        mapper = async_map_ordered(self.input_iter(), self.create_input_factory(), concurrency=BLOB_MAX_PARALLELISM)
        async with aclosing(mapper) as streamer:
            async for put_item in streamer:
                await self.processed_input_queue.put(put_item)

        # End-of-stream marker that closes the downstream queue iterator.
        await self.processed_input_queue.put(None)
        self.done_callback()
        yield
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class InputPumper:
    """Read ``FunctionPutInputsItem`` objects from a queue and send them to the server.

    Items are grouped into batches of at most ``max_batch_size`` and sent with
    ``FunctionPutInputs``. If a ``map_items_manager`` is given, it is told about each batch
    before the send and about the server's response afterwards.
    """

    def __init__(
        self,
        client: "modal.client._Client",
        *,
        input_queue: asyncio.Queue,
        function: "modal.functions._Function",
        function_call_id: str,
        max_batch_size: int,
        map_items_manager: Optional["_MapItemsManager"] = None,
    ):
        self.client = client
        self.function = function
        self.function_call_id = function_call_id
        self.input_queue = input_queue
        self.max_batch_size = max_batch_size
        self.map_items_manager = map_items_manager
        # Total number of inputs acknowledged by FunctionPutInputs so far.
        self.inputs_sent = 0

    async def pump_inputs(self):
        """Send queued inputs batch by batch, yielding once after each acknowledged batch."""
        assert self.client.stub
        async for batch in queue_batch_iterator(self.input_queue, max_batch_size=self.max_batch_size):
            if self.map_items_manager is not None:
                # Registers the items with the manager in the SENDING state.
                await self.map_items_manager.add_items(batch)

            put_request = api_pb2.FunctionPutInputsRequest(
                function_id=self.function.object_id,
                inputs=batch,
                function_call_id=self.function_call_id,
            )
            logger.debug(
                f"Pushing {len(batch)} inputs to server. "
                f"Num queued inputs awaiting push is {self.input_queue.qsize()}. "
            )

            put_response = await self.client.stub.FunctionPutInputs(put_request, retry=self._function_inputs_retry)
            self.inputs_sent += len(batch)

            if self.map_items_manager is not None:
                # Moves the items to WAITING_FOR_OUTPUT and records their input_id / input_jwt.
                self.map_items_manager.handle_put_inputs_response(put_response.inputs)

            logger.debug(
                f"Successfully pushed {len(batch)} inputs to server. Num queued inputs awaiting push is "
                f"{self.input_queue.qsize()}."
            )
            yield

    @property
    def _function_inputs_retry(self) -> Retry:
        """Build the retry policy used for input-submission RPCs.

        Retries are unbounded (``max_retries=None``), capped at ``PUMP_INPUTS_MAX_RETRY_DELAY``
        between attempts, and also cover ``RESOURCE_EXHAUSTED``. A throttling warning is attached
        for ``RESOURCE_EXHAUSTED``. With a warning interval of 8, the original authors note it is
        emitted roughly every 30 seconds, which keeps it from being spammy.
        """
        function_name = self.function._function_name
        throttle_warning = RetryWarningMessage(
            message=f"Warning: map progress for function {function_name} is limited."
            " Common bottlenecks include slow iteration over results, or function backlogs.",
            warning_interval=8,
            errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
        )
        policy = Retry(
            max_retries=None,
            max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
            warning_message=throttle_warning,
        )
        return policy
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class SyncInputPumper(InputPumper):
    """Input pumper for synchronous maps that can also re-submit inputs for retry.

    Batches are capped at ``MAP_INVOCATION_CHUNK_SIZE``. A ``_MapItemsManager`` is required,
    because it tracks each input's state across the initial send and any retries.
    """

    def __init__(
        self,
        client: "modal.client._Client",
        *,
        input_queue: asyncio.Queue,
        retry_queue: TimestampPriorityQueue,
        function: "modal.functions._Function",
        function_call_jwt: str,
        function_call_id: str,
        map_items_manager: "_MapItemsManager",
    ):
        super().__init__(
            client,
            input_queue=input_queue,
            function=function,
            function_call_id=function_call_id,
            max_batch_size=MAP_INVOCATION_CHUNK_SIZE,
            map_items_manager=map_items_manager,
        )
        self.function_call_jwt = function_call_jwt
        # Queue of input indices that are due to be retried.
        self.retry_queue = retry_queue
        # Total number of inputs re-submitted via FunctionRetryInputs so far.
        self.inputs_retried = 0

    async def retry_inputs(self):
        """Re-submit batches of indices taken from the retry queue, yielding after each batch."""
        async for idx_batch in queue_batch_iterator(self.retry_queue, max_batch_size=self.max_batch_size):
            # The manager builds FunctionRetryInputsItems from its stored contexts and marks
            # those contexts as RETRYING.
            retry_items: list[api_pb2.FunctionRetryInputsItem] = await self.map_items_manager.prepare_items_for_retry(
                idx_batch
            )
            retry_request = api_pb2.FunctionRetryInputsRequest(
                function_call_jwt=self.function_call_jwt,
                inputs=retry_items,
            )
            retry_response = await self.client.stub.FunctionRetryInputs(
                retry_request, retry=self._function_inputs_retry
            )
            # Moves the contexts back to WAITING_FOR_OUTPUT and stores the refreshed input JWTs.
            self.map_items_manager.handle_retry_response(retry_response.input_jwts)
            logger.debug(f"Successfully pushed retry for {len(retry_items)} to server.")
            self.inputs_retried += len(retry_items)
            yield
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class AsyncInputPumper(InputPumper):
    """Input pumper for spawned (async) maps.

    Inputs are sent in batches of up to ``SPAWN_MAP_INVOCATION_CHUNK_SIZE``. Once the input
    queue is exhausted, the total input count is reported via ``FunctionFinishInputs``.
    """

    def __init__(
        self,
        client: "modal.client._Client",
        *,
        input_queue: asyncio.Queue,
        function: "modal.functions._Function",
        function_call_id: str,
    ):
        super().__init__(
            client,
            input_queue=input_queue,
            function=function,
            function_call_id=function_call_id,
            max_batch_size=SPAWN_MAP_INVOCATION_CHUNK_SIZE,
        )

    async def pump_inputs(self):
        """Send every queued input, then finalize the input count; yields once at the end."""
        # Drain the base-class generator; its per-batch yields are not needed here.
        async for _ in super().pump_inputs():
            continue

        finish_request = api_pb2.FunctionFinishInputsRequest(
            function_id=self.function.object_id,
            function_call_id=self.function_call_id,
            num_inputs=self.inputs_sent,
        )
        # Retried without an attempt limit.
        await self.client.stub.FunctionFinishInputs(finish_request, retry=Retry(max_retries=None))
        yield
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
async def _spawn_map_invocation(
    function: "modal.functions._Function", raw_input_queue: _SynchronizedQueue, client: "modal.client._Client"
) -> tuple[str, int]:
    """Start an async ("spawn") map call and upload all of its inputs.

    A map function call is created with ``FunctionMap`` using the
    ``FUNCTION_CALL_INVOCATION_TYPE_ASYNC`` invocation type. Inputs are read from
    ``raw_input_queue`` until its ``None`` sentinel. An ``InputPreprocessor`` serializes them
    and an ``AsyncInputPumper`` sends them, finishing with ``FunctionFinishInputs``.
    No outputs are fetched here.

    Returns:
        A ``(function_call_id, inputs_created)`` tuple.

    If serializing or sending inputs raises, the other worker and the stats-logging task are
    cancelled before the exception propagates, so no background tasks are left running.
    """
    assert client.stub
    request = api_pb2.FunctionMapRequest(
        function_id=function.object_id,
        parent_input_id=current_input_id() or "",
        function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
        function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC,
    )
    response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
    function_call_id = response.function_call_id

    have_all_inputs = False
    inputs_created = 0

    def set_inputs_created(count: int) -> None:
        # Called by the preprocessor with the running total of created inputs.
        nonlocal inputs_created
        assert count > inputs_created
        inputs_created = count

    def set_have_all_inputs() -> None:
        nonlocal have_all_inputs
        have_all_inputs = True

    # Serialized inputs flow from the preprocessor to the pumper; None marks end of input.
    input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
    input_preprocessor = InputPreprocessor(
        client=client,
        raw_input_queue=raw_input_queue,
        processed_input_queue=input_queue,
        function=function,
        created_callback=set_inputs_created,
        done_callback=set_have_all_inputs,
    )

    input_pumper = AsyncInputPumper(
        client=client,
        input_queue=input_queue,
        function=function,
        function_call_id=function_call_id,
    )

    def log_stats():
        logger.debug(
            f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} inputs_sent={input_pumper.inputs_sent} "
        )

    async def log_task():
        # Log progress every 10 seconds, and once more when cancelled.
        while True:
            log_stats()
            try:
                await asyncio.sleep(10)
            except asyncio.CancelledError:
                # Log final stats before exiting
                log_stats()
                break

    async def consume_generator(gen):
        async for _ in gen:
            pass

    log_debug_stats_task = asyncio.create_task(log_task())
    worker_tasks = [
        asyncio.create_task(consume_generator(input_preprocessor.drain_input_generator())),
        asyncio.create_task(consume_generator(input_pumper.pump_inputs())),
    ]
    try:
        await asyncio.gather(*worker_tasks)
    finally:
        # asyncio.gather does not cancel the remaining awaitables when one of them raises.
        # Without this, a failed preprocessor would leave the pumper blocked forever on a queue
        # that never receives its None sentinel, and the stats task would keep logging.
        # Cancelling already-finished tasks is a no-op, so the success path is unaffected.
        for task in worker_tasks:
            task.cancel()
        log_debug_stats_task.cancel()
        await asyncio.gather(*worker_tasks, log_debug_stats_task, return_exceptions=True)
    return function_call_id, inputs_created
|
|
360
|
+
|
|
361
|
+
|
|
89
362
|
async def _map_invocation(
|
|
90
363
|
function: "modal.functions._Function",
|
|
91
364
|
raw_input_queue: _SynchronizedQueue,
|
|
@@ -104,7 +377,7 @@ async def _map_invocation(
|
|
|
104
377
|
return_exceptions=return_exceptions,
|
|
105
378
|
function_call_invocation_type=function_call_invocation_type,
|
|
106
379
|
)
|
|
107
|
-
response: api_pb2.FunctionMapResponse = await
|
|
380
|
+
response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
|
|
108
381
|
|
|
109
382
|
function_call_id = response.function_call_id
|
|
110
383
|
function_call_jwt = response.function_call_jwt
|
|
@@ -117,8 +390,6 @@ async def _map_invocation(
|
|
|
117
390
|
have_all_inputs = False
|
|
118
391
|
map_done_event = asyncio.Event()
|
|
119
392
|
inputs_created = 0
|
|
120
|
-
inputs_sent = 0
|
|
121
|
-
inputs_retried = 0
|
|
122
393
|
outputs_completed = 0
|
|
123
394
|
outputs_received = 0
|
|
124
395
|
retried_outputs = 0
|
|
@@ -135,25 +406,24 @@ async def _map_invocation(
|
|
|
135
406
|
retry_policy, function_call_invocation_type, retry_queue, sync_client_retries_enabled, max_inputs_outstanding
|
|
136
407
|
)
|
|
137
408
|
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
max_object_size_bytes=function._max_object_size_bytes,
|
|
147
|
-
idx=idx,
|
|
148
|
-
method_name=function._use_method_name,
|
|
149
|
-
)
|
|
409
|
+
input_preprocessor = InputPreprocessor(
|
|
410
|
+
client=client,
|
|
411
|
+
raw_input_queue=raw_input_queue,
|
|
412
|
+
processed_input_queue=input_queue,
|
|
413
|
+
function=function,
|
|
414
|
+
created_callback=lambda x: update_state(set_inputs_created=x),
|
|
415
|
+
done_callback=lambda: update_state(set_have_all_inputs=True),
|
|
416
|
+
)
|
|
150
417
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
418
|
+
input_pumper = SyncInputPumper(
|
|
419
|
+
client=client,
|
|
420
|
+
input_queue=input_queue,
|
|
421
|
+
retry_queue=retry_queue,
|
|
422
|
+
function=function,
|
|
423
|
+
map_items_manager=map_items_manager,
|
|
424
|
+
function_call_jwt=function_call_jwt,
|
|
425
|
+
function_call_id=function_call_id,
|
|
426
|
+
)
|
|
157
427
|
|
|
158
428
|
def update_state(set_have_all_inputs=None, set_inputs_created=None, set_outputs_completed=None):
|
|
159
429
|
# This should be the only method that needs nonlocal of the following vars
|
|
@@ -175,84 +445,6 @@ async def _map_invocation(
|
|
|
175
445
|
# map is done
|
|
176
446
|
map_done_event.set()
|
|
177
447
|
|
|
178
|
-
async def drain_input_generator():
|
|
179
|
-
# Parallelize uploading blobs
|
|
180
|
-
async with aclosing(
|
|
181
|
-
async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
|
|
182
|
-
) as streamer:
|
|
183
|
-
async for item in streamer:
|
|
184
|
-
await input_queue.put(item)
|
|
185
|
-
|
|
186
|
-
# close queue iterator
|
|
187
|
-
await input_queue.put(None)
|
|
188
|
-
update_state(set_have_all_inputs=True)
|
|
189
|
-
yield
|
|
190
|
-
|
|
191
|
-
async def pump_inputs():
|
|
192
|
-
assert client.stub
|
|
193
|
-
nonlocal inputs_sent
|
|
194
|
-
async for items in queue_batch_iterator(input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
|
|
195
|
-
# Add items to the manager. Their state will be SENDING.
|
|
196
|
-
await map_items_manager.add_items(items)
|
|
197
|
-
request = api_pb2.FunctionPutInputsRequest(
|
|
198
|
-
function_id=function.object_id,
|
|
199
|
-
inputs=items,
|
|
200
|
-
function_call_id=function_call_id,
|
|
201
|
-
)
|
|
202
|
-
logger.debug(
|
|
203
|
-
f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
|
|
204
|
-
)
|
|
205
|
-
|
|
206
|
-
resp = await send_inputs(client.stub.FunctionPutInputs, request)
|
|
207
|
-
inputs_sent += len(items)
|
|
208
|
-
# Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
|
|
209
|
-
map_items_manager.handle_put_inputs_response(resp.inputs)
|
|
210
|
-
logger.debug(
|
|
211
|
-
f"Successfully pushed {len(items)} inputs to server. "
|
|
212
|
-
f"Num queued inputs awaiting push is {input_queue.qsize()}."
|
|
213
|
-
)
|
|
214
|
-
yield
|
|
215
|
-
|
|
216
|
-
async def retry_inputs():
|
|
217
|
-
nonlocal inputs_retried
|
|
218
|
-
async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
|
|
219
|
-
# For each index, use the context in the manager to create a FunctionRetryInputsItem.
|
|
220
|
-
# This will also update the context state to RETRYING.
|
|
221
|
-
inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
|
|
222
|
-
retriable_idxs
|
|
223
|
-
)
|
|
224
|
-
request = api_pb2.FunctionRetryInputsRequest(
|
|
225
|
-
function_call_jwt=function_call_jwt,
|
|
226
|
-
inputs=inputs,
|
|
227
|
-
)
|
|
228
|
-
resp = await send_inputs(client.stub.FunctionRetryInputs, request)
|
|
229
|
-
# Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
|
|
230
|
-
# to the new value in the response.
|
|
231
|
-
map_items_manager.handle_retry_response(resp.input_jwts)
|
|
232
|
-
logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
|
|
233
|
-
inputs_retried += len(inputs)
|
|
234
|
-
yield
|
|
235
|
-
|
|
236
|
-
async def send_inputs(
|
|
237
|
-
fn: "modal.client.UnaryUnaryWrapper",
|
|
238
|
-
request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
|
|
239
|
-
) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
|
|
240
|
-
# with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
|
|
241
|
-
retry_warning_message = RetryWarningMessage(
|
|
242
|
-
message=f"Warning: map progress for function {function._function_name} is limited."
|
|
243
|
-
" Common bottlenecks include slow iteration over results, or function backlogs.",
|
|
244
|
-
warning_interval=8,
|
|
245
|
-
errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
|
|
246
|
-
)
|
|
247
|
-
return await retry_transient_errors(
|
|
248
|
-
fn,
|
|
249
|
-
request,
|
|
250
|
-
max_retries=None,
|
|
251
|
-
max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
|
|
252
|
-
additional_status_codes=[Status.RESOURCE_EXHAUSTED],
|
|
253
|
-
retry_warning_message=retry_warning_message,
|
|
254
|
-
)
|
|
255
|
-
|
|
256
448
|
async def get_all_outputs():
|
|
257
449
|
assert client.stub
|
|
258
450
|
nonlocal \
|
|
@@ -281,11 +473,12 @@ async def _map_invocation(
|
|
|
281
473
|
input_jwts=input_jwts,
|
|
282
474
|
)
|
|
283
475
|
get_response_task = asyncio.create_task(
|
|
284
|
-
|
|
285
|
-
client.stub.FunctionGetOutputs,
|
|
476
|
+
client.stub.FunctionGetOutputs(
|
|
286
477
|
request,
|
|
287
|
-
|
|
288
|
-
|
|
478
|
+
retry=Retry(
|
|
479
|
+
max_retries=20,
|
|
480
|
+
attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
|
|
481
|
+
),
|
|
289
482
|
)
|
|
290
483
|
)
|
|
291
484
|
map_done_task = asyncio.create_task(map_done_event.wait())
|
|
@@ -344,7 +537,7 @@ async def _map_invocation(
|
|
|
344
537
|
clear_on_success=True,
|
|
345
538
|
requested_at=time.time(),
|
|
346
539
|
)
|
|
347
|
-
await
|
|
540
|
+
await client.stub.FunctionGetOutputs(request)
|
|
348
541
|
await retry_queue.close()
|
|
349
542
|
|
|
350
543
|
async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
|
|
@@ -395,8 +588,11 @@ async def _map_invocation(
|
|
|
395
588
|
def log_stats():
|
|
396
589
|
logger.debug(
|
|
397
590
|
f"Map stats: sync_client_retries_enabled={sync_client_retries_enabled} "
|
|
398
|
-
f"have_all_inputs={have_all_inputs}
|
|
399
|
-
f"
|
|
591
|
+
f"have_all_inputs={have_all_inputs} "
|
|
592
|
+
f"inputs_created={inputs_created} "
|
|
593
|
+
f"input_sent={input_pumper.inputs_sent} "
|
|
594
|
+
f"inputs_retried={input_pumper.inputs_retried} "
|
|
595
|
+
f"outputs_received={outputs_received} "
|
|
400
596
|
f"successful_completions={successful_completions} failed_completions={failed_completions} "
|
|
401
597
|
f"no_context_duplicates={no_context_duplicates} old_retry_duplicates={stale_retry_duplicates} "
|
|
402
598
|
f"already_complete_duplicates={already_complete_duplicates} "
|
|
@@ -415,7 +611,12 @@ async def _map_invocation(
|
|
|
415
611
|
|
|
416
612
|
log_debug_stats_task = asyncio.create_task(log_debug_stats())
|
|
417
613
|
async with aclosing(
|
|
418
|
-
async_merge(
|
|
614
|
+
async_merge(
|
|
615
|
+
input_preprocessor.drain_input_generator(),
|
|
616
|
+
input_pumper.pump_inputs(),
|
|
617
|
+
input_pumper.retry_inputs(),
|
|
618
|
+
poll_outputs(),
|
|
619
|
+
)
|
|
419
620
|
) as streamer:
|
|
420
621
|
async for response in streamer:
|
|
421
622
|
if response is not None: # type: ignore[unreachable]
|
|
@@ -424,6 +625,367 @@ async def _map_invocation(
|
|
|
424
625
|
await log_debug_stats_task
|
|
425
626
|
|
|
426
627
|
|
|
628
|
+
async def _map_invocation_inputplane(
|
|
629
|
+
function: "modal.functions._Function",
|
|
630
|
+
raw_input_queue: _SynchronizedQueue,
|
|
631
|
+
client: "modal.client._Client",
|
|
632
|
+
order_outputs: bool,
|
|
633
|
+
return_exceptions: bool,
|
|
634
|
+
wrap_returned_exceptions: bool,
|
|
635
|
+
count_update_callback: Optional[Callable[[int, int], None]],
|
|
636
|
+
) -> typing.AsyncGenerator[Any, None]:
|
|
637
|
+
"""Input-plane implementation of a function map invocation.
|
|
638
|
+
|
|
639
|
+
This is analogous to `_map_invocation`, but instead of the control-plane
|
|
640
|
+
`FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
|
|
641
|
+
the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
|
|
642
|
+
"""
|
|
643
|
+
|
|
644
|
+
assert function._input_plane_url, "_map_invocation_inputplane should only be used for input-plane backed functions"
|
|
645
|
+
|
|
646
|
+
input_plane_stub = await client.get_stub(function._input_plane_url)
|
|
647
|
+
|
|
648
|
+
# Required for _create_input.
|
|
649
|
+
assert client.stub, "Client must be hydrated with a stub for _map_invocation_inputplane"
|
|
650
|
+
|
|
651
|
+
# ------------------------------------------------------------
|
|
652
|
+
# Invocation-wide state
|
|
653
|
+
# ------------------------------------------------------------
|
|
654
|
+
|
|
655
|
+
have_all_inputs = False
|
|
656
|
+
map_done_event = asyncio.Event()
|
|
657
|
+
|
|
658
|
+
inputs_created = 0
|
|
659
|
+
outputs_completed = 0
|
|
660
|
+
successful_completions = 0
|
|
661
|
+
failed_completions = 0
|
|
662
|
+
no_context_duplicates = 0
|
|
663
|
+
stale_retry_duplicates = 0
|
|
664
|
+
already_complete_duplicates = 0
|
|
665
|
+
retried_outputs = 0
|
|
666
|
+
input_queue_size = 0
|
|
667
|
+
last_entry_id = ""
|
|
668
|
+
|
|
669
|
+
# The input-plane server returns this after the first request.
|
|
670
|
+
map_token = None
|
|
671
|
+
map_token_received = asyncio.Event()
|
|
672
|
+
|
|
673
|
+
# Single priority queue that holds *both* fresh inputs (timestamp == now)
|
|
674
|
+
# and future retries (timestamp > now).
|
|
675
|
+
queue: TimestampPriorityQueue[api_pb2.MapStartOrContinueItem] = TimestampPriorityQueue()
|
|
676
|
+
|
|
677
|
+
# Maximum number of inputs that may be in-flight (the server sends this in
|
|
678
|
+
# the first response – fall back to the default if we never receive it for
|
|
679
|
+
# any reason).
|
|
680
|
+
max_inputs_outstanding = MAX_INPUTS_OUTSTANDING_DEFAULT
|
|
681
|
+
|
|
682
|
+
# Set a default retry policy to construct an instance of _MapItemsManager.
|
|
683
|
+
# We'll update the retry policy with the actual user-specified retry policy
|
|
684
|
+
# from the server in the first MapStartOrContinue response.
|
|
685
|
+
retry_policy = api_pb2.FunctionRetryPolicy(
|
|
686
|
+
retries=0,
|
|
687
|
+
initial_delay_ms=1000,
|
|
688
|
+
max_delay_ms=1000,
|
|
689
|
+
backoff_coefficient=1.0,
|
|
690
|
+
)
|
|
691
|
+
map_items_manager = _MapItemsManager(
|
|
692
|
+
retry_policy=retry_policy,
|
|
693
|
+
function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
|
|
694
|
+
retry_queue=queue,
|
|
695
|
+
sync_client_retries_enabled=True,
|
|
696
|
+
max_inputs_outstanding=MAX_INPUTS_OUTSTANDING_DEFAULT,
|
|
697
|
+
is_input_plane_instance=True,
|
|
698
|
+
)
|
|
699
|
+
|
|
700
|
+
def update_counters(
|
|
701
|
+
created_delta: int = 0, completed_delta: int = 0, set_have_all_inputs: Union[bool, None] = None
|
|
702
|
+
):
|
|
703
|
+
nonlocal inputs_created, outputs_completed, have_all_inputs
|
|
704
|
+
|
|
705
|
+
if created_delta:
|
|
706
|
+
inputs_created += created_delta
|
|
707
|
+
if completed_delta:
|
|
708
|
+
outputs_completed += completed_delta
|
|
709
|
+
if set_have_all_inputs is not None:
|
|
710
|
+
have_all_inputs = set_have_all_inputs
|
|
711
|
+
|
|
712
|
+
if count_update_callback is not None:
|
|
713
|
+
count_update_callback(outputs_completed, inputs_created)
|
|
714
|
+
|
|
715
|
+
if have_all_inputs and outputs_completed >= inputs_created:
|
|
716
|
+
map_done_event.set()
|
|
717
|
+
|
|
718
|
+
async def create_input(argskwargs):
|
|
719
|
+
idx = inputs_created + 1 # 1-indexed map call idx
|
|
720
|
+
update_counters(created_delta=1)
|
|
721
|
+
(args, kwargs) = argskwargs
|
|
722
|
+
put_item: api_pb2.FunctionPutInputsItem = await _create_input(
|
|
723
|
+
args,
|
|
724
|
+
kwargs,
|
|
725
|
+
client.stub,
|
|
726
|
+
idx=idx,
|
|
727
|
+
function=function,
|
|
728
|
+
)
|
|
729
|
+
return api_pb2.MapStartOrContinueItem(input=put_item)
|
|
730
|
+
|
|
731
|
+
async def input_iter():
|
|
732
|
+
while True:
|
|
733
|
+
raw_input = await raw_input_queue.get()
|
|
734
|
+
if raw_input is None: # end of input sentinel
|
|
735
|
+
break
|
|
736
|
+
yield raw_input # args, kwargs
|
|
737
|
+
|
|
738
|
+
async def drain_input_generator():
    """Serialize every raw input and push it onto the send queue.

    Serialization runs through `async_map_ordered` with up to
    `BLOB_MAX_PARALLELISM` concurrent `create_input` calls. Once the source is
    exhausted, `have_all_inputs` is flagged so completion can be detected. A single
    `None` is yielded at the end so this coroutine can take part in `async_merge`.
    """
    serialized_inputs = async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
    async with aclosing(serialized_inputs) as stream:
        async for map_item in stream:
            await queue.put(time.time(), map_item)

    # Every input has now been read and enqueued.
    update_counters(set_have_all_inputs=True)
    yield
|
|
748
|
+
|
|
749
|
+
async def pump_inputs():
|
|
750
|
+
nonlocal map_token, max_inputs_outstanding
|
|
751
|
+
async for batch in queue_batch_iterator(queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
|
|
752
|
+
# Convert the queued items into the proto format expected by the RPC.
|
|
753
|
+
request_items: list[api_pb2.MapStartOrContinueItem] = [
|
|
754
|
+
api_pb2.MapStartOrContinueItem(input=qi.input, attempt_token=qi.attempt_token) for qi in batch
|
|
755
|
+
]
|
|
756
|
+
|
|
757
|
+
await map_items_manager.add_items_inputplane(request_items)
|
|
758
|
+
|
|
759
|
+
# Build request
|
|
760
|
+
request = api_pb2.MapStartOrContinueRequest(
|
|
761
|
+
function_id=function.object_id,
|
|
762
|
+
map_token=map_token,
|
|
763
|
+
parent_input_id=current_input_id() or "",
|
|
764
|
+
items=request_items,
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
metadata = await client.get_input_plane_metadata(function._input_plane_region)
|
|
768
|
+
|
|
769
|
+
response: api_pb2.MapStartOrContinueResponse = await input_plane_stub.MapStartOrContinue(
|
|
770
|
+
request,
|
|
771
|
+
retry=Retry(
|
|
772
|
+
additional_status_codes=[Status.RESOURCE_EXHAUSTED],
|
|
773
|
+
max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
|
|
774
|
+
max_retries=None,
|
|
775
|
+
),
|
|
776
|
+
metadata=metadata,
|
|
777
|
+
)
|
|
778
|
+
|
|
779
|
+
# match response items to the corresponding request item index
|
|
780
|
+
response_items_idx_tuple = [
|
|
781
|
+
(request_items[idx].input.idx, attempt_token)
|
|
782
|
+
for idx, attempt_token in enumerate(response.attempt_tokens)
|
|
783
|
+
]
|
|
784
|
+
|
|
785
|
+
map_items_manager.handle_put_continue_response(response_items_idx_tuple)
|
|
786
|
+
|
|
787
|
+
# Set the function call id and actual retry policy with the data from the first response.
|
|
788
|
+
# This conditional is skipped for subsequent iterations of this for-loop.
|
|
789
|
+
if map_token is None:
|
|
790
|
+
map_token = response.map_token
|
|
791
|
+
map_token_received.set()
|
|
792
|
+
max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
|
|
793
|
+
map_items_manager.set_retry_policy(response.retry_policy)
|
|
794
|
+
# Update the retry policy for the first batch of inputs.
|
|
795
|
+
# Subsequent batches will have the correct user-specified retry policy
|
|
796
|
+
# set by the updated _MapItemsManager.
|
|
797
|
+
map_items_manager.update_items_retry_policy(response.retry_policy)
|
|
798
|
+
yield
|
|
799
|
+
|
|
800
|
+
async def check_lost_inputs():
    """Periodically ask the input plane whether any in-flight inputs were lost.

    Roughly once per second, until `map_done_event` is set, this sends the attempt
    tokens of every input waiting for output to `MapCheckInputs`. The per-input
    `lost` flags are passed to `map_items_manager.handle_check_inputs_response`,
    which re-queues lost inputs. `None` is yielded after each check so this
    coroutine can take part in `async_merge`. Cancellation ends the loop quietly.
    """
    try:
        while not map_done_event.is_set():
            if map_token is None:
                await map_token_received.wait()
                continue

            # Wait about 1s, but wake up early if the map finishes in the meantime.
            sleep_task = asyncio.create_task(asyncio.sleep(1))
            map_done_task = asyncio.create_task(map_done_event.wait())
            try:
                done, _ = await asyncio.wait([sleep_task, map_done_task], return_when=FIRST_COMPLETED)
            finally:
                # asyncio.wait() never cancels its tasks. Without this, every iteration would
                # leave a pending map_done_event.wait() task behind, and cancellation would
                # leak both tasks.
                sleep_task.cancel()
                map_done_task.cancel()
            if map_done_task in done:
                break

            # check_inputs = [(idx, attempt_token), ...]
            check_inputs = map_items_manager.get_input_idxs_waiting_for_output()
            request = api_pb2.MapCheckInputsRequest(
                last_entry_id=last_entry_id,
                timeout=0,  # Non-blocking read
                attempt_tokens=[attempt_token for _, attempt_token in check_inputs],
            )

            metadata = await client.get_input_plane_metadata(function._input_plane_region)
            response: api_pb2.MapCheckInputsResponse = await input_plane_stub.MapCheckInputs(
                request, metadata=metadata
            )
            # response.lost[i] answers for check_inputs[i]; pair them positionally.
            # check_inputs_response = [(idx, lost: bool), ...]
            check_inputs_response = [(idx, lost) for (idx, _), lost in zip(check_inputs, response.lost)]
            await map_items_manager.handle_check_inputs_response(check_inputs_response)
            yield
    except asyncio.CancelledError:
        pass
|
|
835
|
+
|
|
836
|
+
async def get_all_outputs():
|
|
837
|
+
nonlocal \
|
|
838
|
+
successful_completions, \
|
|
839
|
+
failed_completions, \
|
|
840
|
+
no_context_duplicates, \
|
|
841
|
+
stale_retry_duplicates, \
|
|
842
|
+
already_complete_duplicates, \
|
|
843
|
+
retried_outputs, \
|
|
844
|
+
last_entry_id
|
|
845
|
+
|
|
846
|
+
while not map_done_event.is_set():
|
|
847
|
+
if map_token is None:
|
|
848
|
+
await map_token_received.wait()
|
|
849
|
+
continue
|
|
850
|
+
|
|
851
|
+
request = api_pb2.MapAwaitRequest(
|
|
852
|
+
map_token=map_token,
|
|
853
|
+
last_entry_id=last_entry_id,
|
|
854
|
+
requested_at=time.time(),
|
|
855
|
+
timeout=OUTPUTS_TIMEOUT,
|
|
856
|
+
)
|
|
857
|
+
metadata = await client.get_input_plane_metadata(function._input_plane_region)
|
|
858
|
+
get_response_task = asyncio.create_task(
|
|
859
|
+
input_plane_stub.MapAwait(
|
|
860
|
+
request,
|
|
861
|
+
retry=Retry(
|
|
862
|
+
max_retries=20,
|
|
863
|
+
attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
|
|
864
|
+
),
|
|
865
|
+
metadata=metadata,
|
|
866
|
+
)
|
|
867
|
+
)
|
|
868
|
+
map_done_task = asyncio.create_task(map_done_event.wait())
|
|
869
|
+
try:
|
|
870
|
+
done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
|
|
871
|
+
if get_response_task in done:
|
|
872
|
+
map_done_task.cancel()
|
|
873
|
+
response = get_response_task.result()
|
|
874
|
+
else:
|
|
875
|
+
assert map_done_event.is_set()
|
|
876
|
+
# map is done - no more outputs, so return early
|
|
877
|
+
return
|
|
878
|
+
finally:
|
|
879
|
+
# clean up tasks, in case of cancellations etc.
|
|
880
|
+
get_response_task.cancel()
|
|
881
|
+
map_done_task.cancel()
|
|
882
|
+
last_entry_id = response.last_entry_id
|
|
883
|
+
|
|
884
|
+
for output_item in response.outputs:
|
|
885
|
+
output_type = await map_items_manager.handle_get_outputs_response(output_item, int(time.time()))
|
|
886
|
+
if output_type == _OutputType.SUCCESSFUL_COMPLETION:
|
|
887
|
+
successful_completions += 1
|
|
888
|
+
elif output_type == _OutputType.FAILED_COMPLETION:
|
|
889
|
+
failed_completions += 1
|
|
890
|
+
elif output_type == _OutputType.RETRYING:
|
|
891
|
+
retried_outputs += 1
|
|
892
|
+
elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
|
|
893
|
+
no_context_duplicates += 1
|
|
894
|
+
elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
|
|
895
|
+
stale_retry_duplicates += 1
|
|
896
|
+
elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
|
|
897
|
+
already_complete_duplicates += 1
|
|
898
|
+
else:
|
|
899
|
+
raise Exception(f"Unknown output type: {output_type}")
|
|
900
|
+
|
|
901
|
+
if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
|
|
902
|
+
update_counters(completed_delta=1)
|
|
903
|
+
yield output_item
|
|
904
|
+
|
|
905
|
+
async def get_all_outputs_and_clean_up():
    """Relay every item from `get_all_outputs()`, closing the send `queue` afterwards.

    The close happens however the relay ends: exhaustion, an error, or the consumer
    closing this generator.
    """
    try:
        async with aclosing(get_all_outputs()) as outputs:
            async for output_item in outputs:
                yield output_item
    finally:
        await queue.close()
|
|
913
|
+
|
|
914
|
+
async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
    """Deserialize one output item and return `(item.idx, value)`.

    A failure while processing the result is re-raised unless `return_exceptions`
    is set. In that case the exception itself becomes the value, wrapped in
    `modal.exception.UserCodeException` when `wrap_returned_exceptions` is set.
    """
    try:
        value = await _process_result(item.result, item.data_format, input_plane_stub, client)
    except Exception as exc:
        if not return_exceptions:
            raise exc
        # Prior to client 1.0.4 there was a bug where return_exceptions would wrap any
        # returned exceptions in a synchronicity.UserCodeException. Wrapping is kept as a
        # deprecated, non-breaking compatibility bandaid for migrating away from that.
        value = modal.exception.UserCodeException(exc) if wrap_returned_exceptions else exc
    return (item.idx, value)
|
|
929
|
+
|
|
930
|
+
async def poll_outputs():
    """Deserialize map outputs and yield them wrapped in `_OutputValue`.

    Results are fetched through `async_map_ordered` with up to
    `BLOB_MAX_PARALLELISM` concurrent `fetch_output` calls. When `order_outputs` is
    false they are yielded as they arrive. Otherwise they are buffered and released
    strictly in 1-based map-index order.
    """
    pending_by_idx = {}  # outputs that arrived ahead of their turn
    next_idx = 1  # 1-indexed map call idx

    fetched = async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
    async with aclosing(fetched) as stream:
        async for idx, output in stream:
            if not order_outputs:
                yield _OutputValue(output)
                continue

            pending_by_idx[idx] = output
            # Release the contiguous run starting at next_idx; stop at the first gap
            # and wait for the next output from the server.
            while next_idx in pending_by_idx:
                yield _OutputValue(pending_by_idx.pop(next_idx))
                next_idx += 1

    assert len(pending_by_idx) == 0
|
|
957
|
+
|
|
958
|
+
async def log_debug_stats():
    """Log map statistics at debug level every 10 seconds until cancelled.

    On cancellation, one final snapshot is logged and the coroutine returns
    normally; the cancellation is not re-raised.
    """

    def log_stats():
        logger.debug(
            f"Map stats:\nsuccessful_completions={successful_completions} failed_completions={failed_completions} "
            f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
            f"already_complete_duplicates={already_complete_duplicates} retried_outputs={retried_outputs} "
            f"map_token={map_token} max_inputs_outstanding={max_inputs_outstanding} "
            f"map_items_manager_size={len(map_items_manager)} input_queue_size={input_queue_size}"
        )

    try:
        while True:
            log_stats()
            await asyncio.sleep(10)
    except asyncio.CancelledError:
        # Emit one last snapshot on the way out.
        log_stats()
|
|
976
|
+
|
|
977
|
+
log_task = asyncio.create_task(log_debug_stats())
|
|
978
|
+
|
|
979
|
+
async with aclosing(
|
|
980
|
+
async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), check_lost_inputs())
|
|
981
|
+
) as merged:
|
|
982
|
+
async for maybe_output in merged:
|
|
983
|
+
if maybe_output is not None: # ignore None sentinels
|
|
984
|
+
yield maybe_output.value
|
|
985
|
+
|
|
986
|
+
log_task.cancel()
|
|
987
|
+
|
|
988
|
+
|
|
427
989
|
async def _map_helper(
|
|
428
990
|
self: "modal.functions.Function",
|
|
429
991
|
async_input_gen: typing.AsyncGenerator[Any, None],
|
|
@@ -620,6 +1182,56 @@ def _map_sync(
|
|
|
620
1182
|
)
|
|
621
1183
|
|
|
622
1184
|
|
|
1185
|
+
async def _experimental_spawn_map_async(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
    """Async implementation of `spawn_map`.

    Zips the (sync or async) input iterators into argument tuples and delegates to
    `_spawn_map_helper`. The same `kwargs` mapping accompanies every tuple and is
    never mutated here.
    """
    zipped_inputs = async_zip(*(sync_or_async_iter(iterator) for iterator in input_iterators))
    return await _spawn_map_helper(self, zipped_inputs, kwargs)
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
async def _spawn_map_helper(
    self: "modal.functions.Function", async_input_gen, kwargs={}
) -> "modal.functions._FunctionCall":
    """Feed `async_input_gen` into `self._spawn_map` and return the resulting function call.

    Inputs travel as `(args, kwargs)` pairs through a `SynchronizedQueue` and are
    terminated by a `None` sentinel, while `_spawn_map` consumes them concurrently.

    If either side fails, or this coroutine is cancelled, the other side is
    cancelled too. `asyncio.gather` alone leaves the sibling running, so a failing
    input iterator would otherwise strand `_spawn_map` waiting for a sentinel that
    never arrives.
    """
    raw_input_queue: Any = SynchronizedQueue()  # type: ignore
    await raw_input_queue.init.aio()

    async def feed_queue():
        async with aclosing(async_input_gen) as streamer:
            async for args in streamer:
                await raw_input_queue.put.aio((args, kwargs))
        await raw_input_queue.put.aio(None)  # end-of-input sentinel

    spawn_task = asyncio.ensure_future(self._spawn_map.aio(raw_input_queue))
    feed_task = asyncio.ensure_future(feed_queue())
    try:
        fc, _ = await asyncio.gather(spawn_task, feed_task)
    except BaseException:
        # gather() propagates the first error but does not cancel the other awaitable.
        spawn_task.cancel()
        feed_task.cancel()
        raise
    return fc
|
|
1204
|
+
|
|
1205
|
+
|
|
1206
|
+
def _experimental_spawn_map_sync(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
    """mdmd:hidden
    Spawn parallel execution over a set of inputs, returning as soon as the inputs are created.

    Unlike `modal.Function.map`, this method does not block on completion of the remote execution but
    returns a `modal.FunctionCall` object that can be used to poll status and retrieve results later.

    Takes one iterator argument per argument in the function being mapped over.

    Example:
    ```python
    @app.function()
    def my_func(a, b):
        return a ** b


    @app.local_entrypoint()
    def main():
        fc = my_func.spawn_map([1, 2], [3, 4])
    ```

    """
    # Blocking entry point: run the async implementation on a temporary event loop.
    # Calling this from inside a running loop is rejected with the message below.
    spawn_coro = _experimental_spawn_map_async(self, *input_iterators, kwargs=kwargs)
    nested_loop_error = (
        "You can't run Function.spawn_map() from an async function. Use Function.spawn_map.aio() instead."
    )
    return run_coroutine_in_temporary_event_loop(spawn_coro, nested_loop_error)
|
|
1233
|
+
|
|
1234
|
+
|
|
623
1235
|
async def _spawn_map_async(self, *input_iterators, kwargs={}) -> None:
|
|
624
1236
|
"""This runs in an event loop on the main thread. It consumes inputs from the input iterators and creates async
|
|
625
1237
|
function calls for each.
|
|
@@ -756,12 +1368,19 @@ class _MapItemContext:
|
|
|
756
1368
|
sync_client_retries_enabled: bool
|
|
757
1369
|
# Both these futures are strings. Omitting generic type because
|
|
758
1370
|
# it causes an error when running `inv protoc type-stubs`.
|
|
1371
|
+
# Unused. But important, input_id is not set for inputplane invocations.
|
|
759
1372
|
input_id: asyncio.Future
|
|
760
1373
|
input_jwt: asyncio.Future
|
|
761
1374
|
previous_input_jwt: Optional[str]
|
|
762
1375
|
_event_loop: asyncio.AbstractEventLoop
|
|
763
1376
|
|
|
764
|
-
def __init__(
|
|
1377
|
+
def __init__(
|
|
1378
|
+
self,
|
|
1379
|
+
input: api_pb2.FunctionInput,
|
|
1380
|
+
retry_manager: RetryManager,
|
|
1381
|
+
sync_client_retries_enabled: bool,
|
|
1382
|
+
is_input_plane_instance: bool = False,
|
|
1383
|
+
):
|
|
765
1384
|
self.state = _MapItemState.SENDING
|
|
766
1385
|
self.input = input
|
|
767
1386
|
self.retry_manager = retry_manager
|
|
@@ -772,7 +1391,22 @@ class _MapItemContext:
|
|
|
772
1391
|
# a race condition where we could receive outputs before we have
|
|
773
1392
|
# recorded the input ID and JWT in `pending_outputs`.
|
|
774
1393
|
self.input_jwt = self._event_loop.create_future()
|
|
1394
|
+
# Unused. But important, this is not set for inputplane invocations.
|
|
775
1395
|
self.input_id = self._event_loop.create_future()
|
|
1396
|
+
self._is_input_plane_instance = is_input_plane_instance
|
|
1397
|
+
|
|
1398
|
+
def handle_map_start_or_continue_response(self, attempt_token: str):
    """Record the attempt token the input plane issued for this item's latest attempt.

    The first token resolves `input_jwt`. On a later call (a retry), the resolved
    future is replaced with a new one holding the new token. An item still in
    SENDING moves to WAITING_FOR_OUTPUT; any other state means an output was
    already received, so the state is left alone.
    """
    if self.input_jwt.done():
        # Futures can only be resolved once, so a new attempt needs a fresh one.
        # Create it on the item's own loop, the same way __init__ does, instead of
        # relying on whatever loop a bare asyncio.Future() would pick.
        self.input_jwt = self._event_loop.create_future()
    self.input_jwt.set_result(attempt_token)

    # Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
    # RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
    if self.state == _MapItemState.SENDING:
        self.state = _MapItemState.WAITING_FOR_OUTPUT
|
|
776
1410
|
|
|
777
1411
|
def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
|
|
778
1412
|
self.input_jwt.set_result(item.input_jwt)
|
|
@@ -799,7 +1433,7 @@ class _MapItemContext:
|
|
|
799
1433
|
if self.state == _MapItemState.COMPLETE:
|
|
800
1434
|
logger.debug(
|
|
801
1435
|
f"Received output for input marked as complete. Must be duplicate, so ignoring. "
|
|
802
|
-
f"idx={item.idx} input_id={item.input_id}
|
|
1436
|
+
f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
|
|
803
1437
|
)
|
|
804
1438
|
return _OutputType.ALREADY_COMPLETE_DUPLICATE
|
|
805
1439
|
# If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
|
|
@@ -847,7 +1481,11 @@ class _MapItemContext:
|
|
|
847
1481
|
|
|
848
1482
|
self.state = _MapItemState.WAITING_TO_RETRY
|
|
849
1483
|
|
|
850
|
-
|
|
1484
|
+
if self._is_input_plane_instance:
|
|
1485
|
+
retry_item = await self.create_map_start_or_continue_item(item.idx)
|
|
1486
|
+
await retry_queue.put(now_seconds + delay_ms / 1_000, retry_item)
|
|
1487
|
+
else:
|
|
1488
|
+
await retry_queue.put(now_seconds + delay_ms / 1_000, item.idx)
|
|
851
1489
|
|
|
852
1490
|
return _OutputType.RETRYING
|
|
853
1491
|
|
|
@@ -862,10 +1500,23 @@ class _MapItemContext:
|
|
|
862
1500
|
retry_count=self.retry_manager.retry_count,
|
|
863
1501
|
)
|
|
864
1502
|
|
|
1503
|
+
def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
    """Swap this item's retry manager for a new `RetryManager` built from `retry_policy`.

    The previous manager object, and whatever retry state it held, is discarded.
    """
    replacement_manager = RetryManager(retry_policy)
    self.retry_manager = replacement_manager
|
|
1505
|
+
|
|
865
1506
|
def handle_retry_response(self, input_jwt: str):
    """Resolve `input_jwt` with the JWT issued for a retried input and resume waiting for its output."""
    jwt_future = self.input_jwt
    jwt_future.set_result(input_jwt)
    self.state = _MapItemState.WAITING_FOR_OUTPUT
|
|
868
1509
|
|
|
1510
|
+
async def create_map_start_or_continue_item(self, idx: int) -> api_pb2.MapStartOrContinueItem:
    """Build a `MapStartOrContinueItem` that re-submits this input under map index `idx`.

    Waits for `input_jwt` to resolve and sends its value as `attempt_token`. A
    non-empty attempt token is what marks the item as a retry when it is registered
    again through `add_items_inputplane`.
    """
    token = await self.input_jwt
    put_inputs_item = api_pb2.FunctionPutInputsItem(input=self.input, idx=idx)
    return api_pb2.MapStartOrContinueItem(input=put_inputs_item, attempt_token=token)
|
|
1519
|
+
|
|
869
1520
|
|
|
870
1521
|
class _MapItemsManager:
|
|
871
1522
|
def __init__(
|
|
@@ -875,6 +1526,7 @@ class _MapItemsManager:
|
|
|
875
1526
|
retry_queue: TimestampPriorityQueue,
|
|
876
1527
|
sync_client_retries_enabled: bool,
|
|
877
1528
|
max_inputs_outstanding: int,
|
|
1529
|
+
is_input_plane_instance: bool = False,
|
|
878
1530
|
):
|
|
879
1531
|
self._retry_policy = retry_policy
|
|
880
1532
|
self.function_call_invocation_type = function_call_invocation_type
|
|
@@ -885,6 +1537,10 @@ class _MapItemsManager:
|
|
|
885
1537
|
self._inputs_outstanding = asyncio.BoundedSemaphore(max_inputs_outstanding)
|
|
886
1538
|
self._item_context: dict[int, _MapItemContext] = {}
|
|
887
1539
|
self._sync_client_retries_enabled = sync_client_retries_enabled
|
|
1540
|
+
self._is_input_plane_instance = is_input_plane_instance
|
|
1541
|
+
|
|
1542
|
+
def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
    """Set the policy used for the `RetryManager` of every item registered from now on.

    Contexts that already exist keep their current managers.
    """
    self._retry_policy = retry_policy
|
|
888
1544
|
|
|
889
1545
|
async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
|
|
890
1546
|
for item in items:
|
|
@@ -897,9 +1553,28 @@ class _MapItemsManager:
|
|
|
897
1553
|
sync_client_retries_enabled=self._sync_client_retries_enabled,
|
|
898
1554
|
)
|
|
899
1555
|
|
|
1556
|
+
async def add_items_inputplane(self, items: list[api_pb2.MapStartOrContinueItem]):
    """Register a batch of input-plane items that is about to be sent.

    Items with a non-empty `attempt_token` are retries of inputs that already have a
    context; they are moved back to SENDING. Every other item is a new input: it
    waits for a slot on the `_inputs_outstanding` semaphore and gets a fresh
    `_MapItemContext` built with the current retry policy.
    """
    for item in items:
        idx = item.input.idx
        if item.attempt_token != "":  # if it is a retry item
            ctx = self._item_context.get(idx, None)
            # The context can already be gone if the input completed while its retry
            # was queued. Tolerate that, as the other handle_* methods do, rather than
            # raising KeyError out of the send loop.
            if ctx is not None:
                ctx.state = _MapItemState.SENDING
            continue
        # acquire semaphore to limit the number of inputs in progress
        # (either queued to be sent, waiting for completion, or retrying)
        await self._inputs_outstanding.acquire()
        self._item_context[idx] = _MapItemContext(
            input=item.input.input,
            retry_manager=RetryManager(self._retry_policy),
            sync_client_retries_enabled=self._sync_client_retries_enabled,
            is_input_plane_instance=self._is_input_plane_instance,
        )
|
|
1570
|
+
|
|
900
1571
|
async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
    """Build retry request items for `retriable_idxs`, sequentially and in the given order."""
    retry_items = []
    for idx in retriable_idxs:
        item_ctx = self._item_context[idx]
        retry_items.append(await item_ctx.prepare_item_for_retry())
    return retry_items
|
|
902
1573
|
|
|
1574
|
+
def update_items_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
    """Give every currently tracked item a fresh retry manager built from `retry_policy`."""
    for item_ctx in list(self._item_context.values()):
        item_ctx.set_retry_policy(retry_policy)
|
|
1577
|
+
|
|
903
1578
|
def get_input_jwts_waiting_for_output(self) -> list[str]:
|
|
904
1579
|
"""
|
|
905
1580
|
Returns a list of input_jwts for inputs that are waiting for output.
|
|
@@ -911,6 +1586,17 @@ class _MapItemsManager:
|
|
|
911
1586
|
if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
|
|
912
1587
|
]
|
|
913
1588
|
|
|
1589
|
+
def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
    """
    Return `(idx, attempt_token)` pairs for inputs that are waiting for output.

    Only items in WAITING_FOR_OUTPUT whose `input_jwt` future has already resolved
    are included. The attempt token is that future's result.
    """
    # The map index is assigned by the client and known up front; only the attempt
    # token comes back from the server, which is why it lives in a future.
    return [
        (item_idx, item_ctx.input_jwt.result())
        for item_idx, item_ctx in self._item_context.items()
        if item_ctx.state == _MapItemState.WAITING_FOR_OUTPUT and item_ctx.input_jwt.done()
    ]
|
|
1599
|
+
|
|
914
1600
|
def _remove_item(self, item_idx: int):
|
|
915
1601
|
del self._item_context[item_idx]
|
|
916
1602
|
self._inputs_outstanding.release()
|
|
@@ -918,6 +1604,18 @@ class _MapItemsManager:
|
|
|
918
1604
|
def get_item_context(self, item_idx: int) -> "Optional[_MapItemContext]":
    """Return the context tracked for map index `item_idx`.

    Returns None when no context exists, either because it was never registered or
    because it has already been removed. The annotation reflects this instead of
    promising a context unconditionally.
    """
    return self._item_context.get(item_idx)
|
|
920
1606
|
|
|
1607
|
+
def handle_put_continue_response(
    self,
    items: list[tuple[int, str]],  # idx, input_jwt
):
    """Deliver attempt tokens from a `MapStartOrContinue` response to their item contexts.

    `items` holds `(map_idx, attempt_token)` pairs. An index without a context is
    skipped: its output was already received and the context deleted before this
    response was processed, which happens when the MapAwait poll in
    get_all_outputs() wins the race.
    """
    for map_idx, attempt_token in items:
        item_ctx = self._item_context.get(map_idx)
        if item_ctx is None:
            continue
        item_ctx.handle_map_start_or_continue_response(attempt_token)
|
|
1618
|
+
|
|
921
1619
|
def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
|
|
922
1620
|
for item in items:
|
|
923
1621
|
ctx = self._item_context.get(item.idx, None)
|
|
@@ -937,6 +1635,16 @@ class _MapItemsManager:
|
|
|
937
1635
|
if ctx is not None:
|
|
938
1636
|
ctx.handle_retry_response(input_jwt)
|
|
939
1637
|
|
|
1638
|
+
async def handle_check_inputs_response(self, response: list[tuple[int, bool]]):
    """Re-queue inputs that the input plane reported as lost.

    `response` holds `(map_idx, lost)` pairs. For each lost input that still has a
    context, the item is marked WAITING_TO_RETRY and a continuation item carrying
    its current attempt token is built. The retry manager is advanced, and the item
    is put on the retry queue stamped with the current time.
    """
    for map_idx, lost in response:
        if not lost:
            continue
        item_ctx = self._item_context.get(map_idx)
        if item_ctx is None:
            continue
        item_ctx.state = _MapItemState.WAITING_TO_RETRY
        retry_item = await item_ctx.create_map_start_or_continue_item(map_idx)
        # Increment the retry count, but ignore the backoff: lost inputs are retried at once.
        # NOTE(review): the returned delay is discarded, so this path does not appear to stop
        # once the retry policy is exhausted -- confirm get_delay_ms() semantics.
        _ = item_ctx.retry_manager.get_delay_ms()
        await self._retry_queue.put(time.time(), retry_item)
|
|
1647
|
+
|
|
940
1648
|
async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
|
|
941
1649
|
ctx = self._item_context.get(item.idx, None)
|
|
942
1650
|
if ctx is None:
|