modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (160)
  1. modal/__init__.py +0 -2
  2. modal/__main__.py +3 -4
  3. modal/_billing.py +80 -0
  4. modal/_clustered_functions.py +7 -3
  5. modal/_clustered_functions.pyi +15 -3
  6. modal/_container_entrypoint.py +51 -69
  7. modal/_functions.py +508 -240
  8. modal/_grpc_client.py +171 -0
  9. modal/_load_context.py +105 -0
  10. modal/_object.py +81 -21
  11. modal/_output.py +58 -45
  12. modal/_partial_function.py +48 -73
  13. modal/_pty.py +7 -3
  14. modal/_resolver.py +26 -46
  15. modal/_runtime/asgi.py +4 -3
  16. modal/_runtime/container_io_manager.py +358 -220
  17. modal/_runtime/container_io_manager.pyi +296 -101
  18. modal/_runtime/execution_context.py +18 -2
  19. modal/_runtime/execution_context.pyi +64 -7
  20. modal/_runtime/gpu_memory_snapshot.py +262 -57
  21. modal/_runtime/user_code_imports.py +28 -58
  22. modal/_serialization.py +90 -6
  23. modal/_traceback.py +42 -1
  24. modal/_tunnel.pyi +380 -12
  25. modal/_utils/async_utils.py +84 -29
  26. modal/_utils/auth_token_manager.py +111 -0
  27. modal/_utils/blob_utils.py +181 -58
  28. modal/_utils/deprecation.py +19 -0
  29. modal/_utils/function_utils.py +91 -47
  30. modal/_utils/grpc_utils.py +89 -66
  31. modal/_utils/mount_utils.py +26 -1
  32. modal/_utils/name_utils.py +17 -3
  33. modal/_utils/task_command_router_client.py +536 -0
  34. modal/_utils/time_utils.py +34 -6
  35. modal/app.py +256 -88
  36. modal/app.pyi +909 -92
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +18 -0
  39. modal/builder/PREVIEW.txt +18 -0
  40. modal/builder/base-images.json +58 -0
  41. modal/cli/_download.py +19 -3
  42. modal/cli/_traceback.py +3 -2
  43. modal/cli/app.py +4 -4
  44. modal/cli/cluster.py +15 -7
  45. modal/cli/config.py +5 -3
  46. modal/cli/container.py +7 -6
  47. modal/cli/dict.py +22 -16
  48. modal/cli/entry_point.py +12 -5
  49. modal/cli/environment.py +5 -4
  50. modal/cli/import_refs.py +3 -3
  51. modal/cli/launch.py +102 -5
  52. modal/cli/network_file_system.py +11 -12
  53. modal/cli/profile.py +3 -2
  54. modal/cli/programs/launch_instance_ssh.py +94 -0
  55. modal/cli/programs/run_jupyter.py +1 -1
  56. modal/cli/programs/run_marimo.py +95 -0
  57. modal/cli/programs/vscode.py +1 -1
  58. modal/cli/queues.py +57 -26
  59. modal/cli/run.py +91 -23
  60. modal/cli/secret.py +48 -22
  61. modal/cli/token.py +7 -8
  62. modal/cli/utils.py +4 -7
  63. modal/cli/volume.py +31 -25
  64. modal/client.py +15 -85
  65. modal/client.pyi +183 -62
  66. modal/cloud_bucket_mount.py +5 -3
  67. modal/cloud_bucket_mount.pyi +197 -5
  68. modal/cls.py +200 -126
  69. modal/cls.pyi +446 -68
  70. modal/config.py +29 -11
  71. modal/container_process.py +319 -19
  72. modal/container_process.pyi +190 -20
  73. modal/dict.py +290 -71
  74. modal/dict.pyi +835 -83
  75. modal/environments.py +15 -27
  76. modal/environments.pyi +46 -24
  77. modal/exception.py +14 -2
  78. modal/experimental/__init__.py +194 -40
  79. modal/experimental/flash.py +618 -0
  80. modal/experimental/flash.pyi +380 -0
  81. modal/experimental/ipython.py +11 -7
  82. modal/file_io.py +29 -36
  83. modal/file_io.pyi +251 -53
  84. modal/file_pattern_matcher.py +56 -16
  85. modal/functions.pyi +673 -92
  86. modal/gpu.py +1 -1
  87. modal/image.py +528 -176
  88. modal/image.pyi +1572 -145
  89. modal/io_streams.py +458 -128
  90. modal/io_streams.pyi +433 -52
  91. modal/mount.py +216 -151
  92. modal/mount.pyi +225 -78
  93. modal/network_file_system.py +45 -62
  94. modal/network_file_system.pyi +277 -56
  95. modal/object.pyi +93 -17
  96. modal/parallel_map.py +942 -129
  97. modal/parallel_map.pyi +294 -15
  98. modal/partial_function.py +0 -2
  99. modal/partial_function.pyi +234 -19
  100. modal/proxy.py +17 -8
  101. modal/proxy.pyi +36 -3
  102. modal/queue.py +270 -65
  103. modal/queue.pyi +817 -57
  104. modal/runner.py +115 -101
  105. modal/runner.pyi +205 -49
  106. modal/sandbox.py +512 -136
  107. modal/sandbox.pyi +845 -111
  108. modal/schedule.py +1 -1
  109. modal/secret.py +300 -70
  110. modal/secret.pyi +589 -34
  111. modal/serving.py +7 -11
  112. modal/serving.pyi +7 -8
  113. modal/snapshot.py +11 -8
  114. modal/snapshot.pyi +25 -4
  115. modal/token_flow.py +4 -4
  116. modal/token_flow.pyi +28 -8
  117. modal/volume.py +416 -158
  118. modal/volume.pyi +1117 -121
  119. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
  120. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  121. modal_docs/mdmd/mdmd.py +17 -4
  122. modal_proto/api.proto +534 -79
  123. modal_proto/api_grpc.py +337 -1
  124. modal_proto/api_pb2.py +1522 -968
  125. modal_proto/api_pb2.pyi +1619 -134
  126. modal_proto/api_pb2_grpc.py +699 -4
  127. modal_proto/api_pb2_grpc.pyi +226 -14
  128. modal_proto/modal_api_grpc.py +175 -154
  129. modal_proto/sandbox_router.proto +145 -0
  130. modal_proto/sandbox_router_grpc.py +105 -0
  131. modal_proto/sandbox_router_pb2.py +149 -0
  132. modal_proto/sandbox_router_pb2.pyi +333 -0
  133. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  134. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  135. modal_proto/task_command_router.proto +144 -0
  136. modal_proto/task_command_router_grpc.py +105 -0
  137. modal_proto/task_command_router_pb2.py +149 -0
  138. modal_proto/task_command_router_pb2.pyi +333 -0
  139. modal_proto/task_command_router_pb2_grpc.py +203 -0
  140. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  141. modal_version/__init__.py +1 -1
  142. modal/requirements/PREVIEW.txt +0 -16
  143. modal/requirements/base-images.json +0 -26
  144. modal-1.0.3.dev10.dist-info/RECORD +0 -179
  145. modal_proto/modal_options_grpc.py +0 -3
  146. modal_proto/options.proto +0 -19
  147. modal_proto/options_grpc.py +0 -3
  148. modal_proto/options_pb2.py +0 -35
  149. modal_proto/options_pb2.pyi +0 -20
  150. modal_proto/options_pb2_grpc.py +0 -4
  151. modal_proto/options_pb2_grpc.pyi +0 -7
  152. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  153. /modal/{requirements → builder}/2023.12.txt +0 -0
  154. /modal/{requirements → builder}/2024.04.txt +0 -0
  155. /modal/{requirements → builder}/2024.10.txt +0 -0
  156. /modal/{requirements → builder}/README.md +0 -0
  157. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  158. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  159. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  160. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/parallel_map.py CHANGED
@@ -1,13 +1,16 @@
 # Copyright Modal Labs 2024
 import asyncio
 import enum
+import inspect
 import time
 import typing
+from asyncio import FIRST_COMPLETED
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 from grpclib import Status
 
+import modal.exception
 from modal._runtime.execution_context import current_input_id
 from modal._utils.async_utils import (
     AsyncOrSyncIterable,
@@ -25,13 +28,14 @@ from modal._utils.async_utils import (
     warn_if_generator_is_not_consumed,
 )
 from modal._utils.blob_utils import BLOB_MAX_PARALLELISM
+from modal._utils.deprecation import deprecation_warning
 from modal._utils.function_utils import (
     ATTEMPT_TIMEOUT_GRACE_PERIOD,
     OUTPUTS_TIMEOUT,
     _create_input,
     _process_result,
 )
-from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, RetryWarningMessage, retry_transient_errors
+from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, Retry, RetryWarningMessage
 from modal._utils.jwt_utils import DecodedJwt
 from modal.config import logger
 from modal.retries import RetryManager
@@ -75,19 +79,293 @@ class _OutputValue:
 
 MAX_INPUTS_OUTSTANDING_DEFAULT = 1000
 
-# maximum number of inputs to send to the server in a single request
+# Maximum number of inputs to send to the server per FunctionPutInputs request
 MAP_INVOCATION_CHUNK_SIZE = 49
+SPAWN_MAP_INVOCATION_CHUNK_SIZE = 512
+
 
 if typing.TYPE_CHECKING:
     import modal.functions
 
 
+class InputPreprocessor:
+    """
+    Constructs FunctionPutInputsItem objects from the raw-input queue, and puts them in the processed-input queue.
+    """
+
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        raw_input_queue: _SynchronizedQueue,
+        processed_input_queue: asyncio.Queue,
+        function: "modal.functions._Function",
+        created_callback: Callable[[int], None],
+        done_callback: Callable[[], None],
+    ):
+        self.client = client
+        self.function = function
+        self.inputs_created = 0
+        self.raw_input_queue = raw_input_queue
+        self.processed_input_queue = processed_input_queue
+        self.created_callback = created_callback
+        self.done_callback = done_callback
+
+    async def input_iter(self):
+        while 1:
+            raw_input = await self.raw_input_queue.get()
+            if raw_input is None:  # end of input sentinel
+                break
+            yield raw_input  # args, kwargs
+
+    def create_input_factory(self):
+        async def create_input(argskwargs):
+            idx = self.inputs_created
+            self.inputs_created += 1
+            self.created_callback(self.inputs_created)
+            (args, kwargs) = argskwargs
+            return await _create_input(
+                args,
+                kwargs,
+                self.client.stub,
+                idx=idx,
+                function=self.function,
+            )
+
+        return create_input
+
+    async def drain_input_generator(self):
+        # Parallelize uploading blobs
+        async with aclosing(
+            async_map_ordered(self.input_iter(), self.create_input_factory(), concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for item in streamer:
+                await self.processed_input_queue.put(item)
+
+        # close queue iterator
+        await self.processed_input_queue.put(None)
+        self.done_callback()
+        yield
+
+
+class InputPumper:
+    """
+    Reads inputs from a queue of FunctionPutInputsItems, and sends them to the server.
+    """
+
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        input_queue: asyncio.Queue,
+        function: "modal.functions._Function",
+        function_call_id: str,
+        max_batch_size: int,
+        map_items_manager: Optional["_MapItemsManager"] = None,
+    ):
+        self.client = client
+        self.function = function
+        self.map_items_manager = map_items_manager
+        self.input_queue = input_queue
+        self.inputs_sent = 0
+        self.function_call_id = function_call_id
+        self.max_batch_size = max_batch_size
+
+    async def pump_inputs(self):
+        assert self.client.stub
+        async for items in queue_batch_iterator(self.input_queue, max_batch_size=self.max_batch_size):
+            # Add items to the manager. Their state will be SENDING.
+            if self.map_items_manager is not None:
+                await self.map_items_manager.add_items(items)
+            request = api_pb2.FunctionPutInputsRequest(
+                function_id=self.function.object_id,
+                inputs=items,
+                function_call_id=self.function_call_id,
+            )
+            logger.debug(
+                f"Pushing {len(items)} inputs to server. Num queued inputs awaiting"
+                f" push is {self.input_queue.qsize()}. "
+            )
+
+            resp = await self.client.stub.FunctionPutInputs(request, retry=self._function_inputs_retry)
+            self.inputs_sent += len(items)
+            # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
+            if self.map_items_manager is not None:
+                self.map_items_manager.handle_put_inputs_response(resp.inputs)
+            logger.debug(
+                f"Successfully pushed {len(items)} inputs to server. "
+                f"Num queued inputs awaiting push is {self.input_queue.qsize()}."
+            )
+            yield
+
+    @property
+    def _function_inputs_retry(self) -> Retry:
+        # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
+        retry_warning_message = RetryWarningMessage(
+            message=f"Warning: map progress for function {self.function._function_name} is limited."
+            " Common bottlenecks include slow iteration over results, or function backlogs.",
+            warning_interval=8,
+            errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
+        )
+        return Retry(
+            max_retries=None,
+            max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
+            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+            warning_message=retry_warning_message,
+        )
+
+
+class SyncInputPumper(InputPumper):
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        input_queue: asyncio.Queue,
+        retry_queue: TimestampPriorityQueue,
+        function: "modal.functions._Function",
+        function_call_jwt: str,
+        function_call_id: str,
+        map_items_manager: "_MapItemsManager",
+    ):
+        super().__init__(
+            client,
+            input_queue=input_queue,
+            function=function,
+            function_call_id=function_call_id,
+            max_batch_size=MAP_INVOCATION_CHUNK_SIZE,
+            map_items_manager=map_items_manager,
+        )
+        self.retry_queue = retry_queue
+        self.inputs_retried = 0
+        self.function_call_jwt = function_call_jwt
+
+    async def retry_inputs(self):
+        async for retriable_idxs in queue_batch_iterator(self.retry_queue, max_batch_size=self.max_batch_size):
+            # For each index, use the context in the manager to create a FunctionRetryInputsItem.
+            # This will also update the context state to RETRYING.
+            inputs: list[api_pb2.FunctionRetryInputsItem] = await self.map_items_manager.prepare_items_for_retry(
+                retriable_idxs
+            )
+            request = api_pb2.FunctionRetryInputsRequest(
+                function_call_jwt=self.function_call_jwt,
+                inputs=inputs,
+            )
+            resp = await self.client.stub.FunctionRetryInputs(request, retry=self._function_inputs_retry)
+            # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
+            # to the new value in the response.
+            self.map_items_manager.handle_retry_response(resp.input_jwts)
+            logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
+            self.inputs_retried += len(inputs)
+            yield
+
+
+class AsyncInputPumper(InputPumper):
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        input_queue: asyncio.Queue,
+        function: "modal.functions._Function",
+        function_call_id: str,
+    ):
+        super().__init__(
+            client,
+            input_queue=input_queue,
+            function=function,
+            function_call_id=function_call_id,
+            max_batch_size=SPAWN_MAP_INVOCATION_CHUNK_SIZE,
+        )
+
+    async def pump_inputs(self):
+        async for _ in super().pump_inputs():
+            pass
+        request = api_pb2.FunctionFinishInputsRequest(
+            function_id=self.function.object_id,
+            function_call_id=self.function_call_id,
+            num_inputs=self.inputs_sent,
+        )
+        await self.client.stub.FunctionFinishInputs(request, retry=Retry(max_retries=None))
+        yield
+
+
+async def _spawn_map_invocation(
+    function: "modal.functions._Function", raw_input_queue: _SynchronizedQueue, client: "modal.client._Client"
+) -> tuple[str, int]:
+    assert client.stub
+    request = api_pb2.FunctionMapRequest(
+        function_id=function.object_id,
+        parent_input_id=current_input_id() or "",
+        function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
+        function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC,
+    )
+    response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
+    function_call_id = response.function_call_id
+
+    have_all_inputs = False
+    inputs_created = 0
+
+    def set_inputs_created(set_inputs_created):
+        nonlocal inputs_created
+        assert set_inputs_created is None or set_inputs_created > inputs_created
+        inputs_created = set_inputs_created
+
+    def set_have_all_inputs():
+        nonlocal have_all_inputs
+        have_all_inputs = True
+
+    input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
+    input_preprocessor = InputPreprocessor(
+        client=client,
+        raw_input_queue=raw_input_queue,
+        processed_input_queue=input_queue,
+        function=function,
+        created_callback=set_inputs_created,
+        done_callback=set_have_all_inputs,
+    )
+
+    input_pumper = AsyncInputPumper(
+        client=client,
+        input_queue=input_queue,
+        function=function,
+        function_call_id=function_call_id,
+    )
+
+    def log_stats():
+        logger.debug(
+            f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} inputs_sent={input_pumper.inputs_sent} "
+        )
+
+    async def log_task():
+        while True:
+            log_stats()
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                # Log final stats before exiting
+                log_stats()
+                break
+
+    async def consume_generator(gen):
+        async for _ in gen:
+            pass
+
+    log_debug_stats_task = asyncio.create_task(log_task())
+    await asyncio.gather(
+        consume_generator(input_preprocessor.drain_input_generator()),
+        consume_generator(input_pumper.pump_inputs()),
+    )
+    log_debug_stats_task.cancel()
+    await log_debug_stats_task
+    return function_call_id, inputs_created
+
+
 async def _map_invocation(
     function: "modal.functions._Function",
     raw_input_queue: _SynchronizedQueue,
     client: "modal.client._Client",
     order_outputs: bool,
     return_exceptions: bool,
+    wrap_returned_exceptions: bool,
     count_update_callback: Optional[Callable[[int, int], None]],
     function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
 ):
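
For orientation: the `InputPreprocessor` and `InputPumper` classes added above factor what used to be nested coroutines inside `_map_invocation` into a producer/consumer pipeline joined by an `asyncio.Queue`, with `None` as the end-of-input sentinel. A minimal, stdlib-only sketch of that shape (illustrative names; no Modal internals):

```python
import asyncio

async def preprocess(raw: list, queue: asyncio.Queue) -> None:
    # Producer: stage each item, then signal completion with a None sentinel,
    # as InputPreprocessor.drain_input_generator does.
    for item in raw:
        await queue.put(item)
    await queue.put(None)

async def pump(queue: asyncio.Queue) -> None:
    # Consumer: drain until the sentinel; the real InputPumper batches sends.
    while (item := await queue.get()) is not None:
        print(f"would send input {item!r} to the server")

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    # Run both sides concurrently, like the drain/pump generators above.
    await asyncio.gather(preprocess([1, 2, 3], queue), pump(queue))

asyncio.run(main())
```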
@@ -99,7 +377,7 @@ async def _map_invocation(
         return_exceptions=return_exceptions,
         function_call_invocation_type=function_call_invocation_type,
     )
-    response: api_pb2.FunctionMapResponse = await retry_transient_errors(client.stub.FunctionMap, request)
+    response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
 
     function_call_id = response.function_call_id
     function_call_jwt = response.function_call_jwt
@@ -110,9 +388,8 @@ async def _map_invocation(
     max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
 
     have_all_inputs = False
+    map_done_event = asyncio.Event()
     inputs_created = 0
-    inputs_sent = 0
-    inputs_retried = 0
     outputs_completed = 0
     outputs_received = 0
     retried_outputs = 0
@@ -122,10 +399,6 @@ async def _map_invocation(
     stale_retry_duplicates = 0
     no_context_duplicates = 0
 
-    def count_update():
-        if count_update_callback is not None:
-            count_update_callback(outputs_completed, inputs_created)
-
     retry_queue = TimestampPriorityQueue()
     completed_outputs: set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)
     input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
@@ -133,109 +406,50 @@ async def _map_invocation(
         retry_policy, function_call_invocation_type, retry_queue, sync_client_retries_enabled, max_inputs_outstanding
     )
 
-    async def create_input(argskwargs):
-        nonlocal inputs_created
-        idx = inputs_created
-        inputs_created += 1
-        (args, kwargs) = argskwargs
-        return await _create_input(args, kwargs, client.stub, idx=idx, method_name=function._use_method_name)
-
-    async def input_iter():
-        while 1:
-            raw_input = await raw_input_queue.get()
-            if raw_input is None:  # end of input sentinel
-                break
-            yield raw_input  # args, kwargs
-
-    async def drain_input_generator():
-        nonlocal have_all_inputs
-
-        # Parallelize uploading blobs
-        async with aclosing(
-            async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
-        ) as streamer:
-            async for item in streamer:
-                await input_queue.put(item)
-
-        # close queue iterator
-        await input_queue.put(None)
-        have_all_inputs = True
-        yield
+    input_preprocessor = InputPreprocessor(
+        client=client,
+        raw_input_queue=raw_input_queue,
+        processed_input_queue=input_queue,
+        function=function,
+        created_callback=lambda x: update_state(set_inputs_created=x),
+        done_callback=lambda: update_state(set_have_all_inputs=True),
+    )
 
-    async def pump_inputs():
-        assert client.stub
-        nonlocal inputs_created, inputs_sent
-        async for items in queue_batch_iterator(input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
-            # Add items to the manager. Their state will be SENDING.
-            await map_items_manager.add_items(items)
-            request = api_pb2.FunctionPutInputsRequest(
-                function_id=function.object_id,
-                inputs=items,
-                function_call_id=function_call_id,
-            )
-            logger.debug(
-                f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
-            )
+    input_pumper = SyncInputPumper(
+        client=client,
+        input_queue=input_queue,
+        retry_queue=retry_queue,
+        function=function,
+        map_items_manager=map_items_manager,
+        function_call_jwt=function_call_jwt,
+        function_call_id=function_call_id,
+    )
 
-            resp = await send_inputs(client.stub.FunctionPutInputs, request)
-            count_update()
-            inputs_sent += len(items)
-            # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
-            map_items_manager.handle_put_inputs_response(resp.inputs)
-            logger.debug(
-                f"Successfully pushed {len(items)} inputs to server. "
-                f"Num queued inputs awaiting push is {input_queue.qsize()}."
-            )
-            yield
+    def update_state(set_have_all_inputs=None, set_inputs_created=None, set_outputs_completed=None):
+        # This should be the only method that needs nonlocal of the following vars
+        nonlocal have_all_inputs, inputs_created, outputs_completed
+        assert set_have_all_inputs is not False  # not allowed
+        assert set_inputs_created is None or set_inputs_created > inputs_created
+        assert set_outputs_completed is None or set_outputs_completed > outputs_completed
+        if set_have_all_inputs is not None:
+            have_all_inputs = set_have_all_inputs
+        if set_inputs_created is not None:
+            inputs_created = set_inputs_created
+        if set_outputs_completed is not None:
+            outputs_completed = set_outputs_completed
 
-    async def retry_inputs():
-        nonlocal inputs_retried
-        async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
-            # For each index, use the context in the manager to create a FunctionRetryInputsItem.
-            # This will also update the context state to RETRYING.
-            inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
-                retriable_idxs
-            )
-            request = api_pb2.FunctionRetryInputsRequest(
-                function_call_jwt=function_call_jwt,
-                inputs=inputs,
-            )
-            resp = await send_inputs(client.stub.FunctionRetryInputs, request)
-            # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
-            # to the new value in the response.
-            map_items_manager.handle_retry_response(resp.input_jwts)
-            logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
-            inputs_retried += len(inputs)
-            yield
+        if count_update_callback is not None:
+            count_update_callback(outputs_completed, inputs_created)
 
-    async def send_inputs(
-        fn: "modal.client.UnaryUnaryWrapper",
-        request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
-    ) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
-        # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
-        retry_warning_message = RetryWarningMessage(
-            message=f"Warning: map progress for function {function._function_name} is limited."
-            " Common bottlenecks include slow iteration over results, or function backlogs.",
-            warning_interval=8,
-            errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
-        )
-        return await retry_transient_errors(
-            fn,
-            request,
-            max_retries=None,
-            max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
-            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-            retry_warning_message=retry_warning_message,
-        )
+        if have_all_inputs and outputs_completed >= inputs_created:
+            # map is done
+            map_done_event.set()
 
     async def get_all_outputs():
         assert client.stub
         nonlocal \
-            inputs_created, \
             successful_completions, \
             failed_completions, \
-            outputs_completed, \
-            have_all_inputs, \
            outputs_received, \
             already_complete_duplicates, \
             no_context_duplicates, \
@@ -244,7 +458,7 @@ async def _map_invocation(
 
         last_entry_id = "0-0"
 
-        while not have_all_inputs or outputs_completed < inputs_created:
+        while not map_done_event.is_set():
            logger.debug(f"Requesting outputs. Have {outputs_completed} outputs, {inputs_created} inputs.")
             # Get input_jwts of all items in the WAITING_FOR_OUTPUT state.
             # The server uses these to track for lost inputs.
@@ -258,12 +472,29 @@ async def _map_invocation(
                 requested_at=time.time(),
                 input_jwts=input_jwts,
             )
-            response = await retry_transient_errors(
-                client.stub.FunctionGetOutputs,
-                request,
-                max_retries=20,
-                attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+            get_response_task = asyncio.create_task(
+                client.stub.FunctionGetOutputs(
+                    request,
+                    retry=Retry(
+                        max_retries=20,
+                        attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+                    ),
+                )
             )
+            map_done_task = asyncio.create_task(map_done_event.wait())
+            try:
+                done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
+                if get_response_task in done:
+                    map_done_task.cancel()
+                    response = get_response_task.result()
+                else:
+                    assert map_done_event.is_set()
+                    # map is done - no more outputs, so return early
+                    return
+            finally:
+                # clean up tasks, in case of cancellations etc.
+                get_response_task.cancel()
+                map_done_task.cancel()
 
             last_entry_id = response.last_entry_id
             now_seconds = int(time.time())
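
This hunk replaces the blanket `retry_transient_errors` wrapper with an explicit race: the in-flight `FunctionGetOutputs` task runs against a `map_done_event` waiter under `asyncio.wait(..., return_when=FIRST_COMPLETED)`, so polling stops as soon as all outputs are accounted for. The same shape in isolation (stdlib only; `rpc` is a stand-in for the real call):

```python
import asyncio

async def poll_once(rpc_coro, done_event: asyncio.Event):
    """Return the RPC result, or None if done_event fires first."""
    rpc_task = asyncio.create_task(rpc_coro)
    done_task = asyncio.create_task(done_event.wait())
    try:
        done, _ = await asyncio.wait({rpc_task, done_task}, return_when=asyncio.FIRST_COMPLETED)
        if rpc_task in done:
            return rpc_task.result()
        return None  # the event won the race: no more outputs expected
    finally:
        # Mirror the finally block above: always clean up both tasks.
        rpc_task.cancel()
        done_task.cancel()

async def demo():
    async def rpc():  # stand-in for the FunctionGetOutputs call
        await asyncio.sleep(0.1)
        return "outputs"

    print(await poll_once(rpc(), asyncio.Event()))  # -> outputs

asyncio.run(demo())
```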
@@ -288,7 +519,7 @@ async def _map_invocation(
 
                 if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
                     completed_outputs.add(item.input_id)
-                    outputs_completed += 1
+                    update_state(set_outputs_completed=outputs_completed + 1)
                     yield item
 
     async def get_all_outputs_and_clean_up():
@@ -306,7 +537,7 @@ async def _map_invocation(
                 clear_on_success=True,
                 requested_at=time.time(),
             )
-            await retry_transient_errors(client.stub.FunctionGetOutputs, request)
+            await client.stub.FunctionGetOutputs(request)
             await retry_queue.close()
 
     async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
@@ -314,7 +545,13 @@ async def _map_invocation(
             output = await _process_result(item.result, item.data_format, client.stub, client)
         except Exception as e:
             if return_exceptions:
-                output = e
+                if wrap_returned_exceptions:
+                    # Prior to client 1.0.4 there was a bug where return_exceptions would wrap
+                    # any returned exceptions in a synchronicity.UserCodeException. This adds
+                    # deprecated non-breaking compatibility bandaid for migrating away from that:
+                    output = modal.exception.UserCodeException(e)
+                else:
+                    output = e
             else:
                 raise e
         return (item.idx, output)
@@ -328,7 +565,6 @@ async def _map_invocation(
            async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
         ) as streamer:
             async for idx, output in streamer:
-                count_update()
                 if not order_outputs:
                     yield _OutputValue(output)
                 else:
@@ -352,8 +588,11 @@ async def _map_invocation(
     def log_stats():
         logger.debug(
             f"Map stats: sync_client_retries_enabled={sync_client_retries_enabled} "
-            f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} input_sent={inputs_sent} "
-            f"inputs_retried={inputs_retried} outputs_received={outputs_received} "
+            f"have_all_inputs={have_all_inputs} "
+            f"inputs_created={inputs_created} "
+            f"input_sent={input_pumper.inputs_sent} "
+            f"inputs_retried={input_pumper.inputs_retried} "
+            f"outputs_received={outputs_received} "
             f"successful_completions={successful_completions} failed_completions={failed_completions} "
             f"no_context_duplicates={no_context_duplicates} old_retry_duplicates={stale_retry_duplicates} "
             f"already_complete_duplicates={already_complete_duplicates} "
@@ -372,21 +611,388 @@ async def _map_invocation(
 
     log_debug_stats_task = asyncio.create_task(log_debug_stats())
     async with aclosing(
-        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), retry_inputs())
+        async_merge(
+            input_preprocessor.drain_input_generator(),
+            input_pumper.pump_inputs(),
+            input_pumper.retry_inputs(),
+            poll_outputs(),
+        )
     ) as streamer:
         async for response in streamer:
-            if response is not None:
+            if response is not None:  # type: ignore[unreachable]
                 yield response.value
     log_debug_stats_task.cancel()
     await log_debug_stats_task
 
 
+async def _map_invocation_inputplane(
+    function: "modal.functions._Function",
+    raw_input_queue: _SynchronizedQueue,
+    client: "modal.client._Client",
+    order_outputs: bool,
+    return_exceptions: bool,
+    wrap_returned_exceptions: bool,
+    count_update_callback: Optional[Callable[[int, int], None]],
+) -> typing.AsyncGenerator[Any, None]:
+    """Input-plane implementation of a function map invocation.
+
+    This is analogous to `_map_invocation`, but instead of the control-plane
+    `FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
+    the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
+    """
+
+    assert function._input_plane_url, "_map_invocation_inputplane should only be used for input-plane backed functions"
+
+    input_plane_stub = await client.get_stub(function._input_plane_url)
+
+    # Required for _create_input.
+    assert client.stub, "Client must be hydrated with a stub for _map_invocation_inputplane"
+
+    # ------------------------------------------------------------
+    # Invocation-wide state
+    # ------------------------------------------------------------
+
+    have_all_inputs = False
+    map_done_event = asyncio.Event()
+
+    inputs_created = 0
+    outputs_completed = 0
+    successful_completions = 0
+    failed_completions = 0
+    no_context_duplicates = 0
+    stale_retry_duplicates = 0
+    already_complete_duplicates = 0
+    retried_outputs = 0
+    input_queue_size = 0
+    last_entry_id = ""
+
+    # The input-plane server returns this after the first request.
+    map_token = None
+    map_token_received = asyncio.Event()
+
+    # Single priority queue that holds *both* fresh inputs (timestamp == now)
+    # and future retries (timestamp > now).
+    queue: TimestampPriorityQueue[api_pb2.MapStartOrContinueItem] = TimestampPriorityQueue()
+
+    # Maximum number of inputs that may be in-flight (the server sends this in
+    # the first response – fall back to the default if we never receive it for
+    # any reason).
+    max_inputs_outstanding = MAX_INPUTS_OUTSTANDING_DEFAULT
+
+    # Set a default retry policy to construct an instance of _MapItemsManager.
+    # We'll update the retry policy with the actual user-specified retry policy
+    # from the server in the first MapStartOrContinue response.
+    retry_policy = api_pb2.FunctionRetryPolicy(
+        retries=0,
+        initial_delay_ms=1000,
+        max_delay_ms=1000,
+        backoff_coefficient=1.0,
+    )
+    map_items_manager = _MapItemsManager(
+        retry_policy=retry_policy,
+        function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
+        retry_queue=queue,
+        sync_client_retries_enabled=True,
+        max_inputs_outstanding=MAX_INPUTS_OUTSTANDING_DEFAULT,
+        is_input_plane_instance=True,
+    )
+
+    def update_counters(
+        created_delta: int = 0, completed_delta: int = 0, set_have_all_inputs: Union[bool, None] = None
+    ):
+        nonlocal inputs_created, outputs_completed, have_all_inputs
+
+        if created_delta:
+            inputs_created += created_delta
+        if completed_delta:
+            outputs_completed += completed_delta
+        if set_have_all_inputs is not None:
+            have_all_inputs = set_have_all_inputs
+
+        if count_update_callback is not None:
+            count_update_callback(outputs_completed, inputs_created)
+
+        if have_all_inputs and outputs_completed >= inputs_created:
+            map_done_event.set()
+
+    async def create_input(argskwargs):
+        idx = inputs_created + 1  # 1-indexed map call idx
+        update_counters(created_delta=1)
+        (args, kwargs) = argskwargs
+        put_item: api_pb2.FunctionPutInputsItem = await _create_input(
+            args,
+            kwargs,
+            client.stub,
+            idx=idx,
+            function=function,
+        )
+        return api_pb2.MapStartOrContinueItem(input=put_item)
+
+    async def input_iter():
+        while True:
+            raw_input = await raw_input_queue.get()
+            if raw_input is None:  # end of input sentinel
+                break
+            yield raw_input  # args, kwargs
+
+    async def drain_input_generator():
+        async with aclosing(
+            async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for q_item in streamer:
+                await queue.put(time.time(), q_item)
+
+        # All inputs have been read.
+        update_counters(set_have_all_inputs=True)
+        yield
+
+    async def pump_inputs():
+        nonlocal map_token, max_inputs_outstanding
+        async for batch in queue_batch_iterator(queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            # Convert the queued items into the proto format expected by the RPC.
+            request_items: list[api_pb2.MapStartOrContinueItem] = [
+                api_pb2.MapStartOrContinueItem(input=qi.input, attempt_token=qi.attempt_token) for qi in batch
+            ]
+
+            await map_items_manager.add_items_inputplane(request_items)
+
+            # Build request
+            request = api_pb2.MapStartOrContinueRequest(
+                function_id=function.object_id,
+                map_token=map_token,
+                parent_input_id=current_input_id() or "",
+                items=request_items,
+            )
+
+            metadata = await client.get_input_plane_metadata(function._input_plane_region)
+
+            response: api_pb2.MapStartOrContinueResponse = await input_plane_stub.MapStartOrContinue(
+                request,
+                retry=Retry(
+                    additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+                    max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
+                    max_retries=None,
+                ),
+                metadata=metadata,
+            )
+
+            # match response items to the corresponding request item index
+            response_items_idx_tuple = [
+                (request_items[idx].input.idx, attempt_token)
+                for idx, attempt_token in enumerate(response.attempt_tokens)
+            ]
+
+            map_items_manager.handle_put_continue_response(response_items_idx_tuple)
+
+            # Set the function call id and actual retry policy with the data from the first response.
+            # This conditional is skipped for subsequent iterations of this for-loop.
+            if map_token is None:
+                map_token = response.map_token
+                map_token_received.set()
+                max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
+                map_items_manager.set_retry_policy(response.retry_policy)
+                # Update the retry policy for the first batch of inputs.
+                # Subsequent batches will have the correct user-specified retry policy
+                # set by the updated _MapItemsManager.
+                map_items_manager.update_items_retry_policy(response.retry_policy)
+            yield
+
+    async def check_lost_inputs():
+        nonlocal last_entry_id  # shared with get_all_outputs
+        try:
+            while not map_done_event.is_set():
+                if map_token is None:
+                    await map_token_received.wait()
+                    continue
+
+                sleep_task = asyncio.create_task(asyncio.sleep(1))
+                map_done_task = asyncio.create_task(map_done_event.wait())
+                done, _ = await asyncio.wait([sleep_task, map_done_task], return_when=FIRST_COMPLETED)
+                if map_done_task in done:
+                    break
+
+                # check_inputs = [(idx, attempt_token), ...]
+                check_inputs = map_items_manager.get_input_idxs_waiting_for_output()
+                attempt_tokens = [attempt_token for _, attempt_token in check_inputs]
+                request = api_pb2.MapCheckInputsRequest(
+                    last_entry_id=last_entry_id,
+                    timeout=0,  # Non-blocking read
+                    attempt_tokens=attempt_tokens,
+                )
+
+                metadata = await client.get_input_plane_metadata(function._input_plane_region)
+                response: api_pb2.MapCheckInputsResponse = await input_plane_stub.MapCheckInputs(
+                    request, metadata=metadata
+                )
+                check_inputs_response = [
+                    (check_inputs[resp_idx][0], response.lost[resp_idx]) for resp_idx, _ in enumerate(response.lost)
+                ]
+                # check_inputs_response = [(idx, lost: bool), ...]
+                await map_items_manager.handle_check_inputs_response(check_inputs_response)
+                yield
+        except asyncio.CancelledError:
+            pass
+
+    async def get_all_outputs():
+        nonlocal \
+            successful_completions, \
+            failed_completions, \
+            no_context_duplicates, \
+            stale_retry_duplicates, \
+            already_complete_duplicates, \
+            retried_outputs, \
+            last_entry_id
+
+        while not map_done_event.is_set():
+            if map_token is None:
+                await map_token_received.wait()
+                continue
+
+            request = api_pb2.MapAwaitRequest(
+                map_token=map_token,
+                last_entry_id=last_entry_id,
+                requested_at=time.time(),
+                timeout=OUTPUTS_TIMEOUT,
+            )
+            metadata = await client.get_input_plane_metadata(function._input_plane_region)
+            get_response_task = asyncio.create_task(
+                input_plane_stub.MapAwait(
+                    request,
+                    retry=Retry(
+                        max_retries=20,
+                        attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+                    ),
+                    metadata=metadata,
+                )
+            )
+            map_done_task = asyncio.create_task(map_done_event.wait())
+            try:
+                done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
+                if get_response_task in done:
+                    map_done_task.cancel()
+                    response = get_response_task.result()
+                else:
+                    assert map_done_event.is_set()
+                    # map is done - no more outputs, so return early
+                    return
+            finally:
+                # clean up tasks, in case of cancellations etc.
+                get_response_task.cancel()
+                map_done_task.cancel()
+            last_entry_id = response.last_entry_id
+
+            for output_item in response.outputs:
+                output_type = await map_items_manager.handle_get_outputs_response(output_item, int(time.time()))
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION:
+                    successful_completions += 1
+                elif output_type == _OutputType.FAILED_COMPLETION:
+                    failed_completions += 1
+                elif output_type == _OutputType.RETRYING:
+                    retried_outputs += 1
+                elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
+                    no_context_duplicates += 1
+                elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
+                    stale_retry_duplicates += 1
+                elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
+                    already_complete_duplicates += 1
+                else:
+                    raise Exception(f"Unknown output type: {output_type}")
+
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
+                    update_counters(completed_delta=1)
+                    yield output_item
+
+    async def get_all_outputs_and_clean_up():
+        try:
+            async with aclosing(get_all_outputs()) as stream:
+                async for item in stream:
+                    yield item
+        finally:
+            await queue.close()
+            pass
+
+    async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
+        try:
+            output = await _process_result(item.result, item.data_format, input_plane_stub, client)
+        except Exception as e:
+            if return_exceptions:
+                if wrap_returned_exceptions:
+                    # Prior to client 1.0.4 there was a bug where return_exceptions would wrap
+                    # any returned exceptions in a synchronicity.UserCodeException. This adds
+                    # deprecated non-breaking compatibility bandaid for migrating away from that:
+                    output = modal.exception.UserCodeException(e)
+                else:
+                    output = e
+            else:
+                raise e
+        return (item.idx, output)
+
+    async def poll_outputs():
+        # map to store out-of-order outputs received
+        received_outputs = {}
+        output_idx = 1  # 1-indexed map call idx
+
+        async with aclosing(
+            async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for idx, output in streamer:
+                if not order_outputs:
+                    yield _OutputValue(output)
+                else:
+                    # hold on to outputs for function maps, so we can reorder them correctly.
+                    received_outputs[idx] = output
+
+                    while True:
+                        if output_idx not in received_outputs:
+                            # we haven't received the output for the current index yet.
+                            # stop returning outputs to the caller and instead wait for
+                            # the next output to arrive from the server.
+                            break
+
+                        output = received_outputs.pop(output_idx)
+                        yield _OutputValue(output)
+                        output_idx += 1
+
+        assert len(received_outputs) == 0
+
+    async def log_debug_stats():
+        def log_stats():
+            logger.debug(
+                f"Map stats:\nsuccessful_completions={successful_completions} failed_completions={failed_completions} "
+                f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
+                f"already_complete_duplicates={already_complete_duplicates} retried_outputs={retried_outputs} "
+                f"map_token={map_token} max_inputs_outstanding={max_inputs_outstanding} "
+                f"map_items_manager_size={len(map_items_manager)} input_queue_size={input_queue_size}"
+            )
+
+        while True:
+            log_stats()
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                # Log final stats before exiting
+                log_stats()
+                break
+
+    log_task = asyncio.create_task(log_debug_stats())
+
+    async with aclosing(
+        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), check_lost_inputs())
+    ) as merged:
+        async for maybe_output in merged:
+            if maybe_output is not None:  # ignore None sentinels
+                yield maybe_output.value
+
+    log_task.cancel()
+
+
 async def _map_helper(
     self: "modal.functions.Function",
     async_input_gen: typing.AsyncGenerator[Any, None],
     kwargs={},  # any extra keyword arguments for the function
     order_outputs: bool = True,  # return outputs in order
     return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
+    wrap_returned_exceptions: bool = True,
 ) -> typing.AsyncGenerator[Any, None]:
     """Core implementation that supports `_map_async()`, `_starmap_async()` and `_for_each_async()`.
 
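
A comment in the new `_map_invocation_inputplane` notes that a single `TimestampPriorityQueue` holds both fresh inputs (stamped with the current time) and scheduled retries (stamped in the future). A rough sketch of that scheduling idea, assuming nothing about Modal's actual `TimestampPriorityQueue` beyond what the comment states:

```python
import heapq
import time

class TimestampedQueue:
    """Min-heap keyed on timestamp; pop_due returns everything already eligible."""

    def __init__(self) -> None:
        self._heap: list[tuple[float, int, object]] = []
        self._seq = 0  # tie-breaker so payloads are never compared directly

    def put(self, timestamp: float, item: object) -> None:
        heapq.heappush(self._heap, (timestamp, self._seq, item))
        self._seq += 1

    def pop_due(self) -> list[object]:
        now, due = time.time(), []
        while self._heap and self._heap[0][0] <= now:
            due.append(heapq.heappop(self._heap)[2])
        return due

q = TimestampedQueue()
q.put(time.time(), "fresh input")            # eligible immediately
q.put(time.time() + 1.5, "retry of idx 3")   # eligible after its backoff delay
print(q.pop_due())  # -> ['fresh input']
```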
@@ -399,9 +1005,8 @@ async def _map_helper(
     We could make this explicit as an improvement or even let users decide what they
     prefer: throughput (prioritize queueing inputs) or latency (prioritize yielding results)
     """
-
     raw_input_queue: Any = SynchronizedQueue()  # type: ignore
-    raw_input_queue.init()
+    await raw_input_queue.init.aio()
 
     async def feed_queue():
         async with aclosing(async_input_gen) as streamer:
@@ -417,12 +1022,41 @@ async def _map_helper(
     # synchronicity-wrapped, since they accept executable code in the form of iterators that we don't want to run inside
     # the synchronicity thread. Instead, we delegate to `._map()` with a safer Queue as input.
     async with aclosing(
-        async_merge(self._map.aio(raw_input_queue, order_outputs, return_exceptions), feed_queue())
+        async_merge(
+            self._map.aio(raw_input_queue, order_outputs, return_exceptions, wrap_returned_exceptions), feed_queue()
+        )
     ) as map_output_stream:
         async for output in map_output_stream:
             yield output
 
 
+def _maybe_warn_about_exceptions(func_name: str, return_exceptions: bool, wrap_returned_exceptions: bool):
+    if return_exceptions and wrap_returned_exceptions:
+        deprecation_warning(
+            (2025, 6, 27),
+            (
+                f"Function.{func_name} currently leaks an internal exception wrapping type "
+                "(modal.exceptions.UserCodeException) when `return_exceptions=True` is set. "
+                "In the future, this will change, and the underlying exception will be returned directly.\n"
+                "To opt into the future behavior and silence this warning, add `wrap_returned_exceptions=False`:\n\n"
+                f" f.{func_name}(..., return_exceptions=True, wrap_returned_exceptions=False)"
+            ),
+        )
+
+
+def _invoked_from_sync_wrapper() -> bool:
+    """Check whether the calling function was called from a sync wrapper."""
+    # This is temporary: we only need it to avoind double-firing the wrap_returned_exceptions warning.
+    # (We don't want to push the warning lower in the stack beacuse then we can't attribute to the user's code.)
+    try:
+        frame = inspect.currentframe()
+        caller_function_name = frame.f_back.f_back.f_code.co_name
+        # Embeds some assumptions about how the current calling stack works, but this is just temporary.
+        return caller_function_name == "asend"
+    except Exception:
+        return False
+
+
 @warn_if_generator_is_not_consumed(function_name="Function.map.aio")
 async def _map_async(
     self: "modal.functions.Function",
@@ -432,10 +1066,18 @@ async def _map_async(
     kwargs={},  # any extra keyword arguments for the function
     order_outputs: bool = True,  # return outputs in order
     return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
+    wrap_returned_exceptions: bool = True,  # wrap returned exceptions in modal.exception.UserCodeException
 ) -> typing.AsyncGenerator[Any, None]:
+    if not _invoked_from_sync_wrapper():
+        _maybe_warn_about_exceptions("map.aio", return_exceptions, wrap_returned_exceptions)
     async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
     async for output in _map_helper(
-        self, async_input_gen, kwargs=kwargs, order_outputs=order_outputs, return_exceptions=return_exceptions
+        self,
+        async_input_gen,
+        kwargs=kwargs,
+        order_outputs=order_outputs,
+        return_exceptions=return_exceptions,
+        wrap_returned_exceptions=wrap_returned_exceptions,
     ):
         yield output
 
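
Taken together with the deprecation warning added above, the migration path for callers that pass `return_exceptions=True` looks like this (`my_func` is the mapped function from the docstring examples below):

```python
# Current default: returned exceptions arrive wrapped in
# modal.exception.UserCodeException, and a DeprecationWarning fires.
results = list(my_func.map(range(3), return_exceptions=True))

# Opting into the future behavior silences the warning and yields the
# underlying exceptions directly:
results = list(my_func.map(range(3), return_exceptions=True, wrap_returned_exceptions=False))
```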
@@ -448,13 +1090,17 @@ async def _starmap_async(
     kwargs={},
     order_outputs: bool = True,
     return_exceptions: bool = False,
+    wrap_returned_exceptions: bool = True,
 ) -> typing.AsyncIterable[Any]:
+    if not _invoked_from_sync_wrapper():
+        _maybe_warn_about_exceptions("starmap.aio", return_exceptions, wrap_returned_exceptions)
     async for output in _map_helper(
         self,
         sync_or_async_iter(input_iterator),
         kwargs=kwargs,
         order_outputs=order_outputs,
         return_exceptions=return_exceptions,
+        wrap_returned_exceptions=wrap_returned_exceptions,
     ):
         yield output
 
@@ -464,7 +1110,12 @@ async def _for_each_async(self, *input_iterators, kwargs={}, ignore_exceptions:
     # rather than iterating over the result
     async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
     async for _ in _map_helper(
-        self, async_input_gen, kwargs=kwargs, order_outputs=False, return_exceptions=ignore_exceptions
+        self,
+        async_input_gen,
+        kwargs=kwargs,
+        order_outputs=False,
+        return_exceptions=ignore_exceptions,
+        wrap_returned_exceptions=False,
     ):
         pass
 
@@ -476,6 +1127,7 @@ def _map_sync(
     kwargs={},  # any extra keyword arguments for the function
     order_outputs: bool = True,  # return outputs in order
     return_exceptions: bool = False,  # propagate exceptions (False) or aggregate them in the results list (True)
+    wrap_returned_exceptions: bool = True,
 ) -> AsyncOrSyncIterable:
     """Parallel map over a set of inputs.
 
@@ -513,10 +1165,16 @@ def _map_sync(
     print(list(my_func.map(range(3), return_exceptions=True)))
     ```
     """
+    _maybe_warn_about_exceptions("map", return_exceptions, wrap_returned_exceptions)
 
     return AsyncOrSyncIterable(
         _map_async(
-            self, *input_iterators, kwargs=kwargs, order_outputs=order_outputs, return_exceptions=return_exceptions
+            self,
+            *input_iterators,
+            kwargs=kwargs,
+            order_outputs=order_outputs,
+            return_exceptions=return_exceptions,
+            wrap_returned_exceptions=wrap_returned_exceptions,
         ),
         nested_async_message=(
             "You can't iter(Function.map()) from an async function. Use async for ... in Function.map.aio() instead."
@@ -524,6 +1182,56 @@ def _map_sync(
     )
 
 
+async def _experimental_spawn_map_async(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
+    async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
+    return await _spawn_map_helper(self, async_input_gen, kwargs)
+
+
+async def _spawn_map_helper(
+    self: "modal.functions.Function", async_input_gen, kwargs={}
+) -> "modal.functions._FunctionCall":
+    raw_input_queue: Any = SynchronizedQueue()  # type: ignore
+    await raw_input_queue.init.aio()
+
+    async def feed_queue():
+        async with aclosing(async_input_gen) as streamer:
+            async for args in streamer:
+                await raw_input_queue.put.aio((args, kwargs))
+        await raw_input_queue.put.aio(None)  # end-of-input sentinel
+
+    fc, _ = await asyncio.gather(self._spawn_map.aio(raw_input_queue), feed_queue())
+    return fc
+
+
+def _experimental_spawn_map_sync(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
+    """mdmd:hidden
+    Spawn parallel execution over a set of inputs, returning as soon as the inputs are created.
+
+    Unlike `modal.Function.map`, this method does not block on completion of the remote execution but
+    returns a `modal.FunctionCall` object that can be used to poll status and retrieve results later.
+
+    Takes one iterator argument per argument in the function being mapped over.
+
+    Example:
+    ```python
+    @app.function()
+    def my_func(a, b):
+        return a ** b
+
+
+    @app.local_entrypoint()
+    def main():
+        fc = my_func.spawn_map([1, 2], [3, 4])
+    ```
+
+    """
+
+    return run_coroutine_in_temporary_event_loop(
+        _experimental_spawn_map_async(self, *input_iterators, kwargs=kwargs),
+        "You can't run Function.spawn_map() from an async function. Use Function.spawn_map.aio() instead.",
+    )
+
+
 async def _spawn_map_async(self, *input_iterators, kwargs={}) -> None:
     """This runs in an event loop on the main thread. It consumes inputs from the input iterators and creates async
     function calls for each.
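
As the `_experimental_spawn_map_sync` docstring above notes, the call returns once the inputs are created instead of blocking on results. A hedged sketch of detaching and re-attaching later, continuing the docstring's `app`/`my_func` example; it assumes `modal.FunctionCall.from_id`, the usual way to rehydrate a call handle from its ID, also applies to spawned maps:

```python
import modal

@app.local_entrypoint()
def main():
    fc = my_func.spawn_map([1, 2], [3, 4])  # returns as soon as inputs are enqueued
    print(fc.object_id)  # persist this ID to find the map call again later

    # Later, possibly from another process (assumption: map-type calls can be
    # rehydrated like regular spawned calls):
    fc2 = modal.FunctionCall.from_id(fc.object_id)
```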
@@ -569,7 +1277,7 @@ def _spawn_map_sync(self, *input_iterators, kwargs={}) -> None:
569
1277
 
570
1278
  return run_coroutine_in_temporary_event_loop(
571
1279
  _spawn_map_async(self, *input_iterators, kwargs=kwargs),
572
- "You can't run Function.spawn_map() from an async function. Use Function.map.aio() instead.",
1280
+ "You can't run Function.spawn_map() from an async function. Use Function.spawn_map.aio() instead.",
573
1281
  )
574
1282
 
575
1283
 
@@ -596,6 +1304,7 @@ def _starmap_sync(
596
1304
  kwargs={},
597
1305
  order_outputs: bool = True,
598
1306
  return_exceptions: bool = False,
1307
+ wrap_returned_exceptions: bool = True,
599
1308
  ) -> AsyncOrSyncIterable:
600
1309
  """Like `map`, but spreads arguments over multiple function arguments.
601
1310
 
@@ -613,9 +1322,15 @@ def _starmap_sync(
613
1322
  assert list(my_func.starmap([(1, 2), (3, 4)])) == [3, 7]
614
1323
  ```
615
1324
  """
1325
+ _maybe_warn_about_exceptions("starmap", return_exceptions, wrap_returned_exceptions)
616
1326
  return AsyncOrSyncIterable(
617
1327
  _starmap_async(
618
- self, input_iterator, kwargs=kwargs, order_outputs=order_outputs, return_exceptions=return_exceptions
1328
+ self,
1329
+ input_iterator,
1330
+ kwargs=kwargs,
1331
+ order_outputs=order_outputs,
1332
+ return_exceptions=return_exceptions,
1333
+ wrap_returned_exceptions=wrap_returned_exceptions,
619
1334
  ),
620
1335
  nested_async_message=(
621
1336
  "You can't `iter(Function.starmap())` from an async function. "
@@ -653,12 +1368,19 @@ class _MapItemContext:
653
1368
  sync_client_retries_enabled: bool
654
1369
  # Both these futures are strings. Omitting generic type because
655
1370
  # it causes an error when running `inv protoc type-stubs`.
1371
+ # Unused. But important, input_id is not set for inputplane invocations.
656
1372
  input_id: asyncio.Future
657
1373
  input_jwt: asyncio.Future
658
1374
  previous_input_jwt: Optional[str]
659
1375
  _event_loop: asyncio.AbstractEventLoop
660
1376
 
661
- def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager, sync_client_retries_enabled: bool):
1377
+ def __init__(
1378
+ self,
1379
+ input: api_pb2.FunctionInput,
1380
+ retry_manager: RetryManager,
1381
+ sync_client_retries_enabled: bool,
1382
+ is_input_plane_instance: bool = False,
1383
+ ):
662
1384
  self.state = _MapItemState.SENDING
663
1385
  self.input = input
664
1386
  self.retry_manager = retry_manager
@@ -669,7 +1391,22 @@ class _MapItemContext:
         # a race condition where we could receive outputs before we have
         # recorded the input ID and JWT in `pending_outputs`.
         self.input_jwt = self._event_loop.create_future()
+        # Unused, but important: this is not set for input-plane invocations.
         self.input_id = self._event_loop.create_future()
+        self._is_input_plane_instance = is_input_plane_instance
+
+    def handle_map_start_or_continue_response(self, attempt_token: str):
+        if not self.input_jwt.done():
+            self.input_jwt.set_result(attempt_token)
+        else:
+            # A done future can't be set again, so create a fresh one for the next value.
+            self.input_jwt = asyncio.Future()
+            self.input_jwt.set_result(attempt_token)
+
+        # Move to WAITING_FOR_OUTPUT only if the current state is SENDING. If the state is
+        # RETRYING, WAITING_TO_RETRY, or COMPLETE, we already got the output.
+        if self.state == _MapItemState.SENDING:
+            self.state = _MapItemState.WAITING_FOR_OUTPUT
 
     def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
         self.input_jwt.set_result(item.input_jwt)
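The replace-the-future move in `handle_map_start_or_continue_response` is worth calling out: a resolved `asyncio.Future` cannot be set twice, so delivering a new attempt token for the same input means swapping in a fresh future. A standalone sketch of the pattern, with no Modal types involved:

```python
import asyncio

class TokenSlot:
    """Delivers successive values through an awaitable slot.

    Mirrors the trick above: a done Future can't be set again,
    so each new value gets a fresh Future.
    """

    def __init__(self) -> None:
        self._future: asyncio.Future = asyncio.get_running_loop().create_future()

    def publish(self, value: str) -> None:
        if not self._future.done():
            self._future.set_result(value)
        else:
            # Already resolved: swap in a new future for the next value.
            self._future = asyncio.Future()
            self._future.set_result(value)

    async def latest(self) -> str:
        return await self._future

async def demo() -> None:
    slot = TokenSlot()
    slot.publish("attempt-token-1")
    print(await slot.latest())  # attempt-token-1
    slot.publish("attempt-token-2")  # swaps in a fresh future
    print(await slot.latest())  # attempt-token-2

asyncio.run(demo())
```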
@@ -692,10 +1429,11 @@ class _MapItemContext:
         Return the _OutputType describing how this output was handled; COMPLETE transitions
         are reflected in `self.state`.
         """
         # If the item is already complete, this is a duplicate output and can be ignored.
+
         if self.state == _MapItemState.COMPLETE:
             logger.debug(
                 f"Received output for input marked as complete. Must be duplicate, so ignoring. "
-                f"idx={item.idx} input_id={item.input_id}, retry_count={item.retry_count}"
+                f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
             )
             return _OutputType.ALREADY_COMPLETE_DUPLICATE
         # If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
@@ -737,12 +1475,17 @@ class _MapItemContext:
             delay_ms = 0
 
         # None means the maximum number of retries has been reached, so output the error
-        if delay_ms is None:
+        if delay_ms is None or item.result.status == api_pb2.GenericResult.GENERIC_STATUS_TERMINATED:
             self.state = _MapItemState.COMPLETE
             return _OutputType.FAILED_COMPLETION
 
         self.state = _MapItemState.WAITING_TO_RETRY
-        await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)
+
+        if self._is_input_plane_instance:
+            retry_item = await self.create_map_start_or_continue_item(item.idx)
+            await retry_queue.put(now_seconds + delay_ms / 1_000, retry_item)
+        else:
+            await retry_queue.put(now_seconds + delay_ms / 1_000, item.idx)
 
         return _OutputType.RETRYING
 
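Both retry branches above schedule work for a future wall-clock time: `now_seconds + delay_ms / 1_000`, paired with either a bare index (control plane) or a prebuilt `MapStartOrContinueItem` (input plane). A simplified stand-in for the timestamp-ordered queue this relies on; the real `TimestampPriorityQueue` lives in modal's async utils, and this sketch only assumes its put/get shape:

```python
import asyncio
import heapq
import time

class SimpleTimestampQueue:
    """Pops items only once their scheduled timestamp has passed.

    A simplified stand-in for the TimestampPriorityQueue used above.
    """

    def __init__(self) -> None:
        self._heap: list[tuple[float, int, object]] = []
        self._counter = 0  # tie-breaker so heapq never compares payloads

    async def put(self, timestamp: float, item: object) -> None:
        heapq.heappush(self._heap, (timestamp, self._counter, item))
        self._counter += 1

    async def get(self) -> object:
        while True:
            if self._heap and self._heap[0][0] <= time.time():
                return heapq.heappop(self._heap)[2]
            await asyncio.sleep(0.01)  # poll until the earliest deadline arrives

async def demo() -> None:
    q = SimpleTimestampQueue()
    await q.put(time.time() + 0.05, "retry idx=3")
    print(await q.get())  # yields after ~50 ms: retry idx=3

asyncio.run(demo())
```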
@@ -757,10 +1500,23 @@ class _MapItemContext:
             retry_count=self.retry_manager.retry_count,
         )
 
+    def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
+        self.retry_manager = RetryManager(retry_policy)
+
     def handle_retry_response(self, input_jwt: str):
         self.input_jwt.set_result(input_jwt)
         self.state = _MapItemState.WAITING_FOR_OUTPUT
 
+    async def create_map_start_or_continue_item(self, idx: int) -> api_pb2.MapStartOrContinueItem:
+        attempt_token = await self.input_jwt
+        return api_pb2.MapStartOrContinueItem(
+            input=api_pb2.FunctionPutInputsItem(
+                input=self.input,
+                idx=idx,
+            ),
+            attempt_token=attempt_token,
+        )
+
 
 class _MapItemsManager:
     def __init__(
@@ -770,6 +1526,7 @@ class _MapItemsManager:
         retry_queue: TimestampPriorityQueue,
         sync_client_retries_enabled: bool,
         max_inputs_outstanding: int,
+        is_input_plane_instance: bool = False,
     ):
         self._retry_policy = retry_policy
         self.function_call_invocation_type = function_call_invocation_type
@@ -780,6 +1537,10 @@ class _MapItemsManager:
         self._inputs_outstanding = asyncio.BoundedSemaphore(max_inputs_outstanding)
         self._item_context: dict[int, _MapItemContext] = {}
         self._sync_client_retries_enabled = sync_client_retries_enabled
+        self._is_input_plane_instance = is_input_plane_instance
+
+    def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
+        self._retry_policy = retry_policy
 
     async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
         for item in items:
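Both `add_items` and the new `add_items_inputplane` (next hunk) gate admission on `self._inputs_outstanding`, the bounded semaphore created in `__init__`; `_remove_item` releases a permit when an input completes, which is what un-blocks the next pending input. A self-contained sketch of that backpressure pattern (the limit here is hypothetical):

```python
import asyncio

MAX_OUTSTANDING = 2  # hypothetical limit; the real value comes from the server

async def demo() -> None:
    outstanding = asyncio.BoundedSemaphore(MAX_OUTSTANDING)

    async def submit(idx: int) -> None:
        # Blocks once MAX_OUTSTANDING inputs are in flight.
        await outstanding.acquire()
        try:
            await asyncio.sleep(0.01)  # stand-in for sending + awaiting an output
        finally:
            outstanding.release()  # mirrors _remove_item() releasing the permit

    await asyncio.gather(*(submit(i) for i in range(5)))
    print("all inputs processed with at most", MAX_OUTSTANDING, "in flight")

asyncio.run(demo())
```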
@@ -792,9 +1553,28 @@ class _MapItemsManager:
                 sync_client_retries_enabled=self._sync_client_retries_enabled,
             )
 
+    async def add_items_inputplane(self, items: list[api_pb2.MapStartOrContinueItem]):
+        for item in items:
+            if item.attempt_token != "":  # a retry item: its context already exists
+                self._item_context[item.input.idx].state = _MapItemState.SENDING
+                continue
+            # Acquire the semaphore to limit the number of inputs in progress
+            # (either queued to be sent, waiting for completion, or retrying).
+            await self._inputs_outstanding.acquire()
+            self._item_context[item.input.idx] = _MapItemContext(
+                input=item.input.input,
+                retry_manager=RetryManager(self._retry_policy),
+                sync_client_retries_enabled=self._sync_client_retries_enabled,
+                is_input_plane_instance=self._is_input_plane_instance,
+            )
+
     async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
         return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]
 
+    def update_items_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
+        for ctx in self._item_context.values():
+            ctx.set_retry_policy(retry_policy)
+
     def get_input_jwts_waiting_for_output(self) -> list[str]:
         """
         Returns a list of input_jwts for inputs that are waiting for output.
@@ -806,6 +1586,17 @@ class _MapItemsManager:
             if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
         ]
 
+    def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
+        """
+        Returns (idx, attempt_token) pairs for inputs that are waiting for output.
+        """
+        # The idx doesn't need a future because it is set by the client, not the server.
+        return [
+            (idx, ctx.input_jwt.result())
+            for idx, ctx in self._item_context.items()
+            if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
+        ]
+
     def _remove_item(self, item_idx: int):
         del self._item_context[item_idx]
         self._inputs_outstanding.release()
@@ -813,6 +1604,18 @@ class _MapItemsManager:
     def get_item_context(self, item_idx: int) -> _MapItemContext:
         return self._item_context.get(item_idx)
 
+    def handle_put_continue_response(
+        self,
+        items: list[tuple[int, str]],  # (idx, attempt_token)
+    ):
+        for idx, attempt_token in items:
+            ctx = self._item_context.get(idx, None)
+            # If the context is None, then get_all_outputs() has already received a successful
+            # output and deleted the context. This happens if FunctionGetOutputs completes
+            # before the MapStartOrContinueResponse is received.
+            if ctx is not None:
+                ctx.handle_map_start_or_continue_response(attempt_token)
+
     def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
         for item in items:
             ctx = self._item_context.get(item.idx, None)
@@ -832,6 +1635,16 @@ class _MapItemsManager:
             if ctx is not None:
                 ctx.handle_retry_response(input_jwt)
 
+    async def handle_check_inputs_response(self, response: list[tuple[int, bool]]):
+        for idx, lost in response:
+            ctx = self._item_context.get(idx, None)
+            if ctx is not None and lost:
+                ctx.state = _MapItemState.WAITING_TO_RETRY
+                retry_item = await ctx.create_map_start_or_continue_item(idx)
+                # Increment the retry count, but schedule lost inputs for immediate retry.
+                _ = ctx.retry_manager.get_delay_ms()
+                await self._retry_queue.put(time.time(), retry_item)
+
     async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
         ctx = self._item_context.get(item.idx, None)
         if ctx is None: