modal 0.73.116__py3-none-any.whl → 0.73.126__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
modal/_functions.py CHANGED
@@ -1,4 +1,5 @@
1
1
  # Copyright Modal Labs 2023
2
+ import asyncio
2
3
  import dataclasses
3
4
  import inspect
4
5
  import textwrap
@@ -24,7 +25,7 @@ from ._pty import get_pty_info
24
25
  from ._resolver import Resolver
25
26
  from ._resources import convert_fn_config_to_resources_config
26
27
  from ._runtime.execution_context import current_input_id, is_local
27
- from ._serialization import serialize, serialize_proto_params
28
+ from ._serialization import apply_defaults, serialize, serialize_proto_params, validate_params
28
29
  from ._traceback import print_server_warnings
29
30
  from ._utils.async_utils import (
30
31
  TaskContext,
@@ -174,7 +175,7 @@ class _Invocation:
174
175
  input_jwt=input.input_jwt,
175
176
  input_id=input.input_id,
176
177
  item=item,
177
- sync_client_retries_enabled=response.sync_client_retries_enabled
178
+ sync_client_retries_enabled=response.sync_client_retries_enabled,
178
179
  )
179
180
  return _Invocation(client.stub, function_call_id, client, retry_context)
180
181
 
@@ -256,9 +257,13 @@ class _Invocation:
256
257
  try:
257
258
  return await self._get_single_output(ctx.input_jwt)
258
259
  except (UserCodeException, FunctionTimeoutError) as exc:
259
- await user_retry_manager.raise_or_sleep(exc)
260
+ delay_ms = user_retry_manager.get_delay_ms()
261
+ if delay_ms is None:
262
+ raise exc
263
+ await asyncio.sleep(delay_ms / 1000)
260
264
  except InternalFailure:
261
- # For system failures on the server, we retry immediately.
265
+ # For system failures on the server, we retry immediately,
266
+ # and the failure does not count towards the retry policy.
262
267
  pass
263
268
  await self._retry_input()
264
269
 
@@ -968,8 +973,11 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
968
973
  "Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
969
974
  "Use (<parameter_name>=value) keyword arguments when constructing classes instead."
970
975
  )
971
- serialized_params = serialize_proto_params(kwargs, parent._class_parameter_info.schema)
972
- can_use_parent = len(parent._class_parameter_info.schema) == 0
976
+ schema = parent._class_parameter_info.schema
977
+ kwargs_with_defaults = apply_defaults(kwargs, schema)
978
+ validate_params(kwargs_with_defaults, schema)
979
+ serialized_params = serialize_proto_params(kwargs_with_defaults)
980
+ can_use_parent = len(parent._class_parameter_info.schema) == 0 # no parameters
973
981
  else:
974
982
  can_use_parent = len(args) + len(kwargs) == 0 and options is None
975
983
  serialized_params = serialize((args, kwargs))
@@ -1304,6 +1312,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1304
1312
  order_outputs,
1305
1313
  return_exceptions,
1306
1314
  count_update_callback,
1315
+ api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
1307
1316
  )
1308
1317
  ) as stream:
1309
1318
  async for item in stream:
@@ -63,7 +63,9 @@ class IOContext:
63
63
  """
64
64
 
65
65
  input_ids: list[str]
66
+ retry_counts: list[int]
66
67
  function_call_ids: list[str]
68
+ function_inputs: list[api_pb2.FunctionInput]
67
69
  finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"
68
70
 
69
71
  _cancel_issued: bool = False
@@ -72,6 +74,7 @@ class IOContext:
72
74
  def __init__(
73
75
  self,
74
76
  input_ids: list[str],
77
+ retry_counts: list[int],
75
78
  function_call_ids: list[str],
76
79
  finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
77
80
  function_inputs: list[api_pb2.FunctionInput],
@@ -79,9 +82,10 @@ class IOContext:
79
82
  client: _Client,
80
83
  ):
81
84
  self.input_ids = input_ids
85
+ self.retry_counts = retry_counts
82
86
  self.function_call_ids = function_call_ids
83
87
  self.finalized_function = finalized_function
84
- self._function_inputs = function_inputs
88
+ self.function_inputs = function_inputs
85
89
  self._is_batched = is_batched
86
90
  self._client = client
87
91
 
@@ -90,11 +94,11 @@ class IOContext:
90
94
  cls,
91
95
  client: _Client,
92
96
  finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
93
- inputs: list[tuple[str, str, api_pb2.FunctionInput]],
97
+ inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
94
98
  is_batched: bool,
95
99
  ) -> "IOContext":
96
100
  assert len(inputs) >= 1 if is_batched else len(inputs) == 1
97
- input_ids, function_call_ids, function_inputs = zip(*inputs)
101
+ input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
98
102
 
99
103
  async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
100
104
  # If we got a pointer to a blob, download it from S3.
@@ -111,7 +115,7 @@ class IOContext:
111
115
  method_name = function_inputs[0].method_name
112
116
  assert all(method_name == input.method_name for input in function_inputs)
113
117
  finalized_function = finalized_functions[method_name]
114
- return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)
118
+ return cls(input_ids, retry_counts, function_call_ids, finalized_function, function_inputs, is_batched, client)
115
119
 
116
120
  def set_cancel_callback(self, cb: Callable[[], None]):
117
121
  self._cancel_callback = cb
@@ -135,7 +139,7 @@ class IOContext:
135
139
  # to make sure we handle user exceptions properly
136
140
  # and don't retry
137
141
  deserialized_args = [
138
- deserialize(input.args, self._client) if input.args else ((), {}) for input in self._function_inputs
142
+ deserialize(input.args, self._client) if input.args else ((), {}) for input in self.function_inputs
139
143
  ]
140
144
  if not self._is_batched:
141
145
  return deserialized_args[0]
@@ -551,7 +555,7 @@ class _ContainerIOManager:
551
555
  self,
552
556
  batch_max_size: int,
553
557
  batch_wait_ms: int,
554
- ) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
558
+ ) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
555
559
  request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
556
560
  iteration = 0
557
561
  while self._fetching_inputs:
@@ -586,8 +590,7 @@ class _ContainerIOManager:
586
590
  if item.kill_switch:
587
591
  logger.debug(f"Task {self.task_id} input kill signal input.")
588
592
  return
589
-
590
- inputs.append((item.input_id, item.function_call_id, item.input))
593
+ inputs.append((item.input_id, item.retry_count, item.function_call_id, item.input))
591
594
  if item.input.final_input:
592
595
  if request.batch_max_size > 0:
593
596
  logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
@@ -648,8 +651,9 @@ class _ContainerIOManager:
648
651
  output_created_at=output_created_at,
649
652
  result=result,
650
653
  data_format=data_format,
654
+ retry_count=retry_count,
651
655
  )
652
- for input_id, result in zip(io_context.input_ids, results)
656
+ for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
653
657
  ]
654
658
  await retry_transient_errors(
655
659
  self._client.stub.FunctionPutOutputs,
@@ -14,7 +14,9 @@ class Sentinel: ...
14
14
 
15
15
  class IOContext:
16
16
  input_ids: list[str]
17
+ retry_counts: list[int]
17
18
  function_call_ids: list[str]
19
+ function_inputs: list[modal_proto.api_pb2.FunctionInput]
18
20
  finalized_function: modal._runtime.user_code_imports.FinalizedFunction
19
21
  _cancel_issued: bool
20
22
  _cancel_callback: typing.Optional[collections.abc.Callable[[], None]]
@@ -22,6 +24,7 @@ class IOContext:
22
24
  def __init__(
23
25
  self,
24
26
  input_ids: list[str],
27
+ retry_counts: list[int],
25
28
  function_call_ids: list[str],
26
29
  finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
27
30
  function_inputs: list[modal_proto.api_pb2.FunctionInput],
@@ -33,7 +36,7 @@ class IOContext:
33
36
  cls,
34
37
  client: modal.client._Client,
35
38
  finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
36
- inputs: list[tuple[str, str, modal_proto.api_pb2.FunctionInput]],
39
+ inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
37
40
  is_batched: bool,
38
41
  ) -> IOContext: ...
39
42
  def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
@@ -116,7 +119,7 @@ class _ContainerIOManager:
116
119
  def get_max_inputs_to_fetch(self): ...
117
120
  def _generate_inputs(
118
121
  self, batch_max_size: int, batch_wait_ms: int
119
- ) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
122
+ ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
120
123
  def run_inputs_outputs(
121
124
  self,
122
125
  finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
@@ -287,10 +290,10 @@ class ContainerIOManager:
287
290
  class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
288
291
  def __call__(
289
292
  self, batch_max_size: int, batch_wait_ms: int
290
- ) -> typing.Iterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
293
+ ) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
291
294
  def aio(
292
295
  self, batch_max_size: int, batch_wait_ms: int
293
- ) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
296
+ ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
294
297
 
295
298
  _generate_inputs: ___generate_inputs_spec[typing_extensions.Self]
296
299
 
modal/_serialization.py CHANGED
@@ -400,6 +400,7 @@ class ParamTypeInfo:
400
400
  default_field: str
401
401
  proto_field: str
402
402
  converter: typing.Callable[[str], typing.Any]
403
+ type: type
403
404
 
404
405
 
405
406
  PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
@@ -411,75 +412,112 @@ PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
411
412
 
412
413
  PROTO_TYPE_INFO = {
413
414
  # Protobuf type enum -> encode/decode helper metadata
414
- api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(default_field="string_default", proto_field="string_value", converter=str),
415
- api_pb2.PARAM_TYPE_INT: ParamTypeInfo(default_field="int_default", proto_field="int_value", converter=int),
415
+ api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
416
+ default_field="string_default",
417
+ proto_field="string_value",
418
+ converter=str,
419
+ type=str,
420
+ ),
421
+ api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
422
+ default_field="int_default",
423
+ proto_field="int_value",
424
+ converter=int,
425
+ type=int,
426
+ ),
416
427
  api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
417
- default_field="bytes_default", proto_field="bytes_value", converter=assert_bytes
428
+ default_field="bytes_default",
429
+ proto_field="bytes_value",
430
+ converter=assert_bytes,
431
+ type=bytes,
418
432
  ),
419
433
  }
420
434
 
421
435
 
422
- def serialize_proto_params(python_params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]) -> bytes:
423
- proto_params: list[api_pb2.ClassParameterValue] = []
436
+ def apply_defaults(
437
+ python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
438
+ ) -> dict[str, Any]:
439
+ """Apply any declared defaults from the provided schema, if values aren't provided in python_params
440
+
441
+ Conceptually similar to inspect.BoundArguments.apply_defaults.
442
+
443
+ Note: Apply this before serializing parameters in order to get consistent parameter
444
+ pools regardless if a value is explicitly provided or not.
445
+ """
446
+ result = {**python_params}
424
447
  for schema_param in schema:
425
- type_info = PROTO_TYPE_INFO.get(schema_param.type)
426
- if not type_info:
427
- raise ValueError(f"Unsupported parameter type: {schema_param.type}")
448
+ if schema_param.has_default and schema_param.name not in python_params:
449
+ default_field_name = schema_param.WhichOneof("default_oneof")
450
+ if default_field_name is None:
451
+ raise InvalidError(f"{schema_param.name} declared as having a default, but has no default value")
452
+ result[schema_param.name] = getattr(schema_param, default_field_name)
453
+ return result
454
+
455
+
456
+ def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
457
+ proto_params: list[api_pb2.ClassParameterValue] = []
458
+ for param_name, python_value in python_params.items():
459
+ python_type = type(python_value)
460
+ protobuf_type = get_proto_parameter_type(python_type)
461
+ type_info = PROTO_TYPE_INFO.get(protobuf_type)
428
462
  proto_param = api_pb2.ClassParameterValue(
429
- name=schema_param.name,
430
- type=schema_param.type,
463
+ name=param_name,
464
+ type=protobuf_type,
431
465
  )
432
- python_value = python_params.get(schema_param.name)
433
- if python_value is None:
434
- if schema_param.has_default:
435
- python_value = getattr(schema_param, type_info.default_field)
436
- else:
437
- raise ValueError(f"Missing required parameter: {schema_param.name}")
438
466
  try:
439
467
  converted_value = type_info.converter(python_value)
440
468
  except ValueError as exc:
441
- raise ValueError(f"Invalid type for parameter {schema_param.name}: {exc}")
469
+ raise ValueError(f"Invalid type for parameter {param_name}: {exc}")
442
470
  setattr(proto_param, type_info.proto_field, converted_value)
443
471
  proto_params.append(proto_param)
444
472
  proto_bytes = api_pb2.ClassParameterSet(parameters=proto_params).SerializeToString(deterministic=True)
445
473
  return proto_bytes
446
474
 
447
475
 
448
- def deserialize_proto_params(serialized_params: bytes, schema: list[api_pb2.ClassParameterSpec]) -> dict[str, Any]:
449
- # TODO: this currently requires the schema to decode a payload, but we should make the validation
450
- # distinct from the deserialization
476
+ def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
451
477
  proto_struct = api_pb2.ClassParameterSet()
452
478
  proto_struct.ParseFromString(serialized_params)
453
- value_by_name = {p.name: p for p in proto_struct.parameters}
454
479
  python_params = {}
455
- for schema_param in schema:
456
- if schema_param.name not in value_by_name:
457
- # TODO: handle default values? Could just be a flag on the FunctionParameter schema spec,
458
- # allowing it to not be supplied in the FunctionParameterSet?
459
- raise AttributeError(f"Constructor arguments don't match declared parameters (missing {schema_param.name})")
460
- param_value = value_by_name[schema_param.name]
461
- if schema_param.type != param_value.type:
462
- raise ValueError(
463
- "Constructor arguments types don't match declared parameters "
464
- f"({schema_param.name}: type {schema_param.type} != type {param_value.type})"
465
- )
480
+ for param in proto_struct.parameters:
466
481
  python_value: Any
467
- if schema_param.type == api_pb2.PARAM_TYPE_STRING:
468
- python_value = param_value.string_value
469
- elif schema_param.type == api_pb2.PARAM_TYPE_INT:
470
- python_value = param_value.int_value
471
- elif schema_param.type == api_pb2.PARAM_TYPE_BYTES:
472
- python_value = param_value.bytes_value
482
+ if param.type == api_pb2.PARAM_TYPE_STRING:
483
+ python_value = param.string_value
484
+ elif param.type == api_pb2.PARAM_TYPE_INT:
485
+ python_value = param.int_value
486
+ elif param.type == api_pb2.PARAM_TYPE_BYTES:
487
+ python_value = param.bytes_value
473
488
  else:
474
- # TODO(elias): based on `parameters` declared types, we could add support for
475
- # custom non proto types encoded as bytes in the proto, e.g. PARAM_TYPE_PYTHON_PICKLE
476
- raise NotImplementedError("Only strings and ints are supported parameter value types at the moment")
489
+ raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")
477
490
 
478
- python_params[schema_param.name] = python_value
491
+ python_params[param.name] = python_value
479
492
 
480
493
  return python_params
481
494
 
482
495
 
496
+ def validate_params(params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
497
+ # first check that all declared values are provided
498
+ for schema_param in schema:
499
+ if schema_param.name not in params:
500
+ # we expect all values to be present - even defaulted ones (defaults are applied on payload construction)
501
+ raise InvalidError(f"Missing required parameter: {schema_param.name}")
502
+ python_value = params[schema_param.name]
503
+ python_type = type(python_value)
504
+ param_protobuf_type = get_proto_parameter_type(python_type)
505
+ if schema_param.type != param_protobuf_type:
506
+ expected_python_type = PROTO_TYPE_INFO[schema_param.type].type
507
+ raise TypeError(
508
+ f"Parameter '{schema_param.name}' type error: expected {expected_python_type.__name__}, "
509
+ f"got {python_type.__name__}"
510
+ )
511
+
512
+ schema_fields = {p.name for p in schema}
513
+ # then check that no extra values are provided
514
+ non_declared_fields = params.keys() - schema_fields
515
+ if non_declared_fields:
516
+ raise InvalidError(
517
+ f"The following parameter names were provided but are not present in the schema: {non_declared_fields}"
518
+ )
519
+
520
+
483
521
  def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function, _client: "modal.client._Client"):
484
522
  if function_def.class_parameter_info.format in (
485
523
  api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_UNSPECIFIED,
@@ -488,11 +526,21 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
488
526
  # legacy serialization format - pickle of `(args, kwargs)` w/ support for modal object arguments
489
527
  param_args, param_kwargs = deserialize(serialized_params, _client)
490
528
  elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
491
- param_args = ()
492
- param_kwargs = deserialize_proto_params(serialized_params, list(function_def.class_parameter_info.schema))
529
+ param_args = () # we use kwargs only for our implicit constructors
530
+ param_kwargs = deserialize_proto_params(serialized_params)
531
+ # TODO: We can probably remove the validation below since we do validation in the caller?
532
+ validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
493
533
  else:
494
534
  raise ExecutionError(
495
535
  f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
496
536
  )
497
537
 
498
538
  return param_args, param_kwargs
539
+
540
+
541
+ def get_proto_parameter_type(parameter_type: type) -> "api_pb2.ParameterType.ValueType":
542
+ if parameter_type not in PYTHON_TO_PROTO_TYPE:
543
+ type_name = getattr(parameter_type, "__name__", repr(parameter_type))
544
+ supported = ", ".join(parameter_type.__name__ for parameter_type in PYTHON_TO_PROTO_TYPE.keys())
545
+ raise InvalidError(f"{type_name} is not a supported parameter type. Use one of: {supported}")
546
+ return PYTHON_TO_PROTO_TYPE[parameter_type]
@@ -12,6 +12,7 @@ from dataclasses import dataclass
12
12
  from typing import (
13
13
  Any,
14
14
  Callable,
15
+ Generic,
15
16
  Optional,
16
17
  TypeVar,
17
18
  Union,
@@ -26,6 +27,10 @@ from typing_extensions import ParamSpec, assert_type
26
27
  from ..exception import InvalidError
27
28
  from .logger import logger
28
29
 
30
+ T = TypeVar("T")
31
+ P = ParamSpec("P")
32
+ V = TypeVar("V")
33
+
29
34
  synchronizer = synchronicity.Synchronizer()
30
35
 
31
36
 
@@ -260,7 +265,72 @@ def run_coro_blocking(coro):
260
265
  return fut.result()
261
266
 
262
267
 
263
- async def queue_batch_iterator(q: asyncio.Queue, max_batch_size=100, debounce_time=0.015):
268
+ class TimestampPriorityQueue(Generic[T]):
269
+ """
270
+ A priority queue that schedules items to be processed at specific timestamps.
271
+ """
272
+
273
+ _MAX_PRIORITY = float("inf")
274
+
275
+ def __init__(self, maxsize: int = 0):
276
+ self.condition = asyncio.Condition()
277
+ self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
278
+
279
+ async def close(self):
280
+ await self.put(self._MAX_PRIORITY, None)
281
+
282
+ async def put(self, timestamp: float, item: Union[T, None]):
283
+ """
284
+ Add an item to the queue to be processed at a specific timestamp.
285
+ """
286
+ await self._queue.put((timestamp, item))
287
+ async with self.condition:
288
+ self.condition.notify_all() # notify any waiting coroutines
289
+
290
+ async def get(self) -> Union[T, None]:
291
+ """
292
+ Get the next item from the queue that is ready to be processed.
293
+ """
294
+ while True:
295
+ async with self.condition:
296
+ while self.empty():
297
+ await self.condition.wait()
298
+ # peek at the next item
299
+ timestamp, item = await self._queue.get()
300
+ now = time.time()
301
+ if timestamp < now:
302
+ return item
303
+ if timestamp == self._MAX_PRIORITY:
304
+ return None
305
+ # not ready yet, calculate sleep time
306
+ sleep_time = timestamp - now
307
+ self._queue.put_nowait((timestamp, item)) # put it back
308
+ # wait until either the timeout or a new item is added
309
+ try:
310
+ await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
311
+ except asyncio.TimeoutError:
312
+ continue
313
+
314
+ def empty(self) -> bool:
315
+ return self._queue.empty()
316
+
317
+ def qsize(self) -> int:
318
+ return self._queue.qsize()
319
+
320
+ async def clear(self):
321
+ """
322
+ Clear the retry queue. Used for testing to simulate reading all elements from queue using queue_batch_iterator.
323
+ """
324
+ while not self.empty():
325
+ await self.get()
326
+
327
+ def __len__(self):
328
+ return self._queue.qsize()
329
+
330
+
331
+ async def queue_batch_iterator(
332
+ q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015
333
+ ):
264
334
  """
