modal 0.73.116__py3-none-any.whl → 0.73.126__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
modal/cls.py CHANGED
@@ -21,7 +21,7 @@ from ._partial_function import (
 )
 from ._resolver import Resolver
 from ._resources import convert_fn_config_to_resources_config
-from ._serialization import PYTHON_TO_PROTO_TYPE, check_valid_cls_constructor_arg
+from ._serialization import check_valid_cls_constructor_arg, get_proto_parameter_type
 from ._traceback import print_server_warnings
 from ._utils.async_utils import synchronize_api, synchronizer
 from ._utils.deprecation import deprecation_warning, renamed_parameter, warn_on_renamed_autoscaler_settings
@@ -362,15 +362,6 @@ class _Obj:
 Obj = synchronize_api(_Obj)
 
 
-def _validate_parameter_type(cls_name: str, parameter_name: str, parameter_type: type):
-    if parameter_type not in PYTHON_TO_PROTO_TYPE:
-        type_name = getattr(parameter_type, "__name__", repr(parameter_type))
-        supported = ", ".join(parameter_type.__name__ for parameter_type in PYTHON_TO_PROTO_TYPE.keys())
-        raise InvalidError(
-            f"{cls_name}.{parameter_name}: {type_name} is not a supported parameter type. Use one of: {supported}"
-        )
-
-
 class _Cls(_Object, type_prefix="cs"):
     """
     Cls adds method pooling and [lifecycle hook](/docs/guide/lifecycle-functions) behavior
@@ -467,12 +458,11 @@ class _Cls(_Object, type_prefix="cs"):
         annotations = user_cls.__dict__.get("__annotations__", {})  # compatible with older pythons
         missing_annotations = params.keys() - annotations.keys()
         if missing_annotations:
-            raise InvalidError("All modal.parameter() specifications need to be type annotated")
+            raise InvalidError("All modal.parameter() specifications need to be type-annotated")
 
         annotated_params = {k: t for k, t in annotations.items() if k in params}
         for k, t in annotated_params.items():
-            if t not in PYTHON_TO_PROTO_TYPE:
-                _validate_parameter_type(user_cls.__name__, k, t)
+            get_proto_parameter_type(t)
 
     @staticmethod
     def from_local(user_cls, app: "modal.app._App", class_service_function: _Function) -> "_Cls":
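
The net effect of this change is that unsupported `modal.parameter()` annotations are now rejected by `get_proto_parameter_type` in `modal._serialization` rather than by a local helper, while still raising `InvalidError` at class-definition time. As a rough illustration (the app, class, and parameter names below are made up, and the exact set of supported annotation types is whatever `modal._serialization` accepts in this release), a parametrized class is declared like this:

```python
import modal

app = modal.App("parameter-validation-demo")  # hypothetical app name

@app.cls()
class Translator:
    # Every modal.parameter() must carry a type annotation; annotations
    # outside the supported set (e.g. str, int) raise InvalidError when
    # the class is processed, now via get_proto_parameter_type(t).
    model_name: str = modal.parameter()
    batch_size: int = modal.parameter(default=8)

    @modal.method()
    def describe(self) -> str:
        return f"{self.model_name} (batch={self.batch_size})"
```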
modal/cls.pyi CHANGED
@@ -109,8 +109,6 @@ class Obj:
     async def _aenter(self): ...
     def __getattr__(self, k): ...
 
-def _validate_parameter_type(cls_name: str, parameter_name: str, parameter_type: type): ...
-
 class _Cls(modal._object._Object):
     _class_service_function: typing.Optional[modal._functions._Function]
     _options: typing.Optional[_ServiceOptions]
modal/functions.pyi CHANGED
@@ -198,11 +198,11 @@ class Function(
 
     _call_generator_nowait: ___call_generator_nowait_spec[typing_extensions.Self]
 
-    class __remote_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
+    class __remote_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
         def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
         async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
 
-    remote: __remote_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
+    remote: __remote_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]
 
     class __remote_gen_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(self, *args, **kwargs) -> typing.Generator[typing.Any, None, None]: ...
@@ -217,19 +217,19 @@ class Function(
         self, *args: modal._functions.P.args, **kwargs: modal._functions.P.kwargs
     ) -> modal._functions.OriginalReturnType: ...
 
-    class ___experimental_spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
+    class ___experimental_spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
         def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
         async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
 
     _experimental_spawn: ___experimental_spawn_spec[
-        modal._functions.ReturnType, modal._functions.P, typing_extensions.Self
+        modal._functions.P, modal._functions.ReturnType, typing_extensions.Self
     ]
 
-    class __spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
+    class __spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
         def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
         async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
 
-    spawn: __spawn_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
+    spawn: __spawn_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]
 
     def get_raw_f(self) -> collections.abc.Callable[..., typing.Any]: ...
 
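The modal/functions.pyi changes only reorder the type arguments of the generated `Protocol` specs so that the `ParamSpec` comes before the return-type `TypeVar`, matching the order in which the specs are subscripted. Subscription order has to mirror declaration order, or the checker binds arguments to the wrong type parameters. A minimal, self-contained sketch of that rule (the names here are illustrative, not Modal's):

```python
from typing import TypeVar

from typing_extensions import ParamSpec, Protocol

P = ParamSpec("P")
R = TypeVar("R", covariant=True)

class CallSpec(Protocol[P, R]):
    # Declared as (P, R): any subscription CallSpec[X, Y] binds P=X, R=Y.
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...

def apply(fn: CallSpec[[int], str], x: int) -> str:
    # CallSpec[[int], str] binds P=[int] and R=str. Writing the subscript
    # in the opposite order would try to bind R where P belongs.
    return fn(x)

def stringify(x: int) -> str:
    return str(x)

print(apply(stringify, 3))  # "3"
```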
modal/parallel_map.py CHANGED
@@ -1,5 +1,6 @@
 # Copyright Modal Labs 2024
 import asyncio
+import enum
 import time
 import typing
 from dataclasses import dataclass
@@ -10,6 +11,7 @@ from grpclib import Status
 from modal._runtime.execution_context import current_input_id
 from modal._utils.async_utils import (
     AsyncOrSyncIterable,
+    TimestampPriorityQueue,
     aclosing,
     async_map_ordered,
     async_merge,
@@ -28,7 +30,9 @@ from modal._utils.function_utils import (
     _process_result,
 )
 from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, RetryWarningMessage, retry_transient_errors
+from modal._utils.jwt_utils import DecodedJwt
 from modal.config import logger
+from modal.retries import RetryManager
 from modal_proto import api_pb2
 
 if typing.TYPE_CHECKING:
@@ -66,6 +70,12 @@ class _OutputValue:
     value: Any
 
 
+# maximum number of inputs that can be in progress (either queued to be sent,
+# or waiting for completion). if this limit is reached, we will block sending
+# more inputs to the server until some of the existing inputs are completed.
+MAP_MAX_INPUTS_OUTSTANDING = 1000
+
+# maximum number of inputs to send to the server in a single request
 MAP_INVOCATION_CHUNK_SIZE = 49
 
 if typing.TYPE_CHECKING:
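
`MAP_MAX_INPUTS_OUTSTANDING` is enforced further down in this diff by `_MapItemsManager`, which acquires an `asyncio.BoundedSemaphore` slot per input in `add_items()` and releases it when the input completes. A dependency-free sketch of that backpressure pattern (the limit of 3 and the prints are illustrative only):

```python
import asyncio

MAX_OUTSTANDING = 3  # stand-in for MAP_MAX_INPUTS_OUTSTANDING

async def main() -> None:
    outstanding = asyncio.BoundedSemaphore(MAX_OUTSTANDING)

    async def complete(idx: int) -> None:
        await asyncio.sleep(0.01)  # pretend the server ran the input
        outstanding.release()      # frees a slot for the next input
        print(f"input {idx} completed")

    for idx in range(6):
        # Blocks here once MAX_OUTSTANDING inputs are in flight, so the
        # producer can't run arbitrarily far ahead of completions.
        await outstanding.acquire()
        print(f"input {idx} sent")
        asyncio.create_task(complete(idx))

    await asyncio.sleep(0.1)  # let the remaining completions finish

asyncio.run(main())
```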
@@ -79,6 +89,7 @@ async def _map_invocation(
     order_outputs: bool,
     return_exceptions: bool,
     count_update_callback: Optional[Callable[[int, int], None]],
+    function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
 ):
     assert client.stub
     request = api_pb2.FunctionMapRequest(
@@ -86,28 +97,43 @@ async def _map_invocation(
         parent_input_id=current_input_id() or "",
         function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
         return_exceptions=return_exceptions,
+        function_call_invocation_type=function_call_invocation_type,
     )
-    response = await retry_transient_errors(client.stub.FunctionMap, request)
+    response: api_pb2.FunctionMapResponse = await retry_transient_errors(client.stub.FunctionMap, request)
 
     function_call_id = response.function_call_id
+    function_call_jwt = response.function_call_jwt
+    retry_policy = response.retry_policy
+    sync_client_retries_enabled = response.sync_client_retries_enabled
 
     have_all_inputs = False
-    num_inputs = 0
-    num_outputs = 0
+    inputs_created = 0
+    inputs_sent = 0
+    inputs_retried = 0
+    outputs_completed = 0
+    outputs_received = 0
+    retried_outputs = 0
+    successful_completions = 0
+    failed_completions = 0
+    already_complete_duplicates = 0
+    stale_retry_duplicates = 0
+    no_context_duplicates = 0
 
     def count_update():
         if count_update_callback is not None:
-            count_update_callback(num_outputs, num_inputs)
+            count_update_callback(outputs_completed, inputs_created)
 
-    pending_outputs: dict[str, int] = {}  # Map input_id -> next expected gen_index value
+    retry_queue = TimestampPriorityQueue()
     completed_outputs: set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)
-
-    input_queue: asyncio.Queue = asyncio.Queue()
+    input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
+    map_items_manager = _MapItemsManager(
+        retry_policy, function_call_invocation_type, retry_queue, sync_client_retries_enabled
+    )
 
     async def create_input(argskwargs):
-        nonlocal num_inputs
-        idx = num_inputs
-        num_inputs += 1
+        nonlocal inputs_created
+        idx = inputs_created
+        inputs_created += 1
         (args, kwargs) = argskwargs
         return await _create_input(args, kwargs, client, idx=idx, method_name=function._use_method_name)
 
@@ -119,6 +145,8 @@ async def _map_invocation(
             yield raw_input  # args, kwargs
 
     async def drain_input_generator():
+        nonlocal have_all_inputs
+
         # Parallelize uploading blobs
         async with aclosing(
             async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
@@ -132,49 +160,100 @@ async def _map_invocation(
 
     async def pump_inputs():
         assert client.stub
-        nonlocal have_all_inputs, num_inputs
-        async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE):
+        nonlocal have_all_inputs, inputs_created, inputs_sent
+        async for items in queue_batch_iterator(input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            # Add items to the manager. Their state will be SENDING.
+            await map_items_manager.add_items(items)
             request = api_pb2.FunctionPutInputsRequest(
-                function_id=function.object_id, inputs=items, function_call_id=function_call_id
+                function_id=function.object_id,
+                inputs=items,
+                function_call_id=function_call_id,
             )
             logger.debug(
                 f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
             )
-            # with 8 retries we log the warning below about every 30 seconds which isn't too spammy.
-            retry_warning_message = RetryWarningMessage(
-                message=f"Warning: map progress for function {function._function_name} is limited."
-                " Common bottlenecks include slow iteration over results, or function backlogs.",
-                warning_interval=8,
-                errors_to_warn_for=[Status.RESOURCE_EXHAUSTED])
-            resp = await retry_transient_errors(
-                client.stub.FunctionPutInputs,
-                request,
-                max_retries=None,
-                max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
-                additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-                retry_warning_message=retry_warning_message)
+
+            resp = await send_inputs(client.stub.FunctionPutInputs, request)
             count_update()
-            for item in resp.inputs:
-                pending_outputs.setdefault(item.input_id, 0)
+            inputs_sent += len(items)
+            # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
+            map_items_manager.handle_put_inputs_response(resp.inputs)
             logger.debug(
                 f"Successfully pushed {len(items)} inputs to server. "
                 f"Num queued inputs awaiting push is {input_queue.qsize()}."
             )
-
         have_all_inputs = True
         yield
 
+    async def retry_inputs():
+        nonlocal inputs_retried
+        async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            # For each index, use the context in the manager to create a FunctionRetryInputsItem.
+            # This will also update the context state to RETRYING.
+            inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
+                retriable_idxs
+            )
+            request = api_pb2.FunctionRetryInputsRequest(
+                function_call_jwt=function_call_jwt,
+                inputs=inputs,
+            )
+            resp = await send_inputs(client.stub.FunctionRetryInputs, request)
+            # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
+            # to the new value in the response.
+            map_items_manager.handle_retry_response(resp.input_jwts)
+            logger.debug(f"Successfully pushed retry for {len(inputs)} to server.")
+            inputs_retried += len(inputs)
+        yield
+
+    async def send_inputs(
+        fn: "modal.client.UnaryUnaryWrapper",
+        request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
+    ) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
+        # with 8 retries we log the warning below about every 30 seconds, which isn't too spammy.
+        retry_warning_message = RetryWarningMessage(
+            message=f"Warning: map progress for function {function._function_name} is limited."
+            " Common bottlenecks include slow iteration over results, or function backlogs.",
+            warning_interval=8,
+            errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
+        )
+        return await retry_transient_errors(
+            fn,
+            request,
+            max_retries=None,
+            max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
+            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+            retry_warning_message=retry_warning_message,
+        )
+
     async def get_all_outputs():
         assert client.stub
-        nonlocal num_inputs, num_outputs, have_all_inputs
+        nonlocal \
+            inputs_created, \
+            successful_completions, \
+            failed_completions, \
+            outputs_completed, \
+            have_all_inputs, \
+            outputs_received, \
+            already_complete_duplicates, \
+            no_context_duplicates, \
+            stale_retry_duplicates, \
+            retried_outputs
+
         last_entry_id = "0-0"
-        while not have_all_inputs or len(pending_outputs) > len(completed_outputs):
+
+        while not have_all_inputs or outputs_completed < inputs_created:
+            logger.debug(f"Requesting outputs. Have {outputs_completed} outputs, {inputs_created} inputs.")
+            # Get input_jwts of all items in the WAITING_FOR_OUTPUT state.
+            # The server uses these to check for lost inputs.
+            input_jwts = [input_jwt for input_jwt in map_items_manager.get_input_jwts_waiting_for_output()]
+
             request = api_pb2.FunctionGetOutputsRequest(
                 function_call_id=function_call_id,
                 timeout=OUTPUTS_TIMEOUT,
                 last_entry_id=last_entry_id,
                 clear_on_success=False,
                 requested_at=time.time(),
+                input_jwts=input_jwts,
             )
             response = await retry_transient_errors(
                 client.stub.FunctionGetOutputs,
@@ -183,19 +262,31 @@ async def _map_invocation(
                 attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
             )
 
-            if len(response.outputs) == 0:
-                continue
-
             last_entry_id = response.last_entry_id
+            now_seconds = int(time.time())
             for item in response.outputs:
-                pending_outputs.setdefault(item.input_id, 0)
-                if item.input_id in completed_outputs:
-                    # If this input is already completed, it means the output has already been
-                    # processed and was received again due to a duplicate.
-                    continue
-                completed_outputs.add(item.input_id)
-                num_outputs += 1
-                yield item
+                outputs_received += 1
+                # If the output failed, and there are retries remaining, the input will be placed on the
+                # retry queue, and state updated to WAITING_TO_RETRY. Otherwise, the output is considered
+                # complete and the item is removed from the manager.
+                output_type = await map_items_manager.handle_get_outputs_response(item, now_seconds)
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION:
+                    successful_completions += 1
+                elif output_type == _OutputType.FAILED_COMPLETION:
+                    failed_completions += 1
+                elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
+                    no_context_duplicates += 1
+                elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
+                    stale_retry_duplicates += 1
+                elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
+                    already_complete_duplicates += 1
+                elif output_type == _OutputType.RETRYING:
+                    retried_outputs += 1
+
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
+                    completed_outputs.add(item.input_id)
+                    outputs_completed += 1
+                    yield item
 
     async def get_all_outputs_and_clean_up():
         assert client.stub
@@ -213,6 +304,7 @@ async def _map_invocation(
                 requested_at=time.time(),
             )
             await retry_transient_errors(client.stub.FunctionGetOutputs, request)
+            await retry_queue.close()
 
     async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
         try:
@@ -239,17 +331,50 @@ async def _map_invocation(
             else:
                 # hold on to outputs for function maps, so we can reorder them correctly.
                 received_outputs[idx] = output
-                while output_idx in received_outputs:
+
+                while True:
+                    if output_idx not in received_outputs:
+                        # we haven't received the output for the current index yet.
+                        # stop returning outputs to the caller and instead wait for
+                        # the next output to arrive from the server.
+                        break
+
                     output = received_outputs.pop(output_idx)
                     yield _OutputValue(output)
                     output_idx += 1
 
         assert len(received_outputs) == 0
 
+    async def log_debug_stats():
+        def log_stats():
+            logger.debug(
+                f"Map stats: sync_client_retries_enabled={sync_client_retries_enabled} "
+                f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} inputs_sent={inputs_sent} "
+                f"inputs_retried={inputs_retried} outputs_received={outputs_received} "
+                f"successful_completions={successful_completions} failed_completions={failed_completions} "
+                f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
+                f"already_complete_duplicates={already_complete_duplicates} "
+                f"retried_outputs={retried_outputs} input_queue_size={input_queue.qsize()} "
+                f"retry_queue_size={retry_queue.qsize()} map_items_manager={len(map_items_manager)}"
+            )
+
+        while True:
+            log_stats()
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                # Log final stats before exiting
+                log_stats()
+                break
+
+    log_debug_stats_task = asyncio.create_task(log_debug_stats())
+    async with aclosing(
+        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), retry_inputs())
+    ) as streamer:
         async for response in streamer:
             if response is not None:
                 yield response.value
+    log_debug_stats_task.cancel()
+    await log_debug_stats_task
 
 
 @warn_if_generator_is_not_consumed(function_name="Function.map")
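
The `log_debug_stats` helper added above follows a common asyncio pattern: a long-lived background task reports periodically and treats cancellation as its shutdown signal, emitting one final report from the `CancelledError` handler. Because the task catches the cancellation and returns normally, the `await log_debug_stats_task` after the stream completes does not re-raise. A minimal standalone version of the pattern (intervals and messages are illustrative):

```python
import asyncio

async def report_stats() -> None:
    def log() -> None:
        print("stats tick")

    while True:
        log()
        try:
            await asyncio.sleep(0.05)
        except asyncio.CancelledError:
            log()  # one final snapshot on shutdown
            break  # swallow the cancellation and exit cleanly

async def main() -> None:
    task = asyncio.create_task(report_stats())
    await asyncio.sleep(0.12)  # stand-in for consuming the map stream
    task.cancel()
    await task  # returns normally; the task handled its own cancellation

asyncio.run(main())
```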
@@ -431,3 +556,227 @@ def _starmap_sync(
             "Use Function.map.aio()/Function.for_each.aio() instead."
         ),
     )
+
+
+class _MapItemState(enum.Enum):
+    # The input is being sent to the server with a PutInputs request, but the response has not been received yet.
+    SENDING = 1
+    # A call to either PutInputs or FunctionRetry has completed, and we are waiting to receive the output.
+    WAITING_FOR_OUTPUT = 2
+    # The input is on the retry queue, and waiting for its delay to expire.
+    WAITING_TO_RETRY = 3
+    # The input is being sent to the server with a FunctionRetry request, but the response has not been received yet.
+    RETRYING = 4
+    # The output has been received and was either successful, or failed with no more retries remaining.
+    COMPLETE = 5
+
+
+class _OutputType(enum.Enum):
+    SUCCESSFUL_COMPLETION = 1
+    FAILED_COMPLETION = 2
+    RETRYING = 3
+    ALREADY_COMPLETE_DUPLICATE = 4
+    STALE_RETRY_DUPLICATE = 5
+    NO_CONTEXT_DUPLICATE = 6
+
+
+class _MapItemContext:
+    state: _MapItemState
+    input: api_pb2.FunctionInput
+    retry_manager: RetryManager
+    sync_client_retries_enabled: bool
+    # Both these futures resolve to strings. Omitting the generic type because
+    # it causes an error when running `inv protoc type-stubs`.
+    input_id: asyncio.Future
+    input_jwt: asyncio.Future
+    previous_input_jwt: Optional[str]
+    _event_loop: asyncio.AbstractEventLoop
+
+    def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager, sync_client_retries_enabled: bool):
+        self.state = _MapItemState.SENDING
+        self.input = input
+        self.retry_manager = retry_manager
+        self.sync_client_retries_enabled = sync_client_retries_enabled
+        self._event_loop = asyncio.get_event_loop()
+        # create a future for each input, to be resolved when we have
+        # received the input ID and JWT from the server. this addresses
+        # a race condition where we could receive outputs before we have
+        # recorded the input ID and JWT in `pending_outputs`.
+        self.input_jwt = self._event_loop.create_future()
+        self.input_id = self._event_loop.create_future()
+
+    def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
+        self.input_jwt.set_result(item.input_jwt)
+        self.input_id.set_result(item.input_id)
+        # Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
+        # RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
+        if self.state == _MapItemState.SENDING:
+            self.state = _MapItemState.WAITING_FOR_OUTPUT
+
+    async def handle_get_outputs_response(
+        self,
+        item: api_pb2.FunctionGetOutputsItem,
+        now_seconds: int,
+        function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
+        retry_queue: TimestampPriorityQueue,
+    ) -> _OutputType:
+        """
+        Processes the output, and determines if it is complete or needs to be retried.
+
+        Returns the _OutputType describing how the output was handled.
+        """
+        # If the item is already complete, this is a duplicate output and can be ignored.
+        if self.state == _MapItemState.COMPLETE:
+            logger.debug(
+                f"Received output for input marked as complete. Must be duplicate, so ignoring. "
+                f"idx={item.idx} input_id={item.input_id}, retry_count={item.retry_count}"
+            )
+            return _OutputType.ALREADY_COMPLETE_DUPLICATE
+        # If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
+        if item.retry_count != self.retry_manager.retry_count:
+            logger.debug(
+                f"Received output with stale retry_count, so ignoring. "
+                f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count} "
+                f"expected_retry_count={self.retry_manager.retry_count}"
+            )
+            return _OutputType.STALE_RETRY_DUPLICATE
+
+        # retry failed inputs when the function call invocation type is SYNC
+        if (
+            item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
+            or function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
+            or not self.sync_client_retries_enabled
+        ):
+            self.state = _MapItemState.COMPLETE
+            if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
+                return _OutputType.SUCCESSFUL_COMPLETION
+            else:
+                return _OutputType.FAILED_COMPLETION
+
+        # Get the retry delay and increment the retry count.
+        # TODO(ryan): We must call this for lost inputs - even though we will set the retry delay to 0 later -
+        # because we must increment the retry count. That's awkward, let's come up with something better.
+        # TODO(ryan): To maintain parity with server-side retries, retrying lost inputs should not count towards
+        # the retry policy. However we use the retry_count number as a unique identifier on each attempt to:
+        # 1) ignore duplicate outputs
+        # 2) ignore late outputs received from previous attempts
+        # 3) avoid a server race condition between FunctionRetry and GetOutputs that results in deleted input metadata
+        # For now, lost inputs will count towards the retry policy. But let's address this in another PR, perhaps by
+        # tracking total attempts and attempts which count towards the retry policy separately.
+        delay_ms = self.retry_manager.get_delay_ms()
+
+        # For system failures on the server, we retry immediately,
+        # and the failure does not count towards the retry policy.
+        if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
+            delay_ms = 0
+
+        # None means the maximum number of retries has been reached, so output the error
+        if delay_ms is None:
+            self.state = _MapItemState.COMPLETE
+            return _OutputType.FAILED_COMPLETION
+
+        self.state = _MapItemState.WAITING_TO_RETRY
+        await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)
+
+        return _OutputType.RETRYING
+
+    async def prepare_item_for_retry(self) -> api_pb2.FunctionRetryInputsItem:
+        self.state = _MapItemState.RETRYING
+        # If the input_jwt is not set, then put_inputs hasn't returned yet. Block until we have it.
+        input_jwt = await self.input_jwt
+        self.input_jwt = self._event_loop.create_future()
+        return api_pb2.FunctionRetryInputsItem(
+            input_jwt=input_jwt,
+            input=self.input,
+            retry_count=self.retry_manager.retry_count,
+        )
+
+    def handle_retry_response(self, input_jwt: str):
+        self.input_jwt.set_result(input_jwt)
+        self.state = _MapItemState.WAITING_FOR_OUTPUT
+
+
+class _MapItemsManager:
+    def __init__(
+        self,
+        retry_policy: api_pb2.FunctionRetryPolicy,
+        function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
+        retry_queue: TimestampPriorityQueue,
+        sync_client_retries_enabled: bool,
+    ):
+        self._retry_policy = retry_policy
+        self.function_call_invocation_type = function_call_invocation_type
+        self._retry_queue = retry_queue
+        # semaphore to limit the number of inputs that can be in progress at once
+        self._inputs_outstanding = asyncio.BoundedSemaphore(MAP_MAX_INPUTS_OUTSTANDING)
+        self._item_context: dict[int, _MapItemContext] = {}
+        self._sync_client_retries_enabled = sync_client_retries_enabled
+
+    async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
+        for item in items:
+            # acquire semaphore to limit the number of inputs in progress
+            # (either queued to be sent, waiting for completion, or retrying)
+            await self._inputs_outstanding.acquire()
+            self._item_context[item.idx] = _MapItemContext(
+                input=item.input,
+                retry_manager=RetryManager(self._retry_policy),
+                sync_client_retries_enabled=self._sync_client_retries_enabled,
+            )
+
+    async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
+        return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]
+
+    def get_input_jwts_waiting_for_output(self) -> list[str]:
+        """
+        Returns a list of input_jwts for inputs that are waiting for output.
+        """
+        # If input_jwt is not done, the call to PutInputs has not completed, so omit it from results.
+        return [
+            ctx.input_jwt.result()
+            for ctx in self._item_context.values()
+            if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
+        ]
+
+    def _remove_item(self, item_idx: int):
+        del self._item_context[item_idx]
+        self._inputs_outstanding.release()
+
+    def get_item_context(self, item_idx: int) -> _MapItemContext:
+        return self._item_context.get(item_idx)
+
+    def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
+        for item in items:
+            ctx = self._item_context.get(item.idx, None)
+            # If the context is None, then get_all_outputs() has already received a successful
+            # output, and deleted the context. This happens if FunctionGetOutputs completes
+            # before FunctionPutInputsResponse is received.
+            if ctx is not None:
+                ctx.handle_put_inputs_response(item)
+
+    def handle_retry_response(self, input_jwts: list[str]):
+        for input_jwt in input_jwts:
+            decoded_jwt = DecodedJwt.decode_without_verification(input_jwt)
+            ctx = self._item_context.get(decoded_jwt.payload["idx"], None)
+            # If the context is None, then get_all_outputs() has already received a successful
+            # output, and deleted the context. This happens if FunctionGetOutputs completes
+            # before FunctionRetryInputsResponse is received.
+            if ctx is not None:
+                ctx.handle_retry_response(input_jwt)
+
+    async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
+        ctx = self._item_context.get(item.idx, None)
+        if ctx is None:
+            # We've already processed this output, so we can skip it.
+            # This can happen because the worker can sometimes send duplicate outputs.
+            logger.debug(
+                f"Received output that does not have entry in item_context map, so ignoring. "
+                f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
+            )
+            return _OutputType.NO_CONTEXT_DUPLICATE
+        output_type = await ctx.handle_get_outputs_response(
+            item, now_seconds, self.function_call_invocation_type, self._retry_queue
+        )
+        if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
+            self._remove_item(item.idx)
+        return output_type
+
+    def __len__(self):
+        return len(self._item_context)
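
Taken together, the parallel_map.py changes implement client-side retries for `Function.map`: each in-flight input is tracked by a `_MapItemContext` that moves through the `_MapItemState` machine (SENDING → WAITING_FOR_OUTPUT → WAITING_TO_RETRY → RETRYING → COMPLETE), with failed inputs parked on a `TimestampPriorityQueue` until their backoff delay expires. A dependency-free sketch of that delayed-retry queue follows; this is not Modal's `TimestampPriorityQueue`, just the same idea under assumed semantics:

```python
import asyncio
import heapq
import time

class SimpleTimestampQueue:
    """Min-heap of (ready_at, idx); get() only returns items that are due."""

    def __init__(self) -> None:
        self._heap: list[tuple[float, int]] = []
        self._new_item = asyncio.Event()

    async def put(self, ready_at: float, idx: int) -> None:
        heapq.heappush(self._heap, (ready_at, idx))
        self._new_item.set()

    async def get(self) -> int:
        while True:
            if self._heap and self._heap[0][0] <= time.time():
                return heapq.heappop(self._heap)[1]
            # Wait for a new item, or poll again shortly before the next deadline.
            self._new_item.clear()
            try:
                await asyncio.wait_for(self._new_item.wait(), timeout=0.05)
            except asyncio.TimeoutError:
                pass

async def demo() -> None:
    queue = SimpleTimestampQueue()
    base_delay = 0.1
    # Exponential backoff: retry attempt n becomes due after base * 2**n seconds.
    for attempt in range(3):
        await queue.put(time.time() + base_delay * 2**attempt, attempt)
    for _ in range(3):
        idx = await queue.get()
        print(f"retrying input {idx} at t={time.time():.2f}")

asyncio.run(demo())
```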