modal 0.73.130__py3-none-any.whl → 0.73.131__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modal/_functions.py CHANGED
@@ -25,7 +25,12 @@ from ._pty import get_pty_info
25
25
  from ._resolver import Resolver
26
26
  from ._resources import convert_fn_config_to_resources_config
27
27
  from ._runtime.execution_context import current_input_id, is_local
28
- from ._serialization import apply_defaults, serialize, serialize_proto_params, validate_params
28
+ from ._serialization import (
29
+ apply_defaults,
30
+ serialize,
31
+ serialize_proto_params,
32
+ validate_parameter_values,
33
+ )
29
34
  from ._traceback import print_server_warnings
30
35
  from ._utils.async_utils import (
31
36
  TaskContext,
@@ -975,7 +980,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
975
980
  )
976
981
  schema = parent._class_parameter_info.schema
977
982
  kwargs_with_defaults = apply_defaults(kwargs, schema)
978
- validate_params(kwargs_with_defaults, schema)
983
+ validate_parameter_values(kwargs_with_defaults, schema)
979
984
  serialized_params = serialize_proto_params(kwargs_with_defaults)
980
985
  can_use_parent = len(parent._class_parameter_info.schema) == 0 # no parameters
981
986
  else:
@@ -1312,7 +1317,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1312
1317
  order_outputs,
1313
1318
  return_exceptions,
1314
1319
  count_update_callback,
1315
- api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
1320
+ api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
1316
1321
  )
1317
1322
  ) as stream:
1318
1323
  async for item in stream:
modal/_resolver.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # Copyright Modal Labs 2023
2
2
  import asyncio
3
3
  import contextlib
4
+ import traceback
4
5
  import typing
5
6
  from asyncio import Future
6
7
  from collections.abc import Hashable
@@ -153,7 +154,11 @@ class Resolver:
153
154
  self._deduplication_cache[deduplication_key] = cached_future
154
155
 
155
156
  # TODO(elias): print original exception/trace rather than the Resolver-internal trace
156
- return await cached_future
157
+ try:
158
+ return await cached_future
159
+ except Exception:
160
+ traceback.print_exc()
161
+ raise
157
162
 
158
163
  def objects(self) -> list["modal._object._Object"]:
159
164
  unique_objects: dict[str, "modal._object._Object"] = {}
modal/_serialization.py CHANGED
@@ -1,14 +1,16 @@
1
1
  # Copyright Modal Labs 2022
2
+ import inspect
2
3
  import io
3
4
  import pickle
4
5
  import typing
5
- from dataclasses import dataclass
6
+ from inspect import Parameter
6
7
  from typing import Any
7
8
 
8
9
  from modal._utils.async_utils import synchronizer
9
10
  from modal_proto import api_pb2
10
11
 
11
12
  from ._object import _Object
13
+ from ._type_manager import parameter_serde_registry, schema_registry
12
14
  from ._vendor import cloudpickle
13
15
  from .config import logger
14
16
  from .exception import DeserializationError, ExecutionError, InvalidError
@@ -389,50 +391,6 @@ def check_valid_cls_constructor_arg(key, obj):
389
391
  )
390
392
 
391
393
 
392
- def assert_bytes(obj: Any):
393
- if not isinstance(obj, bytes):
394
- raise TypeError(f"Expected bytes, got {type(obj)}")
395
- return obj
396
-
397
-
398
- @dataclass
399
- class ParamTypeInfo:
400
- default_field: str
401
- proto_field: str
402
- converter: typing.Callable[[str], typing.Any]
403
- type: type
404
-
405
-
406
- PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
407
- # python type -> protobuf type enum
408
- str: api_pb2.PARAM_TYPE_STRING,
409
- int: api_pb2.PARAM_TYPE_INT,
410
- bytes: api_pb2.PARAM_TYPE_BYTES,
411
- }
412
-
413
- PROTO_TYPE_INFO = {
414
- # Protobuf type enum -> encode/decode helper metadata
415
- api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
416
- default_field="string_default",
417
- proto_field="string_value",
418
- converter=str,
419
- type=str,
420
- ),
421
- api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
422
- default_field="int_default",
423
- proto_field="int_value",
424
- converter=int,
425
- type=int,
426
- ),
427
- api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
428
- default_field="bytes_default",
429
- proto_field="bytes_value",
430
- converter=assert_bytes,
431
- type=bytes,
432
- ),
433
- }
434
-
435
-
436
394
  def apply_defaults(
437
395
  python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
438
396
  ) -> dict[str, Any]:
@@ -453,68 +411,56 @@ def apply_defaults(
453
411
  return result
454
412
 
455
413
 
414
+ def encode_parameter_value(name: str, python_value: Any) -> api_pb2.ClassParameterValue:
415
+ """Map to proto parameter representation using python runtime type information"""
416
+ struct = parameter_serde_registry.encode(python_value)
417
+ struct.name = name
418
+ return struct
419
+
420
+
456
421
  def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
457
422
  proto_params: list[api_pb2.ClassParameterValue] = []
458
423
  for param_name, python_value in python_params.items():
459
- python_type = type(python_value)
460
- protobuf_type = get_proto_parameter_type(python_type)
461
- type_info = PROTO_TYPE_INFO.get(protobuf_type)
462
- proto_param = api_pb2.ClassParameterValue(
463
- name=param_name,
464
- type=protobuf_type,
465
- )
466
- try:
467
- converted_value = type_info.converter(python_value)
468
- except ValueError as exc:
469
- raise ValueError(f"Invalid type for parameter {param_name}: {exc}")
470
- setattr(proto_param, type_info.proto_field, converted_value)
471
- proto_params.append(proto_param)
424
+ proto_params.append(encode_parameter_value(param_name, python_value))
472
425
  proto_bytes = api_pb2.ClassParameterSet(parameters=proto_params).SerializeToString(deterministic=True)
473
426
  return proto_bytes
474
427
 
475
428
 
476
429
  def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
477
- proto_struct = api_pb2.ClassParameterSet()
478
- proto_struct.ParseFromString(serialized_params)
430
+ proto_struct = api_pb2.ClassParameterSet.FromString(serialized_params)
479
431
  python_params = {}
480
432
  for param in proto_struct.parameters:
481
- python_value: Any
482
- if param.type == api_pb2.PARAM_TYPE_STRING:
483
- python_value = param.string_value
484
- elif param.type == api_pb2.PARAM_TYPE_INT:
485
- python_value = param.int_value
486
- elif param.type == api_pb2.PARAM_TYPE_BYTES:
487
- python_value = param.bytes_value
488
- else:
489
- raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")
490
-
491
- python_params[param.name] = python_value
433
+ python_params[param.name] = parameter_serde_registry.decode(param)
492
434
 
493
435
  return python_params
494
436
 
495
437
 
496
- def validate_params(params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
497
- # first check that all declared values are provided
498
- for schema_param in schema:
499
- if schema_param.name not in params:
500
- # we expect all values to be present - even defaulted ones (defaults are applied on payload construction)
501
- raise InvalidError(f"Missing required parameter: {schema_param.name}")
502
- python_value = params[schema_param.name]
503
- python_type = type(python_value)
504
- param_protobuf_type = get_proto_parameter_type(python_type)
505
- if schema_param.type != param_protobuf_type:
506
- expected_python_type = PROTO_TYPE_INFO[schema_param.type].type
507
- raise TypeError(
508
- f"Parameter '{schema_param.name}' type error: expected {expected_python_type.__name__}, "
509
- f"got {python_type.__name__}"
510
- )
438
+ def validate_parameter_values(payload: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
439
+ """Ensure parameter payload conforms to the schema of a class
440
+
441
+ Checks that:
442
+ * All fields are specified (defaults are expected to already be applied on the payload)
443
+ * No extra fields are specified
444
+ * The type of each field is correct
445
+ """
446
+ for param_spec in schema:
447
+ if param_spec.name not in payload:
448
+ raise InvalidError(f"Missing required parameter: {param_spec.name}")
449
+ python_value = payload[param_spec.name]
450
+ if param_spec.HasField("full_type") and param_spec.full_type.base_type:
451
+ type_enum_value = param_spec.full_type.base_type
452
+ else:
453
+ type_enum_value = param_spec.type # backwards compatibility pre-full_type
454
+
455
+ parameter_serde_registry.validate_value_for_enum_type(type_enum_value, python_value)
511
456
 
512
457
  schema_fields = {p.name for p in schema}
513
458
  # then check that no extra values are provided
514
- non_declared_fields = params.keys() - schema_fields
459
+ non_declared_fields = payload.keys() - schema_fields
515
460
  if non_declared_fields:
516
461
  raise InvalidError(
517
- f"The following parameter names were provided but are not present in the schema: {non_declared_fields}"
462
+ f"The following parameter names were provided but are not defined class modal.parameters for the class: "
463
+ f"{', '.join(non_declared_fields)}"
518
464
  )
519
465
 
520
466
 
@@ -528,8 +474,6 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
528
474
  elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
529
475
  param_args = () # we use kwargs only for our implicit constructors
530
476
  param_kwargs = deserialize_proto_params(serialized_params)
531
- # TODO: We can probably remove the validation below since we do validation in the caller?
532
- validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
533
477
  else:
534
478
  raise ExecutionError(
535
479
  f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
@@ -538,9 +482,47 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
538
482
  return param_args, param_kwargs
539
483
 
540
484
 
541
- def get_proto_parameter_type(parameter_type: type) -> "api_pb2.ParameterType.ValueType":
542
- if parameter_type not in PYTHON_TO_PROTO_TYPE:
543
- type_name = getattr(parameter_type, "__name__", repr(parameter_type))
544
- supported = ", ".join(parameter_type.__name__ for parameter_type in PYTHON_TO_PROTO_TYPE.keys())
545
- raise InvalidError(f"{type_name} is not a supported parameter type. Use one of: {supported}")
546
- return PYTHON_TO_PROTO_TYPE[parameter_type]
485
+ def _signature_parameter_to_spec(
486
+ python_signature_parameter: inspect.Parameter, include_legacy_parameter_fields: bool = False
487
+ ) -> api_pb2.ClassParameterSpec:
488
+ """Returns proto representation of Parameter as returned by inspect.signature()
489
+
490
+ Setting include_legacy_parameter_fields makes the output backwards compatible with
491
+ pre v0.74 clients looking at class parameter specifications, and should not be used
492
+ when registering *function* schemas.
493
+ """
494
+ declared_type = python_signature_parameter.annotation
495
+ full_proto_type = schema_registry.get_proto_generic_type(declared_type)
496
+ has_default = python_signature_parameter.default is not Parameter.empty
497
+
498
+ field_spec = api_pb2.ClassParameterSpec(
499
+ name=python_signature_parameter.name,
500
+ full_type=full_proto_type,
501
+ has_default=has_default,
502
+ )
503
+ if include_legacy_parameter_fields:
504
+ # add the .{type}_default and `.type` values as required by legacy clients
505
+ # looking at class parameter specs
506
+ if full_proto_type.base_type == api_pb2.PARAM_TYPE_INT:
507
+ if has_default:
508
+ field_spec.int_default = python_signature_parameter.default
509
+ field_spec.type = api_pb2.PARAM_TYPE_INT
510
+ elif full_proto_type.base_type == api_pb2.PARAM_TYPE_STRING:
511
+ if has_default:
512
+ field_spec.string_default = python_signature_parameter.default
513
+ field_spec.type = api_pb2.PARAM_TYPE_STRING
514
+ elif full_proto_type.base_type == api_pb2.PARAM_TYPE_BYTES:
515
+ if has_default:
516
+ field_spec.bytes_default = python_signature_parameter.default
517
+ field_spec.type = api_pb2.PARAM_TYPE_BYTES
518
+
519
+ return field_spec
520
+
521
+
522
+ def signature_to_parameter_specs(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
523
+ # only used for modal.parameter() specs, uses backwards compatible fields and types
524
+ modal_parameters: list[api_pb2.ClassParameterSpec] = []
525
+ for param in signature.parameters.values():
526
+ field_spec = _signature_parameter_to_spec(param, include_legacy_parameter_fields=True)
527
+ modal_parameters.append(field_spec)
528
+ return modal_parameters
modal/_type_manager.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright Modal Labs 2025
2
+ import typing
3
+ from typing import Any
4
+
5
+ import typing_extensions
6
+
7
+ from modal.exception import InvalidError
8
+ from modal_proto import api_pb2
9
+
10
+
11
+ class ParameterProtoSerde(typing.Protocol):
12
+ def encode(self, value: Any) -> api_pb2.ClassParameterValue: ...
13
+
14
+ def decode(self, proto_value: api_pb2.ClassParameterValue) -> Any: ...
15
+
16
+ def validate(self, python_value: Any): ...
17
+
18
+
19
+ class ProtoParameterSerdeRegistry:
20
+ _py_base_type_to_serde: dict[type, ParameterProtoSerde]
21
+ _proto_type_to_serde: dict["api_pb2.ParameterType.ValueType", ParameterProtoSerde]
22
+
23
+ def __init__(self):
24
+ self._py_base_type_to_serde = {}
25
+ self._proto_type_to_serde = {}
26
+
27
+ def register_encoder(self, python_base_type: type) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
28
+ def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
29
+ if python_base_type in self._py_base_type_to_serde:
30
+ raise ValueError("Can't register the same encoder type twice")
31
+ self._py_base_type_to_serde[python_base_type] = ph
32
+ return ph
33
+
34
+ return deco
35
+
36
+ def register_decoder(
37
+ self, enum_type_value: "api_pb2.ParameterType.ValueType"
38
+ ) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
39
+ def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
40
+ if enum_type_value in self._proto_type_to_serde:
41
+ raise ValueError("Can't register the same decoder type twice")
42
+ self._proto_type_to_serde[enum_type_value] = ph
43
+ return ph
44
+
45
+ return deco
46
+
47
+ def encode(self, python_value: Any) -> api_pb2.ClassParameterValue:
48
+ return self._get_encoder(type(python_value)).encode(python_value)
49
+
50
+ def supports_type(self, declared_type: type) -> bool:
51
+ try:
52
+ self._get_encoder(declared_type)
53
+ return True
54
+ except InvalidError:
55
+ return False
56
+
57
+ def decode(self, param_value: api_pb2.ClassParameterValue) -> Any:
58
+ return self._get_decoder(param_value.type).decode(param_value)
59
+
60
+ def validate_parameter_type(self, declared_type: type):
61
+ """Raises a helpful TypeError if the supplied type isn't supported by class parameters"""
62
+ if not parameter_serde_registry.supports_type(declared_type):
63
+ supported_types = self._py_base_type_to_serde.keys()
64
+ supported_str = ", ".join(t.__name__ for t in supported_types)
65
+
66
+ raise TypeError(
67
+ f"{declared_type.__name__} is not a supported modal.parameter() type. Use one of: {supported_str}"
68
+ )
69
+
70
+ def validate_value_for_enum_type(self, enum_value: "api_pb2.ParameterType.ValueType", python_value: Any):
71
+ serde = self._get_decoder(enum_value) # use the schema's expected decoder
72
+ serde.validate(python_value)
73
+
74
+ def _get_encoder(self, python_base_type: type) -> ParameterProtoSerde:
75
+ try:
76
+ return self._py_base_type_to_serde[python_base_type]
77
+ except KeyError:
78
+ raise InvalidError(f"No class parameter encoder implemented for type `{python_base_type.__name__}`")
79
+
80
+ def _get_decoder(self, enum_value: "api_pb2.ParameterType.ValueType") -> ParameterProtoSerde:
81
+ try:
82
+ return self._proto_type_to_serde[enum_value]
83
+ except KeyError:
84
+ try:
85
+ enum_name = api_pb2.ParameterType.Name(enum_value)
86
+ except ValueError:
87
+ enum_name = str(enum_value)
88
+
89
+ raise InvalidError(f"No class parameter decoder implemented for type {enum_name}.")
90
+
91
+
92
+ parameter_serde_registry = ProtoParameterSerdeRegistry()
93
+
94
+
95
+ @parameter_serde_registry.register_encoder(int)
96
+ @parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_INT)
97
+ class IntParameter:
98
+ @staticmethod
99
+ def encode(value: Any) -> api_pb2.ClassParameterValue:
100
+ return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_INT, int_value=value)
101
+
102
+ @staticmethod
103
+ def decode(proto_value: api_pb2.ClassParameterValue) -> int:
104
+ return proto_value.int_value
105
+
106
+ @staticmethod
107
+ def validate(python_value: Any):
108
+ if not isinstance(python_value, int):
109
+ raise TypeError(f"Expected int, got {type(python_value).__name__}")
110
+
111
+
112
+ @parameter_serde_registry.register_encoder(str)
113
+ @parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_STRING)
114
+ class StringParameter:
115
+ @staticmethod
116
+ def encode(value: Any) -> api_pb2.ClassParameterValue:
117
+ return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_STRING, string_value=value)
118
+
119
+ @staticmethod
120
+ def decode(proto_value: api_pb2.ClassParameterValue) -> str:
121
+ return proto_value.string_value
122
+
123
+ @staticmethod
124
+ def validate(python_value: Any):
125
+ if not isinstance(python_value, str):
126
+ raise TypeError(f"Expected str, got {type(python_value).__name__}")
127
+
128
+
129
+ @parameter_serde_registry.register_encoder(bytes)
130
+ @parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_BYTES)
131
+ class BytesParameter:
132
+ @staticmethod
133
+ def encode(value: Any) -> api_pb2.ClassParameterValue:
134
+ return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_BYTES, bytes_value=value)
135
+
136
+ @staticmethod
137
+ def decode(proto_value: api_pb2.ClassParameterValue) -> bytes:
138
+ return proto_value.bytes_value
139
+
140
+ @staticmethod
141
+ def validate(python_value: Any):
142
+ if not isinstance(python_value, bytes):
143
+ raise TypeError(f"Expected bytes, got {type(python_value).__name__}")
144
+
145
+
146
+ SCHEMA_FACTORY_TYPE = typing.Callable[[type], api_pb2.GenericPayloadType]
147
+
148
+
149
+ class SchemaRegistry:
150
+ _schema_factories: dict[type, SCHEMA_FACTORY_TYPE]
151
+
152
+ def __init__(self):
153
+ self._schema_factories = {}
154
+
155
+ def add(self, python_base_type: type) -> typing.Callable[[SCHEMA_FACTORY_TYPE], SCHEMA_FACTORY_TYPE]:
156
+ # decorator for schema factory functions for a base type
157
+ def deco(factory_func: SCHEMA_FACTORY_TYPE) -> SCHEMA_FACTORY_TYPE:
158
+ assert python_base_type not in self._schema_factories
159
+ self._schema_factories[python_base_type] = factory_func
160
+ return factory_func
161
+
162
+ return deco
163
+
164
+ def get(self, python_base_type: type) -> SCHEMA_FACTORY_TYPE:
165
+ try:
166
+ return self._schema_factories[python_base_type]
167
+ except KeyError:
168
+ return unknown_type_schema
169
+
170
+ def get_proto_generic_type(self, declared_type: type):
171
+ if origin := typing_extensions.get_origin(declared_type):
172
+ base_type = origin
173
+ else:
174
+ base_type = declared_type
175
+
176
+ return self.get(base_type)(declared_type)
177
+
178
+
179
+ schema_registry = SchemaRegistry()
180
+
181
+
182
+ @schema_registry.add(int)
183
+ def int_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
184
+ return api_pb2.GenericPayloadType(
185
+ base_type=api_pb2.PARAM_TYPE_INT,
186
+ )
187
+
188
+
189
+ @schema_registry.add(bytes)
190
+ def proto_type_def(declared_python_type: type) -> api_pb2.GenericPayloadType:
191
+ return api_pb2.GenericPayloadType(
192
+ base_type=api_pb2.PARAM_TYPE_BYTES,
193
+ )
194
+
195
+
196
+ def unknown_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
197
+ # TODO: add some metadata for unknown types to the type def?
198
+ return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_UNKNOWN)
199
+
200
+
201
+ @schema_registry.add(str)
202
+ def str_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
203
+ return api_pb2.GenericPayloadType(
204
+ base_type=api_pb2.PARAM_TYPE_STRING,
205
+ )
206
+
207
+
208
+ @schema_registry.add(type(None))
209
+ def none_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
210
+ return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_NONE)
211
+
212
+
213
+ @schema_registry.add(list)
214
+ def list_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
215
+ args = typing_extensions.get_args(full_python_type)
216
+
217
+ return api_pb2.GenericPayloadType(
218
+ base_type=api_pb2.PARAM_TYPE_LIST, sub_types=[schema_registry.get_proto_generic_type(arg) for arg in args]
219
+ )
220
+
221
+
222
+ @schema_registry.add(dict)
223
+ def dict_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
224
+ args = typing_extensions.get_args(full_python_type)
225
+
226
+ return api_pb2.GenericPayloadType(
227
+ base_type=api_pb2.PARAM_TYPE_DICT,
228
+ sub_types=[schema_registry.get_proto_generic_type(arg_type) for arg_type in args],
229
+ )
@@ -16,12 +16,10 @@ import modal_proto
16
16
  from modal_proto import api_pb2
