modal 1.0.6.dev61__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic.

Files changed (75)
  1. modal/__main__.py +2 -2
  2. modal/_clustered_functions.py +3 -0
  3. modal/_clustered_functions.pyi +3 -2
  4. modal/_functions.py +78 -26
  5. modal/_object.py +9 -1
  6. modal/_output.py +14 -25
  7. modal/_runtime/gpu_memory_snapshot.py +158 -54
  8. modal/_utils/async_utils.py +6 -4
  9. modal/_utils/auth_token_manager.py +1 -1
  10. modal/_utils/blob_utils.py +16 -21
  11. modal/_utils/function_utils.py +16 -4
  12. modal/_utils/time_utils.py +8 -4
  13. modal/app.py +0 -4
  14. modal/app.pyi +0 -4
  15. modal/cli/_traceback.py +3 -2
  16. modal/cli/app.py +4 -4
  17. modal/cli/cluster.py +4 -4
  18. modal/cli/config.py +2 -2
  19. modal/cli/container.py +2 -2
  20. modal/cli/dict.py +4 -4
  21. modal/cli/entry_point.py +2 -2
  22. modal/cli/import_refs.py +3 -3
  23. modal/cli/network_file_system.py +8 -9
  24. modal/cli/profile.py +2 -2
  25. modal/cli/queues.py +5 -5
  26. modal/cli/secret.py +5 -5
  27. modal/cli/utils.py +3 -4
  28. modal/cli/volume.py +8 -9
  29. modal/client.py +8 -1
  30. modal/client.pyi +9 -10
  31. modal/container_process.py +2 -2
  32. modal/dict.py +47 -3
  33. modal/dict.pyi +55 -0
  34. modal/exception.py +4 -0
  35. modal/experimental/__init__.py +1 -1
  36. modal/experimental/flash.py +18 -2
  37. modal/experimental/flash.pyi +19 -0
  38. modal/functions.pyi +6 -7
  39. modal/image.py +26 -10
  40. modal/image.pyi +12 -4
  41. modal/mount.py +1 -1
  42. modal/object.pyi +4 -0
  43. modal/parallel_map.py +432 -4
  44. modal/parallel_map.pyi +28 -0
  45. modal/queue.py +46 -3
  46. modal/queue.pyi +53 -0
  47. modal/sandbox.py +105 -25
  48. modal/sandbox.pyi +108 -18
  49. modal/secret.py +48 -5
  50. modal/secret.pyi +55 -0
  51. modal/token_flow.py +3 -3
  52. modal/volume.py +49 -18
  53. modal/volume.pyi +50 -8
  54. {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/METADATA +2 -2
  55. {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/RECORD +75 -75
  56. modal_proto/api.proto +140 -14
  57. modal_proto/api_grpc.py +80 -0
  58. modal_proto/api_pb2.py +927 -756
  59. modal_proto/api_pb2.pyi +488 -34
  60. modal_proto/api_pb2_grpc.py +166 -0
  61. modal_proto/api_pb2_grpc.pyi +52 -0
  62. modal_proto/modal_api_grpc.py +5 -0
  63. modal_version/__init__.py +1 -1
  64. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  65. /modal/{requirements → builder}/2023.12.txt +0 -0
  66. /modal/{requirements → builder}/2024.04.txt +0 -0
  67. /modal/{requirements → builder}/2024.10.txt +0 -0
  68. /modal/{requirements → builder}/2025.06.txt +0 -0
  69. /modal/{requirements → builder}/PREVIEW.txt +0 -0
  70. /modal/{requirements → builder}/README.md +0 -0
  71. /modal/{requirements → builder}/base-images.json +0 -0
  72. {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/WHEEL +0 -0
  73. {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/entry_points.txt +0 -0
  74. {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/licenses/LICENSE +0 -0
  75. {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/top_level.txt +0 -0
modal/parallel_map.py CHANGED
@@ -6,7 +6,7 @@ import time
6
6
  import typing
7
7
  from asyncio import FIRST_COMPLETED
8
8
  from dataclasses import dataclass
9
- from typing import Any, Callable, Optional
9
+ from typing import Any, Callable, Optional, Union
10
10
 
11
11
  from grpclib import Status
12
12
 
@@ -424,6 +424,348 @@ async def _map_invocation(
424
424
  await log_debug_stats_task
425
425
 
426
426
 
427
+ async def _map_invocation_inputplane(
428
+ function: "modal.functions._Function",
429
+ raw_input_queue: _SynchronizedQueue,
430
+ client: "modal.client._Client",
431
+ order_outputs: bool,
432
+ return_exceptions: bool,
433
+ wrap_returned_exceptions: bool,
434
+ count_update_callback: Optional[Callable[[int, int], None]],
435
+ ) -> typing.AsyncGenerator[Any, None]:
436
+ """Input-plane implementation of a function map invocation.
437
+
438
+ This is analogous to `_map_invocation`, but instead of the control-plane
439
+ `FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
440
+ the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
441
+ """
442
+
443
+ assert function._input_plane_url, "_map_invocation_inputplane should only be used for input-plane backed functions"
444
+
445
+ input_plane_stub = await client.get_stub(function._input_plane_url)
446
+
447
+ # Required for _create_input.
448
+ assert client.stub, "Client must be hydrated with a stub for _map_invocation_inputplane"
449
+
450
+ # ------------------------------------------------------------
451
+ # Invocation-wide state
452
+ # ------------------------------------------------------------
453
+
454
+ have_all_inputs = False
455
+ map_done_event = asyncio.Event()
456
+
457
+ inputs_created = 0
458
+ outputs_completed = 0
459
+ successful_completions = 0
460
+ failed_completions = 0
461
+ no_context_duplicates = 0
462
+ stale_retry_duplicates = 0
463
+ already_complete_duplicates = 0
464
+ retried_outputs = 0
465
+ input_queue_size = 0
466
+ last_entry_id = ""
467
+
468
+ # The input-plane server returns this after the first request.
469
+ function_call_id = None
470
+ function_call_id_received = asyncio.Event()
471
+
472
+ # Single priority queue that holds *both* fresh inputs (timestamp == now)
473
+ # and future retries (timestamp > now).
474
+ queue: TimestampPriorityQueue[api_pb2.MapStartOrContinueItem] = TimestampPriorityQueue()
475
+
476
+ # Maximum number of inputs that may be in-flight (the server sends this in
477
+ # the first response – fall back to the default if we never receive it for
478
+ # any reason).
479
+ max_inputs_outstanding = MAX_INPUTS_OUTSTANDING_DEFAULT
480
+
481
+ # Input plane does not yet return a retry policy. So we currently disable retries.
482
+ retry_policy = api_pb2.FunctionRetryPolicy(
483
+ retries=0, # Input plane does not yet return a retry policy. So only retry server failures for now.
484
+ initial_delay_ms=1000,
485
+ max_delay_ms=1000,
486
+ backoff_coefficient=1.0,
487
+ )
488
+ map_items_manager = _MapItemsManager(
489
+ retry_policy=retry_policy,
490
+ function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
491
+ retry_queue=queue,
492
+ sync_client_retries_enabled=True,
493
+ max_inputs_outstanding=MAX_INPUTS_OUTSTANDING_DEFAULT,
494
+ is_input_plane_instance=True,
495
+ )
496
+
497
+ def update_counters(
498
+ created_delta: int = 0, completed_delta: int = 0, set_have_all_inputs: Union[bool, None] = None
499
+ ):
500
+ nonlocal inputs_created, outputs_completed, have_all_inputs
501
+
502
+ if created_delta:
503
+ inputs_created += created_delta
504
+ if completed_delta:
505
+ outputs_completed += completed_delta
506
+ if set_have_all_inputs is not None:
507
+ have_all_inputs = set_have_all_inputs
508
+
509
+ if count_update_callback is not None:
510
+ count_update_callback(outputs_completed, inputs_created)
511
+
512
+ if have_all_inputs and outputs_completed >= inputs_created:
513
+ map_done_event.set()
514
+
515
+ async def create_input(argskwargs):
516
+ idx = inputs_created + 1 # 1-indexed map call idx
517
+ update_counters(created_delta=1)
518
+ (args, kwargs) = argskwargs
519
+ put_item: api_pb2.FunctionPutInputsItem = await _create_input(
520
+ args,
521
+ kwargs,
522
+ client.stub,
523
+ max_object_size_bytes=function._max_object_size_bytes,
524
+ idx=idx,
525
+ method_name=function._use_method_name,
526
+ )
527
+ return api_pb2.MapStartOrContinueItem(input=put_item)
528
+
529
+ async def input_iter():
530
+ while True:
531
+ raw_input = await raw_input_queue.get()
532
+ if raw_input is None: # end of input sentinel
533
+ break
534
+ yield raw_input # args, kwargs
535
+
536
+ async def drain_input_generator():
537
+ async with aclosing(
538
+ async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
539
+ ) as streamer:
540
+ async for q_item in streamer:
541
+ await queue.put(time.time(), q_item)
542
+
543
+ # All inputs have been read.
544
+ update_counters(set_have_all_inputs=True)
545
+ yield
546
+
547
+ async def pump_inputs():
548
+ nonlocal function_call_id, max_inputs_outstanding
549
+ async for batch in queue_batch_iterator(queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
550
+ # Convert the queued items into the proto format expected by the RPC.
551
+ request_items: list[api_pb2.MapStartOrContinueItem] = [
552
+ api_pb2.MapStartOrContinueItem(input=qi.input, attempt_token=qi.attempt_token) for qi in batch
553
+ ]
554
+
555
+ await map_items_manager.add_items_inputplane(request_items)
556
+
557
+ # Build request
558
+ request = api_pb2.MapStartOrContinueRequest(
559
+ function_id=function.object_id,
560
+ function_call_id=function_call_id,
561
+ parent_input_id=current_input_id() or "",
562
+ items=request_items,
563
+ )
564
+
565
+ metadata = await client.get_input_plane_metadata(function._input_plane_region)
566
+
567
+ response: api_pb2.MapStartOrContinueResponse = await retry_transient_errors(
568
+ input_plane_stub.MapStartOrContinue, request, metadata=metadata
569
+ )
570
+
571
+ # match response items to the corresponding request item index
572
+ response_items_idx_tuple = [
573
+ (request_items[idx].input.idx, attempt_token)
574
+ for idx, attempt_token in enumerate(response.attempt_tokens)
575
+ ]
576
+
577
+ map_items_manager.handle_put_continue_response(response_items_idx_tuple)
578
+
579
+ if function_call_id is None:
580
+ function_call_id = response.function_call_id
581
+ function_call_id_received.set()
582
+ max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
583
+ yield
584
+
585
+ async def check_lost_inputs():
586
+ nonlocal last_entry_id # shared with get_all_outputs
587
+ try:
588
+ while not map_done_event.is_set():
589
+ if function_call_id is None:
590
+ await function_call_id_received.wait()
591
+ continue
592
+
593
+ await asyncio.sleep(1)
594
+
595
+ # check_inputs = [(idx, attempt_token), ...]
596
+ check_inputs = map_items_manager.get_input_idxs_waiting_for_output()
597
+ attempt_tokens = [attempt_token for _, attempt_token in check_inputs]
598
+ request = api_pb2.MapCheckInputsRequest(
599
+ last_entry_id=last_entry_id,
600
+ timeout=0, # Non-blocking read
601
+ attempt_tokens=attempt_tokens,
602
+ )
603
+
604
+ metadata = await client.get_input_plane_metadata(function._input_plane_region)
605
+ response: api_pb2.MapCheckInputsResponse = await retry_transient_errors(
606
+ input_plane_stub.MapCheckInputs, request, metadata=metadata
607
+ )
608
+ check_inputs_response = [
609
+ (check_inputs[resp_idx][0], response.lost[resp_idx]) for resp_idx, _ in enumerate(response.lost)
610
+ ]
611
+ # check_inputs_response = [(idx, lost: bool), ...]
612
+ await map_items_manager.handle_check_inputs_response(check_inputs_response)
613
+ yield
614
+ except asyncio.CancelledError:
615
+ pass
616
+
617
+ async def get_all_outputs():
618
+ nonlocal \
619
+ successful_completions, \
620
+ failed_completions, \
621
+ no_context_duplicates, \
622
+ stale_retry_duplicates, \
623
+ already_complete_duplicates, \
624
+ retried_outputs, \
625
+ last_entry_id
626
+
627
+ while not map_done_event.is_set():
628
+ if function_call_id is None:
629
+ await function_call_id_received.wait()
630
+ continue
631
+
632
+ request = api_pb2.MapAwaitRequest(
633
+ function_call_id=function_call_id,
634
+ last_entry_id=last_entry_id,
635
+ requested_at=time.time(),
636
+ timeout=OUTPUTS_TIMEOUT,
637
+ )
638
+ metadata = await client.get_input_plane_metadata(function._input_plane_region)
639
+ get_response_task = asyncio.create_task(
640
+ retry_transient_errors(
641
+ input_plane_stub.MapAwait,
642
+ request,
643
+ max_retries=20,
644
+ attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
645
+ metadata=metadata,
646
+ )
647
+ )
648
+ map_done_task = asyncio.create_task(map_done_event.wait())
649
+ try:
650
+ done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
651
+ if get_response_task in done:
652
+ map_done_task.cancel()
653
+ response = get_response_task.result()
654
+ else:
655
+ assert map_done_event.is_set()
656
+ # map is done - no more outputs, so return early
657
+ return
658
+ finally:
659
+ # clean up tasks, in case of cancellations etc.
660
+ get_response_task.cancel()
661
+ map_done_task.cancel()
662
+ last_entry_id = response.last_entry_id
663
+
664
+ for output_item in response.outputs:
665
+ output_type = await map_items_manager.handle_get_outputs_response(output_item, int(time.time()))
666
+ if output_type == _OutputType.SUCCESSFUL_COMPLETION:
667
+ successful_completions += 1
668
+ elif output_type == _OutputType.FAILED_COMPLETION:
669
+ failed_completions += 1
670
+ elif output_type == _OutputType.RETRYING:
671
+ retried_outputs += 1
672
+ elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
673
+ no_context_duplicates += 1
674
+ elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
675
+ stale_retry_duplicates += 1
676
+ elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
677
+ already_complete_duplicates += 1
678
+ else:
679
+ raise Exception(f"Unknown output type: {output_type}")
680
+
681
+ if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
682
+ update_counters(completed_delta=1)
683
+ yield output_item
684
+
685
+ async def get_all_outputs_and_clean_up():
686
+ try:
687
+ async with aclosing(get_all_outputs()) as stream:
688
+ async for item in stream:
689
+ yield item
690
+ finally:
691
+ await queue.close()
692
+ pass
693
+
694
+ async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
695
+ try:
696
+ output = await _process_result(item.result, item.data_format, input_plane_stub, client)
697
+ except Exception as e:
698
+ if return_exceptions:
699
+ if wrap_returned_exceptions:
700
+ # Prior to client 1.0.4 there was a bug where return_exceptions would wrap
701
+ # any returned exceptions in a synchronicity.UserCodeException. This adds
702
+ # deprecated non-breaking compatibility bandaid for migrating away from that:
703
+ output = modal.exception.UserCodeException(e)
704
+ else:
705
+ output = e
706
+ else:
707
+ raise e
708
+ return (item.idx, output)
709
+
710
+ async def poll_outputs():
711
+ # map to store out-of-order outputs received
712
+ received_outputs = {}
713
+ output_idx = 1 # 1-indexed map call idx
714
+
715
+ async with aclosing(
716
+ async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
717
+ ) as streamer:
718
+ async for idx, output in streamer:
719
+ if not order_outputs:
720
+ yield _OutputValue(output)
721
+ else:
722
+ # hold on to outputs for function maps, so we can reorder them correctly.
723
+ received_outputs[idx] = output
724
+
725
+ while True:
726
+ if output_idx not in received_outputs:
727
+ # we haven't received the output for the current index yet.
728
+ # stop returning outputs to the caller and instead wait for
729
+ # the next output to arrive from the server.
730
+ break
731
+
732
+ output = received_outputs.pop(output_idx)
733
+ yield _OutputValue(output)
734
+ output_idx += 1
735
+
736
+ assert len(received_outputs) == 0
737
+
738
+ async def log_debug_stats():
739
+ def log_stats():
740
+ logger.debug(
741
+ f"Map stats:\nsuccessful_completions={successful_completions} failed_completions={failed_completions} "
742
+ f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
743
+ f"already_complete_duplicates={already_complete_duplicates} retried_outputs={retried_outputs} "
744
+ f"function_call_id={function_call_id} max_inputs_outstanding={max_inputs_outstanding} "
745
+ f"map_items_manager_size={len(map_items_manager)} input_queue_size={input_queue_size}"
746
+ )
747
+
748
+ while True:
749
+ log_stats()
750
+ try:
751
+ await asyncio.sleep(10)
752
+ except asyncio.CancelledError:
753
+ # Log final stats before exiting
754
+ log_stats()
755
+ break
756
+
757
+ log_task = asyncio.create_task(log_debug_stats())
758
+
759
+ async with aclosing(
760
+ async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), check_lost_inputs())
761
+ ) as merged:
762
+ async for maybe_output in merged:
763
+ if maybe_output is not None: # ignore None sentinels
764
+ yield maybe_output.value
765
+
766
+ log_task.cancel()
767
+
768
+
427
769
  async def _map_helper(
428
770
  self: "modal.functions.Function",
429
771
  async_input_gen: typing.AsyncGenerator[Any, None],
@@ -756,12 +1098,19 @@ class _MapItemContext:
756
1098
  sync_client_retries_enabled: bool
757
1099
  # Both these futures are strings. Omitting generic type because
758
1100
  # it causes an error when running `inv protoc type-stubs`.
1101
+ # Unused. But important, input_id is not set for inputplane invocations.
759
1102
  input_id: asyncio.Future
760
1103
  input_jwt: asyncio.Future
761
1104
  previous_input_jwt: Optional[str]
762
1105
  _event_loop: asyncio.AbstractEventLoop
763
1106
 
764
- def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager, sync_client_retries_enabled: bool):
1107
+ def __init__(
1108
+ self,
1109
+ input: api_pb2.FunctionInput,
1110
+ retry_manager: RetryManager,
1111
+ sync_client_retries_enabled: bool,
1112
+ is_input_plane_instance: bool = False,
1113
+ ):
765
1114
  self.state = _MapItemState.SENDING
766
1115
  self.input = input
767
1116
  self.retry_manager = retry_manager
@@ -772,7 +1121,22 @@ class _MapItemContext:
772
1121
  # a race condition where we could receive outputs before we have
773
1122
  # recorded the input ID and JWT in `pending_outputs`.
774
1123
  self.input_jwt = self._event_loop.create_future()
1124
+ # Unused. But important, this is not set for inputplane invocations.
775
1125
  self.input_id = self._event_loop.create_future()
1126
+ self._is_input_plane_instance = is_input_plane_instance
1127
+
1128
+ def handle_map_start_or_continue_response(self, attempt_token: str):
1129
+ if not self.input_jwt.done():
1130
+ self.input_jwt.set_result(attempt_token)
1131
+ else:
1132
+ # Create a new future for the next value
1133
+ self.input_jwt = asyncio.Future()
1134
+ self.input_jwt.set_result(attempt_token)
1135
+
1136
+ # Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
1137
+ # RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
1138
+ if self.state == _MapItemState.SENDING:
1139
+ self.state = _MapItemState.WAITING_FOR_OUTPUT
776
1140
 
777
1141
  def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
778
1142
  self.input_jwt.set_result(item.input_jwt)
@@ -799,7 +1163,7 @@ class _MapItemContext:
799
1163
  if self.state == _MapItemState.COMPLETE:
800
1164
  logger.debug(
801
1165
  f"Received output for input marked as complete. Must be duplicate, so ignoring. "
802
- f"idx={item.idx} input_id={item.input_id}, retry_count={item.retry_count}"
1166
+ f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
803
1167
  )
804
1168
  return _OutputType.ALREADY_COMPLETE_DUPLICATE
805
1169
  # If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
@@ -847,7 +1211,11 @@ class _MapItemContext:
847
1211
 
848
1212
  self.state = _MapItemState.WAITING_TO_RETRY
849
1213
 
850
- await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)
1214
+ if self._is_input_plane_instance:
1215
+ retry_item = await self.create_map_start_or_continue_item(item.idx)
1216
+ await retry_queue.put(now_seconds + delay_ms / 1_000, retry_item)
1217
+ else:
1218
+ await retry_queue.put(now_seconds + delay_ms / 1_000, item.idx)
851
1219
 
852
1220
  return _OutputType.RETRYING
853
1221
 
@@ -866,6 +1234,16 @@ class _MapItemContext:
866
1234
  self.input_jwt.set_result(input_jwt)
867
1235
  self.state = _MapItemState.WAITING_FOR_OUTPUT
868
1236
 
1237
+ async def create_map_start_or_continue_item(self, idx: int) -> api_pb2.MapStartOrContinueItem:
1238
+ attempt_token = await self.input_jwt
1239
+ return api_pb2.MapStartOrContinueItem(
1240
+ input=api_pb2.FunctionPutInputsItem(
1241
+ input=self.input,
1242
+ idx=idx,
1243
+ ),
1244
+ attempt_token=attempt_token,
1245
+ )
1246
+
869
1247
 
870
1248
  class _MapItemsManager:
871
1249
  def __init__(
@@ -875,6 +1253,7 @@ class _MapItemsManager:
875
1253
  retry_queue: TimestampPriorityQueue,
876
1254
  sync_client_retries_enabled: bool,
877
1255
  max_inputs_outstanding: int,
1256
+ is_input_plane_instance: bool = False,
878
1257
  ):
879
1258
  self._retry_policy = retry_policy
880
1259
  self.function_call_invocation_type = function_call_invocation_type
@@ -885,6 +1264,7 @@ class _MapItemsManager:
885
1264
  self._inputs_outstanding = asyncio.BoundedSemaphore(max_inputs_outstanding)
886
1265
  self._item_context: dict[int, _MapItemContext] = {}
887
1266
  self._sync_client_retries_enabled = sync_client_retries_enabled
1267
+ self._is_input_plane_instance = is_input_plane_instance
888
1268
 
889
1269
  async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
890
1270
  for item in items:
@@ -897,6 +1277,21 @@ class _MapItemsManager:
897
1277
  sync_client_retries_enabled=self._sync_client_retries_enabled,
898
1278
  )
899
1279
 
1280
+ async def add_items_inputplane(self, items: list[api_pb2.MapStartOrContinueItem]):
1281
+ for item in items:
1282
+ # acquire semaphore to limit the number of inputs in progress
1283
+ # (either queued to be sent, waiting for completion, or retrying)
1284
+ if item.attempt_token != "": # if it is a retry item
1285
+ self._item_context[item.input.idx].state = _MapItemState.SENDING
1286
+ continue
1287
+ await self._inputs_outstanding.acquire()
1288
+ self._item_context[item.input.idx] = _MapItemContext(
1289
+ input=item.input.input,
1290
+ retry_manager=RetryManager(self._retry_policy),
1291
+ sync_client_retries_enabled=self._sync_client_retries_enabled,
1292
+ is_input_plane_instance=self._is_input_plane_instance,
1293
+ )
1294
+
900
1295
  async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
901
1296
  return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]
902
1297
 
@@ -911,6 +1306,17 @@ class _MapItemsManager:
911
1306
  if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
912
1307
  ]
913
1308
 
1309
+ def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
1310
+ """
1311
+ Returns a list of input_idxs for inputs that are waiting for output.
1312
+ """
1313
+ # Idx doesn't need a future because it is set by client and not server.
1314
+ return [
1315
+ (idx, ctx.input_jwt.result())
1316
+ for idx, ctx in self._item_context.items()
1317
+ if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
1318
+ ]
1319
+
914
1320
  def _remove_item(self, item_idx: int):
915
1321
  del self._item_context[item_idx]
916
1322
  self._inputs_outstanding.release()
@@ -918,6 +1324,18 @@ class _MapItemsManager:
918
1324
  def get_item_context(self, item_idx: int) -> _MapItemContext:
919
1325
  return self._item_context.get(item_idx)
920
1326
 
1327
+ def handle_put_continue_response(
1328
+ self,
1329
+ items: list[tuple[int, str]], # idx, input_jwt
1330
+ ):
1331
+ for index, item in items:
1332
+ ctx = self._item_context.get(index, None)
1333
+ # If the context is None, then get_all_outputs() has already received a successful
1334
+ # output, and deleted the context. This happens if FunctionGetOutputs completes
1335
+ # before MapStartOrContinueResponse is received.
1336
+ if ctx is not None:
1337
+ ctx.handle_map_start_or_continue_response(item)
1338
+
921
1339
  def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
922
1340
  for item in items:
923
1341
  ctx = self._item_context.get(item.idx, None)
@@ -937,6 +1355,16 @@ class _MapItemsManager:
937
1355
  if ctx is not None:
938
1356
  ctx.handle_retry_response(input_jwt)
939
1357
 
1358
+ async def handle_check_inputs_response(self, response: list[tuple[int, bool]]):
1359
+ for idx, lost in response:
1360
+ ctx = self._item_context.get(idx, None)
1361
+ if ctx is not None:
1362
+ if lost:
1363
+ ctx.state = _MapItemState.WAITING_TO_RETRY
1364
+ retry_item = await ctx.create_map_start_or_continue_item(idx)
1365
+ _ = ctx.retry_manager.get_delay_ms() # increment retry count but instant retry for lost inputs
1366
+ await self._retry_queue.put(time.time(), retry_item)
1367
+
940
1368
  async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
941
1369
  ctx = self._item_context.get(item.idx, None)
942
1370
  if ctx is None:
modal/parallel_map.pyi CHANGED
@@ -70,6 +70,23 @@ def _map_invocation(
70
70
  count_update_callback: typing.Optional[collections.abc.Callable[[int, int], None]],
71
71
  function_call_invocation_type: int,
72
72
  ): ...
73
+ def _map_invocation_inputplane(
74
+ function: modal._functions._Function,
75
+ raw_input_queue: _SynchronizedQueue,
76
+ client: modal.client._Client,
77
+ order_outputs: bool,
78
+ return_exceptions: bool,
79
+ wrap_returned_exceptions: bool,
80
+ count_update_callback: typing.Optional[collections.abc.Callable[[int, int], None]],
81
+ ) -> typing.AsyncGenerator[typing.Any, None]:
82
+ """Input-plane implementation of a function map invocation.
83
+
84
+ This is analogous to `_map_invocation`, but instead of the control-plane
85
+ `FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
86
+ the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
87
+ """
88
+ ...
89
+
73
90
  def _map_helper(
74
91
  self: modal.functions.Function,
75
92
  async_input_gen: typing.AsyncGenerator[typing.Any, None],
@@ -260,10 +277,12 @@ class _MapItemContext:
260
277
  input: modal_proto.api_pb2.FunctionInput,
261
278
  retry_manager: modal.retries.RetryManager,
262
279
  sync_client_retries_enabled: bool,
280
+ is_input_plane_instance: bool = False,
263
281
  ):
264
282
  """Initialize self. See help(type(self)) for accurate signature."""
265
283
  ...
266
284
 
285
+ def handle_map_start_or_continue_response(self, attempt_token: str): ...
267
286
  def handle_put_inputs_response(self, item: modal_proto.api_pb2.FunctionPutInputsResponseItem): ...
268
287
  async def handle_get_outputs_response(
269
288
  self,
@@ -280,6 +299,7 @@ class _MapItemContext:
280
299
 
281
300
  async def prepare_item_for_retry(self) -> modal_proto.api_pb2.FunctionRetryInputsItem: ...
282
301
  def handle_retry_response(self, input_jwt: str): ...
302
+ async def create_map_start_or_continue_item(self, idx: int) -> modal_proto.api_pb2.MapStartOrContinueItem: ...
283
303
 
284
304
  class _MapItemsManager:
285
305
  def __init__(
@@ -289,11 +309,13 @@ class _MapItemsManager:
289
309
  retry_queue: modal._utils.async_utils.TimestampPriorityQueue,
290
310
  sync_client_retries_enabled: bool,
291
311
  max_inputs_outstanding: int,
312
+ is_input_plane_instance: bool = False,
292
313
  ):
293
314
  """Initialize self. See help(type(self)) for accurate signature."""
294
315
  ...
295
316
 
296
317
  async def add_items(self, items: list[modal_proto.api_pb2.FunctionPutInputsItem]): ...
318
+ async def add_items_inputplane(self, items: list[modal_proto.api_pb2.MapStartOrContinueItem]): ...
297
319
  async def prepare_items_for_retry(
298
320
  self, retriable_idxs: list[int]
299
321
  ) -> list[modal_proto.api_pb2.FunctionRetryInputsItem]: ...
@@ -301,10 +323,16 @@ class _MapItemsManager:
301
323
  """Returns a list of input_jwts for inputs that are waiting for output."""
302
324
  ...
303
325
 
326
+ def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
327
+ """Returns a list of input_idxs for inputs that are waiting for output."""
328
+ ...
329
+
304
330
  def _remove_item(self, item_idx: int): ...
305
331
  def get_item_context(self, item_idx: int) -> _MapItemContext: ...
332
+ def handle_put_continue_response(self, items: list[tuple[int, str]]): ...
306
333
  def handle_put_inputs_response(self, items: list[modal_proto.api_pb2.FunctionPutInputsResponseItem]): ...
307
334
  def handle_retry_response(self, input_jwts: list[str]): ...
335
+ async def handle_check_inputs_response(self, response: list[tuple[int, bool]]): ...
308
336
  async def handle_get_outputs_response(
309
337
  self, item: modal_proto.api_pb2.FunctionGetOutputsItem, now_seconds: int
310
338
  ) -> _OutputType: ...