modal 0.73.115__py3-none-any.whl → 0.73.126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/_functions.py +15 -6
- modal/_runtime/container_io_manager.py +13 -9
- modal/_runtime/container_io_manager.pyi +7 -4
- modal/_serialization.py +92 -44
- modal/_utils/async_utils.py +71 -6
- modal/_utils/function_utils.py +33 -13
- modal/_utils/jwt_utils.py +38 -0
- modal/cli/app.py +15 -0
- modal/client.pyi +2 -2
- modal/cls.py +3 -13
- modal/cls.pyi +0 -2
- modal/functions.pyi +6 -6
- modal/image.py +2 -0
- modal/parallel_map.py +393 -44
- modal/parallel_map.pyi +75 -0
- modal/retries.py +11 -9
- {modal-0.73.115.dist-info → modal-0.73.126.dist-info}/METADATA +1 -1
- {modal-0.73.115.dist-info → modal-0.73.126.dist-info}/RECORD +30 -29
- {modal-0.73.115.dist-info → modal-0.73.126.dist-info}/WHEEL +1 -1
- modal_proto/api.proto +13 -0
- modal_proto/api_grpc.py +16 -0
- modal_proto/api_pb2.py +284 -263
- modal_proto/api_pb2.pyi +43 -0
- modal_proto/api_pb2_grpc.py +33 -0
- modal_proto/api_pb2_grpc.pyi +10 -0
- modal_proto/modal_api_grpc.py +1 -0
- modal_version/_version_generated.py +1 -1
- {modal-0.73.115.dist-info → modal-0.73.126.dist-info}/LICENSE +0 -0
- {modal-0.73.115.dist-info → modal-0.73.126.dist-info}/entry_points.txt +0 -0
- {modal-0.73.115.dist-info → modal-0.73.126.dist-info}/top_level.txt +0 -0
modal/cls.py
CHANGED

@@ -21,7 +21,7 @@ from ._partial_function import (
 )
 from ._resolver import Resolver
 from ._resources import convert_fn_config_to_resources_config
-from ._serialization import
+from ._serialization import check_valid_cls_constructor_arg, get_proto_parameter_type
 from ._traceback import print_server_warnings
 from ._utils.async_utils import synchronize_api, synchronizer
 from ._utils.deprecation import deprecation_warning, renamed_parameter, warn_on_renamed_autoscaler_settings

@@ -362,15 +362,6 @@ class _Obj:
 Obj = synchronize_api(_Obj)


-def _validate_parameter_type(cls_name: str, parameter_name: str, parameter_type: type):
-    if parameter_type not in PYTHON_TO_PROTO_TYPE:
-        type_name = getattr(parameter_type, "__name__", repr(parameter_type))
-        supported = ", ".join(parameter_type.__name__ for parameter_type in PYTHON_TO_PROTO_TYPE.keys())
-        raise InvalidError(
-            f"{cls_name}.{parameter_name}: {type_name} is not a supported parameter type. Use one of: {supported}"
-        )
-
-
 class _Cls(_Object, type_prefix="cs"):
     """
     Cls adds method pooling and [lifecycle hook](/docs/guide/lifecycle-functions) behavior

@@ -467,12 +458,11 @@ class _Cls(_Object, type_prefix="cs"):
         annotations = user_cls.__dict__.get("__annotations__", {})  # compatible with older pythons
         missing_annotations = params.keys() - annotations.keys()
         if missing_annotations:
-            raise InvalidError("All modal.parameter() specifications need to be type
+            raise InvalidError("All modal.parameter() specifications need to be type-annotated")

         annotated_params = {k: t for k, t in annotations.items() if k in params}
         for k, t in annotated_params.items():
-
-            _validate_parameter_type(user_cls.__name__, k, t)
+            get_proto_parameter_type(t)

     @staticmethod
     def from_local(user_cls, app: "modal.app._App", class_service_function: _Function) -> "_Cls":
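The net effect: per-field validation of `modal.parameter()` annotations now lives in `modal/_serialization.py` (`get_proto_parameter_type`), which raises `InvalidError` for annotation types outside `PYTHON_TO_PROTO_TYPE`, replacing the private `_validate_parameter_type` helper. User-facing behavior should be equivalent: unsupported annotations still fail at class-definition time. A hedged illustration (the exact error text now comes from `_serialization.py` and may differ from the removed helper's message):

```python
import modal

app = modal.App()


@app.cls()
class Model:
    # str is a supported parameter type, so this passes validation
    name: str = modal.parameter(default="base")

    # An unsupported annotation would raise modal.exception.InvalidError at
    # definition time, with a message along the lines of:
    #   Model.tags: list is not a supported parameter type. Use one of: ...
    # tags: list = modal.parameter()
```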
modal/cls.pyi
CHANGED

@@ -109,8 +109,6 @@ class Obj:
     async def _aenter(self): ...
     def __getattr__(self, k): ...

-def _validate_parameter_type(cls_name: str, parameter_name: str, parameter_type: type): ...
-
 class _Cls(modal._object._Object):
     _class_service_function: typing.Optional[modal._functions._Function]
     _options: typing.Optional[_ServiceOptions]
modal/functions.pyi
CHANGED

@@ -198,11 +198,11 @@ class Function(

     _call_generator_nowait: ___call_generator_nowait_spec[typing_extensions.Self]

-    class __remote_spec(typing_extensions.Protocol[
+    class __remote_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
         def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
         async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...

-    remote: __remote_spec[modal._functions.
+    remote: __remote_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]

     class __remote_gen_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(self, *args, **kwargs) -> typing.Generator[typing.Any, None, None]: ...

@@ -217,19 +217,19 @@ class Function(
         self, *args: modal._functions.P.args, **kwargs: modal._functions.P.kwargs
     ) -> modal._functions.OriginalReturnType: ...

-    class ___experimental_spawn_spec(typing_extensions.Protocol[
+    class ___experimental_spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
         def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
         async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...

     _experimental_spawn: ___experimental_spawn_spec[
-        modal._functions.
+        modal._functions.P, modal._functions.ReturnType, typing_extensions.Self
     ]

-    class __spawn_spec(typing_extensions.Protocol[
+    class __spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
         def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
         async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...

-    spawn: __spawn_spec[modal._functions.
+    spawn: __spawn_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]

     def get_raw_f(self) -> collections.abc.Callable[..., typing.Any]: ...
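These stub changes re-parameterize the call-shape protocols over `P_INNER` (a `ParamSpec`) and `ReturnType_INNER`, so `Function.remote`, `Function.spawn`, and `Function._experimental_spawn` now expose the wrapped function's exact signature to type checkers. The pattern, reduced to a standalone sketch (names here are illustrative, not modal's):

```python
from typing import TypeVar

from typing_extensions import ParamSpec, Protocol

P = ParamSpec("P")
R = TypeVar("R", covariant=True)


class RemoteSpec(Protocol[P, R]):
    # synchronous call: same parameters as the wrapped function
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
    # async variant, mirroring modal's .aio() convention
    async def aio(self, *args: P.args, **kwargs: P.kwargs) -> R: ...
```

Binding `remote: RemoteSpec[P, R]` inside a class that is generic over `P` and `R` lets a checker flag `f.remote(wrong, args)` and infer the result type of `f.remote(...)` instead of falling back to `Any`.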
modal/image.py
CHANGED

@@ -861,6 +861,8 @@ class _Image(_Object, type_prefix="im"):

         *Added in v0.67.28*: This method replaces the deprecated `modal.Mount.from_local_python_packages` pattern.
         """
+        if not all(isinstance(module, str) for module in modules):
+            raise InvalidError("Local Python modules must be specified as strings.")
         mount = _Mount._from_local_python_packages(*modules, ignore=ignore)
         img = self._add_mount_layer_or_copy(mount, copy=copy)
         img._added_python_source_set |= set(modules)
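Judging by the surrounding docstring, this hunk sits inside `Image.add_local_python_source`, the replacement for `modal.Mount.from_local_python_packages`. The new guard makes a common mistake fail fast: passing an imported module object instead of its name now raises `InvalidError` immediately, rather than producing a confusing failure deeper in mount construction. A hedged usage illustration:

```python
import modal

image = modal.Image.debian_slim()

# OK: module names passed as strings
image = image.add_local_python_source("my_pkg", "my_helpers")

# Rejected as of this version with
# InvalidError("Local Python modules must be specified as strings."):
#
#   import my_pkg
#   image = image.add_local_python_source(my_pkg)  # module object, not a string
```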
modal/parallel_map.py
CHANGED

@@ -1,5 +1,6 @@
 # Copyright Modal Labs 2024
 import asyncio
+import enum
 import time
 import typing
 from dataclasses import dataclass

@@ -10,6 +11,7 @@ from grpclib import Status
 from modal._runtime.execution_context import current_input_id
 from modal._utils.async_utils import (
     AsyncOrSyncIterable,
+    TimestampPriorityQueue,
     aclosing,
     async_map_ordered,
     async_merge,

@@ -28,7 +30,9 @@ from modal._utils.function_utils import (
     _process_result,
 )
 from modal._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, RetryWarningMessage, retry_transient_errors
+from modal._utils.jwt_utils import DecodedJwt
 from modal.config import logger
+from modal.retries import RetryManager
 from modal_proto import api_pb2

 if typing.TYPE_CHECKING:

@@ -66,6 +70,12 @@ class _OutputValue:
     value: Any


+# maximum number of inputs that can be in progress (either queued to be sent,
+# or waiting for completion). if this limit is reached, we will block sending
+# more inputs to the server until some of the existing inputs are completed.
+MAP_MAX_INPUTS_OUTSTANDING = 1000
+
+# maximum number of inputs to send to the server in a single request
 MAP_INVOCATION_CHUNK_SIZE = 49

 if typing.TYPE_CHECKING:
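The retry machinery below leans on `TimestampPriorityQueue`, added to `modal/_utils/async_utils.py` in this release (+71 lines; its implementation is not shown in this diff). From its usage here (`put(timestamp, idx)`, `qsize()`, `close()`, and consumption via `queue_batch_iterator`), it appears to be a priority queue keyed by a ready-at timestamp: items become visible to consumers only once their timestamp has passed. A minimal sketch of the idea, with method names matching the call sites in this file (illustrative only; `close()` omitted):

```python
import asyncio
import heapq
import time


class TimestampPriorityQueueSketch:
    """Items become available once their ready-at timestamp is due."""

    def __init__(self) -> None:
        self._heap: list[tuple[float, int]] = []  # (ready_at, item), earliest first
        self._new_item = asyncio.Event()

    async def put(self, ready_at: float, item: int) -> None:
        heapq.heappush(self._heap, (ready_at, item))
        self._new_item.set()  # wake a consumer that may be sleeping

    async def get(self) -> int:
        while True:
            now = time.time()
            if self._heap and self._heap[0][0] <= now:
                return heapq.heappop(self._heap)[1]
            # sleep until the earliest deadline, or until a new item arrives
            timeout = (self._heap[0][0] - now) if self._heap else None
            self._new_item.clear()
            try:
                await asyncio.wait_for(self._new_item.wait(), timeout)
            except asyncio.TimeoutError:
                pass  # deadline reached; re-check the heap

    def qsize(self) -> int:
        return len(self._heap)
```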
@@ -79,6 +89,7 @@ async def _map_invocation(
     order_outputs: bool,
     return_exceptions: bool,
     count_update_callback: Optional[Callable[[int, int], None]],
+    function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
 ):
     assert client.stub
     request = api_pb2.FunctionMapRequest(

@@ -86,28 +97,43 @@ async def _map_invocation(
         parent_input_id=current_input_id() or "",
         function_call_type=api_pb2.FUNCTION_CALL_TYPE_MAP,
         return_exceptions=return_exceptions,
+        function_call_invocation_type=function_call_invocation_type,
     )
-    response = await retry_transient_errors(client.stub.FunctionMap, request)
+    response: api_pb2.FunctionMapResponse = await retry_transient_errors(client.stub.FunctionMap, request)

     function_call_id = response.function_call_id
+    function_call_jwt = response.function_call_jwt
+    retry_policy = response.retry_policy
+    sync_client_retries_enabled = response.sync_client_retries_enabled

     have_all_inputs = False
-
-
+    inputs_created = 0
+    inputs_sent = 0
+    inputs_retried = 0
+    outputs_completed = 0
+    outputs_received = 0
+    retried_outputs = 0
+    successful_completions = 0
+    failed_completions = 0
+    already_complete_duplicates = 0
+    stale_retry_duplicates = 0
+    no_context_duplicates = 0

     def count_update():
         if count_update_callback is not None:
-            count_update_callback(
+            count_update_callback(outputs_completed, inputs_created)

-
+    retry_queue = TimestampPriorityQueue()
     completed_outputs: set[str] = set()  # Set of input_ids whose outputs are complete (expecting no more values)
-
-
+    input_queue: asyncio.Queue[api_pb2.FunctionPutInputsItem | None] = asyncio.Queue()
+    map_items_manager = _MapItemsManager(
+        retry_policy, function_call_invocation_type, retry_queue, sync_client_retries_enabled
+    )

     async def create_input(argskwargs):
-        nonlocal
-        idx =
-
+        nonlocal inputs_created
+        idx = inputs_created
+        inputs_created += 1
         (args, kwargs) = argskwargs
         return await _create_input(args, kwargs, client, idx=idx, method_name=function._use_method_name)
@@ -119,6 +145,8 @@ async def _map_invocation(
             yield raw_input  # args, kwargs

     async def drain_input_generator():
+        nonlocal have_all_inputs
+
         # Parallelize uploading blobs
         async with aclosing(
             async_map_ordered(input_iter(), create_input, concurrency=BLOB_MAX_PARALLELISM)
@@ -132,49 +160,100 @@ async def _map_invocation(

     async def pump_inputs():
         assert client.stub
-        nonlocal have_all_inputs,
-        async for items in queue_batch_iterator(input_queue, MAP_INVOCATION_CHUNK_SIZE):
+        nonlocal have_all_inputs, inputs_created, inputs_sent
+        async for items in queue_batch_iterator(input_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            # Add items to the manager. Their state will be SENDING.
+            await map_items_manager.add_items(items)
             request = api_pb2.FunctionPutInputsRequest(
-                function_id=function.object_id,
+                function_id=function.object_id,
+                inputs=items,
+                function_call_id=function_call_id,
             )
             logger.debug(
                 f"Pushing {len(items)} inputs to server. Num queued inputs awaiting push is {input_queue.qsize()}."
             )
-
-
-                message=f"Warning: map progress for function {function._function_name} is limited."
-                " Common bottlenecks include slow iteration over results, or function backlogs.",
-                warning_interval=8,
-                errors_to_warn_for=[Status.RESOURCE_EXHAUSTED])
-            resp = await retry_transient_errors(
-                client.stub.FunctionPutInputs,
-                request,
-                max_retries=None,
-                max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
-                additional_status_codes=[Status.RESOURCE_EXHAUSTED],
-                retry_warning_message=retry_warning_message)
+
+            resp = await send_inputs(client.stub.FunctionPutInputs, request)
             count_update()
-
-
+            inputs_sent += len(items)
+            # Change item state to WAITING_FOR_OUTPUT, and set the input_id and input_jwt which are in the response.
+            map_items_manager.handle_put_inputs_response(resp.inputs)
             logger.debug(
                 f"Successfully pushed {len(items)} inputs to server. "
                 f"Num queued inputs awaiting push is {input_queue.qsize()}."
             )
-
         have_all_inputs = True
         yield

+    async def retry_inputs():
+        nonlocal inputs_retried
+        async for retriable_idxs in queue_batch_iterator(retry_queue, max_batch_size=MAP_INVOCATION_CHUNK_SIZE):
+            # For each index, use the context in the manager to create a FunctionRetryInputsItem.
+            # This will also update the context state to RETRYING.
+            inputs: list[api_pb2.FunctionRetryInputsItem] = await map_items_manager.prepare_items_for_retry(
+                retriable_idxs
+            )
+            request = api_pb2.FunctionRetryInputsRequest(
+                function_call_jwt=function_call_jwt,
+                inputs=inputs,
+            )
+            resp = await send_inputs(client.stub.FunctionRetryInputs, request)
+            # Update the state to WAITING_FOR_OUTPUT, and update the input_jwt in the context
+            # to the new value in the response.
+            map_items_manager.handle_retry_response(resp.input_jwts)
+            logger.debug(f"Successfully pushed retries for {len(inputs)} inputs to server.")
+            inputs_retried += len(inputs)
+        yield
+
+    async def send_inputs(
+        fn: "modal.client.UnaryUnaryWrapper",
+        request: typing.Union[api_pb2.FunctionPutInputsRequest, api_pb2.FunctionRetryInputsRequest],
+    ) -> typing.Union[api_pb2.FunctionPutInputsResponse, api_pb2.FunctionRetryInputsResponse]:
+        # with 8 retries we log the warning below about every 30 seconds, which isn't too spammy.
+        retry_warning_message = RetryWarningMessage(
+            message=f"Warning: map progress for function {function._function_name} is limited."
+            " Common bottlenecks include slow iteration over results, or function backlogs.",
+            warning_interval=8,
+            errors_to_warn_for=[Status.RESOURCE_EXHAUSTED],
+        )
+        return await retry_transient_errors(
+            fn,
+            request,
+            max_retries=None,
+            max_delay=PUMP_INPUTS_MAX_RETRY_DELAY,
+            additional_status_codes=[Status.RESOURCE_EXHAUSTED],
+            retry_warning_message=retry_warning_message,
+        )
+
     async def get_all_outputs():
         assert client.stub
-        nonlocal
+        nonlocal \
+            inputs_created, \
+            successful_completions, \
+            failed_completions, \
+            outputs_completed, \
+            have_all_inputs, \
+            outputs_received, \
+            already_complete_duplicates, \
+            no_context_duplicates, \
+            stale_retry_duplicates, \
+            retried_outputs
+
         last_entry_id = "0-0"
-
+
+        while not have_all_inputs or outputs_completed < inputs_created:
+            logger.debug(f"Requesting outputs. Have {outputs_completed} outputs, {inputs_created} inputs.")
+            # Get input_jwts of all items in the WAITING_FOR_OUTPUT state.
+            # The server uses these to check for lost inputs.
+            input_jwts = [input_jwt for input_jwt in map_items_manager.get_input_jwts_waiting_for_output()]
+
             request = api_pb2.FunctionGetOutputsRequest(
                 function_call_id=function_call_id,
                 timeout=OUTPUTS_TIMEOUT,
                 last_entry_id=last_entry_id,
                 clear_on_success=False,
                 requested_at=time.time(),
+                input_jwts=input_jwts,
             )
             response = await retry_transient_errors(
                 client.stub.FunctionGetOutputs,
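`retry_inputs()` and `_MapItemsManager.handle_retry_response()` (further below) depend on `DecodedJwt` from the new `modal/_utils/jwt_utils.py` (+38 lines, not shown in this diff). The name `decode_without_verification` suggests it only base64-decodes the payload to read claims such as `idx`, without checking the signature; the client treats the JWT as an opaque handle and leaves verification to the server. A minimal sketch of what such a helper could look like (illustrative; the real implementation is not part of this diff):

```python
import base64
import json
from dataclasses import dataclass


@dataclass
class DecodedJwtSketch:
    payload: dict

    @staticmethod
    def decode_without_verification(token: str) -> "DecodedJwtSketch":
        # A JWT is three base64url segments: header.payload.signature.
        # Only the payload is parsed; the signature is deliberately not checked.
        _header, payload_b64, _signature = token.split(".")
        padded = payload_b64 + "=" * (-len(payload_b64) % 4)  # restore stripped padding
        return DecodedJwtSketch(payload=json.loads(base64.urlsafe_b64decode(padded)))


# e.g. DecodedJwtSketch.decode_without_verification(input_jwt).payload["idx"]
```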
@@ -183,19 +262,31 @@ async def _map_invocation(
                 attempt_timeout=OUTPUTS_TIMEOUT + ATTEMPT_TIMEOUT_GRACE_PERIOD,
             )

-            if len(response.outputs) == 0:
-                continue
-
             last_entry_id = response.last_entry_id
+            now_seconds = int(time.time())
             for item in response.outputs:
-
-
-
-
-
-
-
-
+                outputs_received += 1
+                # If the output failed, and there are retries remaining, the input will be placed on the
+                # retry queue, and its state updated to WAITING_TO_RETRY. Otherwise, the output is considered
+                # complete and the item is removed from the manager.
+                output_type = await map_items_manager.handle_get_outputs_response(item, now_seconds)
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION:
+                    successful_completions += 1
+                elif output_type == _OutputType.FAILED_COMPLETION:
+                    failed_completions += 1
+                elif output_type == _OutputType.NO_CONTEXT_DUPLICATE:
+                    no_context_duplicates += 1
+                elif output_type == _OutputType.STALE_RETRY_DUPLICATE:
+                    stale_retry_duplicates += 1
+                elif output_type == _OutputType.ALREADY_COMPLETE_DUPLICATE:
+                    already_complete_duplicates += 1
+                elif output_type == _OutputType.RETRYING:
+                    retried_outputs += 1
+
+                if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
+                    completed_outputs.add(item.input_id)
+                    outputs_completed += 1
+                    yield item

     async def get_all_outputs_and_clean_up():
         assert client.stub
@@ -213,6 +304,7 @@ async def _map_invocation(
             requested_at=time.time(),
         )
         await retry_transient_errors(client.stub.FunctionGetOutputs, request)
+        await retry_queue.close()

     async def fetch_output(item: api_pb2.FunctionGetOutputsItem) -> tuple[int, Any]:
         try:
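The `retry_queue.close()` above is what lets the `retry_inputs()` generator finish: its `queue_batch_iterator(retry_queue, ...)` loop ends once the queue is closed. The general shape of such an iterator, sketched under the assumption that closing enqueues a sentinel (shown over a plain `asyncio.Queue` for simplicity; the real `queue_batch_iterator` lives in `modal/_utils/async_utils.py` and is not part of this diff):

```python
import asyncio
from typing import AsyncIterator

_SENTINEL = object()  # assumed: close() pushes this to unblock consumers


async def queue_batch_iterator_sketch(q: asyncio.Queue, max_batch_size: int) -> AsyncIterator[list]:
    while True:
        batch = [await q.get()]  # block for at least one item
        while len(batch) < max_batch_size and not q.empty():
            batch.append(q.get_nowait())  # then drain whatever is already queued
        if any(i is _SENTINEL for i in batch):
            batch = [i for i in batch if i is not _SENTINEL]
            if batch:
                yield batch
            return  # queue closed: stop iterating
        yield batch
```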
@@ -239,17 +331,50 @@ async def _map_invocation(
             else:
                 # hold on to outputs for function maps, so we can reorder them correctly.
                 received_outputs[idx] = output
-
+
+            while True:
+                if output_idx not in received_outputs:
+                    # we haven't received the output for the current index yet.
+                    # stop returning outputs to the caller and instead wait for
+                    # the next output to arrive from the server.
+                    break
+
                 output = received_outputs.pop(output_idx)
                 yield _OutputValue(output)
                 output_idx += 1

         assert len(received_outputs) == 0

-    async
+    async def log_debug_stats():
+        def log_stats():
+            logger.debug(
+                f"Map stats: sync_client_retries_enabled={sync_client_retries_enabled} "
+                f"have_all_inputs={have_all_inputs} inputs_created={inputs_created} inputs_sent={inputs_sent} "
+                f"inputs_retried={inputs_retried} outputs_received={outputs_received} "
+                f"successful_completions={successful_completions} failed_completions={failed_completions} "
+                f"no_context_duplicates={no_context_duplicates} stale_retry_duplicates={stale_retry_duplicates} "
+                f"already_complete_duplicates={already_complete_duplicates} "
+                f"retried_outputs={retried_outputs} input_queue_size={input_queue.qsize()} "
+                f"retry_queue_size={retry_queue.qsize()} map_items_manager={len(map_items_manager)}"
+            )
+
+        while True:
+            log_stats()
+            try:
+                await asyncio.sleep(10)
+            except asyncio.CancelledError:
+                # Log final stats before exiting
+                log_stats()
+                break
+
+    log_debug_stats_task = asyncio.create_task(log_debug_stats())
+    async with aclosing(
+        async_merge(drain_input_generator(), pump_inputs(), poll_outputs(), retry_inputs())
+    ) as streamer:
         async for response in streamer:
             if response is not None:
                 yield response.value
+    log_debug_stats_task.cancel()
+    await log_debug_stats_task
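The `log_debug_stats` wiring above uses a standard asyncio pattern: a periodic background task that the caller cancels once the output stream finishes; the task catches `asyncio.CancelledError` so it can emit one final report and exit cleanly, which is why the plain `await log_debug_stats_task` that follows does not re-raise. Reduced to a standalone sketch (not modal code):

```python
import asyncio


async def periodic_report(interval: float = 10.0):
    while True:
        print("stats ...")  # stand-in for logger.debug(...)
        try:
            await asyncio.sleep(interval)
        except asyncio.CancelledError:
            print("final stats ...")  # flush one last report on cancellation
            break  # swallow the cancellation so the awaiter doesn't re-raise


async def main():
    task = asyncio.create_task(periodic_report())
    await asyncio.sleep(0.1)  # stand-in for the real work
    task.cancel()
    await task  # returns normally: the task exited instead of re-raising


asyncio.run(main())
```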
@@ -431,3 +556,227 @@ def _starmap_sync(
         "Use Function.map.aio()/Function.for_each.aio() instead."
     ),
 )
+
+
+class _MapItemState(enum.Enum):
+    # The input is being sent to the server with a PutInputs request, but the response has not been received yet.
+    SENDING = 1
+    # A call to either PutInputs or FunctionRetry has completed, and we are waiting to receive the output.
+    WAITING_FOR_OUTPUT = 2
+    # The input is on the retry queue, and waiting for its delay to expire.
+    WAITING_TO_RETRY = 3
+    # The input is being sent to the server with a FunctionRetry request, but the response has not been received yet.
+    RETRYING = 4
+    # The output has been received and was either successful, or failed with no more retries remaining.
+    COMPLETE = 5
+
+
+class _OutputType(enum.Enum):
+    SUCCESSFUL_COMPLETION = 1
+    FAILED_COMPLETION = 2
+    RETRYING = 3
+    ALREADY_COMPLETE_DUPLICATE = 4
+    STALE_RETRY_DUPLICATE = 5
+    NO_CONTEXT_DUPLICATE = 6
+
+
+class _MapItemContext:
+    state: _MapItemState
+    input: api_pb2.FunctionInput
+    retry_manager: RetryManager
+    sync_client_retries_enabled: bool
+    # Both these futures resolve to strings. Omitting the generic type because
+    # it causes an error when running `inv protoc type-stubs`.
+    input_id: asyncio.Future
+    input_jwt: asyncio.Future
+    previous_input_jwt: Optional[str]
+    _event_loop: asyncio.AbstractEventLoop
+
+    def __init__(self, input: api_pb2.FunctionInput, retry_manager: RetryManager, sync_client_retries_enabled: bool):
+        self.state = _MapItemState.SENDING
+        self.input = input
+        self.retry_manager = retry_manager
+        self.sync_client_retries_enabled = sync_client_retries_enabled
+        self._event_loop = asyncio.get_event_loop()
+        # create a future for each input, to be resolved when we have
+        # received the input ID and JWT from the server. this addresses
+        # a race condition where we could receive outputs before we have
+        # recorded the input ID and JWT in `pending_outputs`.
+        self.input_jwt = self._event_loop.create_future()
+        self.input_id = self._event_loop.create_future()
+
+    def handle_put_inputs_response(self, item: api_pb2.FunctionPutInputsResponseItem):
+        self.input_jwt.set_result(item.input_jwt)
+        self.input_id.set_result(item.input_id)
+        # Set state to WAITING_FOR_OUTPUT only if current state is SENDING. If state is
+        # RETRYING, WAITING_TO_RETRY, or COMPLETE, then we already got the output.
+        if self.state == _MapItemState.SENDING:
+            self.state = _MapItemState.WAITING_FOR_OUTPUT
+
+    async def handle_get_outputs_response(
+        self,
+        item: api_pb2.FunctionGetOutputsItem,
+        now_seconds: int,
+        function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
+        retry_queue: TimestampPriorityQueue,
+    ) -> _OutputType:
+        """
+        Processes the output, and determines if it is complete or needs to be retried.
+
+        Returns the _OutputType describing how the output was handled.
+        """
+        # If the item is already complete, this is a duplicate output and can be ignored.
+        if self.state == _MapItemState.COMPLETE:
+            logger.debug(
+                f"Received output for input marked as complete. Must be duplicate, so ignoring. "
+                f"idx={item.idx} input_id={item.input_id}, retry_count={item.retry_count}"
+            )
+            return _OutputType.ALREADY_COMPLETE_DUPLICATE
+        # If the item's retry count doesn't match our retry count, this is probably a duplicate of an old output.
+        if item.retry_count != self.retry_manager.retry_count:
+            logger.debug(
+                f"Received output with stale retry_count, so ignoring. "
+                f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count} "
+                f"expected_retry_count={self.retry_manager.retry_count}"
+            )
+            return _OutputType.STALE_RETRY_DUPLICATE
+
+        # retry failed inputs only when the function call invocation type is SYNC
+        if (
+            item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
+            or function_call_invocation_type != api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
+            or not self.sync_client_retries_enabled
+        ):
+            self.state = _MapItemState.COMPLETE
+            if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
+                return _OutputType.SUCCESSFUL_COMPLETION
+            else:
+                return _OutputType.FAILED_COMPLETION
+
+        # Get the retry delay and increment the retry count.
+        # TODO(ryan): We must call this for lost inputs - even though we will set the retry delay to 0 later -
+        # because we must increment the retry count. That's awkward, let's come up with something better.
+        # TODO(ryan): To maintain parity with server-side retries, retrying lost inputs should not count towards
+        # the retry policy. However we use the retry_count number as a unique identifier on each attempt to:
+        # 1) ignore duplicate outputs
+        # 2) ignore late outputs received from previous attempts
+        # 3) avoid a server race condition between FunctionRetry and GetOutputs that results in deleted input metadata
+        # For now, lost inputs will count towards the retry policy. But let's address this in another PR, perhaps by
+        # tracking total attempts and attempts which count towards the retry policy separately.
+        delay_ms = self.retry_manager.get_delay_ms()
+
+        # For system failures on the server, we retry immediately,
+        # and the failure does not count towards the retry policy.
+        if item.result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
+            delay_ms = 0
+
+        # None means the maximum number of retries has been reached, so output the error
+        if delay_ms is None:
+            self.state = _MapItemState.COMPLETE
+            return _OutputType.FAILED_COMPLETION
+
+        self.state = _MapItemState.WAITING_TO_RETRY
+        await retry_queue.put(now_seconds + (delay_ms / 1000), item.idx)
+
+        return _OutputType.RETRYING
+
+    async def prepare_item_for_retry(self) -> api_pb2.FunctionRetryInputsItem:
+        self.state = _MapItemState.RETRYING
+        # If the input_jwt is not set, then put_inputs hasn't returned yet. Block until we have it.
+        input_jwt = await self.input_jwt
+        self.input_jwt = self._event_loop.create_future()
+        return api_pb2.FunctionRetryInputsItem(
+            input_jwt=input_jwt,
+            input=self.input,
+            retry_count=self.retry_manager.retry_count,
+        )
+
+    def handle_retry_response(self, input_jwt: str):
+        self.input_jwt.set_result(input_jwt)
+        self.state = _MapItemState.WAITING_FOR_OUTPUT
+
+
+class _MapItemsManager:
+    def __init__(
+        self,
+        retry_policy: api_pb2.FunctionRetryPolicy,
+        function_call_invocation_type: "api_pb2.FunctionCallInvocationType.ValueType",
+        retry_queue: TimestampPriorityQueue,
+        sync_client_retries_enabled: bool,
+    ):
+        self._retry_policy = retry_policy
+        self.function_call_invocation_type = function_call_invocation_type
+        self._retry_queue = retry_queue
+        # semaphore to limit the number of inputs that can be in progress at once
+        self._inputs_outstanding = asyncio.BoundedSemaphore(MAP_MAX_INPUTS_OUTSTANDING)
+        self._item_context: dict[int, _MapItemContext] = {}
+        self._sync_client_retries_enabled = sync_client_retries_enabled
+
+    async def add_items(self, items: list[api_pb2.FunctionPutInputsItem]):
+        for item in items:
+            # acquire semaphore to limit the number of inputs in progress
+            # (either queued to be sent, waiting for completion, or retrying)
+            await self._inputs_outstanding.acquire()
+            self._item_context[item.idx] = _MapItemContext(
+                input=item.input,
+                retry_manager=RetryManager(self._retry_policy),
+                sync_client_retries_enabled=self._sync_client_retries_enabled,
+            )
+
+    async def prepare_items_for_retry(self, retriable_idxs: list[int]) -> list[api_pb2.FunctionRetryInputsItem]:
+        return [await self._item_context[idx].prepare_item_for_retry() for idx in retriable_idxs]
+
+    def get_input_jwts_waiting_for_output(self) -> list[str]:
+        """
+        Returns a list of input_jwts for inputs that are waiting for output.
+        """
+        # If input_jwt is not done, the call to PutInputs has not completed, so omit it from the results.
+        return [
+            ctx.input_jwt.result()
+            for ctx in self._item_context.values()
+            if ctx.state == _MapItemState.WAITING_FOR_OUTPUT and ctx.input_jwt.done()
+        ]
+
+    def _remove_item(self, item_idx: int):
+        del self._item_context[item_idx]
+        self._inputs_outstanding.release()
+
+    def get_item_context(self, item_idx: int) -> _MapItemContext:
+        return self._item_context.get(item_idx)
+
+    def handle_put_inputs_response(self, items: list[api_pb2.FunctionPutInputsResponseItem]):
+        for item in items:
+            ctx = self._item_context.get(item.idx, None)
+            # If the context is None, then get_all_outputs() has already received a successful
+            # output, and deleted the context. This happens if FunctionGetOutputs completes
+            # before the FunctionPutInputsResponse is received.
+            if ctx is not None:
+                ctx.handle_put_inputs_response(item)
+
+    def handle_retry_response(self, input_jwts: list[str]):
+        for input_jwt in input_jwts:
+            decoded_jwt = DecodedJwt.decode_without_verification(input_jwt)
+            ctx = self._item_context.get(decoded_jwt.payload["idx"], None)
+            # If the context is None, then get_all_outputs() has already received a successful
+            # output, and deleted the context. This happens if FunctionGetOutputs completes
+            # before the FunctionRetryInputsResponse is received.
+            if ctx is not None:
+                ctx.handle_retry_response(input_jwt)
+
+    async def handle_get_outputs_response(self, item: api_pb2.FunctionGetOutputsItem, now_seconds: int) -> _OutputType:
+        ctx = self._item_context.get(item.idx, None)
+        if ctx is None:
+            # We've already processed this output, so we can skip it.
+            # This can happen because the worker can sometimes send duplicate outputs.
+            logger.debug(
+                f"Received output that does not have an entry in the item_context map, so ignoring. "
+                f"idx={item.idx} input_id={item.input_id} retry_count={item.retry_count}"
+            )
+            return _OutputType.NO_CONTEXT_DUPLICATE
+        output_type = await ctx.handle_get_outputs_response(
+            item, now_seconds, self.function_call_invocation_type, self._retry_queue
+        )
+        if output_type == _OutputType.SUCCESSFUL_COMPLETION or output_type == _OutputType.FAILED_COMPLETION:
+            self._remove_item(item.idx)
+        return output_type
+
+    def __len__(self):
+        return len(self._item_context)