modal 0.73.128__py3-none-any.whl → 0.73.131__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modal/__init__.py CHANGED
@@ -27,7 +27,6 @@ try:
27
27
  asgi_app,
28
28
  batched,
29
29
  build,
30
- concurrent,
31
30
  enter,
32
31
  exit,
33
32
  fastapi_endpoint,
@@ -83,7 +82,6 @@ __all__ = [
83
82
  "asgi_app",
84
83
  "batched",
85
84
  "build",
86
- "concurrent",
87
85
  "current_function_call_id",
88
86
  "current_input_id",
89
87
  "enable_output",
modal/_functions.py CHANGED
@@ -25,7 +25,12 @@ from ._pty import get_pty_info
25
25
  from ._resolver import Resolver
26
26
  from ._resources import convert_fn_config_to_resources_config
27
27
  from ._runtime.execution_context import current_input_id, is_local
28
- from ._serialization import apply_defaults, serialize, serialize_proto_params, validate_params
28
+ from ._serialization import (
29
+ apply_defaults,
30
+ serialize,
31
+ serialize_proto_params,
32
+ validate_parameter_values,
33
+ )
29
34
  from ._traceback import print_server_warnings
30
35
  from ._utils.async_utils import (
31
36
  TaskContext,
@@ -435,8 +440,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
435
440
  max_containers: Optional[int] = None,
436
441
  buffer_containers: Optional[int] = None,
437
442
  scaledown_window: Optional[int] = None,
438
- max_concurrent_inputs: Optional[int] = None,
439
- target_concurrent_inputs: Optional[int] = None,
443
+ allow_concurrent_inputs: Optional[int] = None,
440
444
  batch_max_size: Optional[int] = None,
441
445
  batch_wait_ms: Optional[int] = None,
442
446
  cloud: Optional[str] = None,
@@ -787,8 +791,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
787
791
  runtime_perf_record=config.get("runtime_perf_record"),
788
792
  app_name=app_name,
789
793
  is_builder_function=is_builder_function,
790
- max_concurrent_inputs=max_concurrent_inputs or 0,
791
- target_concurrent_inputs=target_concurrent_inputs or 0,
794
+ target_concurrent_inputs=allow_concurrent_inputs or 0,
792
795
  batch_max_size=batch_max_size or 0,
793
796
  batch_linger_ms=batch_wait_ms or 0,
794
797
  worker_id=config.get("worker_id"),
@@ -977,7 +980,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
977
980
  )
978
981
  schema = parent._class_parameter_info.schema
979
982
  kwargs_with_defaults = apply_defaults(kwargs, schema)
980
- validate_params(kwargs_with_defaults, schema)
983
+ validate_parameter_values(kwargs_with_defaults, schema)
981
984
  serialized_params = serialize_proto_params(kwargs_with_defaults)
982
985
  can_use_parent = len(parent._class_parameter_info.schema) == 0 # no parameters
983
986
  else:
@@ -1314,7 +1317,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1314
1317
  order_outputs,
1315
1318
  return_exceptions,
1316
1319
  count_update_callback,
1317
- api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
1320
+ api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
1318
1321
  )
1319
1322
  ) as stream:
1320
1323
  async for item in stream:
@@ -59,8 +59,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
59
59
  force_build: bool
60
60
  cluster_size: Optional[int] # Experimental: Clustered functions
61
61
  build_timeout: Optional[int]
62
- max_concurrent_inputs: Optional[int]
63
- target_concurrent_inputs: Optional[int]
64
62
 
65
63
  def __init__(
66
64
  self,
@@ -74,8 +72,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
74
72
  cluster_size: Optional[int] = None, # Experimental: Clustered functions
75
73
  force_build: bool = False,
76
74
  build_timeout: Optional[int] = None,
77
- max_concurrent_inputs: Optional[int] = None,
78
- target_concurrent_inputs: Optional[int] = None,
79
75
  ):
