modal 0.73.116__py3-none-any.whl → 0.73.128__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modal/__init__.py CHANGED
@@ -27,6 +27,7 @@ try:
     asgi_app,
     batched,
     build,
+    concurrent,
     enter,
     exit,
     fastapi_endpoint,
@@ -82,6 +83,7 @@ __all__ = [
     "asgi_app",
    "batched",
    "build",
+    "concurrent",
    "current_function_call_id",
    "current_input_id",
    "enable_output",
modal/_functions.py CHANGED
@@ -1,4 +1,5 @@
 # Copyright Modal Labs 2023
+import asyncio
 import dataclasses
 import inspect
 import textwrap
@@ -24,7 +25,7 @@ from ._pty import get_pty_info
 from ._resolver import Resolver
 from ._resources import convert_fn_config_to_resources_config
 from ._runtime.execution_context import current_input_id, is_local
-from ._serialization import serialize, serialize_proto_params
+from ._serialization import apply_defaults, serialize, serialize_proto_params, validate_params
 from ._traceback import print_server_warnings
 from ._utils.async_utils import (
     TaskContext,
@@ -174,7 +175,7 @@ class _Invocation:
             input_jwt=input.input_jwt,
             input_id=input.input_id,
             item=item,
-            sync_client_retries_enabled=response.sync_client_retries_enabled
+            sync_client_retries_enabled=response.sync_client_retries_enabled,
         )
         return _Invocation(client.stub, function_call_id, client, retry_context)

@@ -256,9 +257,13 @@ class _Invocation:
         try:
             return await self._get_single_output(ctx.input_jwt)
         except (UserCodeException, FunctionTimeoutError) as exc:
-            await user_retry_manager.raise_or_sleep(exc)
+            delay_ms = user_retry_manager.get_delay_ms()
+            if delay_ms is None:
+                raise exc
+            await asyncio.sleep(delay_ms / 1000)
         except InternalFailure:
-            # For system failures on the server, we retry immediately.
+            # For system failures on the server, we retry immediately,
+            # and the failure does not count towards the retry policy.
             pass
         await self._retry_input()

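The retry loop now asks the retry manager for the next delay instead of delegating the sleep-or-raise decision to it: a `None` delay signals an exhausted retry budget, so the original exception propagates. A self-contained sketch of that contract, with a hypothetical `RetryManager` standing in for Modal's internal one:

    import asyncio
    from typing import Optional

    class RetryManager:
        """Hypothetical stand-in for Modal's internal retry manager."""

        def __init__(self, max_retries: int = 3, base_ms: float = 100.0):
            self.attempt = 0
            self.max_retries = max_retries
            self.base_ms = base_ms

        def get_delay_ms(self) -> Optional[float]:
            if self.attempt >= self.max_retries:
                return None  # budget exhausted: caller re-raises the original error
            delay = self.base_ms * (2 ** self.attempt)  # exponential backoff
            self.attempt += 1
            return delay

    async def call_with_retries(fn):
        manager = RetryManager()
        while True:
            try:
                return await fn()
            except RuntimeError as exc:  # stand-in for UserCodeException/FunctionTimeoutError
                delay_ms = manager.get_delay_ms()
                if delay_ms is None:
                    raise exc
                await asyncio.sleep(delay_ms / 1000)

Note that `InternalFailure` bypasses `get_delay_ms()` entirely in the real code, so server-side system failures retry immediately without consuming the budget.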
@@ -430,7 +435,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
         max_containers: Optional[int] = None,
         buffer_containers: Optional[int] = None,
         scaledown_window: Optional[int] = None,
-        allow_concurrent_inputs: Optional[int] = None,
+        max_concurrent_inputs: Optional[int] = None,
+        target_concurrent_inputs: Optional[int] = None,
         batch_max_size: Optional[int] = None,
         batch_wait_ms: Optional[int] = None,
         cloud: Optional[str] = None,
@@ -781,7 +787,8 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
             runtime_perf_record=config.get("runtime_perf_record"),
             app_name=app_name,
             is_builder_function=is_builder_function,
-            target_concurrent_inputs=allow_concurrent_inputs or 0,
+            max_concurrent_inputs=max_concurrent_inputs or 0,
+            target_concurrent_inputs=target_concurrent_inputs or 0,
             batch_max_size=batch_max_size or 0,
             batch_linger_ms=batch_wait_ms or 0,
             worker_id=config.get("worker_id"),
@@ -968,8 +975,11 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
                     "Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
                     "Use (<parameter_name>=value) keyword arguments when constructing classes instead."
                 )
-            serialized_params = serialize_proto_params(kwargs, parent._class_parameter_info.schema)
-            can_use_parent = len(parent._class_parameter_info.schema) == 0
+            schema = parent._class_parameter_info.schema
+            kwargs_with_defaults = apply_defaults(kwargs, schema)
+            validate_params(kwargs_with_defaults, schema)
+            serialized_params = serialize_proto_params(kwargs_with_defaults)
+            can_use_parent = len(parent._class_parameter_info.schema) == 0  # no parameters
         else:
             can_use_parent = len(args) + len(kwargs) == 0 and options is None
             serialized_params = serialize((args, kwargs))
@@ -1304,6 +1314,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
                 order_outputs,
                 return_exceptions,
                 count_update_callback,
+                api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
             )
         ) as stream:
             async for item in stream:
@@ -59,6 +59,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
     force_build: bool
     cluster_size: Optional[int]  # Experimental: Clustered functions
     build_timeout: Optional[int]
+    max_concurrent_inputs: Optional[int]
+    target_concurrent_inputs: Optional[int]

     def __init__(
         self,
@@ -72,6 +74,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
         cluster_size: Optional[int] = None,  # Experimental: Clustered functions
         force_build: bool = False,
         build_timeout: Optional[int] = None,
+        max_concurrent_inputs: Optional[int] = None,
+        target_concurrent_inputs: Optional[int] = None,
     ):
         self.raw_f = raw_f
         self.flags = flags
@@ -89,6 +93,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
         self.cluster_size = cluster_size  # Experimental: Clustered functions
         self.force_build = force_build
         self.build_timeout = build_timeout
+        self.max_concurrent_inputs = max_concurrent_inputs
+        self.target_concurrent_inputs = target_concurrent_inputs

     def _get_raw_f(self) -> Callable[P, ReturnType]:
         return self.raw_f
@@ -143,6 +149,8 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
             batch_wait_ms=self.batch_wait_ms,
             force_build=self.force_build,
             build_timeout=self.build_timeout,
+            max_concurrent_inputs=self.max_concurrent_inputs,
+            target_concurrent_inputs=self.target_concurrent_inputs,
         )


@@ -722,3 +730,49 @@ def _batched(
     )

     return wrapper
+
+
+def _concurrent(
+    _warn_parentheses_missing=None,
+    *,
+    max_inputs: int,  # Hard limit on each container's input concurrency
+    target_inputs: Optional[int] = None,  # Input concurrency that Modal's autoscaler should target
+) -> Callable[[Union[Callable[..., Any], _PartialFunction]], _PartialFunction]:
+    """Decorator that allows individual containers to handle multiple inputs concurrently.
+
+    The concurrency mechanism depends on whether the function is async or not:
+    - Async functions will run inputs on a single thread as asyncio tasks.
+    - Synchronous functions will use multi-threading. The code must be thread-safe.
+
+    Input concurrency will be most useful for workflows that are IO-bound
+    (e.g., making network requests) or when running an inference server that supports
+    dynamic batching.
+
+    When `target_inputs` is set, Modal's autoscaler will try to provision resources such
+    that each container is running that many inputs concurrently. Containers may burst up
+    to `max_inputs` if resources are insufficient to remain at the target concurrency.
+    """
+    if _warn_parentheses_missing is not None:
+        raise InvalidError(
+            "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.concurrent()`."
+        )
+
+    if target_inputs and target_inputs > max_inputs:
+        raise InvalidError("`target_inputs` parameter cannot be greater than `max_inputs`.")
+
+    def wrapper(obj: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
+        if isinstance(obj, _PartialFunction):
+            # Risky that we need to mutate the parameters here; should make this safer
+            obj.max_concurrent_inputs = max_inputs
+            obj.target_concurrent_inputs = target_inputs
+            obj.add_flags(_PartialFunctionFlags.FUNCTION)
+            return obj
+
+        return _PartialFunction(
+            obj,
+            _PartialFunctionFlags.FUNCTION,
+            max_concurrent_inputs=max_inputs,
+            target_concurrent_inputs=target_inputs,
+        )
+
+    return wrapper
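
This `_concurrent` implementation backs the newly exported `@modal.concurrent`, replacing the older single `allow_concurrent_inputs=` knob with an explicit per-container cap plus an optional autoscaler target. A hedged usage sketch (the app name and handler are illustrative):

    import modal

    app = modal.App("concurrency-demo")  # hypothetical app name

    @app.function()
    @modal.concurrent(max_inputs=100, target_inputs=80)
    async def handle(request_id: str) -> str:
        # Async functions run concurrent inputs as asyncio tasks on one thread;
        # sync functions would instead be called from multiple threads.
        return request_id

    # Per the validation above, target_inputs > max_inputs raises InvalidError:
    # @modal.concurrent(max_inputs=10, target_inputs=20)  # fails at decoration time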
@@ -63,7 +63,9 @@ class IOContext:
     """

     input_ids: list[str]
+    retry_counts: list[int]
     function_call_ids: list[str]
+    function_inputs: list[api_pb2.FunctionInput]
     finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"

     _cancel_issued: bool = False
@@ -72,6 +74,7 @@
     def __init__(
         self,
         input_ids: list[str],
+        retry_counts: list[int],
         function_call_ids: list[str],
         finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
         function_inputs: list[api_pb2.FunctionInput],
@@ -79,9 +82,10 @@
         client: _Client,
     ):
         self.input_ids = input_ids
+        self.retry_counts = retry_counts
         self.function_call_ids = function_call_ids
         self.finalized_function = finalized_function
-        self._function_inputs = function_inputs
+        self.function_inputs = function_inputs
         self._is_batched = is_batched
         self._client = client

@@ -90,11 +94,11 @@
         cls,
         client: _Client,
         finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
-        inputs: list[tuple[str, str, api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> "IOContext":
         assert len(inputs) >= 1 if is_batched else len(inputs) == 1
-        input_ids, function_call_ids, function_inputs = zip(*inputs)
+        input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)

         async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
             # If we got a pointer to a blob, download it from S3.
@@ -111,7 +115,7 @@
         method_name = function_inputs[0].method_name
         assert all(method_name == input.method_name for input in function_inputs)
         finalized_function = finalized_functions[method_name]
-        return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)
+        return cls(input_ids, retry_counts, function_call_ids, finalized_function, function_inputs, is_batched, client)

     def set_cancel_callback(self, cb: Callable[[], None]):
         self._cancel_callback = cb
@@ -135,7 +139,7 @@
         # to make sure we handle user exceptions properly
         # and don't retry
         deserialized_args = [
-            deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs
+            deserialize(input.args, self._client) if input.args else ((), {}) for input in self.function_inputs
         ]
         if not self._is_batched:
             return deserialized_args[0]
@@ -551,7 +555,7 @@ class _ContainerIOManager:
         self,
         batch_max_size: int,
         batch_wait_ms: int,
-    ) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
+    ) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
         request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
         iteration = 0
         while self._fetching_inputs:
@@ -586,8 +590,7 @@
                 if item.kill_switch:
                     logger.debug(f"Task {self.task_id} input kill signal input.")
                     return
-
-                inputs.append((item.input_id, item.function_call_id, item.input))
+                inputs.append((item.input_id, item.retry_count, item.function_call_id, item.input))
                 if item.input.final_input:
                     if request.batch_max_size > 0:
                         logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
@@ -648,8 +651,9 @@
                 output_created_at=output_created_at,
                 result=result,
                 data_format=data_format,
+                retry_count=retry_count,
            )
-            for input_id, result in zip(io_context.input_ids, results)
+            for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
        ]
        await retry_transient_errors(
            self._client.stub.FunctionPutOutputs,
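
The thread running through these hunks is that each fetched input now carries its server-reported retry count in a 4-tuple, which `IOContext.create` splits with `zip(*inputs)` and the output path echoes back to the server. The tuple plumbing in isolation (values are illustrative):

    # (input_id, retry_count, function_call_id, function_input)
    inputs = [
        ("in-001", 0, "fc-abc", "<FunctionInput proto>"),
        ("in-002", 2, "fc-abc", "<FunctionInput proto>"),  # second retry of this input
    ]

    input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
    assert retry_counts == (0, 2)

    # Outputs are rejoined pairwise so each result reports the attempt it belongs to:
    for input_id, retry_count in zip(input_ids, retry_counts):
        print(input_id, retry_count)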
@@ -14,7 +14,9 @@ class Sentinel: ...

 class IOContext:
     input_ids: list[str]
+    retry_counts: list[int]
     function_call_ids: list[str]
+    function_inputs: list[modal_proto.api_pb2.FunctionInput]
     finalized_function: modal._runtime.user_code_imports.FinalizedFunction
     _cancel_issued: bool
     _cancel_callback: typing.Optional[collections.abc.Callable[[], None]]
@@ -22,6 +24,7 @@ class IOContext:
     def __init__(
         self,
         input_ids: list[str],
+        retry_counts: list[int],
         function_call_ids: list[str],
         finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
         function_inputs: list[modal_proto.api_pb2.FunctionInput],
@@ -33,7 +36,7 @@
         cls,
         client: modal.client._Client,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
-        inputs: list[tuple[str, str, modal_proto.api_pb2.FunctionInput]],
+        inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
         is_batched: bool,
     ) -> IOContext: ...
     def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
@@ -116,7 +119,7 @@ class _ContainerIOManager:
     def get_max_inputs_to_fetch(self): ...
     def _generate_inputs(
         self, batch_max_size: int, batch_wait_ms: int
-    ) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
+    ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
     def run_inputs_outputs(
         self,
         finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
@@ -287,10 +290,10 @@ class ContainerIOManager:
     class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
         def __call__(
             self, batch_max_size: int, batch_wait_ms: int
-        ) -> typing.Iterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
         def aio(
             self, batch_max_size: int, batch_wait_ms: int
-        ) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
+        ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...

     _generate_inputs: ___generate_inputs_spec[typing_extensions.Self]

modal/_serialization.py CHANGED
@@ -400,6 +400,7 @@ class ParamTypeInfo:
     default_field: str
     proto_field: str
     converter: typing.Callable[[str], typing.Any]
+    type: type


 PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
@@ -411,75 +412,112 @@ PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {

 PROTO_TYPE_INFO = {
     # Protobuf type enum -> encode/decode helper metadata
-    api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(default_field="string_default", proto_field="string_value", converter=str),
-    api_pb2.PARAM_TYPE_INT: ParamTypeInfo(default_field="int_default", proto_field="int_value", converter=int),
+    api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
+        default_field="string_default",
+        proto_field="string_value",
+        converter=str,
+        type=str,
+    ),
+    api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
+        default_field="int_default",
+        proto_field="int_value",
+        converter=int,
+        type=int,
+    ),
     api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
-        default_field="bytes_default", proto_field="bytes_value", converter=assert_bytes
+        default_field="bytes_default",
+        proto_field="bytes_value",
+        converter=assert_bytes,
+        type=bytes,
     ),
 }


-def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]) -> bytes:
-    proto_params: list[api_pb2.ClassParameterValue] = []
+def apply_defaults(
+    python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
+) -> dict[str, Any]:
+    """Apply any declared defaults from the provided schema, if values aren't provided in python_params.
+
+    Conceptually similar to inspect.BoundArguments.apply_defaults.
+
+    Note: Apply this before serializing parameters in order to get consistent parameter
+    pools regardless of whether a value is explicitly provided or not.
+    """
+    result = {**python_params}
     for schema_param in schema:
-        type_info = PROTO_TYPE_INFO.get(schema_param.type)
-        if not type_info:
-            raise ValueError(f"Unsupported parameter type: {schema_param.type}")
+        if schema_param.has_default and schema_param.name not in python_params:
+            default_field_name = schema_param.WhichOneof("default_oneof")
+            if default_field_name is None:
+                raise InvalidError(f"{schema_param.name} declared as having a default, but has no default value")
+            result[schema_param.name] = getattr(schema_param, default_field_name)
+    return result
+
+
+def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
+    proto_params: list[api_pb2.ClassParameterValue] = []
+    for param_name, python_value in python_params.items():
+        python_type = type(python_value)
+        protobuf_type = get_proto_parameter_type(python_type)
+        type_info = PROTO_TYPE_INFO.get(protobuf_type)
         proto_param = api_pb2.ClassParameterValue(
-            name=schema_param.name,
-            type=schema_param.type,
+            name=param_name,
+            type=protobuf_type,
         )
-        python_value = python_params.get(schema_param.name)
-        if python_value is None:
-            if schema_param.has_default:
-                python_value = getattr(schema_param, type_info.default_field)
-            else:
-                raise ValueError(f"Missing required parameter: {schema_param.name}")
         try:
             converted_value = type_info.converter(python_value)
         except ValueError as exc:
-            raise ValueError(f"Invalid type for parameter {schema_param.name}: {exc}")
+            raise ValueError(f"Invalid type for parameter {param_name}: {exc}")
         setattr(proto_param, type_info.proto_field, converted_value)
         proto_params.append(proto_param)
     proto_bytes = api_pb2.ClassParameterSet(parameters=proto_params).SerializeToString(deterministic=True)
     return proto_bytes


-def deserialize_proto_params(serialized_params: bytes, schema: list[api_pb2.ClassParameterSpec]) -> dict[str, Any]:
-    # TODO: this currently requires the schema to decode a payload, but we should make the validation
-    # distinct from the deserialization
+def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
     proto_struct = api_pb2.ClassParameterSet()
     proto_struct.ParseFromString(serialized_params)
-    value_by_name = {p.name: p for p in proto_struct.parameters}
     python_params = {}
-    for schema_param in schema:
-        if schema_param.name not in value_by_name:
-            # TODO: handle default values? Could just be a flag on the FunctionParameter schema spec,
-            # allowing it to not be supplied in the FunctionParameterSet?
-            raise AttributeError(f"Constructor arguments don't match declared parameters (missing {schema_param.name})")
-        param_value = value_by_name[schema_param.name]
-        if schema_param.type != param_value.type:
-            raise ValueError(
-                "Constructor arguments types don't match declared parameters "
-                f"({schema_param.name}: type {schema_param.type} != type {param_value.type})"
-            )
+    for param in proto_struct.parameters:
         python_value: Any
-        if schema_param.type == api_pb2.PARAM_TYPE_STRING:
-            python_value = param_value.string_value
-        elif schema_param.type == api_pb2.PARAM_TYPE_INT:
-            python_value = param_value.int_value
-        elif schema_param.type == api_pb2.PARAM_TYPE_BYTES:
-            python_value = param_value.bytes_value
+        if param.type == api_pb2.PARAM_TYPE_STRING:
+            python_value = param.string_value
+        elif param.type == api_pb2.PARAM_TYPE_INT:
+            python_value = param.int_value
+        elif param.type == api_pb2.PARAM_TYPE_BYTES:
+            python_value = param.bytes_value
         else:
-            # TODO(elias): based on `parameters` declared types, we could add support for
-            # custom non proto types encoded as bytes in the proto, e.g. PARAM_TYPE_PYTHON_PICKLE
-            raise NotImplementedError("Only strings and ints are supported parameter value types at the moment")
+            raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")

-        python_params[schema_param.name] = python_value
+        python_params[param.name] = python_value

     return python_params


+def validate_params(params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
+    # first check that all declared values are provided
+    for schema_param in schema:
+        if schema_param.name not in params:
+            # we expect all values to be present - even defaulted ones (defaults are applied on payload construction)
+            raise InvalidError(f"Missing required parameter: {schema_param.name}")
+        python_value = params[schema_param.name]
+        python_type = type(python_value)
+        param_protobuf_type = get_proto_parameter_type(python_type)
+        if schema_param.type != param_protobuf_type:
+            expected_python_type = PROTO_TYPE_INFO[schema_param.type].type
+            raise TypeError(
+                f"Parameter '{schema_param.name}' type error: expected {expected_python_type.__name__}, "
+                f"got {python_type.__name__}"
+            )
+
+    schema_fields = {p.name for p in schema}
+    # then check that no extra values are provided
+    non_declared_fields = params.keys() - schema_fields
+    if non_declared_fields:
+        raise InvalidError(
+            f"The following parameter names were provided but are not present in the schema: {non_declared_fields}"
+        )
+
+
 def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function, _client: "modal.client._Client"):
     if function_def.class_parameter_info.format in (
         api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_UNSPECIFIED,
@@ -488,11 +526,21 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
         # legacy serialization format - pickle of `(args, kwargs)` w/ support for modal object arguments
         param_args, param_kwargs = deserialize(serialized_params, _client)
     elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
-        param_args = ()
-        param_kwargs = deserialize_proto_params(serialized_params, list(function_def.class_parameter_info.schema))
+        param_args = ()  # we use kwargs only for our implicit constructors
+        param_kwargs = deserialize_proto_params(serialized_params)
+        # TODO: We can probably remove the validation below since we do validation in the caller?
+        validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
     else:
         raise ExecutionError(
             f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
         )

     return param_args, param_kwargs
+
+
+def get_proto_parameter_type(parameter_type: type) -> "api_pb2.ParameterType.ValueType":
+    if parameter_type not in PYTHON_TO_PROTO_TYPE:
+        type_name = getattr(parameter_type, "__name__", repr(parameter_type))
+        supported = ", ".join(parameter_type.__name__ for parameter_type in PYTHON_TO_PROTO_TYPE.keys())
+        raise InvalidError(f"{type_name} is not a supported parameter type. Use one of: {supported}")
+    return PYTHON_TO_PROTO_TYPE[parameter_type]
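
Taken together, these changes split one schema-coupled step into a three-stage pipeline: apply schema defaults, validate against the schema, then serialize values self-describingly (note that `deserialize_proto_params` no longer needs the schema at all). A sketch of the round trip using these private helpers, assuming `ClassParameterSpec` exposes the `name`/`type`/`has_default`/`int_default` fields referenced above:

    from modal._serialization import (
        apply_defaults,
        deserialize_proto_params,
        serialize_proto_params,
        validate_params,
    )
    from modal_proto import api_pb2

    schema = [
        api_pb2.ClassParameterSpec(name="model", type=api_pb2.PARAM_TYPE_STRING),
        api_pb2.ClassParameterSpec(
            name="batch_size", type=api_pb2.PARAM_TYPE_INT, has_default=True, int_default=8
        ),
    ]

    params = apply_defaults({"model": "base"}, schema)  # {"model": "base", "batch_size": 8}
    validate_params(params, schema)                     # raises on missing/extra/mistyped names
    payload = serialize_proto_params(params)            # deterministic, schema-free bytes
    assert deserialize_proto_params(payload) == params

Applying defaults before serializing keeps the payload bytes identical whether or not a caller passed the defaulted value explicitly, which is what the `apply_defaults` docstring means by consistent parameter pools.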
@@ -12,6 +12,7 @@ from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
+    Generic,
     Optional,
     TypeVar,
     Union,
@@ -26,6 +27,10 @@ from typing_extensions import ParamSpec, assert_type
 from ..exception import InvalidError
 from .logger import logger

+T = TypeVar("T")
+P = ParamSpec("P")
+V = TypeVar("V")
+
 synchronizer = synchronicity.Synchronizer()


@@ -260,7 +265,72 @@ def run_coro_blocking(coro):
     return fut.result()


-async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_time=0.015):
+class TimestampPriorityQueue(Generic[T]):
+    """
+    A priority queue that schedules items to be processed at specific timestamps.
+    """
+
+    _MAX_PRIORITY = float("inf")
+
+    def __init__(self, maxsize: int = 0):
+        self.condition = asyncio.Condition()
+        self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
+
+    async def close(self):
+        await self.put(self._MAX_PRIORITY, None)
+
+    async def put(self, timestamp: float, item: Union[T, None]):
+        """
+        Add an item to the queue to be processed at a specific timestamp.
+        """
+        await self._queue.put((timestamp, item))
+        async with self.condition:
+            self.condition.notify_all()  # notify any waiting coroutines
+
+    async def get(self) -> Union[T, None]:
+        """
+        Get the next item from the queue that is ready to be processed.
+        """
+        while True:
+            async with self.condition:
+                while self.empty():
+                    await self.condition.wait()
+                # peek at the next item
+                timestamp, item = await self._queue.get()
+                now = time.time()
+                if timestamp < now:
+                    return item
+                if timestamp == self._MAX_PRIORITY:
+                    return None
+                # not ready yet, calculate sleep time
+                sleep_time = timestamp - now
+                self._queue.put_nowait((timestamp, item))  # put it back
+                # wait until either the timeout or a new item is added
+                try:
+                    await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
+                except asyncio.TimeoutError:
+                    continue
+
+    def empty(self) -> bool:
+        return self._queue.empty()
+
+    def qsize(self) -> int:
+        return self._queue.qsize()
+
+    async def clear(self):
+        """
+        Clear the retry queue. Used for testing to simulate reading all elements from queue using queue_batch_iterator.
+        """
+        while not self.empty():
+            await self.get()
+
+    def __len__(self):
+        return self._queue.qsize()
+
+
+async def queue_batch_iterator(
+    q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015
+):
     """
     Read from a queue but return lists of items when queue is large

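`TimestampPriorityQueue` is the scheduling primitive behind delayed retries: `put` takes an absolute timestamp, and `get` returns an item only once its timestamp has passed, or `None` after `close()` enqueues the infinite-priority sentinel. A small usage sketch, assuming the class above is in scope (it lives in Modal's private async utils module):

    import asyncio
    import time

    async def main():
        q = TimestampPriorityQueue()
        now = time.time()
        await q.put(now + 0.2, "retry input B")
        await q.put(now + 0.1, "retry input A")

        # Items come back in timestamp order, each only once its deadline passes.
        print(await q.get())  # "retry input A", after roughly 100 ms
        print(await q.get())  # "retry input B", after roughly 200 ms

        await q.close()               # enqueue the sentinel
        assert await q.get() is None  # consumers treat None as "stop"

    asyncio.run(main())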
@@ -405,11 +475,6 @@ def on_shutdown(coro):
     _shutdown_tasks.append(asyncio.create_task(wrapper()))


-T = TypeVar("T")
-P = ParamSpec("P")
-V = TypeVar("V")
-
-
 def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
     """Convert a blocking function into one that runs in the current loop's executor."""