modal 0.73.128__py3-none-any.whl → 0.73.131__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/__init__.py +0 -2
- modal/_functions.py +10 -7
- modal/_partial_function.py +0 -54
- modal/_resolver.py +6 -1
- modal/_serialization.py +78 -96
- modal/_type_manager.py +229 -0
- modal/_utils/function_utils.py +4 -27
- modal/app.py +5 -34
- modal/app.pyi +2 -3
- modal/client.pyi +2 -2
- modal/cls.py +6 -2
- modal/functions.pyi +7 -8
- modal/partial_function.py +0 -2
- modal/partial_function.pyi +0 -9
- {modal-0.73.128.dist-info → modal-0.73.131.dist-info}/METADATA +1 -1
- {modal-0.73.128.dist-info → modal-0.73.131.dist-info}/RECORD +24 -23
- modal_proto/api.proto +17 -6
- modal_proto/api_pb2.py +717 -704
- modal_proto/api_pb2.pyi +46 -8
- modal_version/_version_generated.py +1 -1
- {modal-0.73.128.dist-info → modal-0.73.131.dist-info}/LICENSE +0 -0
- {modal-0.73.128.dist-info → modal-0.73.131.dist-info}/WHEEL +0 -0
- {modal-0.73.128.dist-info → modal-0.73.131.dist-info}/entry_points.txt +0 -0
- {modal-0.73.128.dist-info → modal-0.73.131.dist-info}/top_level.txt +0 -0
modal/__init__.py
CHANGED
@@ -27,7 +27,6 @@ try:
|
|
27
27
|
asgi_app,
|
28
28
|
batched,
|
29
29
|
build,
|
30
|
-
concurrent,
|
31
30
|
enter,
|
32
31
|
exit,
|
33
32
|
fastapi_endpoint,
|
@@ -83,7 +82,6 @@ __all__ = [
|
|
83
82
|
"asgi_app",
|
84
83
|
"batched",
|
85
84
|
"build",
|
86
|
-
"concurrent",
|
87
85
|
"current_function_call_id",
|
88
86
|
"current_input_id",
|
89
87
|
"enable_output",
|
modal/_functions.py
CHANGED
@@ -25,7 +25,12 @@ from ._pty import get_pty_info
|
|
25
25
|
from ._resolver import Resolver
|
26
26
|
from ._resources import convert_fn_config_to_resources_config
|
27
27
|
from ._runtime.execution_context import current_input_id, is_local
|
28
|
-
from ._serialization import
|
28
|
+
from ._serialization import (
|
29
|
+
apply_defaults,
|
30
|
+
serialize,
|
31
|
+
serialize_proto_params,
|
32
|
+
validate_parameter_values,
|
33
|
+
)
|
29
34
|
from ._traceback import print_server_warnings
|
30
35
|
from ._utils.async_utils import (
|
31
36
|
TaskContext,
|
@@ -435,8 +440,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
435
440
|
max_containers: Optional[int] = None,
|
436
441
|
buffer_containers: Optional[int] = None,
|
437
442
|
scaledown_window: Optional[int] = None,
|
438
|
-
|
439
|
-
target_concurrent_inputs: Optional[int] = None,
|
443
|
+
allow_concurrent_inputs: Optional[int] = None,
|
440
444
|
batch_max_size: Optional[int] = None,
|
441
445
|
batch_wait_ms: Optional[int] = None,
|
442
446
|
cloud: Optional[str] = None,
|
@@ -787,8 +791,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
787
791
|
runtime_perf_record=config.get("runtime_perf_record"),
|
788
792
|
app_name=app_name,
|
789
793
|
is_builder_function=is_builder_function,
|
790
|
-
|
791
|
-
target_concurrent_inputs=target_concurrent_inputs or 0,
|
794
|
+
target_concurrent_inputs=allow_concurrent_inputs or 0,
|
792
795
|
batch_max_size=batch_max_size or 0,
|
793
796
|
batch_linger_ms=batch_wait_ms or 0,
|
794
797
|
worker_id=config.get("worker_id"),
|
@@ -977,7 +980,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
977
980
|
)
|
978
981
|
schema = parent._class_parameter_info.schema
|
979
982
|
kwargs_with_defaults = apply_defaults(kwargs, schema)
|
980
|
-
|
983
|
+
validate_parameter_values(kwargs_with_defaults, schema)
|
981
984
|
serialized_params = serialize_proto_params(kwargs_with_defaults)
|
982
985
|
can_use_parent = len(parent._class_parameter_info.schema) == 0 # no parameters
|
983
986
|
else:
|
@@ -1314,7 +1317,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1314
1317
|
order_outputs,
|
1315
1318
|
return_exceptions,
|
1316
1319
|
count_update_callback,
|
1317
|
-
api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
|
1320
|
+
api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC,
|
1318
1321
|
)
|
1319
1322
|
) as stream:
|
1320
1323
|
async for item in stream:
|
modal/_partial_function.py
CHANGED
@@ -59,8 +59,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
59
59
|
force_build: bool
|
60
60
|
cluster_size: Optional[int] # Experimental: Clustered functions
|
61
61
|
build_timeout: Optional[int]
|
62
|
-
max_concurrent_inputs: Optional[int]
|
63
|
-
target_concurrent_inputs: Optional[int]
|
64
62
|
|
65
63
|
def __init__(
|
66
64
|
self,
|
@@ -74,8 +72,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
74
72
|
cluster_size: Optional[int] = None, # Experimental: Clustered functions
|
75
73
|
force_build: bool = False,
|
76
74
|
build_timeout: Optional[int] = None,
|
77
|
-
max_concurrent_inputs: Optional[int] = None,
|
78
|
-
target_concurrent_inputs: Optional[int] = None,
|
79
75
|
):
|
80
76
|
self.raw_f = raw_f
|
81
77
|
self.flags = flags
|
@@ -93,8 +89,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
93
89
|
self.cluster_size = cluster_size # Experimental: Clustered functions
|
94
90
|
self.force_build = force_build
|
95
91
|
self.build_timeout = build_timeout
|
96
|
-
self.max_concurrent_inputs = max_concurrent_inputs
|
97
|
-
self.target_concurrent_inputs = target_concurrent_inputs
|
98
92
|
|
99
93
|
def _get_raw_f(self) -> Callable[P, ReturnType]:
|
100
94
|
return self.raw_f
|
@@ -149,8 +143,6 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
|
|
149
143
|
batch_wait_ms=self.batch_wait_ms,
|
150
144
|
force_build=self.force_build,
|
151
145
|
build_timeout=self.build_timeout,
|
152
|
-
max_concurrent_inputs=self.max_concurrent_inputs,
|
153
|
-
target_concurrent_inputs=self.target_concurrent_inputs,
|
154
146
|
)
|
155
147
|
|
156
148
|
|
@@ -730,49 +722,3 @@ def _batched(
|
|
730
722
|
)
|
731
723
|
|
732
724
|
return wrapper
|
733
|
-
|
734
|
-
|
735
|
-
def _concurrent(
|
736
|
-
_warn_parentheses_missing=None,
|
737
|
-
*,
|
738
|
-
max_inputs: int, # Hard limit on each container's input concurrency
|
739
|
-
target_inputs: Optional[int] = None, # Input concurrency that Modal's autoscaler should target
|
740
|
-
) -> Callable[[Union[Callable[..., Any], _PartialFunction]], _PartialFunction]:
|
741
|
-
"""Decorator that allows individual containers to handle multiple inputs concurrently.
|
742
|
-
|
743
|
-
The concurrency mechanism depends on whether the function is async or not:
|
744
|
-
- Async functions will run inputs on a single thread as asyncio tasks.
|
745
|
-
- Synchronous functions will use multi-threading. The code must be thread-safe.
|
746
|
-
|
747
|
-
Input concurrency will be most useful for workflows that are IO-bound
|
748
|
-
(e.g., making network requests) or when running an inference server that supports
|
749
|
-
dynamic batching.
|
750
|
-
|
751
|
-
When `target_inputs` is set, Modal's autoscaler will try to provision resources such
|
752
|
-
that each container is running that many inputs concurrently. Containers may burst up to
|
753
|
-
up to `max_inputs` if resources are insufficient to remain at the target concurrency.
|
754
|
-
"""
|
755
|
-
if _warn_parentheses_missing is not None:
|
756
|
-
raise InvalidError(
|
757
|
-
"Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.concurrent()`."
|
758
|
-
)
|
759
|
-
|
760
|
-
if target_inputs and target_inputs > max_inputs:
|
761
|
-
raise InvalidError("`target_inputs` parameter cannot be greater than `max_inputs`.")
|
762
|
-
|
763
|
-
def wrapper(obj: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
|
764
|
-
if isinstance(obj, _PartialFunction):
|
765
|
-
# Risky that we need to mutate the parameters here; should make this safer
|
766
|
-
obj.max_concurrent_inputs = max_inputs
|
767
|
-
obj.target_concurrent_inputs = target_inputs
|
768
|
-
obj.add_flags(_PartialFunctionFlags.FUNCTION)
|
769
|
-
return obj
|
770
|
-
|
771
|
-
return _PartialFunction(
|
772
|
-
obj,
|
773
|
-
_PartialFunctionFlags.FUNCTION,
|
774
|
-
max_concurrent_inputs=max_inputs,
|
775
|
-
target_concurrent_inputs=target_inputs,
|
776
|
-
)
|
777
|
-
|
778
|
-
return wrapper
|
modal/_resolver.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
# Copyright Modal Labs 2023
|
2
2
|
import asyncio
|
3
3
|
import contextlib
|
4
|
+
import traceback
|
4
5
|
import typing
|
5
6
|
from asyncio import Future
|
6
7
|
from collections.abc import Hashable
|
@@ -153,7 +154,11 @@ class Resolver:
|
|
153
154
|
self._deduplication_cache[deduplication_key] = cached_future
|
154
155
|
|
155
156
|
# TODO(elias): print original exception/trace rather than the Resolver-internal trace
|
156
|
-
|
157
|
+
try:
|
158
|
+
return await cached_future
|
159
|
+
except Exception:
|
160
|
+
traceback.print_exc()
|
161
|
+
raise
|
157
162
|
|
158
163
|
def objects(self) -> list["modal._object._Object"]:
|
159
164
|
unique_objects: dict[str, "modal._object._Object"] = {}
|
modal/_serialization.py
CHANGED
@@ -1,14 +1,16 @@
|
|
1
1
|
# Copyright Modal Labs 2022
|
2
|
+
import inspect
|
2
3
|
import io
|
3
4
|
import pickle
|
4
5
|
import typing
|
5
|
-
from
|
6
|
+
from inspect import Parameter
|
6
7
|
from typing import Any
|
7
8
|
|
8
9
|
from modal._utils.async_utils import synchronizer
|
9
10
|
from modal_proto import api_pb2
|
10
11
|
|
11
12
|
from ._object import _Object
|
13
|
+
from ._type_manager import parameter_serde_registry, schema_registry
|
12
14
|
from ._vendor import cloudpickle
|
13
15
|
from .config import logger
|
14
16
|
from .exception import DeserializationError, ExecutionError, InvalidError
|
@@ -389,50 +391,6 @@ def check_valid_cls_constructor_arg(key, obj):
|
|
389
391
|
)
|
390
392
|
|
391
393
|
|
392
|
-
def assert_bytes(obj: Any):
|
393
|
-
if not isinstance(obj, bytes):
|
394
|
-
raise TypeError(f"Expected bytes, got {type(obj)}")
|
395
|
-
return obj
|
396
|
-
|
397
|
-
|
398
|
-
@dataclass
|
399
|
-
class ParamTypeInfo:
|
400
|
-
default_field: str
|
401
|
-
proto_field: str
|
402
|
-
converter: typing.Callable[[str], typing.Any]
|
403
|
-
type: type
|
404
|
-
|
405
|
-
|
406
|
-
PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
|
407
|
-
# python type -> protobuf type enum
|
408
|
-
str: api_pb2.PARAM_TYPE_STRING,
|
409
|
-
int: api_pb2.PARAM_TYPE_INT,
|
410
|
-
bytes: api_pb2.PARAM_TYPE_BYTES,
|
411
|
-
}
|
412
|
-
|
413
|
-
PROTO_TYPE_INFO = {
|
414
|
-
# Protobuf type enum -> encode/decode helper metadata
|
415
|
-
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
|
416
|
-
default_field="string_default",
|
417
|
-
proto_field="string_value",
|
418
|
-
converter=str,
|
419
|
-
type=str,
|
420
|
-
),
|
421
|
-
api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
|
422
|
-
default_field="int_default",
|
423
|
-
proto_field="int_value",
|
424
|
-
converter=int,
|
425
|
-
type=int,
|
426
|
-
),
|
427
|
-
api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
|
428
|
-
default_field="bytes_default",
|
429
|
-
proto_field="bytes_value",
|
430
|
-
converter=assert_bytes,
|
431
|
-
type=bytes,
|
432
|
-
),
|
433
|
-
}
|
434
|
-
|
435
|
-
|
436
394
|
def apply_defaults(
|
437
395
|
python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
|
438
396
|
) -> dict[str, Any]:
|
@@ -453,68 +411,56 @@ def apply_defaults(
|
|
453
411
|
return result
|
454
412
|
|
455
413
|
|
414
|
+
def encode_parameter_value(name: str, python_value: Any) -> api_pb2.ClassParameterValue:
|
415
|
+
"""Map to proto parameter representation using python runtime type information"""
|
416
|
+
struct = parameter_serde_registry.encode(python_value)
|
417
|
+
struct.name = name
|
418
|
+
return struct
|
419
|
+
|
420
|
+
|
456
421
|
def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
|
457
422
|
proto_params: list[api_pb2.ClassParameterValue] = []
|
458
423
|
for param_name, python_value in python_params.items():
|
459
|
-
|
460
|
-
protobuf_type = get_proto_parameter_type(python_type)
|
461
|
-
type_info = PROTO_TYPE_INFO.get(protobuf_type)
|
462
|
-
proto_param = api_pb2.ClassParameterValue(
|
463
|
-
name=param_name,
|
464
|
-
type=protobuf_type,
|
465
|
-
)
|
466
|
-
try:
|
467
|
-
converted_value = type_info.converter(python_value)
|
468
|
-
except ValueError as exc:
|
469
|
-
raise ValueError(f"Invalid type for parameter {param_name}: {exc}")
|
470
|
-
setattr(proto_param, type_info.proto_field, converted_value)
|
471
|
-
proto_params.append(proto_param)
|
424
|
+
proto_params.append(encode_parameter_value(param_name, python_value))
|
472
425
|
proto_bytes = api_pb2.ClassParameterSet(parameters=proto_params).SerializeToString(deterministic=True)
|
473
426
|
return proto_bytes
|
474
427
|
|
475
428
|
|
476
429
|
def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
|
477
|
-
proto_struct = api_pb2.ClassParameterSet()
|
478
|
-
proto_struct.ParseFromString(serialized_params)
|
430
|
+
proto_struct = api_pb2.ClassParameterSet.FromString(serialized_params)
|
479
431
|
python_params = {}
|
480
432
|
for param in proto_struct.parameters:
|
481
|
-
|
482
|
-
if param.type == api_pb2.PARAM_TYPE_STRING:
|
483
|
-
python_value = param.string_value
|
484
|
-
elif param.type == api_pb2.PARAM_TYPE_INT:
|
485
|
-
python_value = param.int_value
|
486
|
-
elif param.type == api_pb2.PARAM_TYPE_BYTES:
|
487
|
-
python_value = param.bytes_value
|
488
|
-
else:
|
489
|
-
raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")
|
490
|
-
|
491
|
-
python_params[param.name] = python_value
|
433
|
+
python_params[param.name] = parameter_serde_registry.decode(param)
|
492
434
|
|
493
435
|
return python_params
|
494
436
|
|
495
437
|
|
496
|
-
def
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
if
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
510
|
-
|
438
|
+
def validate_parameter_values(payload: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
|
439
|
+
"""Ensure parameter payload conforms to the schema of a class
|
440
|
+
|
441
|
+
Checks that:
|
442
|
+
* All fields are specified (defaults are expected to already be applied on the payload)
|
443
|
+
* No extra fields are specified
|
444
|
+
* The type of each field is correct
|
445
|
+
"""
|
446
|
+
for param_spec in schema:
|
447
|
+
if param_spec.name not in payload:
|
448
|
+
raise InvalidError(f"Missing required parameter: {param_spec.name}")
|
449
|
+
python_value = payload[param_spec.name]
|
450
|
+
if param_spec.HasField("full_type") and param_spec.full_type.base_type:
|
451
|
+
type_enum_value = param_spec.full_type.base_type
|
452
|
+
else:
|
453
|
+
type_enum_value = param_spec.type # backwards compatibility pre-full_type
|
454
|
+
|
455
|
+
parameter_serde_registry.validate_value_for_enum_type(type_enum_value, python_value)
|
511
456
|
|
512
457
|
schema_fields = {p.name for p in schema}
|
513
458
|
# then check that no extra values are provided
|
514
|
-
non_declared_fields =
|
459
|
+
non_declared_fields = payload.keys() - schema_fields
|
515
460
|
if non_declared_fields:
|
516
461
|
raise InvalidError(
|
517
|
-
f"The following parameter names were provided but are not
|
462
|
+
f"The following parameter names were provided but are not defined class modal.parameters for the class: "
|
463
|
+
f"{', '.join(non_declared_fields)}"
|
518
464
|
)
|
519
465
|
|
520
466
|
|
@@ -528,8 +474,6 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
|
|
528
474
|
elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
|
529
475
|
param_args = () # we use kwargs only for our implicit constructors
|
530
476
|
param_kwargs = deserialize_proto_params(serialized_params)
|
531
|
-
# TODO: We can probably remove the validation below since we do validation in the caller?
|
532
|
-
validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
|
533
477
|
else:
|
534
478
|
raise ExecutionError(
|
535
479
|
f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
|
@@ -538,9 +482,47 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
|
|
538
482
|
return param_args, param_kwargs
|
539
483
|
|
540
484
|
|
541
|
-
def
|
542
|
-
|
543
|
-
|
544
|
-
|
545
|
-
|
546
|
-
|
485
|
+
def _signature_parameter_to_spec(
|
486
|
+
python_signature_parameter: inspect.Parameter, include_legacy_parameter_fields: bool = False
|
487
|
+
) -> api_pb2.ClassParameterSpec:
|
488
|
+
"""Returns proto representation of Parameter as returned by inspect.signature()
|
489
|
+
|
490
|
+
Setting include_legacy_parameter_fields makes the output backwards compatible with
|
491
|
+
pre v0.74 clients looking at class parameter specifications, and should not be used
|
492
|
+
when registering *function* schemas.
|
493
|
+
"""
|
494
|
+
declared_type = python_signature_parameter.annotation
|
495
|
+
full_proto_type = schema_registry.get_proto_generic_type(declared_type)
|
496
|
+
has_default = python_signature_parameter.default is not Parameter.empty
|
497
|
+
|
498
|
+
field_spec = api_pb2.ClassParameterSpec(
|
499
|
+
name=python_signature_parameter.name,
|
500
|
+
full_type=full_proto_type,
|
501
|
+
has_default=has_default,
|
502
|
+
)
|
503
|
+
if include_legacy_parameter_fields:
|
504
|
+
# add the .{type}_default and `.type` values as required by legacy clients
|
505
|
+
# looking at class parameter specs
|
506
|
+
if full_proto_type.base_type == api_pb2.PARAM_TYPE_INT:
|
507
|
+
if has_default:
|
508
|
+
field_spec.int_default = python_signature_parameter.default
|
509
|
+
field_spec.type = api_pb2.PARAM_TYPE_INT
|
510
|
+
elif full_proto_type.base_type == api_pb2.PARAM_TYPE_STRING:
|
511
|
+
if has_default:
|
512
|
+
field_spec.string_default = python_signature_parameter.default
|
513
|
+
field_spec.type = api_pb2.PARAM_TYPE_STRING
|
514
|
+
elif full_proto_type.base_type == api_pb2.PARAM_TYPE_BYTES:
|
515
|
+
if has_default:
|
516
|
+
field_spec.bytes_default = python_signature_parameter.default
|
517
|
+
field_spec.type = api_pb2.PARAM_TYPE_BYTES
|
518
|
+
|
519
|
+
return field_spec
|
520
|
+
|
521
|
+
|
522
|
+
def signature_to_parameter_specs(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
|
523
|
+
# only used for modal.parameter() specs, uses backwards compatible fields and types
|
524
|
+
modal_parameters: list[api_pb2.ClassParameterSpec] = []
|
525
|
+
for param in signature.parameters.values():
|
526
|
+
field_spec = _signature_parameter_to_spec(param, include_legacy_parameter_fields=True)
|
527
|
+
modal_parameters.append(field_spec)
|
528
|
+
return modal_parameters
|
modal/_type_manager.py
ADDED
@@ -0,0 +1,229 @@
|
|
1
|
+
# Copyright Modal Labs 2025
|
2
|
+
import typing
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
import typing_extensions
|
6
|
+
|
7
|
+
from modal.exception import InvalidError
|
8
|
+
from modal_proto import api_pb2
|
9
|
+
|
10
|
+
|
11
|
+
class ParameterProtoSerde(typing.Protocol):
|
12
|
+
def encode(self, value: Any) -> api_pb2.ClassParameterValue: ...
|
13
|
+
|
14
|
+
def decode(self, proto_value: api_pb2.ClassParameterValue) -> Any: ...
|
15
|
+
|
16
|
+
def validate(self, python_value: Any): ...
|
17
|
+
|
18
|
+
|
19
|
+
class ProtoParameterSerdeRegistry:
|
20
|
+
_py_base_type_to_serde: dict[type, ParameterProtoSerde]
|
21
|
+
_proto_type_to_serde: dict["api_pb2.ParameterType.ValueType", ParameterProtoSerde]
|
22
|
+
|
23
|
+
def __init__(self):
|
24
|
+
self._py_base_type_to_serde = {}
|
25
|
+
self._proto_type_to_serde = {}
|
26
|
+
|
27
|
+
def register_encoder(self, python_base_type: type) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
|
28
|
+
def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
|
29
|
+
if python_base_type in self._py_base_type_to_serde:
|
30
|
+
raise ValueError("Can't register the same encoder type twice")
|
31
|
+
self._py_base_type_to_serde[python_base_type] = ph
|
32
|
+
return ph
|
33
|
+
|
34
|
+
return deco
|
35
|
+
|
36
|
+
def register_decoder(
|
37
|
+
self, enum_type_value: "api_pb2.ParameterType.ValueType"
|
38
|
+
) -> typing.Callable[[ParameterProtoSerde], ParameterProtoSerde]:
|
39
|
+
def deco(ph: ParameterProtoSerde) -> ParameterProtoSerde:
|
40
|
+
if enum_type_value in self._proto_type_to_serde:
|
41
|
+
raise ValueError("Can't register the same decoder type twice")
|
42
|
+
self._proto_type_to_serde[enum_type_value] = ph
|
43
|
+
return ph
|
44
|
+
|
45
|
+
return deco
|
46
|
+
|
47
|
+
def encode(self, python_value: Any) -> api_pb2.ClassParameterValue:
|
48
|
+
return self._get_encoder(type(python_value)).encode(python_value)
|
49
|
+
|
50
|
+
def supports_type(self, declared_type: type) -> bool:
|
51
|
+
try:
|
52
|
+
self._get_encoder(declared_type)
|
53
|
+
return True
|
54
|
+
except InvalidError:
|
55
|
+
return False
|
56
|
+
|
57
|
+
def decode(self, param_value: api_pb2.ClassParameterValue) -> Any:
|
58
|
+
return self._get_decoder(param_value.type).decode(param_value)
|
59
|
+
|
60
|
+
def validate_parameter_type(self, declared_type: type):
|
61
|
+
"""Raises a helpful TypeError if the supplied type isn't supported by class parameters"""
|
62
|
+
if not parameter_serde_registry.supports_type(declared_type):
|
63
|
+
supported_types = self._py_base_type_to_serde.keys()
|
64
|
+
supported_str = ", ".join(t.__name__ for t in supported_types)
|
65
|
+
|
66
|
+
raise TypeError(
|
67
|
+
f"{declared_type.__name__} is not a supported modal.parameter() type. Use one of: {supported_str}"
|
68
|
+
)
|
69
|
+
|
70
|
+
def validate_value_for_enum_type(self, enum_value: "api_pb2.ParameterType.ValueType", python_value: Any):
|
71
|
+
serde = self._get_decoder(enum_value) # use the schema's expected decoder
|
72
|
+
serde.validate(python_value)
|
73
|
+
|
74
|
+
def _get_encoder(self, python_base_type: type) -> ParameterProtoSerde:
|
75
|
+
try:
|
76
|
+
return self._py_base_type_to_serde[python_base_type]
|
77
|
+
except KeyError:
|
78
|
+
raise InvalidError(f"No class parameter encoder implemented for type `{python_base_type.__name__}`")
|
79
|
+
|
80
|
+
def _get_decoder(self, enum_value: "api_pb2.ParameterType.ValueType") -> ParameterProtoSerde:
|
81
|
+
try:
|
82
|
+
return self._proto_type_to_serde[enum_value]
|
83
|
+
except KeyError:
|
84
|
+
try:
|
85
|
+
enum_name = api_pb2.ParameterType.Name(enum_value)
|
86
|
+
except ValueError:
|
87
|
+
enum_name = str(enum_value)
|
88
|
+
|
89
|
+
raise InvalidError(f"No class parameter decoder implemented for type {enum_name}.")
|
90
|
+
|
91
|
+
|
92
|
+
parameter_serde_registry = ProtoParameterSerdeRegistry()
|
93
|
+
|
94
|
+
|
95
|
+
@parameter_serde_registry.register_encoder(int)
|
96
|
+
@parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_INT)
|
97
|
+
class IntParameter:
|
98
|
+
@staticmethod
|
99
|
+
def encode(value: Any) -> api_pb2.ClassParameterValue:
|
100
|
+
return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_INT, int_value=value)
|
101
|
+
|
102
|
+
@staticmethod
|
103
|
+
def decode(proto_value: api_pb2.ClassParameterValue) -> int:
|
104
|
+
return proto_value.int_value
|
105
|
+
|
106
|
+
@staticmethod
|
107
|
+
def validate(python_value: Any):
|
108
|
+
if not isinstance(python_value, int):
|
109
|
+
raise TypeError(f"Expected int, got {type(python_value).__name__}")
|
110
|
+
|
111
|
+
|
112
|
+
@parameter_serde_registry.register_encoder(str)
|
113
|
+
@parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_STRING)
|
114
|
+
class StringParameter:
|
115
|
+
@staticmethod
|
116
|
+
def encode(value: Any) -> api_pb2.ClassParameterValue:
|
117
|
+
return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_STRING, string_value=value)
|
118
|
+
|
119
|
+
@staticmethod
|
120
|
+
def decode(proto_value: api_pb2.ClassParameterValue) -> str:
|
121
|
+
return proto_value.string_value
|
122
|
+
|
123
|
+
@staticmethod
|
124
|
+
def validate(python_value: Any):
|
125
|
+
if not isinstance(python_value, str):
|
126
|
+
raise TypeError(f"Expected str, got {type(python_value).__name__}")
|
127
|
+
|
128
|
+
|
129
|
+
@parameter_serde_registry.register_encoder(bytes)
|
130
|
+
@parameter_serde_registry.register_decoder(api_pb2.PARAM_TYPE_BYTES)
|
131
|
+
class BytesParameter:
|
132
|
+
@staticmethod
|
133
|
+
def encode(value: Any) -> api_pb2.ClassParameterValue:
|
134
|
+
return api_pb2.ClassParameterValue(type=api_pb2.PARAM_TYPE_BYTES, bytes_value=value)
|
135
|
+
|
136
|
+
@staticmethod
|
137
|
+
def decode(proto_value: api_pb2.ClassParameterValue) -> bytes:
|
138
|
+
return proto_value.bytes_value
|
139
|
+
|
140
|
+
@staticmethod
|
141
|
+
def validate(python_value: Any):
|
142
|
+
if not isinstance(python_value, bytes):
|
143
|
+
raise TypeError(f"Expected bytes, got {type(python_value).__name__}")
|
144
|
+
|
145
|
+
|
146
|
+
SCHEMA_FACTORY_TYPE = typing.Callable[[type], api_pb2.GenericPayloadType]
|
147
|
+
|
148
|
+
|
149
|
+
class SchemaRegistry:
|
150
|
+
_schema_factories: dict[type, SCHEMA_FACTORY_TYPE]
|
151
|
+
|
152
|
+
def __init__(self):
|
153
|
+
self._schema_factories = {}
|
154
|
+
|
155
|
+
def add(self, python_base_type: type) -> typing.Callable[[SCHEMA_FACTORY_TYPE], SCHEMA_FACTORY_TYPE]:
|
156
|
+
# decorator for schema factory functions for a base type
|
157
|
+
def deco(factory_func: SCHEMA_FACTORY_TYPE) -> SCHEMA_FACTORY_TYPE:
|
158
|
+
assert python_base_type not in self._schema_factories
|
159
|
+
self._schema_factories[python_base_type] = factory_func
|
160
|
+
return factory_func
|
161
|
+
|
162
|
+
return deco
|
163
|
+
|
164
|
+
def get(self, python_base_type: type) -> SCHEMA_FACTORY_TYPE:
|
165
|
+
try:
|
166
|
+
return self._schema_factories[python_base_type]
|
167
|
+
except KeyError:
|
168
|
+
return unknown_type_schema
|
169
|
+
|
170
|
+
def get_proto_generic_type(self, declared_type: type):
|
171
|
+
if origin := typing_extensions.get_origin(declared_type):
|
172
|
+
base_type = origin
|
173
|
+
else:
|
174
|
+
base_type = declared_type
|
175
|
+
|
176
|
+
return self.get(base_type)(declared_type)
|
177
|
+
|
178
|
+
|
179
|
+
schema_registry = SchemaRegistry()
|
180
|
+
|
181
|
+
|
182
|
+
@schema_registry.add(int)
|
183
|
+
def int_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
184
|
+
return api_pb2.GenericPayloadType(
|
185
|
+
base_type=api_pb2.PARAM_TYPE_INT,
|
186
|
+
)
|
187
|
+
|
188
|
+
|
189
|
+
@schema_registry.add(bytes)
|
190
|
+
def proto_type_def(declared_python_type: type) -> api_pb2.GenericPayloadType:
|
191
|
+
return api_pb2.GenericPayloadType(
|
192
|
+
base_type=api_pb2.PARAM_TYPE_BYTES,
|
193
|
+
)
|
194
|
+
|
195
|
+
|
196
|
+
def unknown_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
|
197
|
+
# TODO: add some metadata for unknown types to the type def?
|
198
|
+
return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_UNKNOWN)
|
199
|
+
|
200
|
+
|
201
|
+
@schema_registry.add(str)
|
202
|
+
def str_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
203
|
+
return api_pb2.GenericPayloadType(
|
204
|
+
base_type=api_pb2.PARAM_TYPE_STRING,
|
205
|
+
)
|
206
|
+
|
207
|
+
|
208
|
+
@schema_registry.add(type(None))
|
209
|
+
def none_type_schema(declared_python_type: type) -> api_pb2.GenericPayloadType:
|
210
|
+
return api_pb2.GenericPayloadType(base_type=api_pb2.PARAM_TYPE_NONE)
|
211
|
+
|
212
|
+
|
213
|
+
@schema_registry.add(list)
|
214
|
+
def list_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
215
|
+
args = typing_extensions.get_args(full_python_type)
|
216
|
+
|
217
|
+
return api_pb2.GenericPayloadType(
|
218
|
+
base_type=api_pb2.PARAM_TYPE_LIST, sub_types=[schema_registry.get_proto_generic_type(arg) for arg in args]
|
219
|
+
)
|
220
|
+
|
221
|
+
|
222
|
+
@schema_registry.add(dict)
|
223
|
+
def dict_schema(full_python_type: type) -> api_pb2.GenericPayloadType:
|
224
|
+
args = typing_extensions.get_args(full_python_type)
|
225
|
+
|
226
|
+
return api_pb2.GenericPayloadType(
|
227
|
+
base_type=api_pb2.PARAM_TYPE_DICT,
|
228
|
+
sub_types=[schema_registry.get_proto_generic_type(arg_type) for arg_type in args],
|
229
|
+
)
|