modal 0.73.130__py3-none-any.whl → 0.73.131__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/_functions.py +8 -3
- modal/_resolver.py +6 -1
- modal/_serialization.py +78 -96
- modal/_type_manager.py +229 -0
- modal/_utils/function_utils.py +4 -27
- modal/client.pyi +2 -2
- modal/cls.py +6 -2
- modal/functions.pyi +6 -6
- {modal-0.73.130.dist-info → modal-0.73.131.dist-info}/METADATA +1 -1
- {modal-0.73.130.dist-info → modal-0.73.131.dist-info}/RECORD +18 -17
- modal_proto/api.proto +15 -4
- modal_proto/api_pb2.py +717 -704
- modal_proto/api_pb2.pyi +40 -2
- modal_version/_version_generated.py +1 -1
- {modal-0.73.130.dist-info → modal-0.73.131.dist-info}/LICENSE +0 -0
- {modal-0.73.130.dist-info → modal-0.73.131.dist-info}/WHEEL +0 -0
- {modal-0.73.130.dist-info → modal-0.73.131.dist-info}/entry_points.txt +0 -0
- {modal-0.73.130.dist-info → modal-0.73.131.dist-info}/top_level.txt +0 -0
modal/_functions.py
CHANGED
@@ -25,7 +25,12 @@ from ._pty import get_pty_info
|
|
25
25
|
from ._resolver import Resolver
|
26
26
|
from ._resources import convert_fn_config_to_resources_config
|
27
27
|
from ._runtime.execution_context import current_input_id, is_local
|
28
|
-
from ._serialization import
|
28
|
+
from ._serialization import (
|
29
|
+
apply_defaults,
|
30
|
+
serialize,
|
31
|
+
serialize_proto_params,
|
32
|
+
validate_parameter_values,
|
33
|
+
)
|
29
34
|
from ._traceback import print_server_warnings
|
30
35
|
from ._utils.async_utils import (
|
31
36
|
TaskContext,
|
@@ -975,7 +980,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
975
980
|
)
|
976
981
|
schema = parent._class_parameter_info.schema
|
977
982
|
kwargs_with_defaults = apply_defaults(kwargs, schema)
|
978
|
-
|
983
|
+
validate_parameter_values(kwargs_with_defaults, schema)
|
979
984
|
serialized_params = serialize_proto_params(kwargs_with_defaults)
|
980
985
|
can_use_parent = len(parent._class_parameter_info.schema) == 0 # no parameters
|
981
986
|
else:
|
@@ -1312,7 +1317,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1312
1317
|
order_outputs,
|
1313
1318
|
return_exceptions,
|
1314
1319
|
count_update_callback,
|
1315
|
-
api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
|
1320
|
+
api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
|
1316
1321
|
)
|
1317
1322
|
) as stream:
|
1318
1323
|
async for item in stream:
|
modal/_resolver.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
# Copyright Modal Labs 2023
|
2
2
|
import asyncio
|
3
3
|
import contextlib
|
4
|
+
import traceback
|
4
5
|
import typing
|
5
6
|
from asyncio import Future
|
6
7
|
from collections.abc import Hashable
|
@@ -153,7 +154,11 @@ class Resolver:
|
|
153
154
|
self._deduplication_cache[deduplication_key] = cached_future
|
154
155
|
|
155
156
|
# TODO(elias): print original exception/trace rather than the Resolver-internal trace
|
156
|
-
|
157
|
+
try:
|
158
|
+
return await cached_future
|
159
|
+
except Exception:
|
160
|
+
traceback.print_exc()
|
161
|
+
raise
|
157
162
|
|
158
163
|
def objects(self) -> list["modal._object._Object"]:
|
159
164
|
unique_objects: dict[str, "modal._object._Object"] = {}
|
modal/_serialization.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
# Copyright Modal Labs 2022
|
2
|
+
import inspect
|
2
3
|
import io
|
3
4
|
import pickle
|
4
5
|
import typing
|
5
|
-
from
|
6
|
+
from inspect import Parameter
|
6
7
|
from typing import Any
|
7
8
|
|
8
9
|
from modal._utils.async_utils import synchronizer
|
9
10
|
from modal_proto import api_pb2
|
10
11
|
|
11
12
|
from ._object import _Object
|
13
|
+
from ._type_manager import parameter_serde_registry, schema_registry
|
12
14
|
from ._vendor import cloudpickle
|
13
15
|
from .config import logger
|
14
16
|
from .exception import DeserializationError, ExecutionError, InvalidError
|
@@ -389,50 +391,6 @@ def check_valid_cls_constructor_arg(key, obj):
|
|
389
391
|
)
|
390
392
|
|
391
393
|
|
392
|
-
def assert_bytes(obj: Any):
|
393
|
-
if not isinstance(obj, bytes):
|
394
|
-
raise TypeError(f"Expected bytes, got {type(obj)}")
|
395
|
-
return obj
|
396
|
-
|
397
|
-
|
398
|
-
@dataclass
|
399
|
-
class ParamTypeInfo:
|
400
|
-
default_field: str
|
401
|
-
proto_field: str
|
402
|
-
converter: typing.Callable[[str], typing.Any]
|
403
|
-
type: type
|
404
|
-
|
405
|
-
|
406
|
-
PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
|
407
|
-
# python type -> protobuf type enum
|
408
|
-
str: api_pb2.PARAM_TYPE_STRING,
|
409
|
-
int: api_pb2.PARAM_TYPE_INT,
|
410
|
-
bytes: api_pb2.PARAM_TYPE_BYTES,
|
411
|
-
}
|
412
|
-
|
413
|
-
PROTO_TYPE_INFO = {
|
414
|
-
# Protobuf type enum -> encode/decode helper metadata
|
415
|
-
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
|
416
|
-
default_field="string_default",
|
417
|
-
proto_field="string_value",
|
418
|
-
converter=str,
|
419
|
-
type=str,
|
420
|
-
),
|
421
|
-
api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
|
422
|
-
default_field="int_default",
|
423
|
-
proto_field="int_value",
|
424
|
-
converter=int,
|
425
|
-
type=int,
|
426
|
-
),
|
427
|
-
api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
|
428
|
-
default_field="bytes_default",
|
429
|
-
proto_field="bytes_value",
|
430
|
-
converter=assert_bytes,
|
431
|
-
type=bytes,
|
432
|
-
),
|
433
|
-
}
|
434
|
-
|
435
|
-
|
436
394
|
def apply_defaults(
|
437
395
|
python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
|
438
396
|
) -> dict[str, Any]:
|
@@ -453,68 +411,56 @@ def apply_defaults(
|
|
453
411
|
return result
|
454
412
|
|
455
413
|
|
414
|
+
def encode_parameter_value(name: str, python_value: Any) -> api_pb2.ClassParameterValue:
|
415
|
+
"""Map to proto parameter representation using python runtime type information"""
|
416
|
+
struct = parameter_serde_registry.encode(python_value)
|
417
|
+
struct.name = name
|
418
|
+
return struct
|
419
|
+
|
420
|
+
|
456
421
|
def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
|
457
422
|
proto_params: list[api_pb2.ClassParameterValue] = []
|
458
423
|
for param_name, python_value in python_params.items():
|
459
|
-
|
460
|
-
protobuf_type = get_proto_parameter_type(python_type)
|
461
|
-
type_info = PROTO_TYPE_INFO.get(protobuf_type)
|
462
|
-
proto_param = api_pb2.ClassParameterValue(
|
463
|
-
name=param_name,
|
464
|
-
type=protobuf_type,
|
465
|
-
)
|
466
|
-
try:
|
467
|
-
converted_value = type_info.converter(python_value)
|
468
|
-
except ValueError as exc:
|
469
|
-
raise ValueError(f"Invalid type for parameter {param_name}: {exc}")
|
470
|
-
setattr(proto_param, type_info.proto_field, converted_value)
|
471
|
-
proto_params.append(proto_param)
|
424
|
+
proto_params.append(encode_parameter_value(param_name, python_value))
|
472
425
|
proto_bytes = api_pb2.ClassParameterSet(parameters=proto_params).SerializeToString(deterministic=True)
|
473
426
|
return proto_bytes
|
474
427
|
|
475
428
|
|
476
429
|
def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
|
477
|
-
proto_struct = api_pb2.ClassParameterSet()
|
478
|
-
proto_struct.ParseFromString(serialized_params)
|
430
|
+
proto_struct = api_pb2.ClassParameterSet.FromString(serialized_params)
|
479
431
|
python_params = {}
|
480
432
|
for param in proto_struct.parameters:
|
481
|
-
|
482
|
-
if param.type == api_pb2.PARAM_TYPE_STRING:
|
483
|
-
python_value = param.string_value
|
484
|
-
elif param.type == api_pb2.PARAM_TYPE_INT:
|
485
|
-
python_value = param.int_value
|
486
|
-
elif param.type == api_pb2.PARAM_TYPE_BYTES:
|
487
|
-
python_value = param.bytes_value
|
488
|
-
else:
|
489
|
-
raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")
|
490
|
-
|
491
|
-
python_params[param.name] = python_value
|
433
|
+
python_params[param.name] = parameter_serde_registry.decode(param)
|
492
434
|
|
493
435
|
return python_params
|
494
436
|
|
495
437
|
|
496
|
-
def
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
if
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
438
|
+
def validate_parameter_values(payload: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
|
439
|
+
"""Ensure parameter payload conforms to the schema of a class
|
440
|
+
|
441
|
+
Checks that:
|
442
|
+
* All fields are specified (defaults are expected to already be applied on the payload)
|
443
|
+
* No extra fields are specified
|
444
|
+
* The type of each field is correct
|
445
|
+
"""
|
446
|
+
for param_spec in schema:
|
447
|
+
if param_spec.name not in payload:
|
448
|
+
raise InvalidError(f"Missing required parameter: {param_spec.name}")
|
449
|
+
python_value = payload[param_spec.name]
|
450
|
+
if param_spec.HasField("full_type") and param_spec.full_type.base_type:
|
451
|
+
type_enum_value = param_spec.full_type.base_type
|
452
|
+
else:
|
453
|
+
type_enum_value = param_spec.type # backwards compatibility pre-full_type
|
454
|
+
|
455
|
+
parameter_serde_registry.validate_value_for_enum_type(type_enum_value, python_value)
|
511
456
|
|
512
457
|
schema_fields = {p.name for p in schema}
|
513
458
|
# then check that no extra values are provided
|
514
|
-
non_declared_fields =
|
459
|
+
non_declared_fields = payload.keys() - schema_fields
|
515
460
|
if non_declared_fields:
|
516
461
|
raise InvalidError(
|
517
|
-
f"The following parameter names were provided but are not
|
462
|
+
f"The following parameter names were provided but are not defined class modal.parameters for the class: "
|
463
|
+
f"{', '.join(non_declared_fields)}"
|
518
464
|
)
|
519
465
|
|
520
466
|
|
@@ -528,8 +474,6 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
|
|
528
474
|
elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
|
529
475
|
param_args = () # we use kwargs only for our implicit constructors
|
530
476
|
param_kwargs = deserialize_proto_params(serialized_params)
|
531
|
-
# TODO: We can probably remove the validation below since we do validation in the caller?
|
532
|
-
validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
|
533
477
|
else:
|
534
478
|
raise ExecutionError(
|
535
479
|
f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
|
@@ -538,9 +482,47 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
|
|
538
482
|
return param_args, param_kwargs
|
539
483
|
|
540
484
|
|
541
|
-
def
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
485
|
+
def _signature_parameter_to_spec(
|
486
|
+
python_signature_parameter: inspect.Parameter, include_legacy_parameter_fields: bool = False
|
487
|
+
) -> api_pb2.ClassParameterSpec:
|
488
|
+
"""Returns proto representation of Parameter as returned by inspect.signature()
|
489
|
+
|
490
|
+
Setting include_legacy_parameter_fields makes the output backwards compatible with
|
491
|
+
pre v0.74 clients looking at class parameter specifications, and should not be used
|
492
|
+
when registering *function* schemas.
|
493
|
+
"""
|
494
|
+
declared_type = python_signature_parameter.annotation
|
495
|
+
full_proto_type = schema_registry.get_proto_generic_type(declared_type)
|
496
|
+
has_default = python_signature_parameter.default is not Parameter.empty
|
497
|
+
|
498
|
+
field_spec = api_pb2.ClassParameterSpec(
|
499
|
+
name=python_signature_parameter.name,
|
500
|
+
full_type=full_proto_type,
|
501
|
+
has_default=has_default,
|
502
|
+
)
|
503
|
+
if include_legacy_parameter_fields:
|
504
|
+
# add the .{type}_default and `.type` values as required by legacy clients
|
505
|
+
# looking at class parameter specs
|
506
|
+
if full_proto_type.base_type == api_pb2.PARAM_TYPE_INT:
|
507
|
+
if has_default:
|
508
|
+
field_spec.int_default = python_signature_parameter.default
|
509
|
+
field_spec.type = api_pb2.PARAM_TYPE_INT
|
510
|
+
elif full_proto_type.base_type == api_pb2.PARAM_TYPE_STRING:
|
511
|
+
if has_default:
|
512
|
+
field_spec.string_default = python_signature_parameter.default
|
513
|
+
field_spec.type = api_pb2.PARAM_TYPE_STRING
|
514
|
+
elif full_proto_type.base_type == api_pb2.PARAM_TYPE_BYTES:
|
515
|
+
if has_default:
|
516
|
+
field_spec.bytes_default = python_signature_parameter.default
|
517
|
+
field_spec.type = api_pb2.PARAM_TYPE_BYTES
|
518
|
+
|
519
|
+
return field_spec
|
520
|
+
|
521
|
+
|
522
|
+
def signature_to_parameter_specs(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
|
523
|
+
# only used for modal.parameter() specs, uses backwards compatible fields and types
|
524
|
+
modal_parameters: list[api_pb2.ClassParameterSpec] = []
|
525
|
+
for param in signature.parameters.values():
|
526
|
+
field_spec = _signature_parameter_to_spec(param, include_legacy_parameter_fields=True)
|
527
|
+
modal_parameters.append(field_spec)
|
528
|
+
return modal_parameters
|
modal/_type_manager.py
ADDED
@@ -0,0 +1,229 @@
|
|
1
|
+
# Copyright Modal Labs 2025
|
2
|
+
import typing
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import typing_extensions
|
6
|
+
|
7
|
+
from modal.exception import InvalidError
|
8
|
+
from modal_proto import api_pb2
|
9
|
+
|
10
|
+
|
11
|
+
class ParameterProtoSerde(typing.Protocol):
|
12
|
+
def encode(self, value: Any) -> api_pb2.ClassParameterValue: ...
|
13
|
+
|
14
|
+
def decode(self, proto_value: api_pb2.ClassParameterValue) -> Any: ...
|
15
|
+
|
16
|
+
def validate(self, python_value: Any): ...
|
17
|
+
|
18
|
+
|
19
|
+
class ProtoParameterSerdeRegistry:
|
20
|
+
_py_base_type_to_serde: dict[type, ParameterProtoSerde]
|
21
|
+
_proto_type_to_serde: dict["api_pb2.ParameterType.ValueType", ParameterProtoSerde]
|
22
|
+
|
23
|
+
def __init__(self):
|
24
|
+
self._py_base_type_to_serde = {}
|
25
|
+
self._proto_type_to_serde = {}
|
26
|
+
|
27
|
+
def register_encoder(self, python_base_type: type) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
|
28
|
+
def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
|
29
|
+
if python_base_type in self._py_base_type_to_serde:
|
30
|
+
raise ValueError("Can't register the same encoder type twice")
|
31
|
+
self._py_base_type_to_serde[python_base_type] = ph
|
32
|
+
return ph
|
33
|
+
|
34
|
+
return deco
|
35
|
+
|
36
|
+
def register_decoder(
|
37
|
+
self, enum_type_value: "api_pb2.ParameterType.ValueType"
|
38
|
+
) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
|
39
|
+
def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
|
40
|
+
if enum_type_value in self._proto_type_to_serde:
|
41
|
+
raise ValueError("Can't register the same decoder type twice")
|
42
|
+
self._proto_type_to_serde[enum_type_value] = ph
|
43
|
+
return ph
|
44
|
+
|
45
|
+
return deco
|
46
|
+
|
47
|
+
def encode(self, python_value: Any) -> api_pb2.ClassParameterValue:
|
48
|
+
return self._get_encoder(type(python_value)).encode(python_value)
|
49
|
+
|
50
|
+
def supports_type(self, declared_type: type) -> bool:
|
51
|
+
try:
|
52
|
+
self._get_encoder(declared_type)
|
53
|
+
return True
|
54
|
+
except InvalidError:
|
55
|
+
return False
|
56
|
+
|
57
|
+
def decode(self, param_value: api_pb2.ClassParameterValue) -> Any:
|
58
|
+
return self._get_decoder(param_value.type).decode(param_value)
|
59
|
+
|
60
|
+
def validate_parameter_type(self, declared_type: type):
|
61
|
+
"""Raises a helpful TypeError if the supplied type isn't supported by class parameters"""
|
62
|
+
if not parameter_serde_registry.supports_type(declared_type):
|
63
|
+
supported_types = self._py_base_type_to_serde.keys()
|
64
|
+
supported_str = ", ".join(t.__name__ for t in supported_types)
|
65
|
+
|
66
|
+
raise TypeError(
|
67
|
+
f"{declared_type.__name__} is not a supported modal.parameter() type. Use one of: {supported_str}"
|
68
|
+
)
|
69
|
+
|
70
|
+
def validate_value_for_enum_type(self, enum_value: "api_pb2.ParameterType.ValueType", python_value: Any):
|
71
|
+
serde = self._get_decoder(enum_value) # use the schema's expected decoder
|
72
|
+
serde.validate(python_value)
|
73
|
+
|
74
|
+
def _get_encoder(self, python_base_type: type) -> ParameterProtoSerde:
|
75
|
+
try:
|
76
|
+
return self._py_base_type_to_serde[python_base_type]
|
77
|
+
except KeyError:
|
78
|
+
raise InvalidError(f"No class parameter encoder implemented for type `{python_base_type.__name__}`")
|
79
|
+
|
80
|
+
def _get_decoder(self, enum_value: "api_pb2.ParameterType.ValueType") -> ParameterProtoSerde:
|
81
|
+
try:
|
82
|
+
return self._proto_type_to_serde[enum_value]
|
83
|
+
except KeyError:
|
84
|
+
try:
|
85
|
+
enum_name = api_pb2.ParameterType.Name(enum_value)
|
86
|
+
except ValueError:
|
87
|
+
enum_name = str(enum_value)
|
88
|
+
|
89
|
+
raise InvalidError(f"No class parameter decoder implemented for type {enum_name}.")
|
90
|
+
|
91
|
+
|
92
|
+
parameter_serde_registry = ProtoParameterSerdeRegistry()
|
93
|
+
|
94
|
+
|
95
|
+
@parameter_serde_registry.register_encoder(int)
|
96
|
+
@parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_INT)
|
97
|
+
class IntParameter:
|
98
|
+
@staticmethod
|
99
|
+
def encode(value: Any) -> api_pb2.ClassParameterValue:
|
100
|
+
return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_INT, int_value=value)
|
101
|
+
|
102
|
+
@staticmethod
|
103
|
+
def decode(proto_value: api_pb2.ClassParameterValue) -> int:
|
104
|
+
return proto_value.int_value
|
105
|
+
|
106
|
+
@staticmethod
|
107
|
+
def validate(python_value: Any):
|
108
|
+
if not isinstance(python_value, int):
|
109
|
+
raise TypeError(f"Expected int, got {type(python_value).__name__}")
|
110
|
+
|
111
|
+
|
112
|
+
@parameter_serde_registry.register_encoder(str)
|
113
|
+
@parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_STRING)
|
114
|
+
class StringParameter:
|
115
|
+
@staticmethod
|
116
|
+
def encode(value: Any) -> api_pb2.ClassParameterValue:
|
117
|
+
return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_STRING, string_value=value)
|
118
|
+
|
119
|
+
@staticmethod
|
120
|
+
def decode(proto_value: api_pb2.ClassParameterValue) -> str:
|
121
|
+
return proto_value.string_value
|
122
|
+
|
123
|
+
@staticmethod
|
124
|
+
def validate(python_value: Any):
|
125
|
+
if not isinstance(python_value, str):
|
126
|
+
raise TypeError(f"Expected str, got {type(python_value).__name__}")
|
127
|
+
|
128
|
+
|
129
|
+
@parameter_serde_registry.register_encoder(bytes)
|
130
|
+
@parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_BYTES)
|
131
|
+
class BytesParameter:
|
132
|
+
@staticmethod
|
133
|
+
def encode(value: Any) -> api_pb2.ClassParameterValue:
|
134
|
+
return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_BYTES, bytes_value=value)
|
135
|
+
|
136
|
+
@staticmethod
|
137
|
+
def decode(proto_value: api_pb2.ClassParameterValue) -> bytes:
|
138
|
+
return proto_value.bytes_value
|
139
|
+
|
140
|
+
@staticmethod
|
141
|
+
def validate(python_value: Any):
|
142
|
+
if not isinstance(python_value, bytes):
|
143
|
+
raise TypeError(f"Expected bytes, got {type(python_value).__name__}")
|
144
|
+
|
145
|
+
|
146
|
+
SCHEMA_FACTORY_TYPE = typing.Callable[[type], api_pb2.GenericPayloadType]
|
147
|
+
|
148
|
+
|
149
|
+
class SchemaRegistry:
|
150
|
+
_schema_factories: dict[type, SCHEMA_FACTORY_TYPE]
|
151
|
+
|
152
|
+
def __init__(self):
|
153
|
+
self._schema_factories = {}
|
154
|
+
|
155
|
+
def add(self, python_base_type: type) -> typing.Callable[[SCHEMA_FACTORY_TYPE], SCHEMA_FACTORY_TYPE]:
|
156
|
+
# decorator for schema factory functions for a base type
|
157
|
+
def deco(factory_func: SCHEMA_FACTORY_TYPE) -> SCHEMA_FACTORY_TYPE:
|
158
|
+
assert python_base_type not in self._schema_factories
|
159
|
+
self._schema_factories[python_base_type] = factory_func
|
160
|
+
return factory_func
|
161
|
+
|
162
|
+
return deco
|
163
|
+
|
164
|
+
def get(self, python_base_type: type) -> SCHEMA_FACTORY_TYPE:
|
165
|
+
try:
|
166
|
+
return self._schema_factories[python_base_type]
|
167
|
+
except KeyError:
|
168
|
+
return unknown_type_schema
|
169
|
+
|
170
|
+
def get_proto_generic_type(self, declared_type: type):
|
171
|
+
if origin := typing_extensions.get_origin(declared_type):
|
172
|
+
base_type = origin
|
173
|
+
else:
|
174
|
+
base_type = declared_type
|
175
|
+
|
176
|
+
return self.get(base_type)(declared_type)
|
177
|
+
|
178
|
+
|
179
|
+
schema_registry = SchemaRegistry()
|
180
|
+
|
181
|
+
|
182
|
+
@schema_registry.add(int)
|
183
|
+
def int_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
184
|
+
return api_pb2.GenericPayloadType(
|
185
|
+
base_type=api_pb2.PARAM_TYPE_INT,
|
186
|
+
)
|
187
|
+
|
188
|
+
|
189
|
+
@schema_registry.add(bytes)
|
190
|
+
def proto_type_def(declared_python_type: type) -> api_pb2.GenericPayloadType:
|
191
|
+
return api_pb2.GenericPayloadType(
|
192
|
+
base_type=api_pb2.PARAM_TYPE_BYTES,
|
193
|
+
)
|
194
|
+
|
195
|
+
|
196
|
+
def unknown_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
|
197
|
+
# TODO: add some metadata for unknown types to the type def?
|
198
|
+
return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_UNKNOWN)
|
199
|
+
|
200
|
+
|
201
|
+
@schema_registry.add(str)
|
202
|
+
def str_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
203
|
+
return api_pb2.GenericPayloadType(
|
204
|
+
base_type=api_pb2.PARAM_TYPE_STRING,
|
205
|
+
)
|
206
|
+
|
207
|
+
|
208
|
+
@schema_registry.add(type(None))
|
209
|
+
def none_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
|
210
|
+
return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_NONE)
|
211
|
+
|
212
|
+
|
213
|
+
@schema_registry.add(list)
|
214
|
+
def list_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
215
|
+
args = typing_extensions.get_args(full_python_type)
|
216
|
+
|
217
|
+
return api_pb2.GenericPayloadType(
|
218
|
+
base_type=api_pb2.PARAM_TYPE_LIST, sub_types=[schema_registry.get_proto_generic_type(arg) for arg in args]
|
219
|
+
)
|
220
|
+
|
221
|
+
|
222
|
+
@schema_registry.add(dict)
|
223
|
+
def dict_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
224
|
+
args = typing_extensions.get_args(full_python_type)
|
225
|
+
|
226
|
+
return api_pb2.GenericPayloadType(
|
227
|
+
base_type=api_pb2.PARAM_TYPE_DICT,
|
228
|
+
sub_types=[schema_registry.get_proto_generic_type(arg_type) for arg_type in args],
|
229
|
+
)
|
modal/_utils/function_utils.py
CHANGED
@@ -16,12 +16,10 @@ import modal_proto
|
|
16
16
|
from modal_proto import api_pb2
|
17
17
|
|
18
18
|
from .._serialization import (
|
19
|
-
PROTO_TYPE_INFO,
|
20
|
-
PYTHON_TO_PROTO_TYPE,
|
21
19
|
deserialize,
|
22
20
|
deserialize_data_format,
|
23
|
-
get_proto_parameter_type,
|
24
21
|
serialize,
|
22
|
+
signature_to_parameter_specs,
|
25
23
|
)
|
26
24
|
from .._traceback import append_modal_tb
|
27
25
|
from ..config import config, logger
|
@@ -106,24 +104,6 @@ def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.Functio
|
|
106
104
|
return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION
|
107
105
|
|
108
106
|
|
109
|
-
def signature_to_protobuf_schema(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
|
110
|
-
modal_parameters: list[api_pb2.ClassParameterSpec] = []
|
111
|
-
for param in signature.parameters.values():
|
112
|
-
has_default = param.default is not param.empty
|
113
|
-
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default)
|
114
|
-
if param.annotation not in PYTHON_TO_PROTO_TYPE:
|
115
|
-
class_param_spec.type = api_pb2.PARAM_TYPE_UNKNOWN
|
116
|
-
else:
|
117
|
-
proto_type = PYTHON_TO_PROTO_TYPE[param.annotation]
|
118
|
-
class_param_spec.type = proto_type
|
119
|
-
proto_type_info = PROTO_TYPE_INFO[proto_type]
|
120
|
-
if has_default and proto_type is not api_pb2.PARAM_TYPE_UNKNOWN:
|
121
|
-
setattr(class_param_spec, proto_type_info.default_field, param.default)
|
122
|
-
|
123
|
-
modal_parameters.append(class_param_spec)
|
124
|
-
return modal_parameters
|
125
|
-
|
126
|
-
|
127
107
|
class FunctionInfo:
|
128
108
|
"""Utility that determines serialization/deserialization mechanisms for functions
|
129
109
|
|
@@ -310,15 +290,12 @@ class FunctionInfo:
|
|
310
290
|
# annotation parameters trigger strictly typed parametrization
|
311
291
|
# which enables web endpoint for parametrized classes
|
312
292
|
signature = _get_class_constructor_signature(self.user_cls)
|
313
|
-
#
|
314
|
-
|
315
|
-
get_proto_parameter_type(param.annotation)
|
316
|
-
|
317
|
-
protobuf_schema = signature_to_protobuf_schema(signature)
|
293
|
+
# at this point, the types in the signature should already have been validated (see Cls.from_local())
|
294
|
+
parameter_specs = signature_to_parameter_specs(signature)
|
318
295
|
|
319
296
|
return api_pb2.ClassParameterInfo(
|
320
297
|
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
|
321
|
-
schema=
|
298
|
+
schema=parameter_specs,
|
322
299
|
)
|
323
300
|
|
324
301
|
def get_entrypoint_mount(self) -> dict[str, _Mount]:
|
modal/client.pyi
CHANGED
@@ -31,7 +31,7 @@ class _Client:
|
|
31
31
|
server_url: str,
|
32
32
|
client_type: int,
|
33
33
|
credentials: typing.Optional[tuple[str, str]],
|
34
|
-
version: str = "0.73.
|
34
|
+
version: str = "0.73.131",
|
35
35
|
): ...
|
36
36
|
def is_closed(self) -> bool: ...
|
37
37
|
@property
|
@@ -93,7 +93,7 @@ class Client:
|
|
93
93
|
server_url: str,
|
94
94
|
client_type: int,
|
95
95
|
credentials: typing.Optional[tuple[str, str]],
|
96
|
-
version: str = "0.73.
|
96
|
+
version: str = "0.73.131",
|
97
97
|
): ...
|
98
98
|
def is_closed(self) -> bool: ...
|
99
99
|
@property
|
modal/cls.py
CHANGED
@@ -21,8 +21,9 @@ from ._partial_function import (
|
|
21
21
|
)
|
22
22
|
from ._resolver import Resolver
|
23
23
|
from ._resources import convert_fn_config_to_resources_config
|
24
|
-
from ._serialization import check_valid_cls_constructor_arg
|
24
|
+
from ._serialization import check_valid_cls_constructor_arg
|
25
25
|
from ._traceback import print_server_warnings
|
26
|
+
from ._type_manager import parameter_serde_registry
|
26
27
|
from ._utils.async_utils import synchronize_api, synchronizer
|
27
28
|
from ._utils.deprecation import deprecation_warning, renamed_parameter, warn_on_renamed_autoscaler_settings
|
28
29
|
from ._utils.grpc_utils import retry_transient_errors
|
@@ -462,7 +463,10 @@ class _Cls(_Object, type_prefix="cs"):
|
|
462
463
|
|
463
464
|
annotated_params = {k: t for k, t in annotations.items() if k in params}
|
464
465
|
for k, t in annotated_params.items():
|
465
|
-
|
466
|
+
try:
|
467
|
+
parameter_serde_registry.validate_parameter_type(t)
|
468
|
+
except TypeError as exc:
|
469
|
+
raise InvalidError(f"Class parameter '{k}': {exc}")
|
466
470
|
|
467
471
|
@staticmethod
|
468
472
|
def from_local(user_cls, app: "modal.app._App", class_service_function: _Function) -> "_Cls":
|
modal/functions.pyi
CHANGED
@@ -198,11 +198,11 @@ class Function(
|
|
198
198
|
|
199
199
|
_call_generator_nowait: ___call_generator_nowait_spec[typing_extensions.Self]
|
200
200
|
|
201
|
-
class __remote_spec(typing_extensions.Protocol[
|
201
|
+
class __remote_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
|
202
202
|
def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
|
203
203
|
async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
|
204
204
|
|
205
|
-
remote: __remote_spec[modal._functions.
|
205
|
+
remote: __remote_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
|
206
206
|
|
207
207
|
class __remote_gen_spec(typing_extensions.Protocol[SUPERSELF]):
|
208
208
|
def __call__(self, *args, **kwargs) -> typing.Generator[typing.Any, None, None]: ...
|
@@ -217,19 +217,19 @@ class Function(
|
|
217
217
|
self, *args: modal._functions.P.args, **kwargs: modal._functions.P.kwargs
|
218
218
|
) -> modal._functions.OriginalReturnType: ...
|
219
219
|
|
220
|
-
class ___experimental_spawn_spec(typing_extensions.Protocol[
|
220
|
+
class ___experimental_spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
|
221
221
|
def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
|
222
222
|
async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
|
223
223
|
|
224
224
|
_experimental_spawn: ___experimental_spawn_spec[
|
225
|
-
modal._functions.
|
225
|
+
modal._functions.ReturnType, modal._functions.P, typing_extensions.Self
|
226
226
|
]
|
227
227
|
|
228
|
-
class __spawn_spec(typing_extensions.Protocol[
|
228
|
+
class __spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
|
229
229
|
def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
|
230
230
|
async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
|
231
231
|
|
232
|
-
spawn: __spawn_spec[modal._functions.
|
232
|
+
spawn: __spawn_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
|
233
233
|
|
234
234
|
def get_raw_f(self) -> collections.abc.Callable[..., typing.Any]: ...
|
235
235
|
|