80
76
  self.raw_f = raw_f
81
77
  self.flags = flags
@@ -93,8 +89,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
93
89
  self.cluster_size = cluster_size # Experimental: Clustered functions
94
90
  self.force_build = force_build
95
91
  self.build_timeout = build_timeout
96
- self.max_concurrent_inputs = max_concurrent_inputs
97
- self.target_concurrent_inputs = target_concurrent_inputs
98
92
 
99
93
  def _get_raw_f(self) -> Callable[P, ReturnType]:
100
94
  return self.raw_f
@@ -149,8 +143,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
149
143
  batch_wait_ms=self.batch_wait_ms,
150
144
  force_build=self.force_build,
151
145
  build_timeout=self.build_timeout,
152
- max_concurrent_inputs=self.max_concurrent_inputs,
153
- target_concurrent_inputs=self.target_concurrent_inputs,
154
146
  )
155
147
 
156
148
 
@@ -730,49 +722,3 @@ def _batched(
730
722
  )
731
723
 
732
724
  return wrapper
733
-
734
-
735
- def _concurrent(
736
- _warn_parentheses_missing=None,
737
- *,
738
- max_inputs: int, # Hard limit on each container's input concurrency
739
- target_inputs: Optional[int] = None, # Input concurrency that Modal's autoscaler should target
740
- ) -> Callable[[Union[Callable[..., Any], _PartialFunction]], _PartialFunction]:
741
- """Decorator that allows individual containers to handle multiple inputs concurrently.
742
-
743
- The concurrency mechanism depends on whether the function is async or not:
744
- - Async functions will run inputs on a single thread as asyncio tasks.
745
- - Synchronous functions will use multi-threading. The code must be thread-safe.
746
-
747
- Input concurrency will be most useful for workflows that are IO-bound
748
- (e.g., making network requests) or when running an inference server that supports
749
- dynamic batching.
750
-
751
- When `target_inputs` is set, Modal's autoscaler will try to provision resources such
752
- that each container is running that many inputs concurrently. Containers may burst up to
753
- up to `max_inputs` if resources are insufficient to remain at the target concurrency.
754
- """
755
- if _warn_parentheses_missing is not None:
756
- raise InvalidError(
757
- "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.concurrent()`."
758
- )
759
-
760
- if target_inputs and target_inputs > max_inputs:
761
- raise InvalidError("`target_inputs` parameter cannot be greater than `max_inputs`.")
762
-
763
- def wrapper(obj: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
764
- if isinstance(obj, _PartialFunction):
765
- # Risky that we need to mutate the parameters here; should make this safer
766
- obj.max_concurrent_inputs = max_inputs
767
- obj.target_concurrent_inputs = target_inputs
768
- obj.add_flags(_PartialFunctionFlags.FUNCTION)
769
- return obj
770
-
771
- return _PartialFunction(
772
- obj,
773
- _PartialFunctionFlags.FUNCTION,
774
- max_concurrent_inputs=max_inputs,
775
- target_concurrent_inputs=target_inputs,
776
- )
777
-
778
- return wrapper
modal/_resolver.py CHANGED
@@ -1,6 +1,7 @@
1
1
  # Copyright Modal Labs 2023
2
2
  import asyncio
3
3
  import contextlib
4
+ import traceback
4
5
  import typing
5
6
  from asyncio import Future
6
7
  from collections.abc import Hashable
@@ -153,7 +154,11 @@ class Resolver:
153
154
  self._deduplication_cache[deduplication_key] = cached_future
154
155
 
155
156
  # TODO(elias): print original exception/trace rather than the Resolver-internal trace
156
- return await cached_future
157
+ try:
158
+ return await cached_future
159
+ except Exception:
160
+ traceback.print_exc()
161
+ raise
157
162
 
158
163
  def objects(self) -> list["modal._object._Object"]:
159
164
  unique_objects: dict[str, "modal._object._Object"] = {}
modal/_serialization.py CHANGED
@@ -1,14 +1,16 @@
1
1
  # Copyright Modal Labs 2022
2
+ import inspect
2
3
  import io
3
4
  import pickle
4
5
  import typing
5
- from dataclasses import dataclass
6
+ from inspect import Parameter
6
7
  from typing import Any
7
8
 
8
9
  from modal._utils.async_utils import synchronizer
9
10
  from modal_proto import api_pb2
10
11
 
11
12
  from ._object import _Object
13
+ from ._type_manager import parameter_serde_registry, schema_registry
12
14
  from ._vendor import cloudpickle
13
15
  from .config import logger
14
16
  from .exception import DeserializationError, ExecutionError, InvalidError
@@ -389,50 +391,6 @@ def check_valid_cls_constructor_arg(key, obj):
389
391
  )
