modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modal might be problematic. Click here for more details.
- modal/__init__.py +0 -2
- modal/__main__.py +3 -4
- modal/_billing.py +80 -0
- modal/_clustered_functions.py +7 -3
- modal/_clustered_functions.pyi +15 -3
- modal/_container_entrypoint.py +51 -69
- modal/_functions.py +508 -240
- modal/_grpc_client.py +171 -0
- modal/_load_context.py +105 -0
- modal/_object.py +81 -21
- modal/_output.py +58 -45
- modal/_partial_function.py +48 -73
- modal/_pty.py +7 -3
- modal/_resolver.py +26 -46
- modal/_runtime/asgi.py +4 -3
- modal/_runtime/container_io_manager.py +358 -220
- modal/_runtime/container_io_manager.pyi +296 -101
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +64 -7
- modal/_runtime/gpu_memory_snapshot.py +262 -57
- modal/_runtime/user_code_imports.py +28 -58
- modal/_serialization.py +90 -6
- modal/_traceback.py +42 -1
- modal/_tunnel.pyi +380 -12
- modal/_utils/async_utils.py +84 -29
- modal/_utils/auth_token_manager.py +111 -0
- modal/_utils/blob_utils.py +181 -58
- modal/_utils/deprecation.py +19 -0
- modal/_utils/function_utils.py +91 -47
- modal/_utils/grpc_utils.py +89 -66
- modal/_utils/mount_utils.py +26 -1
- modal/_utils/name_utils.py +17 -3
- modal/_utils/task_command_router_client.py +536 -0
- modal/_utils/time_utils.py +34 -6
- modal/app.py +256 -88
- modal/app.pyi +909 -92
- modal/billing.py +5 -0
- modal/builder/2025.06.txt +18 -0
- modal/builder/PREVIEW.txt +18 -0
- modal/builder/base-images.json +58 -0
- modal/cli/_download.py +19 -3
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +15 -7
- modal/cli/config.py +5 -3
- modal/cli/container.py +7 -6
- modal/cli/dict.py +22 -16
- modal/cli/entry_point.py +12 -5
- modal/cli/environment.py +5 -4
- modal/cli/import_refs.py +3 -3
- modal/cli/launch.py +102 -5
- modal/cli/network_file_system.py +11 -12
- modal/cli/profile.py +3 -2
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_jupyter.py +1 -1
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/programs/vscode.py +1 -1
- modal/cli/queues.py +57 -26
- modal/cli/run.py +91 -23
- modal/cli/secret.py +48 -22
- modal/cli/token.py +7 -8
- modal/cli/utils.py +4 -7
- modal/cli/volume.py +31 -25
- modal/client.py +15 -85
- modal/client.pyi +183 -62
- modal/cloud_bucket_mount.py +5 -3
- modal/cloud_bucket_mount.pyi +197 -5
- modal/cls.py +200 -126
- modal/cls.pyi +446 -68
- modal/config.py +29 -11
- modal/container_process.py +319 -19
- modal/container_process.pyi +190 -20
- modal/dict.py +290 -71
- modal/dict.pyi +835 -83
- modal/environments.py +15 -27
- modal/environments.pyi +46 -24
- modal/exception.py +14 -2
- modal/experimental/__init__.py +194 -40
- modal/experimental/flash.py +618 -0
- modal/experimental/flash.pyi +380 -0
- modal/experimental/ipython.py +11 -7
- modal/file_io.py +29 -36
- modal/file_io.pyi +251 -53
- modal/file_pattern_matcher.py +56 -16
- modal/functions.pyi +673 -92
- modal/gpu.py +1 -1
- modal/image.py +528 -176
- modal/image.pyi +1572 -145
- modal/io_streams.py +458 -128
- modal/io_streams.pyi +433 -52
- modal/mount.py +216 -151
- modal/mount.pyi +225 -78
- modal/network_file_system.py +45 -62
- modal/network_file_system.pyi +277 -56
- modal/object.pyi +93 -17
- modal/parallel_map.py +942 -129
- modal/parallel_map.pyi +294 -15
- modal/partial_function.py +0 -2
- modal/partial_function.pyi +234 -19
- modal/proxy.py +17 -8
- modal/proxy.pyi +36 -3
- modal/queue.py +270 -65
- modal/queue.pyi +817 -57
- modal/runner.py +115 -101
- modal/runner.pyi +205 -49
- modal/sandbox.py +512 -136
- modal/sandbox.pyi +845 -111
- modal/schedule.py +1 -1
- modal/secret.py +300 -70
- modal/secret.pyi +589 -34
- modal/serving.py +7 -11
- modal/serving.pyi +7 -8
- modal/snapshot.py +11 -8
- modal/snapshot.pyi +25 -4
- modal/token_flow.py +4 -4
- modal/token_flow.pyi +28 -8
- modal/volume.py +416 -158
- modal/volume.pyi +1117 -121
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
- modal-1.2.3.dev7.dist-info/RECORD +195 -0
- modal_docs/mdmd/mdmd.py +17 -4
- modal_proto/api.proto +534 -79
- modal_proto/api_grpc.py +337 -1
- modal_proto/api_pb2.py +1522 -968
- modal_proto/api_pb2.pyi +1619 -134
- modal_proto/api_pb2_grpc.py +699 -4
- modal_proto/api_pb2_grpc.pyi +226 -14
- modal_proto/modal_api_grpc.py +175 -154
- modal_proto/sandbox_router.proto +145 -0
- modal_proto/sandbox_router_grpc.py +105 -0
- modal_proto/sandbox_router_pb2.py +149 -0
- modal_proto/sandbox_router_pb2.pyi +333 -0
- modal_proto/sandbox_router_pb2_grpc.py +203 -0
- modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
- modal_proto/task_command_router.proto +144 -0
- modal_proto/task_command_router_grpc.py +105 -0
- modal_proto/task_command_router_pb2.py +149 -0
- modal_proto/task_command_router_pb2.pyi +333 -0
- modal_proto/task_command_router_pb2_grpc.py +203 -0
- modal_proto/task_command_router_pb2_grpc.pyi +75 -0
- modal_version/__init__.py +1 -1
- modal/requirements/PREVIEW.txt +0 -16
- modal/requirements/base-images.json +0 -26
- modal-1.0.3.dev10.dist-info/RECORD +0 -179
- modal_proto/modal_options_grpc.py +0 -3
- modal_proto/options.proto +0 -19
- modal_proto/options_grpc.py +0 -3
- modal_proto/options_pb2.py +0 -35
- modal_proto/options_pb2.pyi +0 -20
- modal_proto/options_pb2_grpc.py +0 -4
- modal_proto/options_pb2_grpc.pyi +0 -7
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/parallel_map.py
CHANGED
|
@@ -1,13 +1,16 @@
|
|
|
1
1
|
# Copyright Modal Labs 2024
|
|
2
2
|
import asyncio
|
|
3
3
|
import enum
|
|
4
|
+
import inspect
|
|
4
5
|
import time
|
|
5
6
|
import typing
|
|
7
|
+
from asyncio import FIRST_COMPLETED
|
|
6
8
|
from dataclasses import dataclass
|
|
7
|
-
from typing import Any, Callable, Optional
|
|
9
|
+
from typing import Any, Callable, Optional, Union
|
|
8
10
|
|
|
9
11
|
from grpclib import Status
|
|
10
12
|
|
|
13
|
+
import modal.exception
|
|
11
14
|
from modal._runtime.execution_context import current_input_id
|
|
12
15
|
from modal._utils.async_utils import (
|
|
13
16
|
AsyncOrSyncIterable,
|
|
@@ -25,13 +28,14 @@ from modal._utils.async_utils import (
|
|
|
25
28
|
warn_if_generator_is_not_consumed,
|
|
26
29
|
)
|
|
27
30
|
from modal._utils.blob_utils import BLOB_MAX_PARALLELISM
|
|
31
|
+
from modal._utils.deprecation import deprecation_warning
|
|
28
32
|
from modal._utils.function_utils import (
|
|
29
33
|
ATTEMPT_TIMEOUT_GRACE_PERIOD,
|
|
30
34
|
OUTPUTS_TIMEOUT,
|
|
31
35
|
_create_input,
|
|
32
36
|
_process_result,
|
|
33
37
|
)
|
|
34
|
-
from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES,
|
|
38
|
+
from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, Retry, RetryWarningMessage
|
|
35
39
|
from modal._utils.jwt_utils import DecodedJwt
|
|
36
40
|
from modal.config import logger
|
|
37
41
|
from modal.retries import RetryManager
|
|
@@ -75,19 +79,293 @@ class _OutputValue:
|
|
|
75
79
|
|
|
76
80
|
MAX_INPUTS_OUTSTANDING_DEFAULT = 1000
|
|
77
81
|
|
|
78
|
-
#
|
|
82
|
+
# Maximum number of inputs to send to the server per FunctionPutInputs request
|
|
79
83
|
MAP_INVOCATION_CHUNK_SIZE = 49
|
|
84
|
+
SPAWN_MAP_INVOCATION_CHUNK_SIZE = 512
|
|
85
|
+
|
|
80
86
|
|
|
81
87
|
if typing.TYPE_CHECKING:
|
|
82
88
|
import modal.functions
|
|
83
89
|
|
|
84
90
|
|
|
91
|
+
class InputPreprocessor:
|
|
92
|
+
"""
|
|
93
|
+
Constructs FunctionPutInputsItem objects from the raw-input queue, and puts them in the processed-input queue.
|
|
94
|
+
"""
|
|
95
|
+
|
|
96
|
+
def __init__(
|
|
97
|
+
self,
|
|
98
|
+
client: "modal.client._Client",
|
|
99
|
+
*,
|
|
100
|
+
raw_input_queue: _SynchronizedQueue,
|
|
101
|
+
processed_input_queue: asyncio.Queue,
|
|
102
|
+
function: "modal.functions._Function",
|
|
103
|
+
created_callback: Callable[[int], None],
|
|
104
|
+
done_callback: Callable[[], None],
|
|
105
|
+
):
|
|
106
|
+
self.client = client
|
|
107
|
+
self.function = function
|
|
108
|
+
self.inputs_created = 0
|
|
109
|
+
self.raw_input_queue = raw_input_queue
|
|
110
|
+
self.processed_input_queue = processed_input_queue
|
|
111
|
+
self.created_callback = created_callback
|
|
112
|
+
self.done_callback = done_callback
|
|
113
|
+
|
|
114
|
+
async def input_iter(self):
|
|
115
|
+
while 1:
|
|
116
|
+
raw_input = await self.raw_input_queue.get()
|
|
117
|
+
if raw_input is None: # end of input sentinel
|
|
118
|
+
break
|
|
119
|
+
yield raw_input # args, kwargs
|
|
120
|
+
|
|
121
|
+
def create_input_factory(self):
|
|
122
|
+
async def create_input(argskwargs):
|
|
123
|
+
idx = self.inputs_created
|
|
124
|
+
self.inputs_created += 1
|
|
125
|
+
self.created_callback(self.inputs_created)
|
|
126
|
+
(args, kwargs) = argskwargs
|
|
127
|
+
return await _create_input(
|
|
128
|
+
args,
|
|
129
|
+
kwargs,
|
|
130
|
+
self.client.stub,
|
|
131
|
+
idx=idx,
|
|
132
|
+
function=self.function,
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
return create_input
|
|
136
|
+
|
|
137
|
+
async def drain_input_generator(self):
|
|
138
|
+
# Parallelize uploading blobs
|
|
139
|
+
async with aclosing(
|
|
140
|
+
async_map_ordered(self.input_iter(), self.create_input_factory(), concurrency=BLOB_MAX_PARALLELISM)
|
|
141
|
+
) as streamer:
|
|
142
|
+
async for item in streamer:
|
|
143
|
+
await self.processed_input_queue.put(item)
|
|
144
|
+
|
|
145
|
+
# close queue iterator
|
|
146
|
+
await self.processed_input_queue.put(None)
|
|
147
|
+
self.done_callback()
|
|
148
|
+
yield
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class InputPumper:
|
|
152
|
+
"""
|
|
153
|
+
Reads inputs from a queue of FunctionPutInputsItems, and sends them to the server.
|
|
154
|
+
"""
|
|
155
|
+
|
|
156
|
+
def __init__(
|
|
157
|
+
self,
|
|
158
|
+
client: "modal.client._Client",
|
|
159
|
+
*,
|
|
160
|
+
input_queue: asyncio.Queue,
|
|
161
|
+
function: "modal.functions._Function",
|
|
162
|
+
function_call_id: str,
|
|
163
|
+
max_batch_size: int,
|
|
164
|
+
map_items_manager: Optional["_MapItemsManager"] = None,
|
|
165
|
+
):
|
|
166
|
+
self.client = client
|
|
167
|
+
self.function = function
|
|
168
|
+
self.map_items_manager = map_items_manager
|
|
169
|
+
self.input_queue = input_queue
|
|
170
|
+
self.inputs_sent = 0
|
|
171
|
+
self.function_call_id = function_call_id
|
|
172
|
+
self.max_batch_size = max_batch_size
|
|
173
|
+
|
|
174
|
+
async def pump_inputs(self):
|
|
175
|
+
assert self.client.stub
|
|
176
|
+
async for items in queue_batch_iterator(self.input_queue, max_batch_size=self.max_batch_size):
|
|
177
|
+
# Add items to the manager. Their state will be SENDING.
|
|
178
|
+
if self.map_items_manager is not None:
|
|
179
|
+
await self.map_items_manager.add_items(items)
|
|
180
|
+
request = api_pb2.FunctionPutInputsRequest(
|
|
181
|
+
function_id=self.function.object_id,
|
|
182
|
+
inputs=items,
|
|
183
|
+
function_call_id=self.function_call_id,
|
|
184
|
+
)
|
|
185
|
+
logger.debug(
|
|
186
|
+
f"Pushing {len(items)} inputs to server. Num queued inputs awaiting"
|
|
187
|
+
f" push is {self.input_queue.qsize()}. "
|
|
188
|
+
)
|
|
189
|
+
|
|
190
|
+
resp = await self.client.stub.FunctionPutInputs(request, retry=self._function_inputs_retry)
|
|
191
|
+
self.inputs_sent += len(items)
|
|
192
|
+
# Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
|
|
193
|
+
if self.map_items_manager is not None:
|
|
194
|
+
self.map_items_manager.handle_put_inputs_response(resp.inputs)
|
|
195
|
+
logger.debug(
|
|
196
|
+
f"Successfully pushed {len(items)} inputs to server. "
|
|
197
|
+
f"Num queued inputs awaiting push is {self.input_queue.qsize()}."
|
|
198
|
+
)
|
|
199
|
+
yield
|
|
200
|
+
|
|
201
|
+
@property
|
|
202
|
+
def _function_inputs_retry(self) -> Retry:
|
|
203
|
+
# with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
|
|
204
|
+
retry_warning_message = RetryWarningMessage(
|
|
205
|
+
message=f"Warning: map progress for function {self.function._function_name} is limited."
|
|
206
|
+
" Common bottlenecks include slow iteration over results, or function backlogs.",
|
|
207
|
+
warning_interval=8,
|
|
208
|
+
errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
|
|
209
|
+
)
|
|
210
|
+
return Retry(
|
|
211
|
+
max_retries=None,
|
|
212
|
+
max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
|
|
213
|
+
additional_status_codes=[Status.RESOURCE_EXHAUSTED],
|
|
214
|
+
warning_message=retry_warning_message,
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class SyncInputPumper(InputPumper):
|
|
219
|
+
def __init__(
|
|
220
|
+
self,
|
|
221
|
+
client: "modal.client._Client",
|
|
222
|
+
*,
|
|
223
|
+
input_queue: asyncio.Queue,
|
|
224
|
+
retry_queue: TimestampPriorityQueue,
|
|
225
|
+
function: "modal.functions._Function",
|
|
226
|
+
function_call_jwt: str,
|
|
227
|
+
function_call_id: str,
|
|
228
|
+
map_items_manager: "_MapItemsManager",
|
|
229
|
+
):
|
|
230
|
+
super().__init__(
|
|
231
|
+
client,
|
|
232
|
+
input_queue=input_queue,
|
|
233
|
+
function=function,
|
|
234
|
+
function_call_id=function_call_id,
|
|
235
|
+
max_batch_size=MAP_INVOCATION_CHUNK_SIZE,
|
|
236
|
+
map_items_manager=map_items_manager,
|
|
237
|
+
)
|
|
238
|
+
self.retry_queue = retry_queue
|
|
239
|
+
self.inputs_retried = 0
|
|
240
|
+
self.function_call_jwt = function_call_jwt
|
|
241
|
+
|
|
242
|
+
async def retry_inputs(self):
|
|
243
|
+
async for retriable_idxs in queue_batch_iterator(self.retry_queue, max_batch_size=self.max_batch_size):
|
|
244
|
+
# For each index, use the context in the manager to create a FunctionRetryInputsItem.
|
|
245
|
+
# This will also update the context state to RETRYING.
|
|
246
|
+
inputs: list[api_pb2.FunctionRetryInputsItem] = await self.map_items_manager.prepare_items_for_retry(
|
|
247
|
+
retriable_idxs
|
|
248
|
+
)
|
|
249
|
+
request = api_pb2.FunctionRetryInputsRequest(
|
|
250
|
+
function_call_jwt=self.function_call_jwt,
|
|
251
|
+
inputs=inputs,
|
|
252
|
+
)
|
|
253
|
+
resp = await self.client.stub.FunctionRetryInputs(request, retry=self._function_inputs_retry)
|
|
254
|
+
# Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
|
|
255
|
+
# to the new value in the response.
|
|
256
|
+
self.map_items_manager.handle_retry_response(resp.input_jwts)
|
|
257
|
+
logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
|
|
258
|
+
self.inputs_retried += len(inputs)
|
|
259
|
+
yield
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
class AsyncInputPumper(InputPumper):
|
|
263
|
+
def __init__(
|
|
264
|
+
self,
|
|
265
|
+
client: "modal.client._Client",
|
|
266
|
+
*,
|
|
267
|
+
input_queue: asyncio.Queue,
|
|
268
|
+
function: "modal.functions._Function",
|
|
269
|
+
function_call_id: str,
|
|
270
|
+
):
|
|
271
|
+
super().__init__(
|
|
272
|
+
client,
|
|
273
|
+
input_queue=input_queue,
|
|
274
|
+
function=function,
|
|
275
|
+
function_call_id=function_call_id,
|
|
276
|
+
max_batch_size=SPAWN_MAP_INVOCATION_CHUNK_SIZE,
|
|
277
|
+
)
|
|
278
|
+
|
|
279
|
+
async def pump_inputs(self):
|
|
280
|
+
async for _ in super().pump_inputs():
|
|
281
|
+
pass
|
|
282
|
+
request = api_pb2.FunctionFinishInputsRequest(
|
|
283
|
+
function_id=self.function.object_id,
|
|
284
|
+
function_call_id=self.function_call_id,
|
|
285
|
+
num_inputs=self.inputs_sent,
|
|
286
|
+
)
|
|
287
|
+
await self.client.stub.FunctionFinishInputs(request, retry=Retry(max_retries=None))
|
|
288
|
+
yield
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
async def _spawn_map_invocation(
|
|
292
|
+
function: "modal.functions._Function", raw_input_queue: _SynchronizedQueue, client: "modal.client._Client"
|
|
293
|
+
) -> tuple[str, int]:
|
|
294
|
+
assert client.stub
|
|
295
|
+
request = api_pb2.FunctionMapRequest(
|
|
296
|
+
function_id=function.object_id,
|
|
297
|
+
parent_input_id=current_input_id() or "",
|
|
298
|
+
function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
|
|
299
|
+
function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC,
|
|
300
|
+
)
|
|
301
|
+
response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
|
|
302
|
+
function_call_id = response.function_call_id
|
|
303
|
+
|
|
304
|
+
have_all_inputs = False
|
|
305
|
+
inputs_created = 0
|
|
306
|
+
|
|
307
|
+
def set_inputs_created(set_inputs_created):
|
|
308
|
+
nonlocal inputs_created
|
|
309
|
+
assert set_inputs_created is None or set_inputs_created > inputs_created
|
|
310
|
+
inputs_created = set_inputs_created
|
|
311
|
+
|
|
312
|
+
def set_have_all_inputs():
|
|
313
|
+
nonlocal have_all_inputs
|
|
314
|
+
have_all_inputs = True
|
|
315
|
+
|
|
316
|
+
input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
|
|
317
|
+
input_preprocessor = InputPreprocessor(
|
|
318
|
+
client=client,
|
|
319
|
+
raw_input_queue=raw_input_queue,
|
|
320
|
+
processed_input_queue=input_queue,
|
|
321
|
+
function=function,
|
|
322
|
+
created_callback=set_inputs_created,
|
|
323
|
+
done_callback=set_have_all_inputs,
|
|
324
|
+
)
|
|
325
|
+
|
|
326
|
+
input_pumper = AsyncInputPumper(
|
|
327
|
+
client=client,
|
|
328
|
+
input_queue=input_queue,
|
|
329
|
+
function=function,
|
|
330
|
+
function_call_id=function_call_id,
|
|
331
|
+
)
|
|
332
|
+
|
|
333
|
+
def log_stats():
|
|
334
|
+
logger.debug(
|
|
335
|
+
f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} inputs_sent={input_pumper.inputs_sent} "
|
|
336
|
+
)
|
|
337
|
+
|
|
338
|
+
async def log_task():
|
|
339
|
+
while True:
|
|
340
|
+
log_stats()
|
|
341
|
+
try:
|
|
342
|
+
await asyncio.sleep(10)
|
|
343
|
+
except asyncio.CancelledError:
|
|
344
|
+
# Log final stats before exiting
|
|
345
|
+
log_stats()
|
|
346
|
+
break
|
|
347
|
+
|
|
348
|
+
async def consume_generator(gen):
|
|
349
|
+
async for _ in gen:
|
|
350
|
+
pass
|
|
351
|
+
|
|
352
|
+
log_debug_stats_task = asyncio.create_task(log_task())
|
|
353
|
+
await asyncio.gather(
|
|
354
|
+
consume_generator(input_preprocessor.drain_input_generator()),
|
|
355
|
+
consume_generator(input_pumper.pump_inputs()),
|
|
356
|
+
)
|
|
357
|
+
log_debug_stats_task.cancel()
|
|
358
|
+
await log_debug_stats_task
|
|
359
|
+
return function_call_id, inputs_created
|
|
360
|
+
|
|
361
|
+
|
|
85
362
|
async def _map_invocation(
|
|
86
363
|
function: "modal.functions._Function",
|
|
87
364
|
raw_input_queue: _SynchronizedQueue,
|
|
88
365
|
client: "modal.client._Client",
|
|
89
366
|
order_outputs: bool,
|
|
90
367
|
return_exceptions: bool,
|
|
368
|
+
wrap_returned_exceptions: bool,
|
|
91
369
|
count_update_callback: Optional[Callable[[int, int], None]],
|
|
92
370
|
function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
|
|
93
371
|
):
|
|
@@ -99,7 +377,7 @@ async def _map_invocation(
|
|
|
99
377
|
return_exceptions=return_exceptions,
|
|
100
378
|
function_call_invocation_type=function_call_invocation_type,
|
|
101
379
|
)
|
|
102
|
-
response: api_pb2.FunctionMapResponse = await
|
|
380
|
+
response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
|
|
103
381
|
|
|
104
382
|
function_call_id = response.function_call_id
|
|
105
383
|
function_call_jwt = response.function_call_jwt
|
|
@@ -110,9 +388,8 @@ async def _map_invocation(
|
|
|
110
388
|
max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
|
|
111
389
|
|
|
112
390
|
have_all_inputs = False
|
|
391
|
+
map_done_event = asyncio.Event()
|
|
113
392
|
inputs_created = 0
|
|
114
|
-
inputs_sent = 0
|
|
115
|
-
inputs_retried = 0
|
|
116
393
|
outputs_completed = 0
|
|
117
394
|
outputs_received = 0
|
|
118
395
|
retried_outputs = 0
|
|
@@ -122,10 +399,6 @@ async def _map_invocation(
|
|
|
122
399
|
stale_retry_duplicates = 0
|
|
123
400
|
no_context_duplicates = 0
|
|
124
401
|
|
|
125
|
-
def count_update():
|
|
126
|
-
if count_update_callback is not None:
|
|
127
|
-
count_update_callback(outputs_completed, inputs_created)
|
|
128
|
-
|
|
129
402
|
retry_queue = TimestampPriorityQueue()
|
|
130
403
|
completed_outputs: set[str] = set() # Set of input_ids whose outputs are complete (expecting no more values)
|
|
131
404
|
input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
|
|
@@ -133,109 +406,50 @@ async def _map_invocation(
|
|
|
133
406
|
retry_policy, function_call_invocation_type, retry_queue, sync_client_retries_enabled, max_inputs_outstanding
|
|
134
407
|
)
|
|
135
408
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
while 1:
|
|
145
|
-
raw_input = await raw_input_queue.get()
|
|
146
|
-
if raw_input is None: # end of input sentinel
|
|
147
|
-
break
|
|
148
|
-
yield raw_input # args, kwargs
|
|
149
|
-
|
|
150
|
-
async def drain_input_generator():
|
|
151
|
-
nonlocal have_all_inputs
|
|
152
|
-
|
|
153
|
-
# Parallelize uploading blobs
|
|
154
|
-
async with aclosing(
|
|
155
|
-
async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
|
|
156
|
-
) as streamer:
|
|
157
|
-
async for item in streamer:
|
|
158
|
-
await input_queue.put(item)
|
|
159
|
-
|
|
160
|
-
# close queue iterator
|
|
161
|
-
await input_queue.put(None)
|
|
162
|
-
have_all_inputs = True
|
|
163
|
-
yield
|
|
409
|
+
input_preprocessor = InputPreprocessor(
|
|
410
|
+
client=client,
|
|
411
|
+
raw_input_queue=raw_input_queue,
|
|
412
|
+
processed_input_queue=input_queue,
|
|
413
|
+
function=function,
|
|
414
|
+
created_callback=lambda x: update_state(set_inputs_created=x),
|
|
415
|
+
done_callback=lambda: update_state(set_have_all_inputs=True),
|
|
416
|
+
)
|
|
164
417
|
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
function_call_id=function_call_id,
|
|
175
|
-
)
|
|
176
|
-
logger.debug(
|
|
177
|
-
f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
|
|
178
|
-
)
|
|
418
|
+
input_pumper = SyncInputPumper(
|
|
419
|
+
client=client,
|
|
420
|
+
input_queue=input_queue,
|
|
421
|
+
retry_queue=retry_queue,
|
|
422
|
+
function=function,
|
|
423
|
+
map_items_manager=map_items_manager,
|
|
424
|
+
function_call_jwt=function_call_jwt,
|
|
425
|
+
function_call_id=function_call_id,
|
|
426
|
+
)
|
|
179
427
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
428
|
+
def update_state(set_have_all_inputs=None, set_inputs_created=None, set_outputs_completed=None):
|
|
429
|
+
# This should be the only method that needs nonlocal of the following vars
|
|
430
|
+
nonlocal have_all_inputs, inputs_created, outputs_completed
|
|
431
|
+
assert set_have_all_inputs is not False # not allowed
|
|
432
|
+
assert set_inputs_created is None or set_inputs_created > inputs_created
|
|
433
|
+
assert set_outputs_completed is None or set_outputs_completed > outputs_completed
|
|
434
|
+
if set_have_all_inputs is not None:
|
|
435
|
+
have_all_inputs = set_have_all_inputs
|
|
436
|
+
if set_inputs_created is not None:
|
|
437
|
+
inputs_created = set_inputs_created
|
|
438
|
+
if set_outputs_completed is not None:
|
|
439
|
+
outputs_completed = set_outputs_completed
|
|
190
440
|
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
|
|
194
|
-
# For each index, use the context in the manager to create a FunctionRetryInputsItem.
|
|
195
|
-
# This will also update the context state to RETRYING.
|
|
196
|
-
inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
|
|
197
|
-
retriable_idxs
|
|
198
|
-
)
|
|
199
|
-
request = api_pb2.FunctionRetryInputsRequest(
|
|
200
|
-
function_call_jwt=function_call_jwt,
|
|
201
|
-
inputs=inputs,
|
|
202
|
-
)
|
|
203
|
-
resp = await send_inputs(client.stub.FunctionRetryInputs, request)
|
|
204
|
-
# Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
|
|
205
|
-
# to the new value in the response.
|
|
206
|
-
map_items_manager.handle_retry_response(resp.input_jwts)
|
|
207
|
-
logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
|
|
208
|
-
inputs_retried += len(inputs)
|
|
209
|
-
yield
|
|
441
|
+
if count_update_callback is not None:
|
|
442
|
+
count_update_callback(outputs_completed, inputs_created)
|
|
210
443
|
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
|
|
215
|
-
# with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
|
|
216
|
-
retry_warning_message = RetryWarningMessage(
|
|
217
|
-
message=f"Warning: map progress for function {function._function_name} is limited."
|
|
218
|
-
" Common bottlenecks include slow iteration over results, or function backlogs.",
|
|
219
|
-
warning_interval=8,
|
|
220
|
-
errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
|
|
221
|
-
)
|
|
222
|
-
return await retry_transient_errors(
|
|
223
|
-
fn,
|
|
224
|
-
request,
|
|
225
|
-
max_retries=None,
|
|
226
|
-
max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
|
|
227
|
-
additional_status_codes=[Status.RESOURCE_EXHAUSTED],
|
|
228
|
-
retry_warning_message=retry_warning_message,
|
|
229
|
-
)
|
|
444
|
+
if have_all_inputs and outputs_completed >= inputs_created:
|
|
445
|
+
# map is done
|
|
446
|
+
map_done_event.set()
|
|
230
447
|
|
|
231
448
|
async def get_all_outputs():
|
|
232
449
|
assert client.stub
|
|
233
450
|
nonlocal \
|
|
234
|
-
inputs_created, \
|
|
235
451
|
successful_completions, \
|
|
236
452
|
failed_completions, \
|
|
237
|
-
outputs_completed, \
|
|
238
|
-
have_all_inputs, \
|
|
239
453
|
outputs_received, \
|
|
240
454
|
already_complete_duplicates, \
|
|
241
455
|
no_context_duplicates, \
|
|
@@ -244,7 +458,7 @@ async def _map_invocation(
|
|
|
244
458
|
|
|
245
459
|
last_entry_id = "0-0"
|
|
246
460
|
|
|
247
|
-
while not
|
|
461
|
+
while not map_done_event.is_set():
|
|
248
462
|
logger.debug(f"Requesting outputs. Have {outputs_completed} outputs, {inputs_created} inputs.")
|
|
249
463
|
# Get input_jwts of all items in the WAITING_FOR_OUTPUT state.
|
|
250
464
|
# The server uses these to track for lost inputs.
|
|
@@ -258,12 +472,29 @@ async def _map_invocation(
|
|
|
258
472
|
requested_at=time.time(),
|
|
259
473
|
input_jwts=input_jwts,
|
|
260
474
|
)
|
|
261
|
-
|
|
262
|
-
client.stub.FunctionGetOutputs
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
475
|
+
get_response_task = asyncio.create_task(
|
|
476
|
+
client.stub.FunctionGetOutputs(
|
|
477
|
+
request,
|
|
478
|
+
retry=Retry(
|
|
479
|
+
max_retries=20,
|
|
480
|
+
attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
|
|
481
|
+
),
|
|
482
|
+
)
|
|
266
483
|
)
|
|
484
|
+
map_done_task = asyncio.create_task(map_done_event.wait())
|
|
485
|
+
try:
|
|
486
|
+
done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
|
|
487
|
+
if get_response_task in done:
|
|
488
|
+
map_done_task.cancel()
|
|
489
|
+
response = get_response_task.result()
|
|
490
|
+
else:
|
|
491
|
+
assert map_done_event.is_set()
|
|
492
|
+
# map is done - no more outputs, so return early
|
|
493
|
+
return
|
|
494
|
+
finally:
|
|
495
|
+
# clean up tasks, in case of cancellations etc.
|
|
496
|
+
get_response_task.cancel()
|
|
497
|
+
map_done_task.cancel()
|
|
267
498
|
|
|
268
499
|
last_entry_id = response.last_entry_id
|
|
269
500
|
now_seconds = int(time.time())
|
|
@@ -288,7 +519,7 @@ async def _map_invocation(
|
|
|
288
519
|
|
|
289
520
|
if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
|
|
290
521
|
completed_outputs.add(item.input_id)
|
|
291
|
-
outputs_completed
|
|
522
|
+
update_state(set_outputs_completed=outputs_completed + 1)
|
|
292
523
|
yield item
|
|
293
524
|
|
|
294
525
|
async def get_all_outputs_and_clean_up():
|
|
@@ -306,7 +537,7 @@ async def _map_invocation(
|
|
|
306
537
|
clear_on_success=True,
|
|
307
538
|
requested_at=time.time(),
|
|
308
539
|
)
|
|
309
|
-
await
|
|
540
|
+
await client.stub.FunctionGetOutputs(request)
|
|
310
541
|
await retry_queue.close()
|
|
311
542
|
|
|
312
543
|
async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
|
|
@@ -314,7 +545,13 @@ async def _map_invocation(
|
|
|
314
545
|
output = await _process_result(item.result, item.data_format, client.stub, client)
|
|
315
546
|
except Exception as e:
|
|
316
547
|
if return_exceptions:
|
|
317
|
-
|
|
548
|
+
if wrap_returned_exceptions:
|
|
549
|
+
# Prior to client 1.0.4 there was a bug where return_exceptions would wrap
|
|
550
|
+
# any returned exceptions in a synchronicity.UserCodeException. This adds
|
|
551
|
+
# deprecated non-breaking compatibility bandaid for migrating away from that:
|
|
552
|
+
output = modal.exception.UserCodeException(e)
|
|
553
|
+
else:
|
|
554
|
+
output = e
|
|
318
555
|
else:
|
|
319
556
|
raise e
|
|
320
557
|
return (item.idx, output)
|
|
@@ -328,7 +565,6 @@ async def _map_invocation(
|
|
|
328
565
|
async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
|
|
329
566
|
) as streamer:
|
|
330
567
|
async for idx, output in streamer:
|
|
331
|
-
count_update()
|
|
332
568
|
if not order_outputs:
|
|
333
569
|
yield _OutputValue(output)
|
|
334
570
|
else:
|
|
@@ -352,8 +588,11 @@ async def _map_invocation(
|
|
|
352
588
|
def log_stats():
|
|
353
589
|
logger.debug(
|
|
354
590
|
f"Map stats: sync_client_retries_enabled={sync_client_retries_enabled} "
|
|
355
|
-
f"have_all_inputs={have_all_inputs}
|
|
356
|
-
f"
|
|
591
|
+
f"have_all_inputs={have_all_inputs} "
|
|
592
|
+
f"inputs_created={inputs_created} "
|
|
593
|
+
f"input_sent={input_pumper.inputs_sent} "
|
|
594
|
+
f"inputs_retried={input_pumper.inputs_retried} "
|
|
595
|
+
f"outputs_received={outputs_received} "
|
|
357
596
|
f"successful_completions={successful_completions} failed_completions={failed_completions} "
|
|
358
597
|
f"no_context_duplicates={no_context_duplicates} old_retry_duplicates={stale_retry_duplicates} "
|
|
359
598
|
f"already_complete_duplicates={already_complete_duplicates} "
|
|
@@ -372,21 +611,388 @@ async def _map_invocation(
|
|
|
372
611
|
|
|
373
612
|
log_debug_stats_task = asyncio.create_task(log_debug_stats())
|
|
374
613
|
async with aclosing(
|
|
375
|
-
async_merge(
|
|
614
|
+
async_merge(
|
|
615
|
+
input_preprocessor.drain_input_generator(),
|
|
616
|
+
input_pumper.pump_inputs(),
|
|
617
|
+
input_pumper.retry_inputs(),
|
|
618
|
+
poll_outputs(),
|
|
619
|
+
)
|
|
376
620
|
) as streamer:
|
|
377
621
|
async for response in streamer:
|
|
378
|
-
if response is not None:
|
|
622
|
+
if response is not None: # type: ignore[unreachable]
|
|
379
623
|
yield response.value
|
|
380
624
|
log_debug_stats_task.cancel()
|
|
381
625
|
await log_debug_stats_task
|
|
382
626
|
|
|
383
627
|
|
|
628
|
+
async def _map_invocation_inputplane(
|
|
629
|
+
function: "modal.functions._Function",
|
|
630
|
+
raw_input_queue: _SynchronizedQueue,
|
|
631
|
+
client: "modal.client._Client",
|
|
632
|
+
order_outputs: bool,
|
|
633
|
+
return_exceptions: bool,
|
|
634
|
+
wrap_returned_exceptions: bool,
|
|
635
|
+
count_update_callback: Optional[Callable[[int, int], None]],
|
|
636
|
+
) -> typing.AsyncGenerator[Any, None]:
|
|
637
|
+
"""Input-plane implementation of a function map invocation.
|
|
638
|
+
|
|
639
|
+
This is analogous to `_map_invocation`, but instead of the control-plane
|
|
640
|
+
`FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
|
|
641
|
+
the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
|
|
642
|
+
"""
|
|
643
|
+
|
|
644
|
+
assert function._input_plane_url, "_map_invocation_inputplane should only be used for input-plane backed functions"
|
|
645
|
+
|
|
646
|
+
input_plane_stub = await client.get_stub(function._input_plane_url)
|
|
647
|
+
|
|
648
|
+
# Required for _create_input.
|
|
649
|
+
assert client.stub, "Client must be hydrated with a stub for _map_invocation_inputplane"
|
|
650
|
+
|
|
651
|
+
# ------------------------------------------------------------
|
|
652
|
+
# Invocation-wide state
|
|
653
|
+
# ------------------------------------------------------------
|
|
654
|
+
|
|
655
|
+
have_all_inputs = False
|
|
656
|
+
map_done_event = asyncio.Event()
|
|
657
|
+
|
|
658
|
+
inputs_created = 0
|
|
659
|
+
outputs_completed = 0
|
|
660
|
+
successful_completions = 0
|
|
661
|
+
failed_completions = 0
|
|
662
|
+
no_context_duplicates = 0
|
|
663
|
+
stale_retry_duplicates = 0
|
|
664
|
+
already_complete_duplicates = 0
|
|
665
|
+
retried_outputs = 0
|
|
666
|
+
input_queue_size = 0
|
|
667
|
+
last_entry_id = ""
|
|
668
|
+
|
|
669
|
+
# The input-plane server returns this after the first request.
|
|
670
|
+
map_token = None
|
|
671
|
+
map_token_received = asyncio.Event()
|
|
672
|
+
|
|
673
|
+
# Single priority queue that holds *both* fresh inputs (timestamp == now)
|
|
674
|
+
# and future retries (timestamp > now).
|
|
675
|
+
queue: TimestampPriorityQueue[api_pb2.MapStartOrContinueItem] = TimestampPriorityQueue()
|
|
676
|
+
|
|
677
|
+
# Maximum number of inputs that may be in-flight (the server sends this in
|
|
678
|
+
# the first response – fall back to the default if we never receive it for
|
|
679
|
+
# any reason).
|
|
680
|
+
max_inputs_outstanding = MAX_INPUTS_OUTSTANDING_DEFAULT
|
|
681
|
+
|
|
682
|
+
# Set a default retry policy to construct an instance of _MapItemsManager.
|
|
683
|
+
# We'll update the retry policy with the actual user-specified retry policy
|
|
684
|
+
# from the server in the first MapStartOrContinue response.
|
|
685
|
+
retry_policy = api_pb2.FunctionRetryPolicy(
|
|
686
|
+
retries=0,
|
|
687
|
+
initial_delay_ms=1000,
|
|
688
|
+
max_delay_ms=1000,
|
|
689
|
+
backoff_coefficient=1.0,
|
|
690
|
+
)
|
|
691
|
+
map_items_manager = _MapItemsManager(
|
|
692
|
+
retry_policy=retry_policy,
|
|
693
|
+
function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
|
|
694
|
+
retry_queue=queue,
|
|
695
|
+
sync_client_retries_enabled=True,
|
|
696
|
+
max_inputs_outstanding=MAX_INPUTS_OUTSTANDING_DEFAULT,
|
|
697
|
+
is_input_plane_instance=True,
|
|
698
|
+
)
|
|
699
|
+
|
|
700
|
+
def update_counters(
|
|
701
|
+
created_delta: int = 0, completed_delta: int = 0, set_have_all_inputs: Union[bool, None] = None
|
|
702
|
+
):
|
|
703
|
+
nonlocal inputs_created, outputs_completed, have_all_inputs
|
|
704
|
+
|
|
705
|
+
if created_delta:
|
|
706
|
+
inputs_created += created_delta
|
|
707
|
+
if completed_delta:
|
|
708
|
+
outputs_completed += completed_delta
|
|
709
|
+
if set_have_all_inputs is not None:
|
|
710
|
+
have_all_inputs = set_have_all_inputs
|
|
711
|
+
|
|
712
|
+
if count_update_callback is not None:
|
|
713
|
+
count_update_callback(outputs_completed, inputs_created)
|
|
714
|
+
|
|
715
|
+
if have_all_inputs and outputs_completed >= inputs_created:
|
|
716
|
+
map_done_event.set()
|
|
717
|
+
|
|
718
|
+
async def create_input(argskwargs):
|
|
719
|
+
idx = inputs_created + 1 # 1-indexed map call idx
|
|
720
|
+
update_counters(created_delta=1)
|
|
721
|
+
(args, kwargs) = argskwargs
|
|
722
|
+
put_item: api_pb2.FunctionPutInputsItem = await _create_input(
|
|
723
|
+
args,
|
|
724
|
+
kwargs,
|
|
725
|
+
client.stub,
|
|
726
|
+
idx=idx,
|
|
727
|
+
function=function,
|
|
728
|
+
)
|
|
729
|
+
return api_pb2.MapStartOrContinueItem(input=put_item)
|
|
730
|
+
|
|
731
|
+
async def input_iter():
|
|
732
|
+
while True:
|
|
733
|
+
raw_input = await raw_input_queue.get()
|
|
734
|
+
if raw_input is None: # end of input sentinel
|
|
735
|
+
break
|
|
736
|
+
yield raw_input # args, kwargs
|
|
737
|
+
|
|
738
|
+
async def drain_input_generator():
|
|
739
|
+
async with aclosing(
|
|
740
|
+
async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
|
|
741
|
+
) as streamer:
|
|
742
|
+
async for q_item in streamer:
|
|
743
|
+
await queue.put(time.time(), q_item)
|
|
744
|
+
|
|
745
|
+
# All inputs have been read.
|
|
746
|
+
update_counters(set_have_all_inputs=True)
|
|
747
|
+
yield
|
|
748
|
+
|
|
749
|
+
async def pump_inputs():
|
|
750
|
+
nonlocal map_token, max_inputs_outstanding
|
|
751
|
+
async for batch in queue_batch_iterator(queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
|
|
752
|
+
# Convert the queued items into the proto format expected by the RPC.
|
|
753
|
+
request_items: list[api_pb2.MapStartOrContinueItem] = [
|
|
754
|
+
api_pb2.MapStartOrContinueItem(input=qi.input, attempt_token=qi.attempt_token) for qi in batch
|
|
755
|
+
]
|
|
756
|
+
|
|
757
|
+
await map_items_manager.add_items_inputplane(request_items)
|
|
758
|
+
|
|
759
|
+
# Build request
|
|
760
|
+
request = api_pb2.MapStartOrContinueRequest(
|
|
761
|
+
function_id=function.object_id,
|
|
762
|
+
map_token=map_token,
|
|
763
|
+
parent_input_id=current_input_id() or "",
|
|
764
|
+
items=request_items,
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
metadata = await client.get_input_plane_metadata(function._input_plane_region)
|
|
768
|
+
|
|
769
|
+
response: api_pb2.MapStartOrContinueResponse = await input_plane_stub.MapStartOrContinue(
|
|
770
|
+
request,
|
|
771
|
+
retry=Retry(
|
|
772
|
+
additional_status_codes=[Status.RESOURCE_EXHAUSTED],
|
|
773
|
+
max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
|
|
774
|
+
max_retries=None,
|
|
775
|
+
),
|
|
776
|
+
metadata=metadata,
|
|
777
|
+
)
|
|
778
|
+
|
|
779
|
+
# match response items to the corresponding request item index
|
|
780
|
+
response_items_idx_tuple = [
|
|
781
|
+
(request_items[idx].input.idx, attempt_token)
|
|
782
|
+
for idx, attempt_token in enumerate(response.attempt_tokens)
|
|
783
|
+
]
|
|
784
|
+
|
|
785
|
+
map_items_manager.handle_put_continue_response(response_items_idx_tuple)
|
|
786
|
+
|
|
787
|
+
# Set the function call id and actual retry policy with the data from the first response.
|
|
788
|
+
# This conditional is skipped for subsequent iterations of this for-loop.
|
|
789
|
+
if map_token is None:
|
|
790
|
+
map_token = response.map_token
|
|
791
|
+
map_token_received.set()
|
|
792
|
+
max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
|
|
793
|
+
map_items_manager.set_retry_policy(response.retry_policy)
|
|
794
|
+
# Update the retry policy for the first batch of inputs.
|
|
795
|
+
# Subsequent batches will have the correct user-specified retry policy
|
|
796
|
+
# set by the updated _MapItemsManager.
|
|
797
|
+
map_items_manager.update_items_retry_policy(response.retry_policy)
|
|
798
|
+
yield
|
|
799
|
+
|
|
800
|
+
async def check_lost_inputs():
|
|
801
|
+
nonlocal last_entry_id # shared with get_all_outputs
|
|
802
|
+
try:
|
|
803
|
+
while not map_done_event.is_set():
|
|
804
|
+
if map_token is None:
|
|
805
|
+
await map_token_received.wait()
|
|
806
|
+
continue
|
|
807
|
+
|
|
808
|
+
sleep_task = asyncio.create_task(asyncio.sleep(1))
|
|
809
|
+
map_done_task = asyncio.create_task(map_done_event.wait())
|
|
810
|
+
done, _ = await asyncio.wait([sleep_task, map_done_task], return_when=FIRST_COMPLETED)
|
|
811
|
+
if map_done_task in done:
|
|
812
|
+
break
|
|
813
|
+
|
|
814
|
+
# check_inputs = [(idx, attempt_token), ...]
|
|
815
|
+
check_inputs = map_items_manager.get_input_idxs_waiting_for_output()
|
|
816
|
+
attempt_tokens = [attempt_token for _, attempt_token in check_inputs]
|
|
817
|
+
request = api_pb2.MapCheckInputsRequest(
|
|
818
|
+
last_entry_id=last_entry_id,
|
|
819
|
+
timeout=0, # Non-blocking read
|
|
820
|
+
attempt_tokens=attempt_tokens,
|
|
821
|
+
)
|
|
822
|
+
|
|
823
|
+
metadata = await client.get_input_plane_metadata(function._input_plane_region)
|
|
824
|
+
response: api_pb2.MapCheckInputsResponse = await input_plane_stub.MapCheckInputs(
|
|
825
|
+
request, metadata=metadata
|
|
826
|
+
)
|
|
827
|
+
check_inputs_response = [
|
|
828
|
+
(check_inputs[resp_idx][0], response.lost[resp_idx]) for resp_idx, _ in enumerate(response.lost)
|
|
829
|
+
]
|
|
830
|
+
# check_inputs_response = [(idx, lost: bool), ...]
|
|
831
|
+
await map_items_manager.handle_check_inputs_response(check_inputs_response)
|
|
832
|
+
yield
|
|
833
|
+
except asyncio.CancelledError:
|
|
834
|
+
pass
|
|
835
|
+
|
|
836
|
+
async def get_all_outputs():
|
|
837
|
+
nonlocal \
|
|
838
|
+
successful_completions, \
|
|
839
|
+
failed_completions, \
|
|
840
|
+
no_context_duplicates, \
|
|
841
|
+
stale_retry_duplicates, \
|
|
842
|
+
already_complete_duplicates, \
|
|
843
|
+
retried_outputs, \
|
|
844
|
+
last_entry_id
|
|
845
|
+
|
|
846
|
+
while not map_done_event.is_set():
|
|
847
|
+
if map_token is None:
|
|
848
|
+
await map_token_received.wait()
|
|
849
|
+
continue
|
|
850
|
+
|
|
851
|
+
request = api_pb2.MapAwaitRequest(
|
|
852
|
+
map_token=map_token,
|
|
853
|
+
last_entry_id=last_entry_id,
|
|
854
|
+
requested_at=time.time(),
|
|
855
|
+
timeout=OUTPUTS_TIMEOUT,
|
|
856
|
+
)
|
|
857
|
+
metadata = await client.get_input_plane_metadata(function._input_plane_region)
|
|
858
|
+
get_response_task = asyncio.create_task(
|
|
859
|
+
input_plane_stub.MapAwait(
|
|
860
|
+
request,
|
|
861
|
+
retry=Retry(
|
|
862
|
+
max_retries=20,
|
|
863
|
+
attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
|
|
864
|
+
),
|
|
865
|
+
metadata=metadata,
|
|
866
|
+
)
|
|
867
|
+
)
|
|
868
|
+
map_done_task = asyncio.create_task(map_done_event.wait())
|
|
869
|
+
try:
|
|
870
|
+
done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
|
|
871
|
+
if get_response_task in done:
|
|
872
|
+
map_done_task.cancel()
|
|
873
|
+
response = get_response_task.result()
|
|
874
|
+
else:
|
|
875
|
+
assert map_done_event.is_set()
|
|
876
|
+
# map is done - no more outputs, so return early
|
|
877
|
+
return
|
|
878
|
+
finally:
|
|
879
|
+
# clean up tasks, in case of cancellations etc.
|
|
880
|
+
get_response_task.cancel()
|
|
881
|
+
map_done_task.cancel()
|
|
882
|
+
last_entry_id = response.last_entry_id
|
|
883
|
+
|
|
884
|
+
for output_item in response.outputs:
|
|
885
|
+
output_type = await map_items_manager.handle_get_outputs_response(output_item, int(time.time()))
|
|
886
|
+
if output_type == _OutputType.SUCCESSFUL_COMPLETION:
|
|
887
|
+
successful_completions += 1
|
|
888
|
+
elif output_type == _OutputType.FAILED_COMPLETION:
|
|
889
|
+
failed_completions += 1
|
|
890
|
+
elif output_type == _OutputType.RETRYING:
|
|
891
|
+
retried_outputs += 1
|
|
892
|
+
elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
|
|
893
|
+
no_context_duplicates += 1
|
|
894
|
+
elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
|
|
895
|
+
stale_retry_duplicates += 1
|
|
896
|
+
elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
|
|
897
|
+
already_complete_duplicates += 1
|
|
898
|
+
else:
|
|
899
|
+
raise Exception(f"Unknown output type: {output_type}")
|
|
900
|
+
|
|
901
|
+
if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
|
|
902
|
+
update_counters(completed_delta=1)
|
|
903
|
+
yield output_item
|
|
904
|
+
|
|
905
|
+
async def get_all_outputs_and_clean_up():
|
|
906
|
+
try:
|
|
907
|
+
async with aclosing(get_all_outputs()) as stream:
|
|
908
|
+
async for item in stream:
|
|
909
|
+
yield item
|
|
910
|
+
finally:
|
|
911
|
+
await queue.close()
|
|
912
|
+
pass
|
|
913
|
+
|
|
914
|
+
async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
|
|
915
|
+
try:
|
|
916
|
+
output = await _process_result(item.result, item.data_format, input_plane_stub, client)
|
|
917
|
+
except Exception as e:
|
|
918
|
+
if return_exceptions:
|
|
919
|
+
if wrap_returned_exceptions:
|
|
920
|
+
# Prior to client 1.0.4 there was a bug where return_exceptions would wrap
|
|
921
|
+
# any returned exceptions in a synchronicity.UserCodeException. This adds
|
|
922
|
+
# deprecated non-breaking compatibility bandaid for migrating away from that:
|
|
923
|
+
output = modal.exception.UserCodeException(e)
|
|
924
|
+
else:
|
|
925
|
+
output = e
|
|
926
|
+
else:
|
|
927
|
+
raise e
|
|
928
|
+
return (item.idx, output)
|
|
929
|
+
|
|
930
|
+
async def poll_outputs():
|
|
931
|
+
# map to store out-of-order outputs received
|
|
932
|
+
received_outputs = {}
|
|
933
|
+
output_idx = 1 # 1-indexed map call idx
|
|
934
|
+
|
|
935
|
+
async with aclosing(
|
|
936
|
+
async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
|
|
937
|
+
) as streamer:
|
|
938
|
+
async for idx, output in streamer:
|
|
939
|
+
if not order_outputs:
|
|
940
|
+
yield _OutputValue(output)
|
|
941
|
+
else:
|
|
942
|
+
# hold on to outputs for function maps, so we can reorder them correctly.
|
|
943
|
+
received_outputs[idx] = output
|
|
944
|
+
|
|
945
|
+
while True:
|
|
946
|
+
if output_idx not in received_outputs:
|
|
947
|
+
# we haven't received the output for the current index yet.
|
|
948
|
+
# stop returning outputs to the caller and instead wait for
|
|
949
|
+
# the next output to arrive from the server.
|
|
950
|
+
break
|
|
951
|
+
|
|
952
|
+
output = received_outputs.pop(output_idx)
|
|
953
|
+
yield _OutputValue(output)
|
|
954
|
+
output_idx += 1
|
|
955
|
+
|
|
956
|
+
assert len(received_outputs) == 0
|
|
957
|
+
|
|
958
|
+
async def log_debug_stats():
|
|
959
|
+
def log_stats():
|
|
960
|
+
logger.debug(
|
|
961
|
+
f"Map stats:\nsuccessful_completions={successful_completions} failed_completions={failed_completions} "
|
|
962
|
+
f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
|
|
963
|
+
f"already_complete_duplicates={already_complete_duplicates} retried_outputs={retried_outputs} "
|
|
964
|
+
f"map_token={map_token} max_inputs_outstanding={max_inputs_outstanding} "
|
|
965
|
+
f"map_items_manager_size={len(map_items_manager)} input_queue_size={input_queue_size}"
|
|
966
|
+
)
|
|
967
|
+
|
|
968
|
+
while True:
|
|
969
|
+
log_stats()
|
|
970
|
+
try:
|
|
971
|
+
await asyncio.sleep(10)
|
|
972
|
+
except asyncio.CancelledError:
|
|
973
|
+
# Log final stats before exiting
|
|
974
|
+
log_stats()
|
|
975
|
+
break
|
|
976
|
+
|
|
977
|
+
log_task = asyncio.create_task(log_debug_stats())
|
|
978
|
+
|
|
979
|
+
async with aclosing(
|
|
980
|
+
async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), check_lost_inputs())
|
|
981
|
+
) as merged:
|
|
982
|
+
async for maybe_output in merged:
|
|
983
|
+
if maybe_output is not None: # ignore None sentinels
|
|
984
|
+
yield maybe_output.value
|
|
985
|
+
|
|
986
|
+
log_task.cancel()
|
|
987
|
+
|
|
988
|
+
|
|
384
989
|
async def _map_helper(
|
|
385
990
|
self: "modal.functions.Function",
|
|
386
991
|
async_input_gen: typing.AsyncGenerator[Any, None],
|
|
387
992
|
kwargs={}, # any extra keyword arguments for the function
|
|
388
993
|
order_outputs: bool = True, # return outputs in order
|
|
389
994
|
return_exceptions: bool = False, # propagate exceptions (False) or aggregate them in the results list (True)
|
|
995
|
+
wrap_returned_exceptions: bool = True,
|
|
390
996
|
) -> typing.AsyncGenerator[Any, None]:
|
|
391
997
|
"""Core implementation that supports `_map_async()`, `_starmap_async()` and `_for_each_async()`.
|
|
392
998
|
|
|
@@ -399,9 +1005,8 @@ async def _map_helper(
|
|
|
399
1005
|
We could make this explicit as an improvement or even let users decide what they
|
|
400
1006
|
prefer: throughput (prioritize queueing inputs) or latency (prioritize yielding results)
|
|
401
1007
|
"""
|
|
402
|
-
|
|
403
1008
|
raw_input_queue: Any = SynchronizedQueue() # type: ignore
|
|
404
|
-
raw_input_queue.init()
|
|
1009
|
+
await raw_input_queue.init.aio()
|
|
405
1010
|
|
|
406
1011
|
async def feed_queue():
|
|
407
1012
|
async with aclosing(async_input_gen) as streamer:
|
|
@@ -417,12 +1022,41 @@ async def _map_helper(
|
|
|
417
1022
|
# synchronicity-wrapped, since they accept executable code in the form of iterators that we don't want to run inside
|
|
418
1023
|
# the synchronicity thread. Instead, we delegate to `._map()` with a safer Queue as input.
|
|
419
1024
|
async with aclosing(
|
|
420
|
-
async_merge(
|
|
1025
|
+
async_merge(
|
|
1026
|
+
self._map.aio(raw_input_queue, order_outputs, return_exceptions, wrap_returned_exceptions), feed_queue()
|
|
1027
|
+
)
|
|
421
1028
|
) as map_output_stream:
|
|
422
1029
|
async for output in map_output_stream:
|
|
423
1030
|
yield output
|
|
424
1031
|
|
|
425
1032
|
|
|
1033
|
+
def _maybe_warn_about_exceptions(func_name: str, return_exceptions: bool, wrap_returned_exceptions: bool):
|
|
1034
|
+
if return_exceptions and wrap_returned_exceptions:
|
|
1035
|
+
deprecation_warning(
|
|
1036
|
+
(2025, 6, 27),
|
|
1037
|
+
(
|
|
1038
|
+
f"Function.{func_name} currently leaks an internal exception wrapping type "
|
|
1039
|
+
"(modal.exceptions.UserCodeException) when `return_exceptions=True` is set. "
|
|
1040
|
+
"In the future, this will change, and the underlying exception will be returned directly.\n"
|
|
1041
|
+
"To opt into the future behavior and silence this warning, add `wrap_returned_exceptions=False`:\n\n"
|
|
1042
|
+
f" f.{func_name}(..., return_exceptions=True, wrap_returned_exceptions=False)"
|
|
1043
|
+
),
|
|
1044
|
+
)
|
|
1045
|
+
|
|
1046
|
+
|
|
1047
|
+
def _invoked_from_sync_wrapper() -> bool:
|
|
1048
|
+
"""Check whether the calling function was called from a sync wrapper."""
|
|
1049
|
+
# This is temporary: we only need it to avoind double-firing the wrap_returned_exceptions warning.
|
|
1050
|
+
# (We don't want to push the warning lower in the stack beacuse then we can't attribute to the user's code.)
|
|
1051
|
+
try:
|
|
1052
|
+
frame = inspect.currentframe()
|
|
1053
|
+
caller_function_name = frame.f_back.f_back.f_code.co_name
|
|
1054
|
+
# Embeds some assumptions about how the current calling stack works, but this is just temporary.
|
|
1055
|
+
return caller_function_name == "asend"
|
|
1056
|
+
except Exception:
|
|
1057
|
+
return False
|
|
1058
|
+
|
|
1059
|
+
|
|
426
1060
|
@warn_if_generator_is_not_consumed(function_name="Function.map.aio")
|
|
427
1061
|
async def _map_async(
|
|
428
1062
|
self: "modal.functions.Function",
|
|
@@ -432,10 +1066,18 @@ async def _map_async(
|
|
|
432
1066
|
kwargs={}, # any extra keyword arguments for the function
|
|
433
1067
|
order_outputs: bool = True, # return outputs in order
|
|
434
1068
|
return_exceptions: bool = False, # propagate exceptions (False) or aggregate them in the results list (True)
|
|
1069
|
+
wrap_returned_exceptions: bool = True, # wrap returned exceptions in modal.exception.UserCodeException
|
|
435
1070
|
) -> typing.AsyncGenerator[Any, None]:
|
|
1071
|
+
if not _invoked_from_sync_wrapper():
|
|
1072
|
+
_maybe_warn_about_exceptions("map.aio", return_exceptions, wrap_returned_exceptions)
|
|
436
1073
|
async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
|
|
437
1074
|
async for output in _map_helper(
|
|
438
|
-
self,
|
|
1075
|
+
self,
|
|
1076
|
+
async_input_gen,
|
|
1077
|
+
kwargs=kwargs,
|
|
1078
|
+
order_outputs=order_outputs,
|
|
1079
|
+
return_exceptions=return_exceptions,
|
|
1080
|
+
wrap_returned_exceptions=wrap_returned_exceptions,
|
|
439
1081
|
):
|
|
440
1082
|
yield output
|
|
441
1083
|
|
|
@@ -448,13 +1090,17 @@ async def _starmap_async(
|
|
|
448
1090
|
kwargs={},
|
|
449
1091
|
order_outputs: bool = True,
|
|
450
1092
|
return_exceptions: bool = False,
|
|
1093
|
+
wrap_returned_exceptions: bool = True,
|
|
451
1094
|
) -> typing.AsyncIterable[Any]:
|
|
1095
|
+
if not _invoked_from_sync_wrapper():
|
|
1096
|
+
_maybe_warn_about_exceptions("starmap.aio", return_exceptions, wrap_returned_exceptions)
|
|
452
1097
|
async for output in _map_helper(
|
|
453
1098
|
self,
|
|
454
1099
|
sync_or_async_iter(input_iterator),
|
|
455
1100
|
kwargs=kwargs,
|
|
456
1101
|
order_outputs=order_outputs,
|
|
457
1102
|
return_exceptions=return_exceptions,
|
|
1103
|
+
wrap_returned_exceptions=wrap_returned_exceptions,
|
|
458
1104
|
):
|
|
459
1105
|
yield output
|
|
460
1106
|
|
|
@@ -464,7 +1110,12 @@ async def _for_each_async(self, *input_iterators, kwargs={}, ignore_exceptions:
|
|
|
464
1110
|
# rather than iterating over the result
|
|
465
1111
|
async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
|
|
466
1112
|
async for _ in _map_helper(
|
|
467
|
-
self,
|
|
1113
|
+
self,
|
|
1114
|
+
async_input_gen,
|
|
1115
|
+
kwargs=kwargs,
|
|
1116
|
+
order_outputs=False,
|
|
1117
|
+
return_exceptions=ignore_exceptions,
|
|
1118
|
+
wrap_returned_exceptions=False,
|
|
468
1119
|
):
|
|
469
1120
|
pass
|
|
470
1121
|
|
|
@@ -476,6 +1127,7 @@ def _map_sync(
|
|
|
476
1127
|
kwargs={}, # any extra keyword arguments for the function
|
|
477
1128
|
order_outputs: bool = True, # return outputs in order
|
|
478
1129
|
return_exceptions: bool = False, # propagate exceptions (False) or aggregate them in the results list (True)
|
|
1130
|
+
wrap_returned_exceptions: bool = True,
|
|
479
1131
|
) -> AsyncOrSyncIterable:
|
|
480
1132
|
"""Parallel map over a set of inputs.
|
|
481
1133
|
|
|
@@ -513,10 +1165,16 @@ def _map_sync(
|
|
|
513
1165
|
print(list(my_func.map(range(3), return_exceptions=True)))
|
|
514
1166
|
```
|
|
515
1167
|
"""
|
|
1168
|
+
_maybe_warn_about_exceptions("map", return_exceptions, wrap_returned_exceptions)
|
|
516
1169
|
|
|
517
1170
|
return AsyncOrSyncIterable(
|
|
518
1171
|
_map_async(
|
|
519
|
-
self,
|
|
1172
|
+
self,
|
|
1173
|
+
*input_iterators,
|
|
1174
|
+
kwargs=kwargs,
|
|
1175
|
+
order_outputs=order_outputs,
|
|
1176
|
+
return_exceptions=return_exceptions,
|
|
1177
|
+
wrap_returned_exceptions=wrap_returned_exceptions,
|
|
520
1178
|
),
|
|
521
1179
|
nested_async_message=(
|
|
522
1180
|
"You can't iter(Function.map()) from an async function. Use async for ... in Function.map.aio() instead."
|
|
@@ -524,6 +1182,56 @@ def _map_sync(
|
|
|
524
1182
|
)
|
|
525
1183
|
|
|
526
1184
|
|
|
1185
|
+
async def _experimental_spawn_map_async(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
|
|
1186
|
+
async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
|
|
1187
|
+
return await _spawn_map_helper(self, async_input_gen, kwargs)
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
async def _spawn_map_helper(
|
|
1191
|
+
self: "modal.functions.Function", async_input_gen, kwargs={}
|
|
1192
|
+
) -> "modal.functions._FunctionCall":
|
|
1193
|
+
raw_input_queue: Any = SynchronizedQueue() # type: ignore
|
|
1194
|
+
await raw_input_queue.init.aio()
|
|
1195
|
+
|
|
1196
|
+
async def feed_queue():
|
|
1197
|
+
async with aclosing(async_input_gen) as streamer:
|
|
1198
|
+
async for args in streamer:
|
|
1199
|
+
await raw_input_queue.put.aio((args, kwargs))
|
|
1200
|
+
await raw_input_queue.put.aio(None) # end-of-input sentinel
|
|
1201
|
+
|
|
1202
|
+
fc, _ = await asyncio.gather(self._spawn_map.aio(raw_input_queue), feed_queue())
|
|
1203
|
+
return fc
|
|
1204
|
+
|
|
1205
|
+
|
|
1206
|
+
def _experimental_spawn_map_sync(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
|
|
1207
|
+
"""mdmd:hidden
|
|
1208
|
+
Spawn parallel execution over a set of inputs, returning as soon as the inputs are created.
|
|
1209
|
+
|
|
1210
|
+
Unlike `modal.Function.map`, this method does not block on completion of the remote execution but
|
|
1211
|
+
returns a `modal.FunctionCall` object that can be used to poll status and retrieve results later.
|
|
1212
|
+
|
|
1213
|
+
Takes one iterator argument per argument in the function being mapped over.
|
|
1214
|
+
|
|
1215
|
+
Example:
|
|
1216
|
+
```python
|
|
1217
|
+
@app.function()
|
|
1218
|
+
def my_func(a, b):
|
|
1219
|
+
return a ** b
|
|
1220
|
+
|
|
1221
|
+
|
|
1222
|
+
@app.local_entrypoint()
|
|
1223
|
+
def main():
|
|
1224
|
+
fc = my_func.spawn_map([1, 2], [3, 4])
|
|
1225
|
+
```
|
|
1226
|
+
|
|
1227
|
+
"""
|
|
1228
|
+
|
|
1229
|
+
return run_coroutine_in_temporary_event_loop(
|
|
1230
|
+
_experimental_spawn_map_async(self, *input_iterators, kwargs=kwargs),
|
|
1231
|
+
"You can't run Function.spawn_map() from an async function. Use Function.spawn_map.aio() instead.",
|
|
1232
|
+
)
|
|
1233
|
+
|
|
1234
|
+
|
|
527
1235
|
async def _spawn_map_async(self, *input_iterators, kwargs={}) -> None:
|
|
528
1236
|
"""This runs in an event loop on the main thread. It consumes inputs from the input iterators and creates async
|
|
529
1237
|
function calls for each.
|
|
@@ -569,7 +1277,7 @@ def _spawn_map_sync(self, *input_iterators, kwargs={}) -> None:
|
|
|
569
1277
|
|
|
570
1278
|
return run_coroutine_in_temporary_event_loop(
|
|
571
1279
|
_spawn_map_async(self, *input_iterators, kwargs=kwargs),
|
|
572
|
-
"You can't run Function.spawn_map() from an async function. Use Function.
|
|
1280
|
+
"You can't run Function.spawn_map() from an async function. Use Function.spawn_map.aio() instead.",
|
|
573
1281
|
)
|
|
574
1282
|
|
|
575
1283
|
|
|
@@ -596,6 +1304,7 @@ def _starmap_sync(
|
|
|
596
1304
|
kwargs={},
|
|
597
1305
|
order_outputs: bool = True,
|
|
598
1306
|
return_exceptions: bool = False,
|
|
1307
|
+
wrap_returned_exceptions: bool = True,
|
|
599
1308
|
) -> AsyncOrSyncIterable:
|
|
600
1309
|
"""Like `map`, but spreads arguments over multiple function arguments.
|
|
601
1310
|
|
|
@@ -613,9 +1322,15 @@ def _starmap_sync(
|
|
|
613
1322
|
assert list(my_func.starmap([(1, 2), (3, 4)])) == [3, 7]
|
|
614
1323
|
```
|
|
615
1324
|
"""
|
|
1325
|
+
_maybe_warn_about_exceptions("starmap", return_exceptions, wrap_returned_exceptions)
|
|
616
1326
|
return AsyncOrSyncIterable(
|
|
617
1327
|
_starmap_async(
|
|
618
|
-
self,
|
|
1328
|
+
self,
|
|
1329
|
+
input_iterator,
|
|
1330
|
+
kwargs=kwargs,
|
|
1331
|
+
order_outputs=order_outputs,
|
|
1332
|
+
return_exceptions=return_exceptions,
|
|
1333
|
+
wrap_returned_exceptions=wrap_returned_exceptions,
|
|
619
1334
|
),
|
|
620
1335
|
nested_async_message=(
|
|
621
1336
|
"You can't `iter(Function.starmap())` from an async function. "
|
|
@@ -653,12 +1368,19 @@ class _MapItemContext:
|
|
|
653
1368
|
sync_client_retries_enabled: bool
|
|
654
1369
|
# Both these futures are strings. Omitting generic type because
|
|
655
1370
|
# it causes an error when running `inv protoc type-stubs`.
|
|
1371
|
+
# Unused. But important, input_id is not set for inputplane invocations.
|
|
656
1372
|
input_id: asyncio.Future
|
|
657
1373
|
input_jwt: asyncio.Future
|
|
658
1374
|
previous_input_jwt: Optional[str]
|
|
659
1375
|
_event_loop: asyncio.AbstractEventLoop
|
|
660
1376
|
|
|
661
|
-
def __init__(
|
|
1377
|
+
def __init__(
|
|
1378
|
+
self,
|
|
1379
|
+
input: api_pb2.FunctionInput,
|
|
1380
|
+
retry_manager: RetryManager,
|
|
1381
|
+
sync_client_retries_enabled: bool,
|
|
1382
|
+
is_input_plane_instance: bool = False,
|
|
1383
|
+
):
|
|
662
1384
|
self.state = _MapItemState.SENDING
|
|
663
1385
|
self.input = input
|
|
664
1386
|
self.retry_manager = retry_manager
|
|
@@ -669,7 +1391,22 @@ class _MapItemContext:
|
|
|
669
1391
|
# a race condition where we could receive outputs before we have
|
|
670
1392
|
# recorded the input ID and JWT in `pending_outputs`.
|
|
671
1393
|
self.input_jwt = self._event_loop.create_future()
|
|
1394
|
+
# Unused. But important, this is not set for inputplane invocations.
|
|
672
1395
|
self.input_id = self._event_loop.create_future()
|
|
1396
|
+
self._is_input_plane_instance = is_input_plane_instance
|
|
1397
|
+
|
|
1398
|
+
def handle_map_start_or_continue_response(self, attempt_token: str):
|
|
1399
|
+
if not self.input_jwt.done():
|
|
1400
|
+
self.input_jwt.set_result(attempt_token)
|
|
1401
|
+
else:
|
|
1402
|
+
# Create a new future for the next value
|
|
1403
|
+
self.input_jwt = asyncio.Future()
|
|
1404
|
+
self.input_jwt.set_result(attempt_token)
|
|
1405
|
+
|
|
1406
|
+
# Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
|
|
1407
|
+
# RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
|
|
1408
|
+
if self.state == _MapItemState.SENDING:
|
|
1409
|
+
self.state = _MapItemState.WAITING_FOR_OUTPUT
|
|
673
1410
|
|
|
674
1411
|
def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
|
|
675
1412
|
self.input_jwt.set_result(item.input_jwt)
|
|
@@ -692,10 +1429,11 @@ class _MapItemContext:
|
|
|
692
1429
|
Return True if input state was changed to COMPLETE, otherwise False.
|
|
693
1430
|
"""
|
|
694
1431
|
# If the item is already complete, this is a duplicate output and can be ignored.
|
|
1432
|
+
|
|
695
1433
|
if self.state == _MapItemState.COMPLETE:
|
|
696
1434
|
logger.debug(
|
|
697
1435
|
f"Received output for input marked as complete. Must be duplicate, so ignoring. "
|
|
698
|
-
f"idx={item.idx} input_id={item.input_id}
|
|
1436
|
+
f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
|
|
699
1437
|
)
|
|
700
1438
|
return _OutputType.ALREADY_COMPLETE_DUPLICATE
|
|
701
1439
|
# If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
|
|
@@ -737,12 +1475,17 @@ class _MapItemContext:
|
|
|
737
1475
|
delay_ms = 0
|
|
738
1476
|
|
|
739
1477
|
# None means the maximum number of retries has been reached, so output the error
|
|
740
|
-
if delay_ms is None:
|
|
1478
|
+
if delay_ms is None or item.result.status == api_pb2.GenericResult.GENERIC_STATUS_TERMINATED:
|
|
741
1479
|
self.state = _MapItemState.COMPLETE
|
|
742
1480
|
return _OutputType.FAILED_COMPLETION
|
|
743
1481
|
|
|
744
1482
|
self.state = _MapItemState.WAITING_TO_RETRY
|
|
745
|
-
|
|
1483
|
+
|
|
1484
|
+
if self._is_input_plane_instance:
|
|
1485
|
+
retry_item = await self.create_map_start_or_continue_item(item.idx)
|
|
1486
|
+
await retry_queue.put(now_seconds + delay_ms / 1_000, retry_item)
|
|
1487
|
+
else:
|
|
1488
|
+
await retry_queue.put(now_seconds + delay_ms / 1_000, item.idx)
|
|
746
1489
|
|
|
747
1490
|
return _OutputType.RETRYING
|
|
748
1491
|
|
|
@@ -757,10 +1500,23 @@ class _MapItemContext:
|
|
|
757
1500
|
retry_count=self.retry_manager.retry_count,
|
|
758
1501
|
)
|
|
759
1502
|
|
|
1503
|
+
def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
|
|
1504
|
+
self.retry_manager = RetryManager(retry_policy)
|
|
1505
|
+
|
|
760
1506
|
def handle_retry_response(self, input_jwt: str):
|
|
761
1507
|
self.input_jwt.set_result(input_jwt)
|
|
762
1508
|
self.state = _MapItemState.WAITING_FOR_OUTPUT
|
|
763
1509
|
|
|
1510
|
+
async def create_map_start_or_continue_item(self, idx: int) -> api_pb2.MapStartOrContinueItem:
|
|
1511
|
+
attempt_token = await self.input_jwt
|
|
1512
|
+
return api_pb2.MapStartOrContinueItem(
|
|
1513
|
+
input=api_pb2.FunctionPutInputsItem(
|
|
1514
|
+
input=self.input,
|
|
1515
|
+
idx=idx,
|
|
1516
|
+
),
|
|
1517
|
+
attempt_token=attempt_token,
|
|
1518
|
+
)
|
|
1519
|
+
|
|
764
1520
|
|
|
765
1521
|
class _MapItemsManager:
|
|
766
1522
|
def __init__(
|
|
@@ -770,6 +1526,7 @@ class _MapItemsManager:
|
|
|
770
1526
|
retry_queue: TimestampPriorityQueue,
|
|
771
1527
|
sync_client_retries_enabled: bool,
|
|
772
1528
|
max_inputs_outstanding: int,
|
|
1529
|
+
is_input_plane_instance: bool = False,
|
|
773
1530
|
):
|
|
774
1531
|
self._retry_policy = retry_policy
|
|
775
1532
|
self.function_call_invocation_type = function_call_invocation_type
|
|
@@ -780,6 +1537,10 @@ class _MapItemsManager:
|
|
|
780
1537
|
self._inputs_outstanding = asyncio.BoundedSemaphore(max_inputs_outstanding)
|
|
781
1538
|
self._item_context: dict[int, _MapItemContext] = {}
|
|
782
1539
|
self._sync_client_retries_enabled = sync_client_retries_enabled
|
|
1540
|
+
self._is_input_plane_instance = is_input_plane_instance
|
|
1541
|
+
|
|
1542
|
+
def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
|
|
1543
|
+
self._retry_policy = retry_policy
|
|
783
1544
|
|
|
784
1545
|
async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
|
|
785
1546
|
for item in items:
|
|
@@ -792,9 +1553,28 @@ class _MapItemsManager:
|
|
|
792
1553
|
sync_client_retries_enabled=self._sync_client_retries_enabled,
|
|
793
1554
|
)
|
|
794
1555
|
|
|
1556
|
+
async def add_items_inputplane(self, items: list[api_pb2.MapStartOrContinueItem]):
|
|
1557
|
+
for item in items:
|
|
1558
|
+
# acquire semaphore to limit the number of inputs in progress
|
|
1559
|
+
# (either queued to be sent, waiting for completion, or retrying)
|
|
1560
|
+
if item.attempt_token != "": # if it is a retry item
|
|
1561
|
+
self._item_context[item.input.idx].state = _MapItemState.SENDING
|
|
1562
|
+
continue
|
|
1563
|
+
await self._inputs_outstanding.acquire()
|
|
1564
|
+
self._item_context[item.input.idx] = _MapItemContext(
|
|
1565
|
+
input=item.input.input,
|
|
1566
|
+
retry_manager=RetryManager(self._retry_policy),
|
|
1567
|
+
sync_client_retries_enabled=self._sync_client_retries_enabled,
|
|
1568
|
+
is_input_plane_instance=self._is_input_plane_instance,
|
|
1569
|
+
)
|
|
1570
|
+
|
|
795
1571
|
async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
|
|
796
1572
|
return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]
|
|
797
1573
|
|
|
1574
|
+
def update_items_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
|
|
1575
|
+
for ctx in self._item_context.values():
|
|
1576
|
+
ctx.set_retry_policy(retry_policy)
|
|
1577
|
+
|
|
798
1578
|
def get_input_jwts_waiting_for_output(self) -> list[str]:
|
|
799
1579
|
"""
|
|
800
1580
|
Returns a list of input_jwts for inputs that are waiting for output.
|
|
@@ -806,6 +1586,17 @@ class _MapItemsManager:
|
|
|
806
1586
|
if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
|
|
807
1587
|
]
|
|
808
1588
|
|
|
1589
|
+
def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
|
|
1590
|
+
"""
|
|
1591
|
+
Returns a list of input_idxs for inputs that are waiting for output.
|
|
1592
|
+
"""
|
|
1593
|
+
# Idx doesn't need a future because it is set by client and not server.
|
|
1594
|
+
return [
|
|
1595
|
+
(idx, ctx.input_jwt.result())
|
|
1596
|
+
for idx, ctx in self._item_context.items()
|
|
1597
|
+
if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
|
|
1598
|
+
]
|
|
1599
|
+
|
|
809
1600
|
def _remove_item(self, item_idx: int):
|
|
810
1601
|
del self._item_context[item_idx]
|
|
811
1602
|
self._inputs_outstanding.release()
|
|
@@ -813,6 +1604,18 @@ class _MapItemsManager:
|
|
|
813
1604
|
def get_item_context(self, item_idx: int) -> _MapItemContext:
|
|
814
1605
|
return self._item_context.get(item_idx)
|
|
815
1606
|
|
|
1607
|
+
def handle_put_continue_response(
|
|
1608
|
+
self,
|
|
1609
|
+
items: list[tuple[int, str]], # idx, input_jwt
|
|
1610
|
+
):
|
|
1611
|
+
for index, item in items:
|
|
1612
|
+
ctx = self._item_context.get(index, None)
|
|
1613
|
+
# If the context is None, then get_all_outputs() has already received a successful
|
|
1614
|
+
# output, and deleted the context. This happens if FunctionGetOutputs completes
|
|
1615
|
+
# before MapStartOrContinueResponse is received.
|
|
1616
|
+
if ctx is not None:
|
|
1617
|
+
ctx.handle_map_start_or_continue_response(item)
|
|
1618
|
+
|
|
816
1619
|
def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
|
|
817
1620
|
for item in items:
|
|
818
1621
|
ctx = self._item_context.get(item.idx, None)
|
|
@@ -832,6 +1635,16 @@ class _MapItemsManager:
|
|
|
832
1635
|
if ctx is not None:
|
|
833
1636
|
ctx.handle_retry_response(input_jwt)
|
|
834
1637
|
|
|
1638
|
+
async def handle_check_inputs_response(self, response: list[tuple[int, bool]]):
|
|
1639
|
+
for idx, lost in response:
|
|
1640
|
+
ctx = self._item_context.get(idx, None)
|
|
1641
|
+
if ctx is not None:
|
|
1642
|
+
if lost:
|
|
1643
|
+
ctx.state = _MapItemState.WAITING_TO_RETRY
|
|
1644
|
+
retry_item = await ctx.create_map_start_or_continue_item(idx)
|
|
1645
|
+
_ = ctx.retry_manager.get_delay_ms() # increment retry count but instant retry for lost inputs
|
|
1646
|
+
await self._retry_queue.put(time.time(), retry_item)
|
|
1647
|
+
|
|
835
1648
|
async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
|
|
836
1649
|
ctx = self._item_context.get(item.idx, None)
|
|
837
1650
|
if ctx is None:
|