modal 1.0.6.dev61__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.
- modal/__main__.py +2 -2
- modal/_clustered_functions.py +3 -0
- modal/_clustered_functions.pyi +3 -2
- modal/_functions.py +78 -26
- modal/_object.py +9 -1
- modal/_output.py +14 -25
- modal/_runtime/gpu_memory_snapshot.py +158 -54
- modal/_utils/async_utils.py +6 -4
- modal/_utils/auth_token_manager.py +1 -1
- modal/_utils/blob_utils.py +16 -21
- modal/_utils/function_utils.py +16 -4
- modal/_utils/time_utils.py +8 -4
- modal/app.py +0 -4
- modal/app.pyi +0 -4
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +4 -4
- modal/cli/config.py +2 -2
- modal/cli/container.py +2 -2
- modal/cli/dict.py +4 -4
- modal/cli/entry_point.py +2 -2
- modal/cli/import_refs.py +3 -3
- modal/cli/network_file_system.py +8 -9
- modal/cli/profile.py +2 -2
- modal/cli/queues.py +5 -5
- modal/cli/secret.py +5 -5
- modal/cli/utils.py +3 -4
- modal/cli/volume.py +8 -9
- modal/client.py +8 -1
- modal/client.pyi +9 -10
- modal/container_process.py +2 -2
- modal/dict.py +47 -3
- modal/dict.pyi +55 -0
- modal/exception.py +4 -0
- modal/experimental/__init__.py +1 -1
- modal/experimental/flash.py +18 -2
- modal/experimental/flash.pyi +19 -0
- modal/functions.pyi +6 -7
- modal/image.py +26 -10
- modal/image.pyi +12 -4
- modal/mount.py +1 -1
- modal/object.pyi +4 -0
- modal/parallel_map.py +432 -4
- modal/parallel_map.pyi +28 -0
- modal/queue.py +46 -3
- modal/queue.pyi +53 -0
- modal/sandbox.py +105 -25
- modal/sandbox.pyi +108 -18
- modal/secret.py +48 -5
- modal/secret.pyi +55 -0
- modal/token_flow.py +3 -3
- modal/volume.py +49 -18
- modal/volume.pyi +50 -8
- {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/METADATA +2 -2
- {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/RECORD +75 -75
- modal_proto/api.proto +140 -14
- modal_proto/api_grpc.py +80 -0
- modal_proto/api_pb2.py +927 -756
- modal_proto/api_pb2.pyi +488 -34
- modal_proto/api_pb2_grpc.py +166 -0
- modal_proto/api_pb2_grpc.pyi +52 -0
- modal_proto/modal_api_grpc.py +5 -0
- modal_version/__init__.py +1 -1
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/2025.06.txt +0 -0
- /modal/{requirements → builder}/PREVIEW.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- /modal/{requirements → builder}/base-images.json +0 -0
- {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/WHEEL +0 -0
- {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/entry_points.txt +0 -0
- {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.6.dev61.dist-info → modal-1.1.1.dist-info}/top_level.txt +0 -0
modal/parallel_map.py
CHANGED
@@ -6,7 +6,7 @@ import time
 import typing
 from asyncio import FIRST_COMPLETED
 from dataclasses import dataclass
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Union
 
 from grpclib import Status
 
@@ -424,6 +424,348 @@ async def _map_invocation(
     await log_debug_stats_task
 
 
+async def _map_invocation_inputplane(
+    function: "modal.functions._Function",
+    raw_input_queue: _SynchronizedQueue,
+    client: "modal.client._Client",
+    order_outputs: bool,
+    return_exceptions: bool,
+    wrap_returned_exceptions: bool,
+    count_update_callback: Optional[Callable[[int, int], None]],
+) -> typing.AsyncGenerator[Any, None]:
+    """Input-plane implementation of a function map invocation.
+
+    This is analogous to `_map_invocation`, but instead of the control-plane
+    `FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
+    the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
+    """
+
+    assert function._input_plane_url, "_map_invocation_inputplane should only be used for input-plane backed functions"
+
+    input_plane_stub = await client.get_stub(function._input_plane_url)
+
+    # Required for _create_input.
+    assert client.stub, "Client must be hydrated with a stub for _map_invocation_inputplane"
+
+    # ------------------------------------------------------------
+    # Invocation-wide state
+    # ------------------------------------------------------------
+
+    have_all_inputs = False
+    map_done_event = asyncio.Event()
+
+    inputs_created = 0
+    outputs_completed = 0
+    successful_completions = 0
+    failed_completions = 0
+    no_context_duplicates = 0
+    stale_retry_duplicates = 0
+    already_complete_duplicates = 0
+    retried_outputs = 0
+    input_queue_size = 0
+    last_entry_id = ""
+
+    # The input-plane server returns this after the first request.
+    function_call_id = None
+    function_call_id_received = asyncio.Event()
+
+    # Single priority queue that holds *both* fresh inputs (timestamp == now)
+    # and future retries (timestamp > now).
+    queue: TimestampPriorityQueue[api_pb2.MapStartOrContinueItem] = TimestampPriorityQueue()
+
+    # Maximum number of inputs that may be in-flight (the server sends this in
+    # the first response – fall back to the default if we never receive it for
+    # any reason).
+    max_inputs_outstanding = MAX_INPUTS_OUTSTANDING_DEFAULT
+
+    # Input plane does not yet return a retry policy. So we currently disable retries.
+    retry_policy = api_pb2.FunctionRetryPolicy(
+        retries=0,  # Input plane does not yet return a retry policy. So only retry server failures for now.
+        initial_delay_ms=1000,
+        max_delay_ms=1000,
+        backoff_coefficient=1.0,
+    )
+    map_items_manager = _MapItemsManager(
+        retry_policy=retry_policy,
+        function_call_invocation_type=api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
+        retry_queue=queue,
+        sync_client_retries_enabled=True,
+        max_inputs_outstanding=MAX_INPUTS_OUTSTANDING_DEFAULT,
+        is_input_plane_instance=True,
+    )
+
+    def update_counters(
+        created_delta: int = 0, completed_delta: int = 0, set_have_all_inputs: Union[bool, None] = None
+    ):
+        nonlocal inputs_created, outputs_completed, have_all_inputs
+
+        if created_delta:
+            inputs_created += created_delta
+        if completed_delta:
+            outputs_completed += completed_delta
+        if set_have_all_inputs is not None:
+            have_all_inputs = set_have_all_inputs
+
+        if count_update_callback is not None:
+            count_update_callback(outputs_completed, inputs_created)
+
+        if have_all_inputs and outputs_completed >= inputs_created:
+            map_done_event.set()
+
+    async def create_input(argskwargs):
+        idx = inputs_created + 1  # 1-indexed map call idx
+        update_counters(created_delta=1)
+        (args, kwargs) = argskwargs
+        put_item: api_pb2.FunctionPutInputsItem = await _create_input(
+            args,
+            kwargs,
+            client.stub,
+            max_object_size_bytes=function._max_object_size_bytes,
+            idx=idx,
+            method_name=function._use_method_name,
+        )
+        return api_pb2.MapStartOrContinueItem(input=put_item)
+
+    async def input_iter():
+        while True:
+            raw_input = await raw_input_queue.get()
+            if raw_input is None:  # end of input sentinel
+                break
+            yield raw_input  # args, kwargs
+
+    async def drain_input_generator():
+        async with aclosing(
+            async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for q_item in streamer:
+                await queue.put(time.time(), q_item)
+
+        # All inputs have been read.
+        update_counters(set_have_all_inputs=True)
+        yield
+
+    async def pump_inputs():
+        nonlocal function_call_id, max_inputs_outstanding
+        async for batch in queue_batch_iterator(queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            # Convert the queued items into the proto format expected by the RPC.
+            request_items: list[api_pb2.MapStartOrContinueItem] = [
+                api_pb2.MapStartOrContinueItem(input=qi.input, attempt_token=qi.attempt_token) for qi in batch
+            ]
+
+            await map_items_manager.add_items_inputplane(request_items)
+
+            # Build request
+            request = api_pb2.MapStartOrContinueRequest(
+                function_id=function.object_id,
+                function_call_id=function_call_id,
+                parent_input_id=current_input_id() or "",
+                items=request_items,
+            )
+
+            metadata = await client.get_input_plane_metadata(function._input_plane_region)
+
+            response: api_pb2.MapStartOrContinueResponse = await retry_transient_errors(
+                input_plane_stub.MapStartOrContinue, request, metadata=metadata
+            )
+
+            # match response items to the corresponding request item index
+            response_items_idx_tuple = [
+                (request_items[idx].input.idx, attempt_token)
+                for idx, attempt_token in enumerate(response.attempt_tokens)
+            ]
+
+            map_items_manager.handle_put_continue_response(response_items_idx_tuple)
+
+            if function_call_id is None:
+                function_call_id = response.function_call_id
+                function_call_id_received.set()
+            max_inputs_outstanding = response.max_inputs_outstanding or MAX_INPUTS_OUTSTANDING_DEFAULT
+        yield
+
+    async def check_lost_inputs():
+        nonlocal last_entry_id  # shared with get_all_outputs
+        try:
+            while not map_done_event.is_set():
+                if function_call_id is None:
+                    await function_call_id_received.wait()
+                    continue
+
+                await asyncio.sleep(1)
+
+                # check_inputs = [(idx, attempt_token), ...]
+                check_inputs = map_items_manager.get_input_idxs_waiting_for_output()
+                attempt_tokens = [attempt_token for _, attempt_token in check_inputs]
+                request = api_pb2.MapCheckInputsRequest(
+                    last_entry_id=last_entry_id,
+                    timeout=0,  # Non-blocking read
+                    attempt_tokens=attempt_tokens,
+                )
+
+                metadata = await client.get_input_plane_metadata(function._input_plane_region)
+                response: api_pb2.MapCheckInputsResponse = await retry_transient_errors(
+                    input_plane_stub.MapCheckInputs, request, metadata=metadata
+                )
+                check_inputs_response = [
+                    (check_inputs[resp_idx][0], response.lost[resp_idx]) for resp_idx, _ in enumerate(response.lost)
+                ]
+                # check_inputs_response = [(idx, lost: bool), ...]
+                await map_items_manager.handle_check_inputs_response(check_inputs_response)
+            yield
+        except asyncio.CancelledError:
+            pass
+
+    async def get_all_outputs():
+        nonlocal \
+            successful_completions, \
+            failed_completions, \
+            no_context_duplicates, \
+            stale_retry_duplicates, \
+            already_complete_duplicates, \
+            retried_outputs, \
+            last_entry_id
+
+        while not map_done_event.is_set():
+            if function_call_id is None:
+                await function_call_id_received.wait()
+                continue
+
+            request = api_pb2.MapAwaitRequest(
+                function_call_id=function_call_id,
+                last_entry_id=last_entry_id,
+                requested_at=time.time(),
+                timeout=OUTPUTS_TIMEOUT,
+            )
+            metadata = await client.get_input_plane_metadata(function._input_plane_region)
+            get_response_task = asyncio.create_task(
+                retry_transient_errors(
+                    input_plane_stub.MapAwait,
+                    request,
+                    max_retries=20,
+                    attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
+                    metadata=metadata,
+                )
+            )
+            map_done_task = asyncio.create_task(map_done_event.wait())
+            try:
+                done, pending = await asyncio.wait([get_response_task, map_done_task], return_when=FIRST_COMPLETED)
+                if get_response_task in done:
+                    map_done_task.cancel()
+                    response = get_response_task.result()
+                else:
+                    assert map_done_event.is_set()
+                    # map is done - no more outputs, so return early
+                    return
+            finally:
+                # clean up tasks, in case of cancellations etc.
+                get_response_task.cancel()
+                map_done_task.cancel()
+            last_entry_id = response.last_entry_id
+
+            for output_item in response.outputs:
+                output_type = await map_items_manager.handle_get_outputs_response(output_item, int(time.time()))
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION:
+                    successful_completions += 1
+                elif output_type == _OutputType.FAILED_COMPLETION:
+                    failed_completions += 1
+                elif output_type == _OutputType.RETRYING:
+                    retried_outputs += 1
+                elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
+                    no_context_duplicates += 1
+                elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
+                    stale_retry_duplicates += 1
+                elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
+                    already_complete_duplicates += 1
+                else:
+                    raise Exception(f"Unknown output type: {output_type}")
+
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
+                    update_counters(completed_delta=1)
+                    yield output_item
+
+    async def get_all_outputs_and_clean_up():
+        try:
+            async with aclosing(get_all_outputs()) as stream:
+                async for item in stream:
+                    yield item
+        finally:
+            await queue.close()
+            pass
+
+    async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
+        try:
+            output = await _process_result(item.result, item.data_format, input_plane_stub, client)
+        except Exception as e:
+            if return_exceptions:
+                if wrap_returned_exceptions:
+                    # Prior to client 1.0.4 there was a bug where return_exceptions would wrap
+                    # any returned exceptions in a synchronicity.UserCodeException. This adds
+                    # deprecated non-breaking compatibility bandaid for migrating away from that:
+                    output = modal.exception.UserCodeException(e)
+                else:
+                    output = e
+            else:
+                raise e
+        return (item.idx, output)
+
+    async def poll_outputs():
+        # map to store out-of-order outputs received
+        received_outputs = {}
+        output_idx = 1  # 1-indexed map call idx
+
+        async with aclosing(
+            async_map_ordered(get_all_outputs_and_clean_up(), fetch_output, concurrency=BLOB_MAX_PARALLELISM)
+        ) as streamer:
+            async for idx, output in streamer:
+                if not order_outputs:
+                    yield _OutputValue(output)
+                else:
+                    # hold on to outputs for function maps, so we can reorder them correctly.
+                    received_outputs[idx] = output
+
+                    while True:
+                        if output_idx not in received_outputs:
+                            # we haven't received the output for the current index yet.
+                            # stop returning outputs to the caller and instead wait for
+                            # the next output to arrive from the server.
+                            break
+
+                        output = received_outputs.pop(output_idx)
+                        yield _OutputValue(output)
+                        output_idx += 1
+
+        assert len(received_outputs) == 0
+
+    async def log_debug_stats():
+        def log_stats():
+            logger.debug(
+                f"Map stats:\nsuccessful_completions={successful_completions} failed_completions={failed_completions} "
+                f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
+                f"already_complete_duplicates={already_complete_duplicates} retried_outputs={retried_outputs} "
+                f"function_call_id={function_call_id} max_inputs_outstanding={max_inputs_outstanding} "
+                f"map_items_manager_size={len(map_items_manager)} input_queue_size={input_queue_size}"
+            )
+
+        while True:
+            log_stats()
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                # Log final stats before exiting
+                log_stats()
+                break
+
+    log_task = asyncio.create_task(log_debug_stats())
+
+    async with aclosing(
+        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), check_lost_inputs())
+    ) as merged:
+        async for maybe_output in merged:
+            if maybe_output is not None:  # ignore None sentinels
+                yield maybe_output.value
+
+    log_task.cancel()
+
+
 async def _map_helper(
     self: "modal.functions.Function",
     async_input_gen: typing.AsyncGenerator[Any, None],
@@ -756,12 +1098,19 @@ class _MapItemContext:
     sync_client_retries_enabled: bool
     # Both these futures are strings. Omitting generic type because
     # it causes an error when running `inv protoc type-stubs`.
+    # Unused. But important, input_id is not set for inputplane invocations.
     input_id: asyncio.Future
     input_jwt: asyncio.Future
     previous_input_jwt: Optional[str]
     _event_loop: asyncio.AbstractEventLoop
 
-    def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager, sync_client_retries_enabled: bool):
+    def __init__(
+        self,
+        input: api_pb2.FunctionInput,
+        retry_manager: RetryManager,
+        sync_client_retries_enabled: bool,
+        is_input_plane_instance: bool = False,
+    ):
         self.state = _MapItemState.SENDING
         self.input = input
         self.retry_manager = retry_manager
@@ -772,7 +1121,22 @@ class _MapItemContext:
         # a race condition where we could receive outputs before we have
         # recorded the input ID and JWT in `pending_outputs`.
         self.input_jwt = self._event_loop.create_future()
+        # Unused. But important, this is not set for inputplane invocations.
         self.input_id = self._event_loop.create_future()
+        self._is_input_plane_instance = is_input_plane_instance
+
+    def handle_map_start_or_continue_response(self, attempt_token: str):
+        if not self.input_jwt.done():
+            self.input_jwt.set_result(attempt_token)
+        else:
+            # Create a new future for the next value
+            self.input_jwt = asyncio.Future()
+            self.input_jwt.set_result(attempt_token)
+
+        # Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
+        # RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
+        if self.state == _MapItemState.SENDING:
+            self.state = _MapItemState.WAITING_FOR_OUTPUT
 
     def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
         self.input_jwt.set_result(item.input_jwt)
@@ -799,7 +1163,7 @@ class _MapItemContext:
         if self.state == _MapItemState.COMPLETE:
             logger.debug(
                 f"Received output for input marked as complete. Must be duplicate, so ignoring. "
-                f"idx={item.idx} input_id={item.input_id}"
+                f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
            )
            return _OutputType.ALREADY_COMPLETE_DUPLICATE
        # If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
@@ -847,7 +1211,11 @@ class _MapItemContext:
 
         self.state = _MapItemState.WAITING_TO_RETRY
 
-        await retry_queue.put(now_seconds + delay_ms / 1_000, item.idx)
+        if self._is_input_plane_instance:
+            retry_item = await self.create_map_start_or_continue_item(item.idx)
+            await retry_queue.put(now_seconds + delay_ms / 1_000, retry_item)
+        else:
+            await retry_queue.put(now_seconds + delay_ms / 1_000, item.idx)
 
         return _OutputType.RETRYING
 
@@ -866,6 +1234,16 @@ class _MapItemContext:
         self.input_jwt.set_result(input_jwt)
         self.state = _MapItemState.WAITING_FOR_OUTPUT
 
+    async def create_map_start_or_continue_item(self, idx: int) -> api_pb2.MapStartOrContinueItem:
+        attempt_token = await self.input_jwt
+        return api_pb2.MapStartOrContinueItem(
+            input=api_pb2.FunctionPutInputsItem(
+                input=self.input,
+                idx=idx,
+            ),
+            attempt_token=attempt_token,
+        )
+
 
 class _MapItemsManager:
     def __init__(
@@ -875,6 +1253,7 @@ class _MapItemsManager:
         retry_queue: TimestampPriorityQueue,
         sync_client_retries_enabled: bool,
         max_inputs_outstanding: int,
+        is_input_plane_instance: bool = False,
     ):
         self._retry_policy = retry_policy
         self.function_call_invocation_type = function_call_invocation_type
@@ -885,6 +1264,7 @@ class _MapItemsManager:
         self._inputs_outstanding = asyncio.BoundedSemaphore(max_inputs_outstanding)
         self._item_context: dict[int, _MapItemContext] = {}
         self._sync_client_retries_enabled = sync_client_retries_enabled
+        self._is_input_plane_instance = is_input_plane_instance
 
     async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
         for item in items:
@@ -897,6 +1277,21 @@ class _MapItemsManager:
                 sync_client_retries_enabled=self._sync_client_retries_enabled,
             )
 
+    async def add_items_inputplane(self, items: list[api_pb2.MapStartOrContinueItem]):
+        for item in items:
+            # acquire semaphore to limit the number of inputs in progress
+            # (either queued to be sent, waiting for completion, or retrying)
+            if item.attempt_token != "":  # if it is a retry item
+                self._item_context[item.input.idx].state = _MapItemState.SENDING
+                continue
+            await self._inputs_outstanding.acquire()
+            self._item_context[item.input.idx] = _MapItemContext(
+                input=item.input.input,
+                retry_manager=RetryManager(self._retry_policy),
+                sync_client_retries_enabled=self._sync_client_retries_enabled,
+                is_input_plane_instance=self._is_input_plane_instance,
+            )
+
     async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
         return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]
 
@@ -911,6 +1306,17 @@ class _MapItemsManager:
             if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
         ]
 
+    def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
+        """
+        Returns a list of input_idxs for inputs that are waiting for output.
+        """
+        # Idx doesn't need a future because it is set by client and not server.
+        return [
+            (idx, ctx.input_jwt.result())
+            for idx, ctx in self._item_context.items()
+            if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
+        ]
+
     def _remove_item(self, item_idx: int):
         del self._item_context[item_idx]
         self._inputs_outstanding.release()
@@ -918,6 +1324,18 @@ class _MapItemsManager:
     def get_item_context(self, item_idx: int) -> _MapItemContext:
         return self._item_context.get(item_idx)
 
+    def handle_put_continue_response(
+        self,
+        items: list[tuple[int, str]],  # idx, input_jwt
+    ):
+        for index, item in items:
+            ctx = self._item_context.get(index, None)
+            # If the context is None, then get_all_outputs() has already received a successful
+            # output, and deleted the context. This happens if FunctionGetOutputs completes
+            # before MapStartOrContinueResponse is received.
+            if ctx is not None:
+                ctx.handle_map_start_or_continue_response(item)
+
     def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
         for item in items:
             ctx = self._item_context.get(item.idx, None)
@@ -937,6 +1355,16 @@ class _MapItemsManager:
             if ctx is not None:
                 ctx.handle_retry_response(input_jwt)
 
+    async def handle_check_inputs_response(self, response: list[tuple[int, bool]]):
+        for idx, lost in response:
+            ctx = self._item_context.get(idx, None)
+            if ctx is not None:
+                if lost:
+                    ctx.state = _MapItemState.WAITING_TO_RETRY
+                    retry_item = await ctx.create_map_start_or_continue_item(idx)
+                    _ = ctx.retry_manager.get_delay_ms()  # increment retry count but instant retry for lost inputs
+                    await self._retry_queue.put(time.time(), retry_item)
+
     async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
         ctx = self._item_context.get(item.idx, None)
         if ctx is None:
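A note on the structure of the new function: `_map_invocation_inputplane` runs its `drain_input_generator`, `pump_inputs`, `poll_outputs`, and `check_lost_inputs` generators concurrently via `async_merge`, and only the output-polling generator yields real values; the others yield bare `None` sentinels that the consumer skips. The sketch below is a minimal, self-contained illustration of that produce/pump/poll pattern in plain asyncio; `merge`, `pump_inputs`, and `poll_outputs` here are toy stand-ins written for this note, not Modal's `async_merge` or its queue types.

import asyncio
from typing import Any, AsyncGenerator, Optional

_DONE = object()  # per-generator completion marker used by merge()


async def merge(*gens: AsyncGenerator[Any, None]) -> AsyncGenerator[Any, None]:
    # Toy stand-in for async_merge: interleave items from several async generators
    # as they become available.
    out: asyncio.Queue = asyncio.Queue()

    async def drain(gen: AsyncGenerator[Any, None]) -> None:
        async for item in gen:
            await out.put(item)
        await out.put(_DONE)

    tasks = [asyncio.create_task(drain(g)) for g in gens]
    finished = 0
    try:
        while finished < len(gens):
            item = await out.get()
            if item is _DONE:
                finished += 1
            else:
                yield item
    finally:
        for t in tasks:
            t.cancel()


async def main() -> None:
    raw_inputs: asyncio.Queue[Optional[int]] = asyncio.Queue()
    pending: asyncio.Queue[int] = asyncio.Queue()
    all_inputs_sent = asyncio.Event()

    async def pump_inputs() -> AsyncGenerator[None, None]:
        # Ship raw inputs to the "server" (here just another queue), then signal completion.
        while True:
            x = await raw_inputs.get()
            if x is None:  # end-of-input sentinel
                break
            await pending.put(x)
        all_inputs_sent.set()
        yield  # None sentinel, ignored by the consumer

    async def poll_outputs() -> AsyncGenerator[int, None]:
        # Yield completed "outputs" as they become available.
        while not (all_inputs_sent.is_set() and pending.empty()):
            try:
                x = await asyncio.wait_for(pending.get(), timeout=0.05)
            except asyncio.TimeoutError:
                continue
            yield x * x  # the "function result"

    for i in range(5):
        await raw_inputs.put(i)
    await raw_inputs.put(None)

    async for maybe_output in merge(pump_inputs(), poll_outputs()):
        if maybe_output is not None:  # skip the pump's sentinel
            print(maybe_output)


asyncio.run(main())

The real implementation merges two more generators (input draining and lost-input checking) and uses a server-driven semaphore instead of an unbounded queue, but the consumption loop at the bottom mirrors the `async for maybe_output in merged` loop in the diff above.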
modal/parallel_map.pyi
CHANGED
@@ -70,6 +70,23 @@ def _map_invocation(
     count_update_callback: typing.Optional[collections.abc.Callable[[int, int], None]],
     function_call_invocation_type: int,
 ): ...
+def _map_invocation_inputplane(
+    function: modal._functions._Function,
+    raw_input_queue: _SynchronizedQueue,
+    client: modal.client._Client,
+    order_outputs: bool,
+    return_exceptions: bool,
+    wrap_returned_exceptions: bool,
+    count_update_callback: typing.Optional[collections.abc.Callable[[int, int], None]],
+) -> typing.AsyncGenerator[typing.Any, None]:
+    """Input-plane implementation of a function map invocation.
+
+    This is analogous to `_map_invocation`, but instead of the control-plane
+    `FunctionMap` / `FunctionPutInputs` / `FunctionGetOutputs` RPCs it speaks
+    the input-plane protocol consisting of `MapStartOrContinue`, `MapAwait`, and `MapCheckInputs`.
+    """
+    ...
+
 def _map_helper(
     self: modal.functions.Function,
     async_input_gen: typing.AsyncGenerator[typing.Any, None],
@@ -260,10 +277,12 @@ class _MapItemContext:
         input: modal_proto.api_pb2.FunctionInput,
         retry_manager: modal.retries.RetryManager,
         sync_client_retries_enabled: bool,
+        is_input_plane_instance: bool = False,
     ):
         """Initialize self. See help(type(self)) for accurate signature."""
         ...
 
+    def handle_map_start_or_continue_response(self, attempt_token: str): ...
     def handle_put_inputs_response(self, item: modal_proto.api_pb2.FunctionPutInputsResponseItem): ...
     async def handle_get_outputs_response(
         self,
@@ -280,6 +299,7 @@ class _MapItemContext:
 
     async def prepare_item_for_retry(self) -> modal_proto.api_pb2.FunctionRetryInputsItem: ...
     def handle_retry_response(self, input_jwt: str): ...
+    async def create_map_start_or_continue_item(self, idx: int) -> modal_proto.api_pb2.MapStartOrContinueItem: ...
 
 class _MapItemsManager:
     def __init__(
@@ -289,11 +309,13 @@ class _MapItemsManager:
         retry_queue: modal._utils.async_utils.TimestampPriorityQueue,
         sync_client_retries_enabled: bool,
         max_inputs_outstanding: int,
+        is_input_plane_instance: bool = False,
     ):
         """Initialize self. See help(type(self)) for accurate signature."""
         ...
 
     async def add_items(self, items: list[modal_proto.api_pb2.FunctionPutInputsItem]): ...
+    async def add_items_inputplane(self, items: list[modal_proto.api_pb2.MapStartOrContinueItem]): ...
     async def prepare_items_for_retry(
         self, retriable_idxs: list[int]
     ) -> list[modal_proto.api_pb2.FunctionRetryInputsItem]: ...
@@ -301,10 +323,16 @@ class _MapItemsManager:
         """Returns a list of input_jwts for inputs that are waiting for output."""
         ...
 
+    def get_input_idxs_waiting_for_output(self) -> list[tuple[int, str]]:
+        """Returns a list of input_idxs for inputs that are waiting for output."""
+        ...
+
     def _remove_item(self, item_idx: int): ...
     def get_item_context(self, item_idx: int) -> _MapItemContext: ...
+    def handle_put_continue_response(self, items: list[tuple[int, str]]): ...
     def handle_put_inputs_response(self, items: list[modal_proto.api_pb2.FunctionPutInputsResponseItem]): ...
     def handle_retry_response(self, input_jwts: list[str]): ...
+    async def handle_check_inputs_response(self, response: list[tuple[int, bool]]): ...
     async def handle_get_outputs_response(
         self, item: modal_proto.api_pb2.FunctionGetOutputsItem, now_seconds: int
     ) -> _OutputType: ...
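One design choice visible in both files: fresh inputs and scheduled retries share a single `TimestampPriorityQueue` (`retry_queue`), keyed by the time at which an item becomes eligible to send, and `handle_check_inputs_response` re-enqueues lost inputs at `time.time()` so they go out again immediately. Below is a minimal, self-contained sketch of that idea using `heapq`; the `TimestampQueue` class is a toy written for this note, not Modal's actual implementation in `modal/_utils/async_utils.py`.

import asyncio
import heapq
import time
from typing import Any


class TimestampQueue:
    """Toy priority queue keyed by a 'ready at' timestamp."""

    def __init__(self) -> None:
        self._heap: list[tuple[float, int, Any]] = []
        self._seq = 0  # tie-breaker so heapq never compares payloads

    def put(self, ready_at: float, item: Any) -> None:
        heapq.heappush(self._heap, (ready_at, self._seq, item))
        self._seq += 1

    async def get(self) -> Any:
        # Wait until the earliest item is ready, then pop it.
        while True:
            if self._heap and self._heap[0][0] <= time.time():
                return heapq.heappop(self._heap)[2]
            await asyncio.sleep(0.05)


async def demo() -> None:
    q = TimestampQueue()
    q.put(time.time(), "fresh input")              # eligible immediately
    q.put(time.time() + 0.2, "retry of input 3")   # becomes eligible after a delay
    print(await q.get())  # "fresh input"
    print(await q.get())  # "retry of input 3", roughly 0.2s later


asyncio.run(demo())

Because a single queue carries both kinds of items, the batching loop in `pump_inputs` does not need to distinguish first sends from retries; the `attempt_token` on each `MapStartOrContinueItem` is what tells the manager which case it is handling.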