390
392
 
391
393
 
392
- def assert_bytes(obj: Any):
393
- if not isinstance(obj, bytes):
394
- raise TypeError(f"Expected bytes, got {type(obj)}")
395
- return obj
396
-
397
-
398
- @dataclass
399
- class ParamTypeInfo:
400
- default_field: str
401
- proto_field: str
402
- converter: typing.Callable[[str], typing.Any]
403
- type: type
404
-
405
-
406
- PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
407
- # python type -> protobuf type enum
408
- str: api_pb2.PARAM_TYPE_STRING,
409
- int: api_pb2.PARAM_TYPE_INT,
410
- bytes: api_pb2.PARAM_TYPE_BYTES,
411
- }
412
-
413
- PROTO_TYPE_INFO = {
414
- # Protobuf type enum -> encode/decode helper metadata
415
- api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
416
- default_field="string_default",
417
- proto_field="string_value",
418
- converter=str,
419
- type=str,
420
- ),
421
- api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
422
- default_field="int_default",
423
- proto_field="int_value",
424
- converter=int,
425
- type=int,
426
- ),
427
- api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
428
- default_field="bytes_default",
429
- proto_field="bytes_value",
430
- converter=assert_bytes,
431
- type=bytes,
432
- ),
433
- }
434
-
435
-
436
394
  def apply_defaults(
437
395
  python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
438
396
  ) -> dict[str, Any]:
