modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of modal might be problematic.
Files changed (147)
  1. modal/__main__.py +3 -4
  2. modal/_billing.py +80 -0
  3. modal/_clustered_functions.py +7 -3
  4. modal/_clustered_functions.pyi +4 -2
  5. modal/_container_entrypoint.py +41 -49
  6. modal/_functions.py +424 -195
  7. modal/_grpc_client.py +171 -0
  8. modal/_load_context.py +105 -0
  9. modal/_object.py +68 -20
  10. modal/_output.py +58 -45
  11. modal/_partial_function.py +36 -11
  12. modal/_pty.py +7 -3
  13. modal/_resolver.py +21 -35
  14. modal/_runtime/asgi.py +4 -3
  15. modal/_runtime/container_io_manager.py +301 -186
  16. modal/_runtime/container_io_manager.pyi +70 -61
  17. modal/_runtime/execution_context.py +18 -2
  18. modal/_runtime/execution_context.pyi +4 -1
  19. modal/_runtime/gpu_memory_snapshot.py +170 -63
  20. modal/_runtime/user_code_imports.py +28 -58
  21. modal/_serialization.py +57 -1
  22. modal/_utils/async_utils.py +33 -12
  23. modal/_utils/auth_token_manager.py +2 -5
  24. modal/_utils/blob_utils.py +110 -53
  25. modal/_utils/function_utils.py +49 -42
  26. modal/_utils/grpc_utils.py +80 -50
  27. modal/_utils/mount_utils.py +26 -1
  28. modal/_utils/name_utils.py +17 -3
  29. modal/_utils/task_command_router_client.py +536 -0
  30. modal/_utils/time_utils.py +34 -6
  31. modal/app.py +219 -83
  32. modal/app.pyi +229 -56
  33. modal/billing.py +5 -0
  34. modal/{requirements → builder}/2025.06.txt +1 -0
  35. modal/{requirements → builder}/PREVIEW.txt +1 -0
  36. modal/cli/_download.py +19 -3
  37. modal/cli/_traceback.py +3 -2
  38. modal/cli/app.py +4 -4
  39. modal/cli/cluster.py +15 -7
  40. modal/cli/config.py +5 -3
  41. modal/cli/container.py +7 -6
  42. modal/cli/dict.py +22 -16
  43. modal/cli/entry_point.py +12 -5
  44. modal/cli/environment.py +5 -4
  45. modal/cli/import_refs.py +3 -3
  46. modal/cli/launch.py +102 -5
  47. modal/cli/network_file_system.py +9 -13
  48. modal/cli/profile.py +3 -2
  49. modal/cli/programs/launch_instance_ssh.py +94 -0
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/run_marimo.py +95 -0
  52. modal/cli/programs/vscode.py +1 -1
  53. modal/cli/queues.py +57 -26
  54. modal/cli/run.py +58 -16
  55. modal/cli/secret.py +48 -22
  56. modal/cli/utils.py +3 -4
  57. modal/cli/volume.py +28 -25
  58. modal/client.py +13 -116
  59. modal/client.pyi +9 -91
  60. modal/cloud_bucket_mount.py +5 -3
  61. modal/cloud_bucket_mount.pyi +5 -1
  62. modal/cls.py +130 -102
  63. modal/cls.pyi +45 -85
  64. modal/config.py +29 -10
  65. modal/container_process.py +291 -13
  66. modal/container_process.pyi +95 -32
  67. modal/dict.py +282 -63
  68. modal/dict.pyi +423 -73
  69. modal/environments.py +15 -27
  70. modal/environments.pyi +5 -15
  71. modal/exception.py +8 -0
  72. modal/experimental/__init__.py +143 -38
  73. modal/experimental/flash.py +247 -78
  74. modal/experimental/flash.pyi +137 -9
  75. modal/file_io.py +14 -28
  76. modal/file_io.pyi +2 -2
  77. modal/file_pattern_matcher.py +25 -16
  78. modal/functions.pyi +134 -61
  79. modal/image.py +255 -86
  80. modal/image.pyi +300 -62
  81. modal/io_streams.py +436 -126
  82. modal/io_streams.pyi +236 -171
  83. modal/mount.py +62 -157
  84. modal/mount.pyi +45 -172
  85. modal/network_file_system.py +30 -53
  86. modal/network_file_system.pyi +16 -76
  87. modal/object.pyi +42 -8
  88. modal/parallel_map.py +821 -113
  89. modal/parallel_map.pyi +134 -0
  90. modal/partial_function.pyi +4 -1
  91. modal/proxy.py +16 -7
  92. modal/proxy.pyi +10 -2
  93. modal/queue.py +263 -61
  94. modal/queue.pyi +409 -66
  95. modal/runner.py +112 -92
  96. modal/runner.pyi +45 -27
  97. modal/sandbox.py +451 -124
  98. modal/sandbox.pyi +513 -67
  99. modal/secret.py +291 -67
  100. modal/secret.pyi +425 -19
  101. modal/serving.py +7 -11
  102. modal/serving.pyi +7 -8
  103. modal/snapshot.py +11 -8
  104. modal/token_flow.py +4 -4
  105. modal/volume.py +344 -98
  106. modal/volume.pyi +464 -68
  107. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
  108. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  109. modal_docs/mdmd/mdmd.py +11 -1
  110. modal_proto/api.proto +399 -67
  111. modal_proto/api_grpc.py +241 -1
  112. modal_proto/api_pb2.py +1395 -1000
  113. modal_proto/api_pb2.pyi +1239 -79
  114. modal_proto/api_pb2_grpc.py +499 -4
  115. modal_proto/api_pb2_grpc.pyi +162 -14
  116. modal_proto/modal_api_grpc.py +175 -160
  117. modal_proto/sandbox_router.proto +145 -0
  118. modal_proto/sandbox_router_grpc.py +105 -0
  119. modal_proto/sandbox_router_pb2.py +149 -0
  120. modal_proto/sandbox_router_pb2.pyi +333 -0
  121. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  122. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  123. modal_proto/task_command_router.proto +144 -0
  124. modal_proto/task_command_router_grpc.py +105 -0
  125. modal_proto/task_command_router_pb2.py +149 -0
  126. modal_proto/task_command_router_pb2.pyi +333 -0
  127. modal_proto/task_command_router_pb2_grpc.py +203 -0
  128. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  129. modal_version/__init__.py +1 -1
  130. modal-1.0.6.dev58.dist-info/RECORD +0 -183
  131. modal_proto/modal_options_grpc.py +0 -3
  132. modal_proto/options.proto +0 -19
  133. modal_proto/options_grpc.py +0 -3
  134. modal_proto/options_pb2.py +0 -35
  135. modal_proto/options_pb2.pyi +0 -20
  136. modal_proto/options_pb2_grpc.py +0 -4
  137. modal_proto/options_pb2_grpc.pyi +0 -7
  138. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  139. /modal/{requirements → builder}/2023.12.txt +0 -0
  140. /modal/{requirements → builder}/2024.04.txt +0 -0
  141. /modal/{requirements → builder}/2024.10.txt +0 -0
  142. /modal/{requirements → builder}/README.md +0 -0
  143. /modal/{requirements → builder}/base-images.json +0 -0
  144. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  145. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  146. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  147. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
modal/parallel_map.py CHANGED
@@ -6,7 +6,7 @@ import time
 import typing
 from asyncio import FIRST_COMPLETED
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union

 from grpclib import Status

@@ -35,7 +35,7 @@ from modal._utils.function_utils import (
     _create_input,
     _process_result,
 )
-from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, RetryWarningMessage, retry_transient_errors
+from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, Retry, RetryWarningMessage
 from modal._utils.jwt_utils import DecodedJwt
 from modal.config import logger
 from modal.retries import RetryManager
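
The import swap above reflects a refactor that runs through the rest of this diff: the standalone `retry_transient_errors(fn, request, ...)` helper is dropped in favor of a `Retry` policy object passed to stub calls via a `retry=` keyword. A minimal sketch of such a policy object and wrapper, with field names mirroring the call sites below (an illustration, not modal's actual `_utils.grpc_utils` implementation):

```python
import asyncio
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Retry:
    # Field names mirror the call sites in this diff; the real class lives in
    # modal._utils.grpc_utils and may differ.
    max_retries: Optional[int] = 3  # None means retry indefinitely
    max_delay: float = 10.0  # cap on exponential backoff, in seconds
    attempt_timeout: Optional[float] = None
    additional_status_codes: list = field(default_factory=list)
    warning_message: Optional[object] = None


async def call_with_retry(fn, request, retry: Retry):
    """Hypothetical wrapper: retry `fn(request)` according to the policy."""
    delay, attempt = 0.1, 0
    while True:
        try:
            return await fn(request)
        except ConnectionError:  # stand-in for retryable gRPC statuses
            attempt += 1
            if retry.max_retries is not None and attempt > retry.max_retries:
                raise
            await asyncio.sleep(delay)
            delay = min(delay * 2, retry.max_delay)
```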
@@ -79,13 +79,286 @@ class _OutputValue:

 MAX_INPUTS_OUTSTANDING_DEFAULT = 1000

-# maximum number of inputs to send to the server in a single request
+# Maximum number of inputs to send to the server per FunctionPutInputs request
 MAP_INVOCATION_CHUNK_SIZE = 49
+SPAWN_MAP_INVOCATION_CHUNK_SIZE = 512
+

 if typing.TYPE_CHECKING:
     import modal.functions


+class InputPreprocessor:
+    """
+    Constructs FunctionPutInputsItem objects from the raw-input queue, and puts them in the processed-input queue.
+    """
+
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        raw_input_queue: _SynchronizedQueue,
+        processed_input_queue: asyncio.Queue,
+        function: "modal.functions._Function",
+        created_callback: Callable[[int], None],
+        done_callback: Callable[[], None],
+    ):
+        self.client = client
+        self.function = function
+        self.inputs_created = 0
+        self.raw_input_queue = raw_input_queue
+        self.processed_input_queue = processed_input_queue
+        self.created_callback = created_callback
+        self.done_callback = done_callback
+
+    async def input_iter(self):
+        while 1:
+            raw_input = await self.raw_input_queue.get()
+            if raw_input is None:  # end of input sentinel
+                break
+            yield raw_input  # args, kwargs
+
+    def create_input_factory(self):
+        async def create_input(argskwargs):
+            idx = self.inputs_created
+            self.inputs_created += 1
+            self.created_callback(self.inputs_created)
+            (args, kwargs) = argskwargs
+            return await _create_input(
+                args,
+                kwargs,
+                self.client.stub,
+                idx=idx,
+                function=self.function,
+            )
+
+        return create_input
+
+    async def drain_input_generator(self):
+        # Parallelize uploading blobs
+        async with aclosing(
+            async_map_ordered(self.input_iter(), self.create_input_factory(), concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for item in streamer:
+                await self.processed_input_queue.put(item)
+
+        # close queue iterator
+        await self.processed_input_queue.put(None)
+        self.done_callback()
+        yield
+
+
+class InputPumper:
+    """
+    Reads inputs from a queue of FunctionPutInputsItems, and sends them to the server.
+    """
+
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        input_queue: asyncio.Queue,
+        function: "modal.functions._Function",
+        function_call_id: str,
+        max_batch_size: int,
+        map_items_manager: Optional["_MapItemsManager"] = None,
+    ):
+        self.client = client
+        self.function = function
+        self.map_items_manager = map_items_manager
+        self.input_queue = input_queue
+        self.inputs_sent = 0
+        self.function_call_id = function_call_id
+        self.max_batch_size = max_batch_size
+
+    async def pump_inputs(self):
+        assert self.client.stub
+        async for items in queue_batch_iterator(self.input_queue, max_batch_size=self.max_batch_size):
+            # Add items to the manager. Their state will be SENDING.
+            if self.map_items_manager is not None:
+                await self.map_items_manager.add_items(items)
+            request = api_pb2.FunctionPutInputsRequest(
+                function_id=self.function.object_id,
+                inputs=items,
+                function_call_id=self.function_call_id,
+            )
+            logger.debug(
+                f"Pushing {len(items)} inputs to server. Num queued inputs awaiting"
+                f" push is {self.input_queue.qsize()}. "
+            )
+
+            resp = await self.client.stub.FunctionPutInputs(request, retry=self._function_inputs_retry)
+            self.inputs_sent += len(items)
+            # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
+            if self.map_items_manager is not None:
+                self.map_items_manager.handle_put_inputs_response(resp.inputs)
+            logger.debug(
+                f"Successfully pushed {len(items)} inputs to server. "
+                f"Num queued inputs awaiting push is {self.input_queue.qsize()}."
+            )
+            yield
+
+    @property
+    def _function_inputs_retry(self) -> Retry:
+        # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
+        retry_warning_message = RetryWarningMessage(
+            message=f"Warning: map progress for function {self.function._function_name} is limited."
+            " Common bottlenecks include slow iteration over results, or function backlogs.",
+            warning_interval=8,
+            errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
+        )
+        return Retry(
+            max_retries=None,
+            max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
+            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+            warning_message=retry_warning_message,
+        )
+
+
+class SyncInputPumper(InputPumper):
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        input_queue: asyncio.Queue,
+        retry_queue: TimestampPriorityQueue,
+        function: "modal.functions._Function",
+        function_call_jwt: str,
+        function_call_id: str,
+        map_items_manager: "_MapItemsManager",
+    ):
+        super().__init__(
+            client,
+            input_queue=input_queue,
+            function=function,
+            function_call_id=function_call_id,
+            max_batch_size=MAP_INVOCATION_CHUNK_SIZE,
+            map_items_manager=map_items_manager,
+        )
+        self.retry_queue = retry_queue
+        self.inputs_retried = 0
+        self.function_call_jwt = function_call_jwt
+
+    async def retry_inputs(self):
+        async for retriable_idxs in queue_batch_iterator(self.retry_queue, max_batch_size=self.max_batch_size):
+            # For each index, use the context in the manager to create a FunctionRetryInputsItem.
+            # This will also update the context state to RETRYING.
+            inputs: list[api_pb2.FunctionRetryInputsItem] = await self.map_items_manager.prepare_items_for_retry(
+                retriable_idxs
+            )
+            request = api_pb2.FunctionRetryInputsRequest(
+                function_call_jwt=self.function_call_jwt,
+                inputs=inputs,
+            )
+            resp = await self.client.stub.FunctionRetryInputs(request, retry=self._function_inputs_retry)
+            # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
+            # to the new value in the response.
+            self.map_items_manager.handle_retry_response(resp.input_jwts)
+            logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
+            self.inputs_retried += len(inputs)
+            yield
+
+
+class AsyncInputPumper(InputPumper):
+    def __init__(
+        self,
+        client: "modal.client._Client",
+        *,
+        input_queue: asyncio.Queue,
+        function: "modal.functions._Function",
+        function_call_id: str,
+    ):
+        super().__init__(
+            client,
+            input_queue=input_queue,
+            function=function,
+            function_call_id=function_call_id,
+            max_batch_size=SPAWN_MAP_INVOCATION_CHUNK_SIZE,
+        )
+
+    async def pump_inputs(self):
+        async for _ in super().pump_inputs():
+            pass
+        request = api_pb2.FunctionFinishInputsRequest(
+            function_id=self.function.object_id,
+            function_call_id=self.function_call_id,
+            num_inputs=self.inputs_sent,
+        )
+        await self.client.stub.FunctionFinishInputs(request, retry=Retry(max_retries=None))
+        yield
+
+
+async def _spawn_map_invocation(
+    function: "modal.functions._Function", raw_input_queue: _SynchronizedQueue, client: "modal.client._Client"
+) -> tuple[str, int]:
+    assert client.stub
+    request = api_pb2.FunctionMapRequest(
+        function_id=function.object_id,
+        parent_input_id=current_input_id() or "",
+        function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
+        function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC,
+    )
+    response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)
+    function_call_id = response.function_call_id
+
+    have_all_inputs = False
+    inputs_created = 0
+
+    def set_inputs_created(set_inputs_created):
+        nonlocal inputs_created
+        assert set_inputs_created is None or set_inputs_created > inputs_created
+        inputs_created = set_inputs_created
+
+    def set_have_all_inputs():
+        nonlocal have_all_inputs
+        have_all_inputs = True
+
+    input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
+    input_preprocessor = InputPreprocessor(
+        client=client,
+        raw_input_queue=raw_input_queue,
+        processed_input_queue=input_queue,
+        function=function,
+        created_callback=set_inputs_created,
+        done_callback=set_have_all_inputs,
+    )
+
+    input_pumper = AsyncInputPumper(
+        client=client,
+        input_queue=input_queue,
+        function=function,
+        function_call_id=function_call_id,
+    )
+
+    def log_stats():
+        logger.debug(
+            f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} inputs_sent={input_pumper.inputs_sent} "
+        )
+
+    async def log_task():
+        while True:
+            log_stats()
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                # Log final stats before exiting
+                log_stats()
+                break
+
+    async def consume_generator(gen):
+        async for _ in gen:
+            pass
+
+    log_debug_stats_task = asyncio.create_task(log_task())
+    await asyncio.gather(
+        consume_generator(input_preprocessor.drain_input_generator()),
+        consume_generator(input_pumper.pump_inputs()),
+    )
+    log_debug_stats_task.cancel()
+    await log_debug_stats_task
+    return function_call_id, inputs_created
+
+
 async def _map_invocation(
     function: "modal.functions._Function",
     raw_input_queue: _SynchronizedQueue,
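
The added classes split the map pipeline into two stages connected by an `asyncio.Queue` and terminated by a `None` sentinel: `InputPreprocessor` transforms raw inputs and forwards them downstream, while `InputPumper` drains the queue in batches and sends each batch to the server. A self-contained sketch of that sentinel-terminated producer/consumer pattern (names and the transform are illustrative, not modal's API):

```python
import asyncio


async def preprocess(raw: asyncio.Queue, processed: asyncio.Queue) -> None:
    # Drain raw inputs until the None sentinel, transform each one,
    # and forward it downstream (stand-in for _create_input()).
    while True:
        item = await raw.get()
        if item is None:
            break
        await processed.put(item * 2)
    await processed.put(None)  # propagate the end-of-input sentinel


async def pump(processed: asyncio.Queue, batch_size: int = 3) -> None:
    # Collect items into batches and "send" each batch
    # (stand-in for queue_batch_iterator + FunctionPutInputs).
    batch = []
    while True:
        item = await processed.get()
        done = item is None
        if not done:
            batch.append(item)
        if batch and (done or len(batch) == batch_size):
            print("sending batch:", batch)
            batch = []
        if done:
            return


async def main() -> None:
    raw: asyncio.Queue = asyncio.Queue()
    processed: asyncio.Queue = asyncio.Queue()
    for x in [*range(7), None]:
        raw.put_nowait(x)
    await asyncio.gather(preprocess(raw, processed), pump(processed))


asyncio.run(main())
```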
@@ -104,7 +377,7 @@ async def _map_invocation(
         return_exceptions=return_exceptions,
         function_call_invocation_type=function_call_invocation_type,
     )
-    response: api_pb2.FunctionMapResponse = await retry_transient_errors(client.stub.FunctionMap, request)
+    response: api_pb2.FunctionMapResponse = await client.stub.FunctionMap(request)

     function_call_id = response.function_call_id
     function_call_jwt = response.function_call_jwt
@@ -117,8 +390,6 @@ async def _map_invocation(
     have_all_inputs = False
     map_done_event = asyncio.Event()
     inputs_created = 0
-    inputs_sent = 0
-    inputs_retried = 0
     outputs_completed = 0
     outputs_received = 0
     retried_outputs = 0
@@ -135,25 +406,24 @@ async def _map_invocation(
         retry_policy, function_call_invocation_type, retry_queue, sync_client_retries_enabled, max_inputs_outstanding
     )

-    async def create_input(argskwargs):
-        idx = inputs_created
-        update_state(set_inputs_created=inputs_created + 1)
-        (args, kwargs) = argskwargs
-        return await _create_input(
-            args,
-            kwargs,
-            client.stub,
-            max_object_size_bytes=function._max_object_size_bytes,
-            idx=idx,
-            method_name=function._use_method_name,
-        )
+    input_preprocessor = InputPreprocessor(
+        client=client,
+        raw_input_queue=raw_input_queue,
+        processed_input_queue=input_queue,
+        function=function,
+        created_callback=lambda x: update_state(set_inputs_created=x),
+        done_callback=lambda: update_state(set_have_all_inputs=True),
+    )

-    async def input_iter():
-        while 1:
-            raw_input = await raw_input_queue.get()
-            if raw_input is None:  # end of input sentinel
-                break
-            yield raw_input  # args, kwargs
+    input_pumper = SyncInputPumper(
+        client=client,
+        input_queue=input_queue,
+        retry_queue=retry_queue,
+        function=function,
+        map_items_manager=map_items_manager,
+        function_call_jwt=function_call_jwt,
+        function_call_id=function_call_id,
+    )

     def update_state(set_have_all_inputs=None, set_inputs_created=None, set_outputs_completed=None):
         # This should be the only method that needs nonlocal of the following vars
@@ -175,84 +445,6 @@ async def _map_invocation(
         # map is done
         map_done_event.set()

-    async def drain_input_generator():
-        # Parallelize uploading blobs
-        async with aclosing(
-            async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
-        ) as streamer:
-            async for item in streamer:
-                await input_queue.put(item)
-
-        # close queue iterator
-        await input_queue.put(None)
-        update_state(set_have_all_inputs=True)
-        yield
-
-    async def pump_inputs():
-        assert client.stub
-        nonlocal inputs_sent
-        async for items in queue_batch_iterator(input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
-            # Add items to the manager. Their state will be SENDING.
-            await map_items_manager.add_items(items)
-            request = api_pb2.FunctionPutInputsRequest(
-                function_id=function.object_id,
-                inputs=items,
-                function_call_id=function_call_id,
-            )
-            logger.debug(
-                f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
-            )
-
-            resp = await send_inputs(client.stub.FunctionPutInputs, request)
-            inputs_sent += len(items)
-            # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
-            map_items_manager.handle_put_inputs_response(resp.inputs)
-            logger.debug(
-                f"Successfully pushed {len(items)} inputs to server. "
-                f"Num queued inputs awaiting push is {input_queue.qsize()}."
-            )
-            yield
-
-    async def retry_inputs():
-        nonlocal inputs_retried
-        async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
-            # For each index, use the context in the manager to create a FunctionRetryInputsItem.
-            # This will also update the context state to RETRYING.
-            inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
-                retriable_idxs
-            )
-            request = api_pb2.FunctionRetryInputsRequest(
-                function_call_jwt=function_call_jwt,
-                inputs=inputs,
-            )
-            resp = await send_inputs(client.stub.FunctionRetryInputs, request)
-            # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
-            # to the new value in the response.
-            map_items_manager.handle_retry_response(resp.input_jwts)
-            logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
-            inputs_retried += len(inputs)
-            yield
-
-    async def send_inputs(
-        fn: "modal.client.UnaryUnaryWrapper",
-        request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
-    ) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
-        # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
-        retry_warning_message = RetryWarningMessage(
-            message=f"Warning: map progress for function {function._function_name} is limited."
-            " Common bottlenecks include slow iteration over results, or function backlogs.",
-            warning_interval=8,
-            errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
-        )
-        return await retry_transient_errors(
-            fn,
-            request,
-            max_retries=None,
-            max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
-            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-            retry_warning_message=retry_warning_message,
-        )
-
     async def get_all_outputs():
         assert client.stub
         nonlocal \
@@ -281,11 +473,12 @@ async def _map_invocation(
                 input_jwts=input_jwts,
             )
             get_response_task = asyncio.create_task(
-                retry_transient_errors(
-                    client.stub.FunctionGetOutputs,
+                client.stub.FunctionGetOutputs(
                     request,
-                    max_retries=20,
-                    attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+                    retry=Retry(
+                        max_retries=20,
+                        attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+                    ),
                 )
             )
             map_done_task = asyncio.create_task(map_done_event.wait())
@@ -344,7 +537,7 @@ async def _map_invocation(
             clear_on_success=True,
             requested_at=time.time(),
         )
-        await retry_transient_errors(client.stub.FunctionGetOutputs, request)
+        await client.stub.FunctionGetOutputs(request)
         await retry_queue.close()

     async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
@@ -395,8 +588,11 @@ async def _map_invocation(
     def log_stats():
         logger.debug(
             f"Map stats: sync_client_retries_enabled={sync_client_retries_enabled} "
-            f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} input_sent={inputs_sent} "
-            f"inputs_retried={inputs_retried} outputs_received={outputs_received} "
+            f"have_all_inputs={have_all_inputs} "
+            f"inputs_created={inputs_created} "
+            f"input_sent={input_pumper.inputs_sent} "
+            f"inputs_retried={input_pumper.inputs_retried} "
+            f"outputs_received={outputs_received} "
             f"successful_completions={successful_completions} failed_completions={failed_completions} "
             f"no_context_duplicates={no_context_duplicates} old_retry_duplicates={stale_retry_duplicates} "
             f"already_complete_duplicates={already_complete_duplicates} "
@@ -415,7 +611,12 @@ async def _map_invocation(

     log_debug_stats_task = asyncio.create_task(log_debug_stats())
     async with aclosing(
-        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), retry_inputs())
+        async_merge(
+            input_preprocessor.drain_input_generator(),
+            input_pumper.pump_inputs(),
+            input_pumper.retry_inputs(),
+            poll_outputs(),
+        )
     ) as streamer:
         async for response in streamer:
             if response is not None:  # type: ignore[unreachable]
@@ -424,6 +625,367 @@ async def _map_invocation(
     await log_debug_stats_task


+async def _map_invocation_inputplane(
+    function: "modal.functions._Function",
+    raw_input_queue: _SynchronizedQueue,
+    client: "modal.client._Client",
+    order_outputs: bool,
+    return_exceptions: bool,
+    wrap_returned_exceptions: bool,
+    count_update_callback: Optional[Callable[[int, int], None]],
+) -> typing.AsyncGenerator[Any, None]:
+    """Input-plane implementation of a function map invocation.
+
+    This is analogous to `_map_invocation`, but instead of the control-plane
+    `FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
+    the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
+    """
+
+    assert function._input_plane_url, "_map_invocation_inputplane should only be used for input-plane backed functions"
+
+    input_plane_stub = await client.get_stub(function._input_plane_url)
+
+    # Required for _create_input.
+    assert client.stub, "Client must be hydrated with a stub for _map_invocation_inputplane"
+
+    # ------------------------------------------------------------
+    # Invocation-wide state
+    # ------------------------------------------------------------
+
+    have_all_inputs = False
+    map_done_event = asyncio.Event()
+
+    inputs_created = 0
+    outputs_completed = 0
+    successful_completions = 0
+    failed_completions = 0
+    no_context_duplicates = 0
+    stale_retry_duplicates = 0
+    already_complete_duplicates = 0
+    retried_outputs = 0
+    input_queue_size = 0
+    last_entry_id = ""
+
+    # The input-plane server returns this after the first request.
+    map_token = None
+    map_token_received = asyncio.Event()
+
+    # Single priority queue that holds *both* fresh inputs (timestamp == now)
+    # and future retries (timestamp > now).
+    queue: TimestampPriorityQueue[api_pb2.MapStartOrContinueItem] = TimestampPriorityQueue()
+
+    # Maximum number of inputs that may be in-flight (the server sends this in
+    # the first response – fall back to the default if we never receive it for
+    # any reason).
+    max_inputs_outstanding = MAX_INPUTS_OUTSTANDING_DEFAULT
+
+    # Set a default retry policy to construct an instance of _MapItemsManager.
+    # We'll update the retry policy with the actual user-specified retry policy
+    # from the server in the first MapStartOrContinue response.
+    retry_policy = api_pb2.FunctionRetryPolicy(
+        retries=0,
+        initial_delay_ms=1000,
+        max_delay_ms=1000,
+        backoff_coefficient=1.0,
+    )
+    map_items_manager = _MapItemsManager(
+        retry_policy=retry_policy,
+        function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
+        retry_queue=queue,
+        sync_client_retries_enabled=True,
+        max_inputs_outstanding=MAX_INPUTS_OUTSTANDING_DEFAULT,
+        is_input_plane_instance=True,
+    )
+
+    def update_counters(
+        created_delta: int = 0, completed_delta: int = 0, set_have_all_inputs: Union[bool, None] = None
+    ):
+        nonlocal inputs_created, outputs_completed, have_all_inputs
+
+        if created_delta:
+            inputs_created += created_delta
+        if completed_delta:
+            outputs_completed += completed_delta
+        if set_have_all_inputs is not None:
+            have_all_inputs = set_have_all_inputs
+
+        if count_update_callback is not None:
+            count_update_callback(outputs_completed, inputs_created)
+
+        if have_all_inputs and outputs_completed >= inputs_created:
+            map_done_event.set()
+
+    async def create_input(argskwargs):
+        idx = inputs_created + 1  # 1-indexed map call idx
+        update_counters(created_delta=1)
+        (args, kwargs) = argskwargs
+        put_item: api_pb2.FunctionPutInputsItem = await _create_input(
+            args,
+            kwargs,
+            client.stub,
+            idx=idx,
+            function=function,
+        )
+        return api_pb2.MapStartOrContinueItem(input=put_item)
+
+    async def input_iter():
+        while True:
+            raw_input = await raw_input_queue.get()
+            if raw_input is None:  # end of input sentinel
+                break
+            yield raw_input  # args, kwargs
+
+    async def drain_input_generator():
+        async with aclosing(
+            async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for q_item in streamer:
+                await queue.put(time.time(), q_item)
+
+        # All inputs have been read.
+        update_counters(set_have_all_inputs=True)
+        yield
+
+    async def pump_inputs():
+        nonlocal map_token, max_inputs_outstanding
+        async for batch in queue_batch_iterator(queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            # Convert the queued items into the proto format expected by the RPC.
+            request_items: list[api_pb2.MapStartOrContinueItem] = [
+                api_pb2.MapStartOrContinueItem(input=qi.input, attempt_token=qi.attempt_token) for qi in batch
+            ]
+
+            await map_items_manager.add_items_inputplane(request_items)
+
+            # Build request
+            request = api_pb2.MapStartOrContinueRequest(
+                function_id=function.object_id,
+                map_token=map_token,
+                parent_input_id=current_input_id() or "",
+                items=request_items,
+            )
+
+            metadata = await client.get_input_plane_metadata(function._input_plane_region)
+
+            response: api_pb2.MapStartOrContinueResponse = await input_plane_stub.MapStartOrContinue(
+                request,
+                retry=Retry(
+                    additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+                    max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
+                    max_retries=None,
+                ),
+                metadata=metadata,
+            )
+
+            # match response items to the corresponding request item index
+            response_items_idx_tuple = [
+                (request_items[idx].input.idx, attempt_token)
+                for idx, attempt_token in enumerate(response.attempt_tokens)
+            ]
+
+            map_items_manager.handle_put_continue_response(response_items_idx_tuple)
+
+            # Set the function call id and actual retry policy with the data from the first response.
+            # This conditional is skipped for subsequent iterations of this for-loop.
+            if map_token is None:
+                map_token = response.map_token
+                map_token_received.set()
+                max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
+                map_items_manager.set_retry_policy(response.retry_policy)
+                # Update the retry policy for the first batch of inputs.
+                # Subsequent batches will have the correct user-specified retry policy
+                # set by the updated _MapItemsManager.
+                map_items_manager.update_items_retry_policy(response.retry_policy)
+            yield
+
+    async def check_lost_inputs():
+        nonlocal last_entry_id  # shared with get_all_outputs
+        try:
+            while not map_done_event.is_set():
+                if map_token is None:
+                    await map_token_received.wait()
+                    continue
+
+                sleep_task = asyncio.create_task(asyncio.sleep(1))
+                map_done_task = asyncio.create_task(map_done_event.wait())
+                done, _ = await asyncio.wait([sleep_task, map_done_task], return_when=FIRST_COMPLETED)
+                if map_done_task in done:
+                    break
+
+                # check_inputs = [(idx, attempt_token), ...]
+                check_inputs = map_items_manager.get_input_idxs_waiting_for_output()
+                attempt_tokens = [attempt_token for _, attempt_token in check_inputs]
+                request = api_pb2.MapCheckInputsRequest(
+                    last_entry_id=last_entry_id,
+                    timeout=0,  # Non-blocking read
+                    attempt_tokens=attempt_tokens,
+                )
+
+                metadata = await client.get_input_plane_metadata(function._input_plane_region)
+                response: api_pb2.MapCheckInputsResponse = await input_plane_stub.MapCheckInputs(
+                    request, metadata=metadata
+                )
+                check_inputs_response = [
+                    (check_inputs[resp_idx][0], response.lost[resp_idx]) for resp_idx, _ in enumerate(response.lost)
+                ]
+                # check_inputs_response = [(idx, lost: bool), ...]
+                await map_items_manager.handle_check_inputs_response(check_inputs_response)
+                yield
+        except asyncio.CancelledError:
+            pass
+
+    async def get_all_outputs():
+        nonlocal \
+            successful_completions, \
+            failed_completions, \
+            no_context_duplicates, \
+            stale_retry_duplicates, \
+            already_complete_duplicates, \
+            retried_outputs, \
+            last_entry_id
+
+        while not map_done_event.is_set():
+            if map_token is None:
+                await map_token_received.wait()
+                continue
+
+            request = api_pb2.MapAwaitRequest(
+                map_token=map_token,
+                last_entry_id=last_entry_id,
+                requested_at=time.time(),
+                timeout=OUTPUTS_TIMEOUT,
+            )
+            metadata = await client.get_input_plane_metadata(function._input_plane_region)
+            get_response_task = asyncio.create_task(
+                input_plane_stub.MapAwait(
+                    request,
+                    retry=Retry(
+                        max_retries=20,
+                        attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+                    ),
+                    metadata=metadata,
+                )
+            )
+            map_done_task = asyncio.create_task(map_done_event.wait())
+            try:
+                done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
+                if get_response_task in done:
+                    map_done_task.cancel()
+                    response = get_response_task.result()
+                else:
+                    assert map_done_event.is_set()
+                    # map is done - no more outputs, so return early
+                    return
+            finally:
+                # clean up tasks, in case of cancellations etc.
+                get_response_task.cancel()
+                map_done_task.cancel()
+            last_entry_id = response.last_entry_id
+
+            for output_item in response.outputs:
+                output_type = await map_items_manager.handle_get_outputs_response(output_item, int(time.time()))
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION:
+                    successful_completions += 1
+                elif output_type == _OutputType.FAILED_COMPLETION:
+                    failed_completions += 1
+                elif output_type == _OutputType.RETRYING:
+                    retried_outputs += 1
+                elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
+                    no_context_duplicates += 1
+                elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
+                    stale_retry_duplicates += 1
+                elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
+                    already_complete_duplicates += 1
+                else:
+                    raise Exception(f"Unknown output type: {output_type}")
+
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
+                    update_counters(completed_delta=1)
+                    yield output_item
+
+    async def get_all_outputs_and_clean_up():
+        try:
+            async with aclosing(get_all_outputs()) as stream:
+                async for item in stream:
+                    yield item
+        finally:
+            await queue.close()
+            pass
+
+    async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
+        try:
+            output = await _process_result(item.result, item.data_format, input_plane_stub, client)
+        except Exception as e:
+            if return_exceptions:
+                if wrap_returned_exceptions:
+                    # Prior to client 1.0.4 there was a bug where return_exceptions would wrap
+                    # any returned exceptions in a synchronicity.UserCodeException. This adds
+                    # deprecated non-breaking compatibility bandaid for migrating away from that:
+                    output = modal.exception.UserCodeException(e)
+                else:
+                    output = e
+            else:
+                raise e
+        return (item.idx, output)
+
+    async def poll_outputs():
+        # map to store out-of-order outputs received
+        received_outputs = {}
+        output_idx = 1  # 1-indexed map call idx
+
+        async with aclosing(
+            async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for idx, output in streamer:
+                if not order_outputs:
+                    yield _OutputValue(output)
+                else:
+                    # hold on to outputs for function maps, so we can reorder them correctly.
+                    received_outputs[idx] = output
+
+                    while True:
+                        if output_idx not in received_outputs:
+                            # we haven't received the output for the current index yet.
+                            # stop returning outputs to the caller and instead wait for
+                            # the next output to arrive from the server.
+                            break
+
+                        output = received_outputs.pop(output_idx)
+                        yield _OutputValue(output)
+                        output_idx += 1
+
+        assert len(received_outputs) == 0
+
+    async def log_debug_stats():
+        def log_stats():
+            logger.debug(
+                f"Map stats:\nsuccessful_completions={successful_completions} failed_completions={failed_completions} "
+                f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
+                f"already_complete_duplicates={already_complete_duplicates} retried_outputs={retried_outputs} "
+                f"map_token={map_token} max_inputs_outstanding={max_inputs_outstanding} "
+                f"map_items_manager_size={len(map_items_manager)} input_queue_size={input_queue_size}"
+            )
+
+        while True:
+            log_stats()
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                # Log final stats before exiting
+                log_stats()
+                break
+
+    log_task = asyncio.create_task(log_debug_stats())
+
+    async with aclosing(
+        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), check_lost_inputs())
+    ) as merged:
+        async for maybe_output in merged:
+            if maybe_output is not None:  # ignore None sentinels
+                yield maybe_output.value
+
+    log_task.cancel()
+
+
 async def _map_helper(
     self: "modal.functions.Function",
     async_input_gen: typing.AsyncGenerator[Any, None],
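
The `poll_outputs` helper above buffers results that arrive out of order and releases them strictly by index. A standalone sketch of that reorder-buffer logic, using the same 1-indexed convention (illustrative, synchronous for brevity):

```python
def reorder(stream):
    # stream yields (idx, output) pairs in arrival order; yield outputs in idx order.
    received, next_idx = {}, 1  # 1-indexed, matching the map call idx above
    for idx, output in stream:
        received[idx] = output
        while next_idx in received:  # release any now-contiguous prefix
            yield received.pop(next_idx)
            next_idx += 1
    assert not received  # every index was eventually received


print(list(reorder([(2, "b"), (1, "a"), (4, "d"), (3, "c")])))  # ['a', 'b', 'c', 'd']
```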
@@ -620,6 +1182,56 @@ def _map_sync(
     )


+async def _experimental_spawn_map_async(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
+    async_input_gen = async_zip(*[sync_or_async_iter(it) for it in input_iterators])
+    return await _spawn_map_helper(self, async_input_gen, kwargs)
+
+
+async def _spawn_map_helper(
+    self: "modal.functions.Function", async_input_gen, kwargs={}
+) -> "modal.functions._FunctionCall":
+    raw_input_queue: Any = SynchronizedQueue()  # type: ignore
+    await raw_input_queue.init.aio()
+
+    async def feed_queue():
+        async with aclosing(async_input_gen) as streamer:
+            async for args in streamer:
+                await raw_input_queue.put.aio((args, kwargs))
+        await raw_input_queue.put.aio(None)  # end-of-input sentinel
+
+    fc, _ = await asyncio.gather(self._spawn_map.aio(raw_input_queue), feed_queue())
+    return fc
+
+
+def _experimental_spawn_map_sync(self, *input_iterators, kwargs={}) -> "modal.functions._FunctionCall":
+    """mdmd:hidden
+    Spawn parallel execution over a set of inputs, returning as soon as the inputs are created.
+
+    Unlike `modal.Function.map`, this method does not block on completion of the remote execution but
+    returns a `modal.FunctionCall` object that can be used to poll status and retrieve results later.
+
+    Takes one iterator argument per argument in the function being mapped over.
+
+    Example:
+    ```python
+    @app.function()
+    def my_func(a, b):
+        return a ** b
+
+
+    @app.local_entrypoint()
+    def main():
+        fc = my_func.spawn_map([1, 2], [3, 4])
+    ```
+
+    """
+
+    return run_coroutine_in_temporary_event_loop(
+        _experimental_spawn_map_async(self, *input_iterators, kwargs=kwargs),
+        "You can't run Function.spawn_map() from an async function. Use Function.spawn_map.aio() instead.",
+    )
+
+
 async def _spawn_map_async(self, *input_iterators, kwargs={}) -> None:
     """This runs in an event loop on the main thread. It consumes inputs from the input iterators and creates async
     function calls for each.
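
Judging by the error message above, the synchronous wrapper refuses to run inside an event loop; async callers would use the `.aio` variant instead. A hedged variation on the docstring example (assuming the experimental method is exposed as `spawn_map`, as the docstring suggests):

```python
# Extends the docstring example above; assumes `app` and `my_func` as defined there.
@app.local_entrypoint()
async def main():
    fc = await my_func.spawn_map.aio([1, 2], [3, 4])
    # `fc` is a modal.FunctionCall handle; status and results can be fetched later.
```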
@@ -756,12 +1368,19 @@ class _MapItemContext:
     sync_client_retries_enabled: bool
     # Both these futures are strings. Omitting generic type because
     # it causes an error when running `inv protoc type-stubs`.
+    # Unused. But important: input_id is not set for input-plane invocations.
     input_id: asyncio.Future
     input_jwt: asyncio.Future
     previous_input_jwt: Optional[str]
     _event_loop: asyncio.AbstractEventLoop

-    def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager, sync_client_retries_enabled: bool):
+    def __init__(
+        self,
+        input: api_pb2.FunctionInput,
+        retry_manager: RetryManager,
+        sync_client_retries_enabled: bool,
+        is_input_plane_instance: bool = False,
+    ):
         self.state = _MapItemState.SENDING
         self.input = input
         self.retry_manager = retry_manager
@@ -772,7 +1391,22 @@ class _MapItemContext:
         # a race condition where we could receive outputs before we have
         # recorded the input ID and JWT in `pending_outputs`.
         self.input_jwt = self._event_loop.create_future()
+        # Unused. But important: this is not set for input-plane invocations.
         self.input_id = self._event_loop.create_future()
+        self._is_input_plane_instance = is_input_plane_instance
+
+    def handle_map_start_or_continue_response(self, attempt_token: str):
+        if not self.input_jwt.done():
+            self.input_jwt.set_result(attempt_token)
+        else:
+            # Create a new future for the next value
+            self.input_jwt = asyncio.Future()
+            self.input_jwt.set_result(attempt_token)
+
+        # Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
+        # RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
+        if self.state == _MapItemState.SENDING:
+            self.state = _MapItemState.WAITING_FOR_OUTPUT

     def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
         self.input_jwt.set_result(item.input_jwt)
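
Note how `handle_map_start_or_continue_response` swaps in a fresh future once the old one is resolved, so code that reads `self.input_jwt` always awaits the newest attempt token. A minimal illustration of that replace-the-future pattern (names are illustrative):

```python
import asyncio


async def main():
    token: asyncio.Future = asyncio.get_running_loop().create_future()
    token.set_result("attempt-1")

    # A resolved future cannot take a second result, so rebind to a fresh one;
    # in the class above the rebinding happens on self.input_jwt.
    if token.done():
        token = asyncio.Future()
    token.set_result("attempt-2")
    print(await token)  # "attempt-2": readers of the attribute see the latest value


asyncio.run(main())
```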
@@ -799,7 +1433,7 @@ class _MapItemContext:
         if self.state == _MapItemState.COMPLETE:
             logger.debug(
                 f"Received output for input marked as complete. Must be duplicate, so ignoring. "
-                f"idx={item.idx} input_id={item.input_id}, retry_count={item.retry_count}"
+                f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
             )
             return _OutputType.ALREADY_COMPLETE_DUPLICATE
         # If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
@@ -847,7 +1481,11 @@ class _MapItemContext:

         self.state = _MapItemState.WAITING_TO_RETRY

-        await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)
+        if self._is_input_plane_instance:
+            retry_item = await self.create_map_start_or_continue_item(item.idx)
+            await retry_queue.put(now_seconds + delay_ms / 1_000, retry_item)
+        else:
+            await retry_queue.put(now_seconds + delay_ms / 1_000, item.idx)

         return _OutputType.RETRYING

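Retries are scheduled by pushing the item back onto the `TimestampPriorityQueue` with a due timestamp (`now_seconds + delay_ms / 1_000`), which the pumper drains as items come due. A self-contained sketch of such a delay queue built on `heapq` (an illustration of the contract, not modal's actual implementation):

```python
import asyncio
import heapq
import time


class DelayQueue:
    """Release items only once their scheduled timestamp has passed."""

    def __init__(self):
        self._heap: list[tuple[float, int, object]] = []
        self._seq = 0  # tie-breaker so heapq never compares payloads

    async def put(self, timestamp: float, item) -> None:
        heapq.heappush(self._heap, (timestamp, self._seq, item))
        self._seq += 1

    async def get(self):
        while True:
            if self._heap and self._heap[0][0] <= time.time():
                return heapq.heappop(self._heap)[2]
            await asyncio.sleep(0.05)  # poll until the earliest item is due


async def demo():
    q = DelayQueue()
    await q.put(time.time() + 0.2, "retry idx=3")  # due in 200 ms
    await q.put(time.time(), "fresh input")  # due immediately
    print(await q.get())  # "fresh input"
    print(await q.get())  # "retry idx=3", once the delay elapses


asyncio.run(demo())
```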
@@ -862,10 +1500,23 @@ class _MapItemContext:
             retry_count=self.retry_manager.retry_count,
         )

+    def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
+        self.retry_manager = RetryManager(retry_policy)
+
     def handle_retry_response(self, input_jwt: str):
         self.input_jwt.set_result(input_jwt)
         self.state = _MapItemState.WAITING_FOR_OUTPUT

+    async def create_map_start_or_continue_item(self, idx: int) -> api_pb2.MapStartOrContinueItem:
+        attempt_token = await self.input_jwt
+        return api_pb2.MapStartOrContinueItem(
+            input=api_pb2.FunctionPutInputsItem(
+                input=self.input,
+                idx=idx,
+            ),
+            attempt_token=attempt_token,
+        )
+

 class _MapItemsManager:
     def __init__(
@@ -875,6 +1526,7 @@ class _MapItemsManager:
         retry_queue: TimestampPriorityQueue,
         sync_client_retries_enabled: bool,
         max_inputs_outstanding: int,
+        is_input_plane_instance: bool = False,
     ):
         self._retry_policy = retry_policy
         self.function_call_invocation_type = function_call_invocation_type
@@ -885,6 +1537,10 @@ class _MapItemsManager:
         self._inputs_outstanding = asyncio.BoundedSemaphore(max_inputs_outstanding)
         self._item_context: dict[int, _MapItemContext] = {}
         self._sync_client_retries_enabled = sync_client_retries_enabled
+        self._is_input_plane_instance = is_input_plane_instance
+
+    def set_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
+        self._retry_policy = retry_policy

     async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
         for item in items:
@@ -897,9 +1553,28 @@ class _MapItemsManager:
             sync_client_retries_enabled=self._sync_client_retries_enabled,
         )

+    async def add_items_inputplane(self, items: list[api_pb2.MapStartOrContinueItem]):
+        for item in items:
+            # acquire semaphore to limit the number of inputs in progress
+            # (either queued to be sent, waiting for completion, or retrying)
+            if item.attempt_token != "":  # if it is a retry item
+                self._item_context[item.input.idx].state = _MapItemState.SENDING
+                continue
+            await self._inputs_outstanding.acquire()
+            self._item_context[item.input.idx] = _MapItemContext(
+                input=item.input.input,
+                retry_manager=RetryManager(self._retry_policy),
+                sync_client_retries_enabled=self._sync_client_retries_enabled,
+                is_input_plane_instance=self._is_input_plane_instance,
+            )
+
     async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
         return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]

+    def update_items_retry_policy(self, retry_policy: api_pb2.FunctionRetryPolicy):
+        for ctx in self._item_context.values():
+            ctx.set_retry_policy(retry_policy)
+
     def get_input_jwts_waiting_for_output(self) -> list[str]:
         """
         Returns a list of input_jwts for inputs that are waiting for output.
@@ -911,6 +1586,17 @@ class _MapItemsManager:
         if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
     ]

+    def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
+        """
+        Returns a list of (idx, attempt_token) pairs for inputs that are waiting for output.
+        """
+        # Idx doesn't need a future because it is set by the client, not the server.
+        return [
+            (idx, ctx.input_jwt.result())
+            for idx, ctx in self._item_context.items()
+            if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
+        ]
+
     def _remove_item(self, item_idx: int):
         del self._item_context[item_idx]
         self._inputs_outstanding.release()
@@ -918,6 +1604,18 @@ class _MapItemsManager:
     def get_item_context(self, item_idx: int) -> _MapItemContext:
         return self._item_context.get(item_idx)

+    def handle_put_continue_response(
+        self,
+        items: list[tuple[int, str]],  # (idx, input_jwt)
+    ):
+        for index, item in items:
+            ctx = self._item_context.get(index, None)
+            # If the context is None, then get_all_outputs() has already received a successful
+            # output and deleted the context. This happens if FunctionGetOutputs completes
+            # before the MapStartOrContinueResponse is received.
+            if ctx is not None:
+                ctx.handle_map_start_or_continue_response(item)
+
     def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
         for item in items:
             ctx = self._item_context.get(item.idx, None)
@@ -937,6 +1635,16 @@ class _MapItemsManager:
         if ctx is not None:
             ctx.handle_retry_response(input_jwt)

+    async def handle_check_inputs_response(self, response: list[tuple[int, bool]]):
+        for idx, lost in response:
+            ctx = self._item_context.get(idx, None)
+            if ctx is not None:
+                if lost:
+                    ctx.state = _MapItemState.WAITING_TO_RETRY
+                    retry_item = await ctx.create_map_start_or_continue_item(idx)
+                    _ = ctx.retry_manager.get_delay_ms()  # increment retry count, but retry lost inputs immediately
+                    await self._retry_queue.put(time.time(), retry_item)
+
     async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
         ctx = self._item_context.get(item.idx, None)
         if ctx is None: