modal 0.73.116__py3-none-any.whl → 0.73.128__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +2 -0
- modal/_functions.py +19 -8
- modal/_partial_function.py +54 -0
- modal/_runtime/container_io_manager.py +13 -9
- modal/_runtime/container_io_manager.pyi +7 -4
- modal/_serialization.py +92 -44
- modal/_utils/async_utils.py +71 -6
- modal/_utils/function_utils.py +33 -13
- modal/_utils/jwt_utils.py +38 -0
- modal/app.py +34 -5
- modal/app.pyi +3 -2
- modal/cli/app.py +15 -0
- modal/client.pyi +2 -2
- modal/cls.py +3 -13
- modal/cls.pyi +0 -2
- modal/functions.pyi +8 -7
- modal/parallel_map.py +393 -44
- modal/parallel_map.pyi +75 -0
- modal/partial_function.py +2 -0
- modal/partial_function.pyi +9 -0
- modal/retries.py +11 -9
- modal/sandbox.py +5 -1
- {modal-0.73.116.dist-info → modal-0.73.128.dist-info}/METADATA +1 -1
- {modal-0.73.116.dist-info → modal-0.73.128.dist-info}/RECORD +36 -35
- {modal-0.73.116.dist-info → modal-0.73.128.dist-info}/WHEEL +1 -1
- modal_proto/api.proto +15 -2
- modal_proto/api_grpc.py +16 -0
- modal_proto/api_pb2.py +284 -263
- modal_proto/api_pb2.pyi +49 -6
- modal_proto/api_pb2_grpc.py +33 -0
- modal_proto/api_pb2_grpc.pyi +10 -0
- modal_proto/modal_api_grpc.py +1 -0
- modal_version/_version_generated.py +1 -1
- {modal-0.73.116.dist-info → modal-0.73.128.dist-info}/LICENSE +0 -0
- {modal-0.73.116.dist-info → modal-0.73.128.dist-info}/entry_points.txt +0 -0
- {modal-0.73.116.dist-info → modal-0.73.128.dist-info}/top_level.txt +0 -0
modal/__init__.py
CHANGED
@@ -27,6 +27,7 @@ try:
|
|
27
27
|
asgi_app,
|
28
28
|
batched,
|
29
29
|
build,
|
30
|
+
concurrent,
|
30
31
|
enter,
|
31
32
|
exit,
|
32
33
|
fastapi_endpoint,
|
@@ -82,6 +83,7 @@ __all__ = [
|
|
82
83
|
"asgi_app",
|
83
84
|
"batched",
|
84
85
|
"build",
|
86
|
+
"concurrent",
|
85
87
|
"current_function_call_id",
|
86
88
|
"current_input_id",
|
87
89
|
"enable_output",
|
modal/_functions.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
# Copyright Modal Labs 2023
|
2
|
+
import asyncio
|
2
3
|
import dataclasses
|
3
4
|
import inspect
|
4
5
|
import textwrap
|
@@ -24,7 +25,7 @@ from ._pty import get_pty_info
|
|
24
25
|
from ._resolver import Resolver
|
25
26
|
from ._resources import convert_fn_config_to_resources_config
|
26
27
|
from ._runtime.execution_context import current_input_id, is_local
|
27
|
-
from ._serialization import serialize, serialize_proto_params
|
28
|
+
from ._serialization import apply_defaults, serialize, serialize_proto_params, validate_params
|
28
29
|
from ._traceback import print_server_warnings
|
29
30
|
from ._utils.async_utils import (
|
30
31
|
TaskContext,
|
@@ -174,7 +175,7 @@ class _Invocation:
|
|
174
175
|
input_jwt=input.input_jwt,
|
175
176
|
input_id=input.input_id,
|
176
177
|
item=item,
|
177
|
-
sync_client_retries_enabled=response.sync_client_retries_enabled
|
178
|
+
sync_client_retries_enabled=response.sync_client_retries_enabled,
|
178
179
|
)
|
179
180
|
return _Invocation(client.stub, function_call_id, client, retry_context)
|
180
181
|
|
@@ -256,9 +257,13 @@ class _Invocation:
|
|
256
257
|
try:
|
257
258
|
return await self._get_single_output(ctx.input_jwt)
|
258
259
|
except (UserCodeException, FunctionTimeoutError) as exc:
|
259
|
-
|
260
|
+
delay_ms = user_retry_manager.get_delay_ms()
|
261
|
+
if delay_ms is None:
|
262
|
+
raise exc
|
263
|
+
await asyncio.sleep(delay_ms / 1000)
|
260
264
|
except InternalFailure:
|
261
|
-
# For system failures on the server, we retry immediately
|
265
|
+
# For system failures on the server, we retry immediately,
|
266
|
+
# and the failure does not count towards the retry policy.
|
262
267
|
pass
|
263
268
|
await self._retry_input()
|
264
269
|
|
@@ -430,7 +435,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
430
435
|
max_containers: Optional[int] = None,
|
431
436
|
buffer_containers: Optional[int] = None,
|
432
437
|
scaledown_window: Optional[int] = None,
|
433
|
-
|
438
|
+
max_concurrent_inputs: Optional[int] = None,
|
439
|
+
target_concurrent_inputs: Optional[int] = None,
|
434
440
|
batch_max_size: Optional[int] = None,
|
435
441
|
batch_wait_ms: Optional[int] = None,
|
436
442
|
cloud: Optional[str] = None,
|
@@ -781,7 +787,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
781
787
|
runtime_perf_record=config.get("runtime_perf_record"),
|
782
788
|
app_name=app_name,
|
783
789
|
is_builder_function=is_builder_function,
|
784
|
-
|
790
|
+
max_concurrent_inputs=max_concurrent_inputs or 0,
|
791
|
+
target_concurrent_inputs=target_concurrent_inputs or 0,
|
785
792
|
batch_max_size=batch_max_size or 0,
|
786
793
|
batch_linger_ms=batch_wait_ms or 0,
|
787
794
|
worker_id=config.get("worker_id"),
|
@@ -968,8 +975,11 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
968
975
|
"Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
|
969
976
|
"Use (<parameter_name>=value) keyword arguments when constructing classes instead."
|
970
977
|
)
|
971
|
-
|
972
|
-
|
978
|
+
schema = parent._class_parameter_info.schema
|
979
|
+
kwargs_with_defaults = apply_defaults(kwargs, schema)
|
980
|
+
validate_params(kwargs_with_defaults, schema)
|
981
|
+
serialized_params = serialize_proto_params(kwargs_with_defaults)
|
982
|
+
can_use_parent = len(parent._class_parameter_info.schema) == 0 # no parameters
|
973
983
|
else:
|
974
984
|
can_use_parent = len(args) + len(kwargs) == 0 and options is None
|
975
985
|
serialized_params = serialize((args, kwargs))
|
@@ -1304,6 +1314,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1304
1314
|
order_outputs,
|
1305
1315
|
return_exceptions,
|
1306
1316
|
count_update_callback,
|
1317
|
+
api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
|
1307
1318
|
)
|
1308
1319
|
) as stream:
|
1309
1320
|
async for item in stream:
|
modal/_partial_function.py
CHANGED
@@ -59,6 +59,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
59
59
|
force_build: bool
|
60
60
|
cluster_size: Optional[int] # Experimental: Clustered functions
|
61
61
|
build_timeout: Optional[int]
|
62
|
+
max_concurrent_inputs: Optional[int]
|
63
|
+
target_concurrent_inputs: Optional[int]
|
62
64
|
|
63
65
|
def __init__(
|
64
66
|
self,
|
@@ -72,6 +74,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
72
74
|
cluster_size: Optional[int] = None, # Experimental: Clustered functions
|
73
75
|
force_build: bool = False,
|
74
76
|
build_timeout: Optional[int] = None,
|
77
|
+
max_concurrent_inputs: Optional[int] = None,
|
78
|
+
target_concurrent_inputs: Optional[int] = None,
|
75
79
|
):
|
76
80
|
self.raw_f = raw_f
|
77
81
|
self.flags = flags
|
@@ -89,6 +93,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
89
93
|
self.cluster_size = cluster_size # Experimental: Clustered functions
|
90
94
|
self.force_build = force_build
|
91
95
|
self.build_timeout = build_timeout
|
96
|
+
self.max_concurrent_inputs = max_concurrent_inputs
|
97
|
+
self.target_concurrent_inputs = target_concurrent_inputs
|
92
98
|
|
93
99
|
def _get_raw_f(self) -> Callable[P, ReturnType]:
|
94
100
|
return self.raw_f
|
@@ -143,6 +149,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
143
149
|
batch_wait_ms=self.batch_wait_ms,
|
144
150
|
force_build=self.force_build,
|
145
151
|
build_timeout=self.build_timeout,
|
152
|
+
max_concurrent_inputs=self.max_concurrent_inputs,
|
153
|
+
target_concurrent_inputs=self.target_concurrent_inputs,
|
146
154
|
)
|
147
155
|
|
148
156
|
|
@@ -722,3 +730,49 @@ def _batched(
|
|
722
730
|
)
|
723
731
|
|
724
732
|
return wrapper
|
733
|
+
|
734
|
+
|
735
|
+
def _concurrent(
    _warn_parentheses_missing=None,
    *,
    max_inputs: int,  # Hard limit on each container's input concurrency
    target_inputs: Optional[int] = None,  # Input concurrency that Modal's autoscaler should target
) -> Callable[[Union[Callable[..., Any], _PartialFunction]], _PartialFunction]:
    """Decorator that lets each container work on several inputs at the same time.

    How the inputs are run concurrently depends on the decorated function:
    - Async functions run their inputs as asyncio tasks on a single thread.
    - Synchronous functions run their inputs on multiple threads, so the code must be thread-safe.

    Input concurrency is most useful for IO-bound workflows (e.g., making network requests)
    or for inference servers that support dynamic batching.

    When `target_inputs` is set, Modal's autoscaler tries to provision resources so that each
    container runs that many inputs concurrently. Containers may burst up to `max_inputs`
    if resources are insufficient to remain at the target concurrency.

    Raises `InvalidError` when the decorator is applied without parentheses, or when
    `target_inputs` is greater than `max_inputs`.
    """
    # The first positional slot only exists to catch `@modal.concurrent` used without parentheses.
    if _warn_parentheses_missing is not None:
        raise InvalidError(
            "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.concurrent()`."
        )

    if target_inputs and target_inputs > max_inputs:
        raise InvalidError("`target_inputs` parameter cannot be greater than `max_inputs`.")

    def wrapper(obj: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
        if not isinstance(obj, _PartialFunction):
            # A plain callable: wrap it in a new partial function that carries the limits.
            return _PartialFunction(
                obj,
                _PartialFunctionFlags.FUNCTION,
                max_concurrent_inputs=max_inputs,
                target_concurrent_inputs=target_inputs,
            )

        # The object was already wrapped by another decorator, so it is updated in place.
        # Risky that we need to mutate the parameters here; should make this safer
        obj.max_concurrent_inputs = max_inputs
        obj.target_concurrent_inputs = target_inputs
        obj.add_flags(_PartialFunctionFlags.FUNCTION)
        return obj

    return wrapper
|
@@ -63,7 +63,9 @@ class IOContext:
|
|
63
63
|
"""
|
64
64
|
|
65
65
|
input_ids: list[str]
|
66
|
+
retry_counts: list[int]
|
66
67
|
function_call_ids: list[str]
|
68
|
+
function_inputs: list[api_pb2.FunctionInput]
|
67
69
|
finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"
|
68
70
|
|
69
71
|
_cancel_issued: bool = False
|
@@ -72,6 +74,7 @@ class IOContext:
|
|
72
74
|
def __init__(
|
73
75
|
self,
|
74
76
|
input_ids: list[str],
|
77
|
+
retry_counts: list[int],
|
75
78
|
function_call_ids: list[str],
|
76
79
|
finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
|
77
80
|
function_inputs: list[api_pb2.FunctionInput],
|
@@ -79,9 +82,10 @@ class IOContext:
|
|
79
82
|
client: _Client,
|
80
83
|
):
|
81
84
|
self.input_ids = input_ids
|
85
|
+
self.retry_counts = retry_counts
|
82
86
|
self.function_call_ids = function_call_ids
|
83
87
|
self.finalized_function = finalized_function
|
84
|
-
self.
|
88
|
+
self.function_inputs = function_inputs
|
85
89
|
self._is_batched = is_batched
|
86
90
|
self._client = client
|
87
91
|
|
@@ -90,11 +94,11 @@ class IOContext:
|
|
90
94
|
cls,
|
91
95
|
client: _Client,
|
92
96
|
finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
|
93
|
-
inputs: list[tuple[str, str, api_pb2.FunctionInput]],
|
97
|
+
inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
|
94
98
|
is_batched: bool,
|
95
99
|
) -> "IOContext":
|
96
100
|
assert len(inputs) >= 1 if is_batched else len(inputs) == 1
|
97
|
-
input_ids, function_call_ids, function_inputs = zip(*inputs)
|
101
|
+
input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
|
98
102
|
|
99
103
|
async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
|
100
104
|
# If we got a pointer to a blob, download it from S3.
|
@@ -111,7 +115,7 @@ class IOContext:
|
|
111
115
|
method_name = function_inputs[0].method_name
|
112
116
|
assert all(method_name == input.method_name for input in function_inputs)
|
113
117
|
finalized_function = finalized_functions[method_name]
|
114
|
-
return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)
|
118
|
+
return cls(input_ids, retry_counts, function_call_ids, finalized_function, function_inputs, is_batched, client)
|
115
119
|
|
116
120
|
def set_cancel_callback(self, cb: Callable[[], None]):
|
117
121
|
self._cancel_callback = cb
|
@@ -135,7 +139,7 @@ class IOContext:
|
|
135
139
|
# to make sure we handle user exceptions properly
|
136
140
|
# and don't retry
|
137
141
|
deserialized_args = [
|
138
|
-
deserialize(input.args, self._client) if input.args else ((), {}) for input in self.
|
142
|
+
deserialize(input.args, self._client) if input.args else ((), {}) for input in self.function_inputs
|
139
143
|
]
|
140
144
|
if not self._is_batched:
|
141
145
|
return deserialized_args[0]
|
@@ -551,7 +555,7 @@ class _ContainerIOManager:
|
|
551
555
|
self,
|
552
556
|
batch_max_size: int,
|
553
557
|
batch_wait_ms: int,
|
554
|
-
) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
|
558
|
+
) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
|
555
559
|
request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
|
556
560
|
iteration = 0
|
557
561
|
while self._fetching_inputs:
|
@@ -586,8 +590,7 @@ class _ContainerIOManager:
|
|
586
590
|
if item.kill_switch:
|
587
591
|
logger.debug(f"Task {self.task_id} input kill signal input.")
|
588
592
|
return
|
589
|
-
|
590
|
-
inputs.append((item.input_id, item.function_call_id, item.input))
|
593
|
+
inputs.append((item.input_id, item.retry_count, item.function_call_id, item.input))
|
591
594
|
if item.input.final_input:
|
592
595
|
if request.batch_max_size > 0:
|
593
596
|
logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
|
@@ -648,8 +651,9 @@ class _ContainerIOManager:
|
|
648
651
|
output_created_at=output_created_at,
|
649
652
|
result=result,
|
650
653
|
data_format=data_format,
|
654
|
+
retry_count=retry_count,
|
651
655
|
)
|
652
|
-
for input_id, result in zip(io_context.input_ids, results)
|
656
|
+
for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
|
653
657
|
]
|
654
658
|
await retry_transient_errors(
|
655
659
|
self._client.stub.FunctionPutOutputs,
|
@@ -14,7 +14,9 @@ class Sentinel: ...
|
|
14
14
|
|
15
15
|
class IOContext:
|
16
16
|
input_ids: list[str]
|
17
|
+
retry_counts: list[int]
|
17
18
|
function_call_ids: list[str]
|
19
|
+
function_inputs: list[modal_proto.api_pb2.FunctionInput]
|
18
20
|
finalized_function: modal._runtime.user_code_imports.FinalizedFunction
|
19
21
|
_cancel_issued: bool
|
20
22
|
_cancel_callback: typing.Optional[collections.abc.Callable[[], None]]
|
@@ -22,6 +24,7 @@ class IOContext:
|
|
22
24
|
def __init__(
|
23
25
|
self,
|
24
26
|
input_ids: list[str],
|
27
|
+
retry_counts: list[int],
|
25
28
|
function_call_ids: list[str],
|
26
29
|
finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
|
27
30
|
function_inputs: list[modal_proto.api_pb2.FunctionInput],
|
@@ -33,7 +36,7 @@ class IOContext:
|
|
33
36
|
cls,
|
34
37
|
client: modal.client._Client,
|
35
38
|
finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
|
36
|
-
inputs: list[tuple[str, str, modal_proto.api_pb2.FunctionInput]],
|
39
|
+
inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
|
37
40
|
is_batched: bool,
|
38
41
|
) -> IOContext: ...
|
39
42
|
def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
|
@@ -116,7 +119,7 @@ class _ContainerIOManager:
|
|
116
119
|
def get_max_inputs_to_fetch(self): ...
|
117
120
|
def _generate_inputs(
|
118
121
|
self, batch_max_size: int, batch_wait_ms: int
|
119
|
-
) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
122
|
+
) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
120
123
|
def run_inputs_outputs(
|
121
124
|
self,
|
122
125
|
finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
|
@@ -287,10 +290,10 @@ class ContainerIOManager:
|
|
287
290
|
class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
|
288
291
|
def __call__(
|
289
292
|
self, batch_max_size: int, batch_wait_ms: int
|
290
|
-
) -> typing.Iterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
293
|
+
) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
291
294
|
def aio(
|
292
295
|
self, batch_max_size: int, batch_wait_ms: int
|
293
|
-
) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
296
|
+
) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
294
297
|
|
295
298
|
_generate_inputs: ___generate_inputs_spec[typing_extensions.Self]
|
296
299
|
|
modal/_serialization.py
CHANGED
@@ -400,6 +400,7 @@ class ParamTypeInfo:
|
|
400
400
|
default_field: str
|
401
401
|
proto_field: str
|
402
402
|
converter: typing.Callable[[str], typing.Any]
|
403
|
+
type: type
|
403
404
|
|
404
405
|
|
405
406
|
PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
|
@@ -411,75 +412,112 @@ PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
|
|
411
412
|
|
412
413
|
PROTO_TYPE_INFO = {
|
413
414
|
# Protobuf type enum -> encode/decode helper metadata
|
414
|
-
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
|
415
|
-
|
415
|
+
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
|
416
|
+
default_field="string_default",
|
417
|
+
proto_field="string_value",
|
418
|
+
converter=str,
|
419
|
+
type=str,
|
420
|
+
),
|
421
|
+
api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
|
422
|
+
default_field="int_default",
|
423
|
+
proto_field="int_value",
|
424
|
+
converter=int,
|
425
|
+
type=int,
|
426
|
+
),
|
416
427
|
api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
|
417
|
-
default_field="bytes_default",
|
428
|
+
default_field="bytes_default",
|
429
|
+
proto_field="bytes_value",
|
430
|
+
converter=assert_bytes,
|
431
|
+
type=bytes,
|
418
432
|
),
|
419
433
|
}
|
420
434
|
|
421
435
|
|
422
|
-
def
|
423
|
-
|
436
|
+
def apply_defaults(
    python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
) -> dict[str, Any]:
    """Return a copy of `python_params` with schema-declared defaults filled in for missing names.

    Conceptually similar to inspect.BoundArguments.apply_defaults. The input mapping
    is not modified.

    Note: Apply this before serializing parameters so that the serialized payload is the
    same whether a defaulted value was passed explicitly or left out.

    Raises `InvalidError` if a schema entry is flagged as having a default but none of its
    `default_oneof` fields is set.
    """
    merged = dict(python_params)
    for spec in schema:
        # Only parameters that declare a default and were not supplied need filling in.
        if not spec.has_default or spec.name in python_params:
            continue
        populated_field = spec.WhichOneof("default_oneof")
        if populated_field is None:
            raise InvalidError(f"{spec.name} declared as having a default, but has no default value")
        merged[spec.name] = getattr(spec, populated_field)
    return merged
|
454
|
+
|
455
|
+
|
456
|
+
def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
    """Encode keyword parameters as a serialized `ClassParameterSet` protobuf message.

    The protobuf type of each parameter is derived from the Python type of its value
    via `get_proto_parameter_type`, and the value is passed through the matching
    converter from `PROTO_TYPE_INFO` before being stored on the message.

    Raises `ValueError` if a converter rejects a value.
    """
    encoded_values: list[api_pb2.ClassParameterValue] = []
    for name, value in python_params.items():
        proto_type = get_proto_parameter_type(type(value))
        info = PROTO_TYPE_INFO.get(proto_type)
        message = api_pb2.ClassParameterValue(name=name, type=proto_type)
        try:
            converted = info.converter(value)
        except ValueError as exc:
            raise ValueError(f"Invalid type for parameter {name}: {exc}")
        # Each protobuf type stores its payload in a different field of the message.
        setattr(message, info.proto_field, converted)
        encoded_values.append(message)

    parameter_set = api_pb2.ClassParameterSet(parameters=encoded_values)
    # Deterministic serialization is requested explicitly for the resulting bytes.
    return parameter_set.SerializeToString(deterministic=True)
|
446
474
|
|
447
475
|
|
448
|
-
def deserialize_proto_params(serialized_params: bytes
|
449
|
-
# TODO: this currently requires the schema to decode a payload, but we should make the validation
|
450
|
-
# distinct from the deserialization
|
476
|
+
def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
    """Decode a serialized `ClassParameterSet` into a `{name: value}` dictionary.

    No schema is consulted here; each value is read from the field that matches the
    type recorded on the parameter itself.

    Raises `NotImplementedError` for a parameter type other than string, int or bytes.
    """
    # Which message field holds the payload for each supported parameter type.
    value_field_by_type = {
        api_pb2.PARAM_TYPE_STRING: "string_value",
        api_pb2.PARAM_TYPE_INT: "int_value",
        api_pb2.PARAM_TYPE_BYTES: "bytes_value",
    }

    parameter_set = api_pb2.ClassParameterSet()
    parameter_set.ParseFromString(serialized_params)

    decoded: dict[str, Any] = {}
    for param in parameter_set.parameters:
        field_name = value_field_by_type.get(param.type)
        if field_name is None:
            raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")
        decoded[param.name] = getattr(param, field_name)

    return decoded
|
481
494
|
|
482
495
|
|
496
|
+
def validate_params(params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
    """Check that `params` matches `schema` exactly, in both names and value types.

    Raises:
        InvalidError: a declared parameter is missing, a value has an unsupported Python
            type, or a name was provided that the schema does not declare.
        TypeError: a value's type differs from the type declared in the schema.
    """
    # Pass 1: every declared parameter must be present and of the declared type.
    for spec in schema:
        name = spec.name
        if name not in params:
            # Defaults are expected to have been applied already, so even defaulted
            # parameters must be present at this point.
            raise InvalidError(f"Missing required parameter: {name}")
        actual_type = type(params[name])
        if get_proto_parameter_type(actual_type) != spec.type:
            expected_type = PROTO_TYPE_INFO[spec.type].type
            raise TypeError(
                f"Parameter '{name}' type error: expected {expected_type.__name__}, got {actual_type.__name__}"
            )

    # Pass 2: nothing outside the schema may be provided.
    declared_names = {spec.name for spec in schema}
    non_declared_fields = params.keys() - declared_names
    if non_declared_fields:
        raise InvalidError(
            f"The following parameter names were provided but are not present in the schema: {non_declared_fields}"
        )
|
519
|
+
|
520
|
+
|
483
521
|
def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function, _client: "modal.client._Client"):
|
484
522
|
if function_def.class_parameter_info.format in (
|
485
523
|
api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_UNSPECIFIED,
|
@@ -488,11 +526,21 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
|
|
488
526
|
# legacy serialization format - pickle of `(args, kwargs)` w/ support for modal object arguments
|
489
527
|
param_args, param_kwargs = deserialize(serialized_params, _client)
|
490
528
|
elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
|
491
|
-
param_args = ()
|
492
|
-
param_kwargs = deserialize_proto_params(serialized_params
|
529
|
+
param_args = () # we use kwargs only for our implicit constructors
|
530
|
+
param_kwargs = deserialize_proto_params(serialized_params)
|
531
|
+
# TODO: We can probably remove the validation below since we do validation in the caller?
|
532
|
+
validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
|
493
533
|
else:
|
494
534
|
raise ExecutionError(
|
495
535
|
f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
|
496
536
|
)
|
497
537
|
|
498
538
|
return param_args, param_kwargs
|
539
|
+
|
540
|
+
|
541
|
+
def get_proto_parameter_type(parameter_type: type) -> "api_pb2.ParameterType.ValueType":
    """Look up the protobuf parameter type registered for a Python type.

    Raises `InvalidError`, listing the supported type names, when `parameter_type`
    has no entry in `PYTHON_TO_PROTO_TYPE`.
    """
    if parameter_type in PYTHON_TO_PROTO_TYPE:
        return PYTHON_TO_PROTO_TYPE[parameter_type]

    # Fall back to repr() for objects that do not expose a __name__.
    type_name = getattr(parameter_type, "__name__", repr(parameter_type))
    supported = ", ".join(known_type.__name__ for known_type in PYTHON_TO_PROTO_TYPE.keys())
    raise InvalidError(f"{type_name} is not a supported parameter type. Use one of: {supported}")
|
modal/_utils/async_utils.py
CHANGED
@@ -12,6 +12,7 @@ from dataclasses import dataclass
|
|
12
12
|
from typing import (
|
13
13
|
Any,
|
14
14
|
Callable,
|
15
|
+
Generic,
|
15
16
|
Optional,
|
16
17
|
TypeVar,
|
17
18
|
Union,
|
@@ -26,6 +27,10 @@ from typing_extensions import ParamSpec, assert_type
|
|
26
27
|
from ..exception import InvalidError
|
27
28
|
from .logger import logger
|
28
29
|
|
30
|
+
T = TypeVar("T")
|
31
|
+
P = ParamSpec("P")
|
32
|
+
V = TypeVar("V")
|
33
|
+
|
29
34
|
synchronizer = synchronicity.Synchronizer()
|
30
35
|
|
31
36
|
|
@@ -260,7 +265,72 @@ def run_coro_blocking(coro):
|
|
260
265
|
return fut.result()
|
261
266
|
|
262
267
|
|
263
|
-
|
268
|
+
class TimestampPriorityQueue(Generic[T]):
    """
    A priority queue that schedules items to be processed at specific timestamps.

    `get()` hands items out in timestamp order, and only once an item's timestamp
    (compared against `time.time()`) has been reached. Items that share a timestamp
    are returned in the order they were added.
    """

    _MAX_PRIORITY = float("inf")

    def __init__(self, maxsize: int = 0):
        self.condition = asyncio.Condition()
        # Entries are (timestamp, sequence number, item). The sequence number breaks ties
        # between equal timestamps, so the heap never falls back to comparing the items
        # themselves; items are not required to be orderable (e.g. dicts or protobuf messages).
        self._queue: asyncio.PriorityQueue[tuple[float, int, Union[T, None]]] = asyncio.PriorityQueue(
            maxsize=maxsize
        )
        self._sequence = 0

    async def close(self):
        """
        Enqueue the end-of-queue marker: a `None` item with the maximum possible timestamp.
        """
        await self.put(self._MAX_PRIORITY, None)

    async def put(self, timestamp: float, item: Union[T, None]):
        """
        Add an item to the queue to be processed at a specific timestamp.
        """
        entry = (timestamp, self._sequence, item)
        self._sequence += 1
        await self._queue.put(entry)
        async with self.condition:
            self.condition.notify_all()  # notify any waiting coroutines

    async def get(self) -> Union[T, None]:
        """
        Get the next item from the queue that is ready to be processed.

        Waits until the earliest item is due. Returns `None` once the earliest remaining
        entry is the marker added by `close()`.
        """
        while True:
            async with self.condition:
                while self.empty():
                    await self.condition.wait()
                # Pop the earliest entry; it is pushed back below if it is not due yet.
                entry = await self._queue.get()
                timestamp, _, item = entry
                now = time.time()
                if timestamp <= now:
                    return item
                if timestamp == self._MAX_PRIORITY:
                    return None
                # not ready yet, calculate sleep time
                sleep_time = timestamp - now
                # Put it back unchanged so it keeps its original position among ties.
                self._queue.put_nowait(entry)
                # wait until either the timeout or a new item is added
                try:
                    await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
                except asyncio.TimeoutError:
                    continue

    def empty(self) -> bool:
        """Return True if no entries are queued."""
        return self._queue.empty()

    def qsize(self) -> int:
        """Return the number of queued entries, whether or not they are due yet."""
        return self._queue.qsize()

    async def clear(self):
        """
        Clear the retry queue. Used for testing to simulate reading all elements from queue using queue_batch_iterator.

        Entries are drained through `get()`, so this waits for entries that are not due yet.
        """
        while not self.empty():
            await self.get()

    def __len__(self):
        return self._queue.qsize()
|
329
|
+
|
330
|
+
|
331
|
+
async def queue_batch_iterator(
|
332
|
+
q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015
|
333
|
+
):
|
264
334
|
"""
|
265
335
|
Read from a queue but return lists of items when queue is large
|
266
336
|
|
@@ -405,11 +475,6 @@ def on_shutdown(coro):
|
|
405
475
|
_shutdown_tasks.append(asyncio.create_task(wrapper()))
|
406
476
|
|
407
477
|
|
408
|
-
T = TypeVar("T")
|
409
|
-
P = ParamSpec("P")
|
410
|
-
V = TypeVar("V")
|
411
|
-
|
412
|
-
|
413
478
|
def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
|
414
479
|
"""Convert a blocking function into one that runs in the current loop's executor."""
|
415
480
|
|