265
335
  Read from a queue but return lists of items when queue is large
266
336
 
@@ -405,11 +475,6 @@ def on_shutdown(coro):
405
475
  _shutdown_tasks.append(asyncio.create_task(wrapper()))
406
476
 
407
477
 
408
- T = TypeVar("T")
409
- P = ParamSpec("P")
410
- V = TypeVar("V")
411
-
412
-
413
478
  def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
414
479
  """Convert a blocking function into one that runs in the current loop's executor."""
415
480
 
@@ -15,7 +15,14 @@ from synchronicity.exceptions import UserCodeException
15
15
  import modal_proto
16
16
  from modal_proto import api_pb2
17
17
 
18
- from .._serialization import PROTO_TYPE_INFO, PYTHON_TO_PROTO_TYPE, deserialize, deserialize_data_format, serialize
18
+ from .._serialization import (
19
+ PROTO_TYPE_INFO,
20
+ PYTHON_TO_PROTO_TYPE,
21
+ deserialize,
22
+ deserialize_data_format,
23
+ get_proto_parameter_type,
24
+ serialize,
25
+ )
19
26
  from .._traceback import append_modal_tb
20
27
  from ..config import config, logger
21
28
  from ..exception import (
@@ -99,6 +106,24 @@ def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.Functio
99
106
  return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION
100
107
 
101
108
 
109
+ def signature_to_protobuf_schema(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
110
+ modal_parameters: list[api_pb2.ClassParameterSpec] = []
111
+ for param in signature.parameters.values():
112
+ has_default = param.default is not param.empty
113
+ class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default)
114
+ if param.annotation not in PYTHON_TO_PROTO_TYPE:
115
+ class_param_spec.type = api_pb2.PARAM_TYPE_UNKNOWN
116
+ else:
117
+ proto_type = PYTHON_TO_PROTO_TYPE[param.annotation]
118
+ class_param_spec.type = proto_type
119
+ proto_type_info = PROTO_TYPE_INFO[proto_type]
120
+ if has_default and proto_type is not api_pb2.PARAM_TYPE_UNKNOWN:
121
+ setattr(class_param_spec, proto_type_info.default_field, param.default)
122
+
123
+ modal_parameters.append(class_param_spec)
124
+ return modal_parameters
125
+
126
+
102
127
  class FunctionInfo:
103
128
  """Utility that determines serialization/deserialization mechanisms for functions
104
129
 
@@ -277,28 +302,23 @@ class FunctionInfo:
277
302
  return api_pb2.ClassParameterInfo()
278
303
 
279
304
  # TODO(elias): Resolve circular dependencies... maybe we'll need some cls_utils module
280
- from modal.cls import _get_class_constructor_signature, _use_annotation_parameters, _validate_parameter_type
305
+ from modal.cls import _get_class_constructor_signature, _use_annotation_parameters
281
306
 
282
307
  if not _use_annotation_parameters(self.user_cls):
283
308
  return api_pb2.ClassParameterInfo(format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PICKLE)
284
309
 
285
310
  # annotation parameters trigger strictly typed parametrization
286
311
  # which enables web endpoint for parametrized classes
287
-
288
- modal_parameters: list[api_pb2.ClassParameterSpec] = []
289
312
  signature = _get_class_constructor_signature(self.user_cls)
313
+ # validate that the schema has no unspecified fields/unsupported class parameter types
290
314
  for param in signature.parameters.values():
291
- has_default = param.default is not param.empty
292
- _validate_parameter_type(self.user_cls.__name__, param.name, param.annotation)
293
- proto_type = PYTHON_TO_PROTO_TYPE[param.annotation]
294
- proto_type_info = PROTO_TYPE_INFO[proto_type]
295
- class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=proto_type)
296
- if has_default:
297
- setattr(class_param_spec, proto_type_info.default_field, param.default)
298
- modal_parameters.append(class_param_spec)
315
+ get_proto_parameter_type(param.annotation)
316
+
317
+ protobuf_schema = signature_to_protobuf_schema(signature)
299
318
 
300
319
  return api_pb2.ClassParameterInfo(
301
- format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO, schema=modal_parameters
320
+ format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
321
+ schema=protobuf_schema,
302
322
  )
303
323
 
304
324
  def get_entrypoint_mount(self) -> dict[str, _Mount]:
@@ -0,0 +1,38 @@
1
+ # Copyright Modal Labs 2025
2
+ import base64
3
+ import json
4
+ from dataclasses import dataclass
5
+ from typing import Any, Dict
6
+
7
+
8
+ @dataclass
9
+ class DecodedJwt:
10
+ header: Dict[str, Any]
11
+ payload: Dict[str, Any]
12
+
13
+ @staticmethod
14
+ def decode_without_verification(token: str) -> "DecodedJwt":
15
+ # Split the JWT into its three parts
16
+ header_b64, payload_b64, _ = token.split(".")
17
+
18
+ # Decode Base64 (with padding handling)
19
+ header_json = base64.urlsafe_b64decode(header_b64 + "==").decode("utf-8")
20
+ payload_json = base64.urlsafe_b64decode(payload_b64 + "==").decode("utf-8")
21
+
22
+ # Convert JSON strings to dictionaries
23
+ header = json.loads(header_json)
24
+ payload = json.loads(payload_json)
25
+
26
+ return DecodedJwt(header, payload)
27
+
28
+ @staticmethod
29
+ def _base64url_encode(data: str) -> str:
30
+ """Encodes data to Base64 URL-safe format without padding."""
31
+ return base64.urlsafe_b64encode(data.encode()).rstrip(b"=").decode()
32
+
33
+ @staticmethod
34
+ def encode_without_signature(fields: Dict[str, Any]) -> str:
35
+ """Encodes an Unsecured JWT (without a signature)."""
36
+ header_b64 = DecodedJwt._base64url_encode(json.dumps({"alg": "none", "typ": "JWT"}))
37
+ payload_b64 = DecodedJwt._base64url_encode(json.dumps(fields))
38
+ return f"{header_b64}.{payload_b64}." # No signature
modal/cli/app.py CHANGED
@@ -227,6 +227,8 @@ async def history(
227
227
  ]
228
228
  rows = []
229
229
  deployments_with_tags = False
230
+ deployments_with_commit_info = False
231
+ deployments_with_dirty_commit = False
230
232
  for idx, app_stats in enumerate(resp.app_deployment_histories):
231
233
  style = "bold green" if idx == 0 else ""
232
234
 
@@ -241,10 +243,23 @@ async def history(
241
243
  deployments_with_tags = True
242
244
  row.append(Text(app_stats.tag, style=style))
243
245
 
246
+ if app_stats.commit_info.commit_hash:
247
+ deployments_with_commit_info = True
248
+ short_hash = app_stats.commit_info.commit_hash[:7]
249
+ if app_stats.commit_info.dirty:
250
+ deployments_with_dirty_commit = True
251
+ short_hash = f"{short_hash}*"
252
+ row.append(Text(short_hash, style=style))
253
+
244
254
  rows.append(row)
245
255
 
246
256
  if deployments_with_tags:
247
257
  columns.append("Tag")
258
+ if deployments_with_commit_info:
259
+ columns.append("Commit")
248
260
 
249
261
  rows = sorted(rows, key=lambda x: int(str(x[0])[1:]), reverse=True)
250
262
  display_table(columns, rows, json)
263
+
264
+ if deployments_with_dirty_commit and not json:
265
+ rich.print("* - repo had uncommitted changes")
modal/client.pyi CHANGED
@@ -31,7 +31,7 @@ class _Client:
31
31
  server_url: str,
32
32
  client_type: int,
33
33
  credentials: typing.Optional[tuple[str, str]],
34
- version: str = "0.73.116",
34
+ version: str = "0.73.126",
35
35
  ): ...
36
36
  def is_closed(self) -> bool: ...
37
37
  @property
@@ -93,7 +93,7 @@ class Client:
93
93
  server_url: str,
94
94
  client_type: int,
95
95
  credentials: typing.Optional[tuple[str, str]],
96
- version: str = "0.73.116",
96
+ version: str = "0.73.126",
97
97
  ): ...
98
98
  def is_closed(self) -> bool: ...
99
99
  @property