@@ -453,68 +411,56 @@ def apply_defaults(
453
411
  return result
454
412
 
455
413
 
414
+ def encode_parameter_value(name: str, python_value: Any) -> api_pb2.ClassParameterValue:
415
+ """Map to proto parameter representation using python runtime type information"""
416
+ struct = parameter_serde_registry.encode(python_value)
417
+ struct.name = name
418
+ return struct
419
+
420
+
456
421
  def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
457
422
  proto_params: list[api_pb2.ClassParameterValue] = []
458
423
  for param_name, python_value in python_params.items():
459
- python_type = type(python_value)
460
- protobuf_type = get_proto_parameter_type(python_type)
461
- type_info = PROTO_TYPE_INFO.get(protobuf_type)
462
- proto_param = api_pb2.ClassParameterValue(
463
- name=param_name,
464
- type=protobuf_type,
465
- )
466
- try:
467
- converted_value = type_info.converter(python_value)
468
- except ValueError as exc:
469
- raise ValueError(f"Invalid type for parameter {param_name}: {exc}")
470
- setattr(proto_param, type_info.proto_field, converted_value)
471
- proto_params.append(proto_param)
424
+ proto_params.append(encode_parameter_value(param_name, python_value))
472
425
  proto_bytes = api_pb2.ClassParameterSet(parameters=proto_params).SerializeToString(deterministic=True)
473
426
  return proto_bytes
474
427
 
475
428
 
476
429
  def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
477
- proto_struct = api_pb2.ClassParameterSet()
478
- proto_struct.ParseFromString(serialized_params)
430
+ proto_struct = api_pb2.ClassParameterSet.FromString(serialized_params)
479
431
  python_params = {}
480
432
  for param in proto_struct.parameters:
481
- python_value: Any
482
- if param.type == api_pb2.PARAM_TYPE_STRING:
483
- python_value = param.string_value
484
- elif param.type == api_pb2.PARAM_TYPE_INT:
485
- python_value = param.int_value
486
- elif param.type == api_pb2.PARAM_TYPE_BYTES:
487
- python_value = param.bytes_value
488
- else:
489
- raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")
490
-
491
- python_params[param.name] = python_value
433
+ python_params[param.name] = parameter_serde_registry.decode(param)
492
434
 
493
435
  return python_params
494
436
 
495
437
 
496
- def validate_params(params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
497
- # first check that all declared values are provided
498
- for schema_param in schema:
499
- if schema_param.name not in params:
500
- # we expect all values to be present - even defaulted ones (defaults are applied on payload construction)
501
- raise InvalidError(f"Missing required parameter: {schema_param.name}")
502
- python_value = params[schema_param.name]
503
- python_type = type(python_value)
504
- param_protobuf_type = get_proto_parameter_type(python_type)
505
- if schema_param.type != param_protobuf_type:
506
- expected_python_type = PROTO_TYPE_INFO[schema_param.type].type
507
- raise TypeError(
508
- f"Parameter '{schema_param.name}' type error: expected {expected_python_type.__name__}, "
509
- f"got {python_type.__name__}"
510
- )
438
+ def validate_parameter_values(payload: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
439
+ """Ensure parameter payload conforms to the schema of a class
440
+
441
+ Checks that:
442
+ * All fields are specified (defaults are expected to already be applied on the payload)
443
+ * No extra fields are specified
444
+ * The type of each field is correct
445
+ """
446
+ for param_spec in schema:
447
+ if param_spec.name not in payload:
448
+ raise InvalidError(f"Missing required parameter: {param_spec.name}")
449
+ python_value = payload[param_spec.name]
450
+ if param_spec.HasField("full_type") and param_spec.full_type.base_type:
451
+ type_enum_value = param_spec.full_type.base_type
452
+ else:
453
+ type_enum_value = param_spec.type # backwards compatibility pre-full_type
454
+
455
+ parameter_serde_registry.validate_value_for_enum_type(type_enum_value, python_value)
511
456
 
512
457
  schema_fields = {p.name for p in schema}
513
458
  # then check that no extra values are provided
514
- non_declared_fields = params.keys() - schema_fields
459
+ non_declared_fields = payload.keys() - schema_fields
515
460
  if non_declared_fields:
516
461
  raise InvalidError(
517
- f"The following parameter names were provided but are not present in the schema: {non_declared_fields}"
462
+ f"The following parameter names were provided but are not defined class modal.parameters for the class: "
463
+ f"{', '.join(non_declared_fields)}"
518
464
  )
519
465
 
520
466
 
@@ -528,8 +474,6 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
528
474
  elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
529
475
  param_args = () # we use kwargs only for our implicit constructors
530
476
  param_kwargs = deserialize_proto_params(serialized_params)
531
- # TODO: We can probably remove the validation below since we do validation in the caller?
532
- validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
533
477
  else:
534
478
  raise ExecutionError(
535
479
  f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
@@ -538,9 +482,47 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
538
482
  return param_args, param_kwargs
539
483
 
540
484
 
541
- def get_proto_parameter_type(parameter_type: type) -> "api_pb2.ParameterType.ValueType":
542
- if parameter_type not in PYTHON_TO_PROTO_TYPE:
543
- type_name = getattr(parameter_type, "__name__", repr(parameter_type))
544
- supported = ", ".join(parameter_type.__name__ for parameter_type in PYTHON_TO_PROTO_TYPE.keys())
545
- raise InvalidError(f"{type_name} is not a supported parameter type. Use one of: {supported}")
546
- return PYTHON_TO_PROTO_TYPE[parameter_type]
485
+ def _signature_parameter_to_spec(
486
+ python_signature_parameter: inspect.Parameter, include_legacy_parameter_fields: bool = False
487
+ ) -> api_pb2.ClassParameterSpec:
488
+ """Returns proto representation of Parameter as returned by inspect.signature()
489
+
490
+ Setting include_legacy_parameter_fields makes the output backwards compatible with
491
+ pre v0.74 clients looking at class parameter specifications, and should not be used
492
+ when registering *function* schemas.
493
+ """
494
+ declared_type = python_signature_parameter.annotation
495
+ full_proto_type = schema_registry.get_proto_generic_type(declared_type)
496
+ has_default = python_signature_parameter.default is not Parameter.empty
497
+
498
+ field_spec = api_pb2.ClassParameterSpec(
499
+ name=python_signature_parameter.name,
500
+ full_type=full_proto_type,
501
+ has_default=has_default,
502
+ )
503
+ if include_legacy_parameter_fields:
504
+ # add the .{type}_default and `.type` values as required by legacy clients
505
+ # looking at class parameter specs
506
+ if full_proto_type.base_type == api_pb2.PARAM_TYPE_INT:
507
+ if has_default:
508
+ field_spec.int_default = python_signature_parameter.default
509
+ field_spec.type = api_pb2.PARAM_TYPE_INT
510
+ elif full_proto_type.base_type == api_pb2.PARAM_TYPE_STRING:
511
+ if has_default:
512
+ field_spec.string_default = python_signature_parameter.default
513
+ field_spec.type = api_pb2.PARAM_TYPE_STRING
514
+ elif full_proto_type.base_type == api_pb2.PARAM_TYPE_BYTES:
515
+ if has_default:
516
+ field_spec.bytes_default = python_signature_parameter.default
517
+ field_spec.type = api_pb2.PARAM_TYPE_BYTES
518
+
519
+ return field_spec
520
+
521
+
522
+ def signature_to_parameter_specs(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
523
+ # only used for modal.parameter() specs, uses backwards compatible fields and types
524
+ modal_parameters: list[api_pb2.ClassParameterSpec] = []
525
+ for param in signature.parameters.values():
526
+ field_spec = _signature_parameter_to_spec(param, include_legacy_parameter_fields=True)
527
+ modal_parameters.append(field_spec)
528
+ return modal_parameters
modal/_type_manager.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright Modal Labs 2025
2
+ import typing
3
+ from typing import Any
4
+
5
+ import typing_extensions
6
+
7
+ from modal.exception import InvalidError
8
+ from modal_proto import api_pb2
9
+
10
+
11
+ class ParameterProtoSerde(typing.Protocol):
12
+ def encode(self, value: Any) -> api_pb2.ClassParameterValue: ...
13
+
14
+ def decode(self, proto_value: api_pb2.ClassParameterValue) -> Any: ...
15
+
16
+ def validate(self, python_value: Any): ...
17
+
18
+
19
+ class ProtoParameterSerdeRegistry:
20
+ _py_base_type_to_serde: dict[type, ParameterProtoSerde]
21
+ _proto_type_to_serde: dict["api_pb2.ParameterType.ValueType", ParameterProtoSerde]
22
+
23
+ def __init__(self):
24
+ self._py_base_type_to_serde = {}
25
+ self._proto_type_to_serde = {}
26
+
27
+ def register_encoder(self, python_base_type: type) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
28
+ def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
29
+ if python_base_type in self._py_base_type_to_serde:
30
+ raise ValueError("Can't register the same encoder type twice")
31
+ self._py_base_type_to_serde[python_base_type] = ph
32
+ return ph
33
+
34
+ return deco
35
+
36
+ def register_decoder(
37
+ self, enum_type_value: "api_pb2.ParameterType.ValueType"
38
+ ) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
39
+ def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
40
+ if enum_type_value in self._proto_type_to_serde:
41
+ raise ValueError("Can't register the same decoder type twice")
42
+ self._proto_type_to_serde[enum_type_value] = ph
43
+ return ph
44
+
45
+ return deco
46
+
47
+ def encode(self, python_value: Any) -> api_pb2.ClassParameterValue:
48
+ return self._get_encoder(type(python_value)).encode(python_value)
49
+
50
+ def supports_type(self, declared_type: type) -> bool:
51
+ try:
52
+ self._get_encoder(declared_type)
53
+ return True
54
+ except InvalidError:
55
+ return False
56
+
57
+ def decode(self, param_value: api_pb2.ClassParameterValue) -> Any:
58
+ return self._get_decoder(param_value.type).decode(param_value)
59
+
60
+ def validate_parameter_type(self, declared_type: type):
61
+ """Raises a helpful TypeError if the supplied type isn't supported by class parameters"""
62
+ if not parameter_serde_registry.supports_type(declared_type):
63
+ supported_types = self._py_base_type_to_serde.keys()
64
+ supported_str = ", ".join(t.__name__ for t in supported_types)
65
+
66
+ raise TypeError(
67
+ f"{declared_type.__name__} is not a supported modal.parameter() type. Use one of: {supported_str}"
68
+ )
69
+
70
+ def validate_value_for_enum_type(self, enum_value: "api_pb2.ParameterType.ValueType", python_value: Any):
71
+ serde = self._get_decoder(enum_value) # use the schema's expected decoder
72
+ serde.validate(python_value)
73
+
74
+ def _get_encoder(self, python_base_type: type) -> ParameterProtoSerde:
75
+ try:
76
+ return self._py_base_type_to_serde[python_base_type]
77
+ except KeyError:
78
+ raise InvalidError(f"No class parameter encoder implemented for type `{python_base_type.__name__}`")
79
+
80
+ def _get_decoder(self, enum_value: "api_pb2.ParameterType.ValueType") -> ParameterProtoSerde:
81
+ try:
82
+ return self._proto_type_to_serde[enum_value]
83
+ except KeyError:
84
+ try:
85
+ enum_name = api_pb2.ParameterType.Name(enum_value)
86
+ except ValueError:
87
+ enum_name = str(enum_value)
88
+
89
+ raise InvalidError(f"No class parameter decoder implemented for type {enum_name}.")
90
+
91
+
92
+ parameter_serde_registry = ProtoParameterSerdeRegistry()
93
+
94
+
95
+ @parameter_serde_registry.register_encoder(int)
96
+ @parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_INT)
97
+ class IntParameter:
98
+ @staticmethod
99
+ def encode(value: Any) -> api_pb2.ClassParameterValue:
100
+ return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_INT, int_value=value)
101
+
102
+ @staticmethod
103
+ def decode(proto_value: api_pb2.ClassParameterValue) -> int:
104
+ return proto_value.int_value
105
+
106
+ @staticmethod
107
+ def validate(python_value: Any):
108
+ if not isinstance(python_value, int):
109
+ raise TypeError(f"Expected int, got {type(python_value).__name__}")
110
+
111
+
112
+ @parameter_serde_registry.register_encoder(str)
113
+ @parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_STRING)
114
+ class StringParameter:
115
+ @staticmethod
116
+ def encode(value: Any) -> api_pb2.ClassParameterValue:
117
+ return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_STRING, string_value=value)
118
+
119
+ @staticmethod
120
+ def decode(proto_value: api_pb2.ClassParameterValue) -> str:
121
+ return proto_value.string_value
122
+
123
+ @staticmethod
124
+ def validate(python_value: Any):
125
+ if not isinstance(python_value, str):
126
+ raise TypeError(f"Expected str, got {type(python_value).__name__}")
127
+
128
+
129
+ @parameter_serde_registry.register_encoder(bytes)
130
+ @parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_BYTES)
131
+ class BytesParameter:
132
+ @staticmethod
133
+ def encode(value: Any) -> api_pb2.ClassParameterValue:
134
+ return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_BYTES, bytes_value=value)
135
+
136
+ @staticmethod
137
+ def decode(proto_value: api_pb2.ClassParameterValue) -> bytes:
138
+ return proto_value.bytes_value
139
+
140
+ @staticmethod
141
+ def validate(python_value: Any):
142
+ if not isinstance(python_value, bytes):
143
+ raise TypeError(f"Expected bytes, got {type(python_value).__name__}")
144
+
145
+
146
+ SCHEMA_FACTORY_TYPE = typing.Callable[[type], api_pb2.GenericPayloadType]
147
+
148
+
149
+ class SchemaRegistry:
150
+ _schema_factories: dict[type, SCHEMA_FACTORY_TYPE]
151
+
152
+ def __init__(self):
153
+ self._schema_factories = {}
154
+
155
+ def add(self, python_base_type: type) -> typing.Callable[[SCHEMA_FACTORY_TYPE], SCHEMA_FACTORY_TYPE]:
156
+ # decorator for schema factory functions for a base type
157
+ def deco(factory_func: SCHEMA_FACTORY_TYPE) -> SCHEMA_FACTORY_TYPE:
158
+ assert python_base_type not in self._schema_factories
159
+ self._schema_factories[python_base_type] = factory_func
160
+ return factory_func
161
+
162
+ return deco
163
+
164
+ def get(self, python_base_type: type) -> SCHEMA_FACTORY_TYPE:
165
+ try:
166
+ return self._schema_factories[python_base_type]
167
+ except KeyError:
168
+ return unknown_type_schema
169
+
170
+ def get_proto_generic_type(self, declared_type: type):
171
+ if origin := typing_extensions.get_origin(declared_type):
172
+ base_type = origin
173
+ else:
174
+ base_type = declared_type
175
+
176
+ return self.get(base_type)(declared_type)
177
+
178
+
179
+ schema_registry = SchemaRegistry()
180
+
181
+
182
+ @schema_registry.add(int)
183
+ def int_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
184
+ return api_pb2.GenericPayloadType(
185
+ base_type=api_pb2.PARAM_TYPE_INT,
186
+ )
187
+
188
+
189
+ @schema_registry.add(bytes)
190
+ def proto_type_def(declared_python_type: type) -> api_pb2.GenericPayloadType:
191
+ return api_pb2.GenericPayloadType(
192
+ base_type=api_pb2.PARAM_TYPE_BYTES,
193
+ )
194
+
195
+
196
+ def unknown_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
197
+ # TODO: add some metadata for unknown types to the type def?
198
+ return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_UNKNOWN)
199
+
200
+
201
+ @schema_registry.add(str)
202
+ def str_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
203
+ return api_pb2.GenericPayloadType(
204
+ base_type=api_pb2.PARAM_TYPE_STRING,
205
+ )
206
+
207
+
208
+ @schema_registry.add(type(None))
209
+ def none_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
210
+ return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_NONE)
211
+
212
+
213
+ @schema_registry.add(list)
214
+ def list_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
215
+ args = typing_extensions.get_args(full_python_type)
216
+
217
+ return api_pb2.GenericPayloadType(
218
+ base_type=api_pb2.PARAM_TYPE_LIST, sub_types=[schema_registry.get_proto_generic_type(arg) for arg in args]
219
+ )
220
+
221
+
222
+ @schema_registry.add(dict)
223
+ def dict_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
224
+ args = typing_extensions.get_args(full_python_type)
225
+
226
+ return api_pb2.GenericPayloadType(
227
+ base_type=api_pb2.PARAM_TYPE_DICT,
228
+ sub_types=[schema_registry.get_proto_generic_type(arg_type) for arg_type in args],
229
+ )