17
17
 
18
18
  from .._serialization import (
19
- PROTO_TYPE_INFO,
20
- PYTHON_TO_PROTO_TYPE,
21
19
  deserialize,
22
20
  deserialize_data_format,
23
- get_proto_parameter_type,
24
21
  serialize,
22
+ signature_to_parameter_specs,
25
23
  )
26
24
  from .._traceback import append_modal_tb
27
25
  from ..config import config, logger
@@ -106,24 +104,6 @@ def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.Functio
106
104
  return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION
107
105
 
108
106
 
109
- def signature_to_protobuf_schema(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
110
- modal_parameters: list[api_pb2.ClassParameterSpec] = []
111
- for param in signature.parameters.values():
112
- has_default = param.default is not param.empty
113
- class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default)
114
- if param.annotation not in PYTHON_TO_PROTO_TYPE:
115
- class_param_spec.type = api_pb2.PARAM_TYPE_UNKNOWN
116
- else:
117
- proto_type = PYTHON_TO_PROTO_TYPE[param.annotation]
118
- class_param_spec.type = proto_type
119
- proto_type_info = PROTO_TYPE_INFO[proto_type]
120
- if has_default and proto_type is not api_pb2.PARAM_TYPE_UNKNOWN:
121
- setattr(class_param_spec, proto_type_info.default_field, param.default)
122
-
123
- modal_parameters.append(class_param_spec)
124
- return modal_parameters
125
-
126
-
127
107
  class FunctionInfo:
128
108
  """Utility that determines serialization/deserialization mechanisms for functions
129
109
 
@@ -310,15 +290,12 @@ class FunctionInfo:
310
290
  # annotation parameters trigger strictly typed parametrization
311
291
  # which enables web endpoint for parametrized classes
312
292
  signature = _get_class_constructor_signature(self.user_cls)
313
- # validate that the schema has no unspecified fields/unsupported class parameter types
314
- for param in signature.parameters.values():
315
- get_proto_parameter_type(param.annotation)
316
-
317
- protobuf_schema = signature_to_protobuf_schema(signature)
293
+ # at this point, the types in the signature should already have been validated (see Cls.from_local())
294
+ parameter_specs = signature_to_parameter_specs(signature)
318
295
 
319
296
  return api_pb2.ClassParameterInfo(
320
297
  format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
321
- schema=protobuf_schema,
298
+ schema=parameter_specs,
322
299
  )
323
300
 
324
301
  def get_entrypoint_mount(self) -> dict[str, _Mount]:
modal/client.pyi CHANGED
@@ -31,7 +31,7 @@ class _Client:
31
31
  server_url: str,
32
32
  client_type: int,
33
33
  credentials: typing.Optional[tuple[str, str]],
34
- version: str = "0.73.130",
34
+ version: str = "0.73.131",
35
35
  ): ...
36
36
  def is_closed(self) -> bool: ...
37
37
  @property
@@ -93,7 +93,7 @@ class Client:
93
93
  server_url: str,
94
94
  client_type: int,
95
95
  credentials: typing.Optional[tuple[str, str]],
96
- version: str = "0.73.130",
96
+ version: str = "0.73.131",
97
97
  ): ...
98
98
  def is_closed(self) -> bool: ...
99
99
  @property
modal/cls.py CHANGED
@@ -21,8 +21,9 @@ from ._partial_function import (
21
21
  )
22
22
  from ._resolver import Resolver
23
23
  from ._resources import convert_fn_config_to_resources_config
24
- from ._serialization import check_valid_cls_constructor_arg, get_proto_parameter_type
24
+ from ._serialization import check_valid_cls_constructor_arg
25
25
  from ._traceback import print_server_warnings
26
+ from ._type_manager import parameter_serde_registry
26
27
  from ._utils.async_utils import synchronize_api, synchronizer
27
28
  from ._utils.deprecation import deprecation_warning, renamed_parameter, warn_on_renamed_autoscaler_settings
28
29
  from ._utils.grpc_utils import retry_transient_errors
@@ -462,7 +463,10 @@ class _Cls(_Object, type_prefix="cs"):
462
463
 
463
464
  annotated_params = {k: t for k, t in annotations.items() if k in params}
464
465
  for k, t in annotated_params.items():
465
- get_proto_parameter_type(t)
466
+ try:
467
+ parameter_serde_registry.validate_parameter_type(t)
468
+ except TypeError as exc:
469
+ raise InvalidError(f"Class parameter '{k}': {exc}")
466
470
 
467
471
  @staticmethod
468
472
  def from_local(user_cls, app: "modal.app._App", class_service_function: _Function) -> "_Cls":
modal/functions.pyi CHANGED
@@ -198,11 +198,11 @@ class Function(
198
198
 
199
199
  _call_generator_nowait: ___call_generator_nowait_spec[typing_extensions.Self]
200
200
 
201
- class __remote_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
201
+ class __remote_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
202
202
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
203
203
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
204
204
 
205
- remote: __remote_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]
205
+ remote: __remote_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
206
206
 
207
207
  class __remote_gen_spec(typing_extensions.Protocol[SUPERSELF]):
208
208
  def __call__(self, *args, **kwargs) -> typing.Generator[typing.Any, None, None]: ...
@@ -217,19 +217,19 @@ class Function(
217
217
  self, *args: modal._functions.P.args, **kwargs: modal._functions.P.kwargs
218
218
  ) -> modal._functions.OriginalReturnType: ...
219
219
 
220
- class ___experimental_spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
220
+ class ___experimental_spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
221
221
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
222
222
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
223
223
 
224
224
  _experimental_spawn: ___experimental_spawn_spec[
225
- modal._functions.P, modal._functions.ReturnType, typing_extensions.Self
225
+ modal._functions.ReturnType, modal._functions.P, typing_extensions.Self
226
226
  ]
227
227
 
228
- class __spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER, SUPERSELF]):
228
+ class __spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER, SUPERSELF]):
229
229
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
230
230
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
231
231
 
232
- spawn: __spawn_spec[modal._functions.P, modal._functions.ReturnType, typing_extensions.Self]
232
+ spawn: __spawn_spec[modal._functions.ReturnType, modal._functions.P, typing_extensions.Self]
233
233
 
234
234
  def get_raw_f(self) -> collections.abc.Callable[..., typing.Any]: ...
235
235
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: modal
3
- Version: 0.73.130
3
+ Version: 0.73.131
4
4
  Summary: Python client library for Modal
5
5
  Author-email: Modal Labs <support@modal.com>
6
6
  License: Apache-2.0