modal 0.73.116__py3-none-any.whl → 0.73.126__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/_functions.py +15 -6
- modal/_runtime/container_io_manager.py +13 -9
- modal/_runtime/container_io_manager.pyi +7 -4
- modal/_serialization.py +92 -44
- modal/_utils/async_utils.py +71 -6
- modal/_utils/function_utils.py +33 -13
- modal/_utils/jwt_utils.py +38 -0
- modal/cli/app.py +15 -0
- modal/client.pyi +2 -2
- modal/cls.py +3 -13
- modal/cls.pyi +0 -2
- modal/functions.pyi +6 -6
- modal/parallel_map.py +393 -44
- modal/parallel_map.pyi +75 -0
- modal/retries.py +11 -9
- {modal-0.73.116.dist-info → modal-0.73.126.dist-info}/METADATA +1 -1
- {modal-0.73.116.dist-info → modal-0.73.126.dist-info}/RECORD +29 -28
- {modal-0.73.116.dist-info → modal-0.73.126.dist-info}/WHEEL +1 -1
- modal_proto/api.proto +13 -0
- modal_proto/api_grpc.py +16 -0
- modal_proto/api_pb2.py +284 -263
- modal_proto/api_pb2.pyi +43 -0
- modal_proto/api_pb2_grpc.py +33 -0
- modal_proto/api_pb2_grpc.pyi +10 -0
- modal_proto/modal_api_grpc.py +1 -0
- modal_version/_version_generated.py +1 -1
- {modal-0.73.116.dist-info → modal-0.73.126.dist-info}/LICENSE +0 -0
- {modal-0.73.116.dist-info → modal-0.73.126.dist-info}/entry_points.txt +0 -0
- {modal-0.73.116.dist-info → modal-0.73.126.dist-info}/top_level.txt +0 -0
modal/_functions.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
# Copyright Modal Labs 2023
|
2
|
+
import asyncio
|
2
3
|
import dataclasses
|
3
4
|
import inspect
|
4
5
|
import textwrap
|
@@ -24,7 +25,7 @@ from ._pty import get_pty_info
|
|
24
25
|
from ._resolver import Resolver
|
25
26
|
from ._resources import convert_fn_config_to_resources_config
|
26
27
|
from ._runtime.execution_context import current_input_id, is_local
|
27
|
-
from ._serialization import serialize, serialize_proto_params
|
28
|
+
from ._serialization import apply_defaults, serialize, serialize_proto_params, validate_params
|
28
29
|
from ._traceback import print_server_warnings
|
29
30
|
from ._utils.async_utils import (
|
30
31
|
TaskContext,
|
@@ -174,7 +175,7 @@ class _Invocation:
|
|
174
175
|
input_jwt=input.input_jwt,
|
175
176
|
input_id=input.input_id,
|
176
177
|
item=item,
|
177
|
-
sync_client_retries_enabled=response.sync_client_retries_enabled
|
178
|
+
sync_client_retries_enabled=response.sync_client_retries_enabled,
|
178
179
|
)
|
179
180
|
return _Invocation(client.stub, function_call_id, client, retry_context)
|
180
181
|
|
@@ -256,9 +257,13 @@ class _Invocation:
|
|
256
257
|
try:
|
257
258
|
return await self._get_single_output(ctx.input_jwt)
|
258
259
|
except (UserCodeException, FunctionTimeoutError) as exc:
|
259
|
-
|
260
|
+
delay_ms = user_retry_manager.get_delay_ms()
|
261
|
+
if delay_ms is None:
|
262
|
+
raise exc
|
263
|
+
await asyncio.sleep(delay_ms / 1000)
|
260
264
|
except InternalFailure:
|
261
|
-
# For system failures on the server, we retry immediately
|
265
|
+
# For system failures on the server, we retry immediately,
|
266
|
+
# and the failure does not count towards the retry policy.
|
262
267
|
pass
|
263
268
|
await self._retry_input()
|
264
269
|
|
@@ -968,8 +973,11 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
968
973
|
"Can't use positional arguments with modal.parameter-based synthetic constructors.\n"
|
969
974
|
"Use (<parameter_name>=value) keyword arguments when constructing classes instead."
|
970
975
|
)
|
971
|
-
|
972
|
-
|
976
|
+
schema = parent._class_parameter_info.schema
|
977
|
+
kwargs_with_defaults = apply_defaults(kwargs, schema)
|
978
|
+
validate_params(kwargs_with_defaults, schema)
|
979
|
+
serialized_params = serialize_proto_params(kwargs_with_defaults)
|
980
|
+
can_use_parent = len(parent._class_parameter_info.schema) == 0 # no parameters
|
973
981
|
else:
|
974
982
|
can_use_parent = len(args) + len(kwargs) == 0 and options is None
|
975
983
|
serialized_params = serialize((args, kwargs))
|
@@ -1304,6 +1312,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1304
1312
|
order_outputs,
|
1305
1313
|
return_exceptions,
|
1306
1314
|
count_update_callback,
|
1315
|
+
api_pb2.FUNCTION_CALL_INVOCATION_TYPE_SYNC
|
1307
1316
|
)
|
1308
1317
|
) as stream:
|
1309
1318
|
async for item in stream:
|
@@ -63,7 +63,9 @@ class IOContext:
|
|
63
63
|
"""
|
64
64
|
|
65
65
|
input_ids: list[str]
|
66
|
+
retry_counts: list[int]
|
66
67
|
function_call_ids: list[str]
|
68
|
+
function_inputs: list[api_pb2.FunctionInput]
|
67
69
|
finalized_function: "modal._runtime.user_code_imports.FinalizedFunction"
|
68
70
|
|
69
71
|
_cancel_issued: bool = False
|
@@ -72,6 +74,7 @@ class IOContext:
|
|
72
74
|
def __init__(
|
73
75
|
self,
|
74
76
|
input_ids: list[str],
|
77
|
+
retry_counts: list[int],
|
75
78
|
function_call_ids: list[str],
|
76
79
|
finalized_function: "modal._runtime.user_code_imports.FinalizedFunction",
|
77
80
|
function_inputs: list[api_pb2.FunctionInput],
|
@@ -79,9 +82,10 @@ class IOContext:
|
|
79
82
|
client: _Client,
|
80
83
|
):
|
81
84
|
self.input_ids = input_ids
|
85
|
+
self.retry_counts = retry_counts
|
82
86
|
self.function_call_ids = function_call_ids
|
83
87
|
self.finalized_function = finalized_function
|
84
|
-
self.
|
88
|
+
self.function_inputs = function_inputs
|
85
89
|
self._is_batched = is_batched
|
86
90
|
self._client = client
|
87
91
|
|
@@ -90,11 +94,11 @@ class IOContext:
|
|
90
94
|
cls,
|
91
95
|
client: _Client,
|
92
96
|
finalized_functions: dict[str, "modal._runtime.user_code_imports.FinalizedFunction"],
|
93
|
-
inputs: list[tuple[str, str, api_pb2.FunctionInput]],
|
97
|
+
inputs: list[tuple[str, int, str, api_pb2.FunctionInput]],
|
94
98
|
is_batched: bool,
|
95
99
|
) -> "IOContext":
|
96
100
|
assert len(inputs) >= 1 if is_batched else len(inputs) == 1
|
97
|
-
input_ids, function_call_ids, function_inputs = zip(*inputs)
|
101
|
+
input_ids, retry_counts, function_call_ids, function_inputs = zip(*inputs)
|
98
102
|
|
99
103
|
async def _populate_input_blobs(client: _Client, input: api_pb2.FunctionInput) -> api_pb2.FunctionInput:
|
100
104
|
# If we got a pointer to a blob, download it from S3.
|
@@ -111,7 +115,7 @@ class IOContext:
|
|
111
115
|
method_name = function_inputs[0].method_name
|
112
116
|
assert all(method_name == input.method_name for input in function_inputs)
|
113
117
|
finalized_function = finalized_functions[method_name]
|
114
|
-
return cls(input_ids, function_call_ids, finalized_function, function_inputs, is_batched, client)
|
118
|
+
return cls(input_ids, retry_counts, function_call_ids, finalized_function, function_inputs, is_batched, client)
|
115
119
|
|
116
120
|
def set_cancel_callback(self, cb: Callable[[], None]):
|
117
121
|
self._cancel_callback = cb
|
@@ -135,7 +139,7 @@ class IOContext:
|
|
135
139
|
# to make sure we handle user exceptions properly
|
136
140
|
# and don't retry
|
137
141
|
deserialized_args = [
|
138
|
-
deserialize(input.args, self._client) if input.args else ((), {}) for input in self.
|
142
|
+
deserialize(input.args, self._client) if input.args else ((), {}) for input in self.function_inputs
|
139
143
|
]
|
140
144
|
if not self._is_batched:
|
141
145
|
return deserialized_args[0]
|
@@ -551,7 +555,7 @@ class _ContainerIOManager:
|
|
551
555
|
self,
|
552
556
|
batch_max_size: int,
|
553
557
|
batch_wait_ms: int,
|
554
|
-
) -> AsyncIterator[list[tuple[str, str, api_pb2.FunctionInput]]]:
|
558
|
+
) -> AsyncIterator[list[tuple[str, int, str, api_pb2.FunctionInput]]]:
|
555
559
|
request = api_pb2.FunctionGetInputsRequest(function_id=self.function_id)
|
556
560
|
iteration = 0
|
557
561
|
while self._fetching_inputs:
|
@@ -586,8 +590,7 @@ class _ContainerIOManager:
|
|
586
590
|
if item.kill_switch:
|
587
591
|
logger.debug(f"Task {self.task_id} input kill signal input.")
|
588
592
|
return
|
589
|
-
|
590
|
-
inputs.append((item.input_id, item.function_call_id, item.input))
|
593
|
+
inputs.append((item.input_id, item.retry_count, item.function_call_id, item.input))
|
591
594
|
if item.input.final_input:
|
592
595
|
if request.batch_max_size > 0:
|
593
596
|
logger.debug(f"Task {self.task_id} Final input not expected in batch input stream")
|
@@ -648,8 +651,9 @@ class _ContainerIOManager:
|
|
648
651
|
output_created_at=output_created_at,
|
649
652
|
result=result,
|
650
653
|
data_format=data_format,
|
654
|
+
retry_count=retry_count,
|
651
655
|
)
|
652
|
-
for input_id, result in zip(io_context.input_ids, results)
|
656
|
+
for input_id, retry_count, result in zip(io_context.input_ids, io_context.retry_counts, results)
|
653
657
|
]
|
654
658
|
await retry_transient_errors(
|
655
659
|
self._client.stub.FunctionPutOutputs,
|
@@ -14,7 +14,9 @@ class Sentinel: ...
|
|
14
14
|
|
15
15
|
class IOContext:
|
16
16
|
input_ids: list[str]
|
17
|
+
retry_counts: list[int]
|
17
18
|
function_call_ids: list[str]
|
19
|
+
function_inputs: list[modal_proto.api_pb2.FunctionInput]
|
18
20
|
finalized_function: modal._runtime.user_code_imports.FinalizedFunction
|
19
21
|
_cancel_issued: bool
|
20
22
|
_cancel_callback: typing.Optional[collections.abc.Callable[[], None]]
|
@@ -22,6 +24,7 @@ class IOContext:
|
|
22
24
|
def __init__(
|
23
25
|
self,
|
24
26
|
input_ids: list[str],
|
27
|
+
retry_counts: list[int],
|
25
28
|
function_call_ids: list[str],
|
26
29
|
finalized_function: modal._runtime.user_code_imports.FinalizedFunction,
|
27
30
|
function_inputs: list[modal_proto.api_pb2.FunctionInput],
|
@@ -33,7 +36,7 @@ class IOContext:
|
|
33
36
|
cls,
|
34
37
|
client: modal.client._Client,
|
35
38
|
finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
|
36
|
-
inputs: list[tuple[str, str, modal_proto.api_pb2.FunctionInput]],
|
39
|
+
inputs: list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]],
|
37
40
|
is_batched: bool,
|
38
41
|
) -> IOContext: ...
|
39
42
|
def set_cancel_callback(self, cb: collections.abc.Callable[[], None]): ...
|
@@ -116,7 +119,7 @@ class _ContainerIOManager:
|
|
116
119
|
def get_max_inputs_to_fetch(self): ...
|
117
120
|
def _generate_inputs(
|
118
121
|
self, batch_max_size: int, batch_wait_ms: int
|
119
|
-
) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
122
|
+
) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
120
123
|
def run_inputs_outputs(
|
121
124
|
self,
|
122
125
|
finalized_functions: dict[str, modal._runtime.user_code_imports.FinalizedFunction],
|
@@ -287,10 +290,10 @@ class ContainerIOManager:
|
|
287
290
|
class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
|
288
291
|
def __call__(
|
289
292
|
self, batch_max_size: int, batch_wait_ms: int
|
290
|
-
) -> typing.Iterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
293
|
+
) -> typing.Iterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
291
294
|
def aio(
|
292
295
|
self, batch_max_size: int, batch_wait_ms: int
|
293
|
-
) -> collections.abc.AsyncIterator[list[tuple[str, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
296
|
+
) -> collections.abc.AsyncIterator[list[tuple[str, int, str, modal_proto.api_pb2.FunctionInput]]]: ...
|
294
297
|
|
295
298
|
_generate_inputs: ___generate_inputs_spec[typing_extensions.Self]
|
296
299
|
|
modal/_serialization.py
CHANGED
@@ -400,6 +400,7 @@ class ParamTypeInfo:
|
|
400
400
|
default_field: str
|
401
401
|
proto_field: str
|
402
402
|
converter: typing.Callable[[str], typing.Any]
|
403
|
+
type: type
|
403
404
|
|
404
405
|
|
405
406
|
PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
|
@@ -411,75 +412,112 @@ PYTHON_TO_PROTO_TYPE: dict[type, "api_pb2.ParameterType.ValueType"] = {
|
|
411
412
|
|
412
413
|
PROTO_TYPE_INFO = {
|
413
414
|
# Protobuf type enum -> encode/decode helper metadata
|
414
|
-
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
|
415
|
-
|
415
|
+
api_pb2.PARAM_TYPE_STRING: ParamTypeInfo(
|
416
|
+
default_field="string_default",
|
417
|
+
proto_field="string_value",
|
418
|
+
converter=str,
|
419
|
+
type=str,
|
420
|
+
),
|
421
|
+
api_pb2.PARAM_TYPE_INT: ParamTypeInfo(
|
422
|
+
default_field="int_default",
|
423
|
+
proto_field="int_value",
|
424
|
+
converter=int,
|
425
|
+
type=int,
|
426
|
+
),
|
416
427
|
api_pb2.PARAM_TYPE_BYTES: ParamTypeInfo(
|
417
|
-
default_field="bytes_default",
|
428
|
+
default_field="bytes_default",
|
429
|
+
proto_field="bytes_value",
|
430
|
+
converter=assert_bytes,
|
431
|
+
type=bytes,
|
418
432
|
),
|
419
433
|
}
|
420
434
|
|
421
435
|
|
422
|
-
def
|
423
|
-
|
436
|
+
def apply_defaults(
|
437
|
+
python_params: typing.Mapping[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]
|
438
|
+
) -> dict[str, Any]:
|
439
|
+
"""Apply any declared defaults from the provided schema, if values aren't provided in python_params
|
440
|
+
|
441
|
+
Conceptually similar to inspect.BoundArguments.apply_defaults.
|
442
|
+
|
443
|
+
Note: Apply this before serializing parameters in order to get consistent parameter
|
444
|
+
pools regardless if a value is explicitly provided or not.
|
445
|
+
"""
|
446
|
+
result = {**python_params}
|
424
447
|
for schema_param in schema:
|
425
|
-
|
426
|
-
|
427
|
-
|
448
|
+
if schema_param.has_default and schema_param.name not in python_params:
|
449
|
+
default_field_name = schema_param.WhichOneof("default_oneof")
|
450
|
+
if default_field_name is None:
|
451
|
+
raise InvalidError(f"{schema_param.name} declared as having a default, but has no default value")
|
452
|
+
result[schema_param.name] = getattr(schema_param, default_field_name)
|
453
|
+
return result
|
454
|
+
|
455
|
+
|
456
|
+
def serialize_proto_params(python_params: dict[str, Any]) -> bytes:
|
457
|
+
proto_params: list[api_pb2.ClassParameterValue] = []
|
458
|
+
for param_name, python_value in python_params.items():
|
459
|
+
python_type = type(python_value)
|
460
|
+
protobuf_type = get_proto_parameter_type(python_type)
|
461
|
+
type_info = PROTO_TYPE_INFO.get(protobuf_type)
|
428
462
|
proto_param = api_pb2.ClassParameterValue(
|
429
|
-
name=
|
430
|
-
type=
|
463
|
+
name=param_name,
|
464
|
+
type=protobuf_type,
|
431
465
|
)
|
432
|
-
python_value = python_params.get(schema_param.name)
|
433
|
-
if python_value is None:
|
434
|
-
if schema_param.has_default:
|
435
|
-
python_value = getattr(schema_param, type_info.default_field)
|
436
|
-
else:
|
437
|
-
raise ValueError(f"Missing required parameter: {schema_param.name}")
|
438
466
|
try:
|
439
467
|
converted_value = type_info.converter(python_value)
|
440
468
|
except ValueError as exc:
|
441
|
-
raise ValueError(f"Invalid type for parameter {
|
469
|
+
raise ValueError(f"Invalid type for parameter {param_name}: {exc}")
|
442
470
|
setattr(proto_param, type_info.proto_field, converted_value)
|
443
471
|
proto_params.append(proto_param)
|
444
472
|
proto_bytes = api_pb2.ClassParameterSet(parameters=proto_params).SerializeToString(deterministic=True)
|
445
473
|
return proto_bytes
|
446
474
|
|
447
475
|
|
448
|
-
def deserialize_proto_params(serialized_params: bytes
|
449
|
-
# TODO: this currently requires the schema to decode a payload, but we should make the validation
|
450
|
-
# distinct from the deserialization
|
476
|
+
def deserialize_proto_params(serialized_params: bytes) -> dict[str, Any]:
|
451
477
|
proto_struct = api_pb2.ClassParameterSet()
|
452
478
|
proto_struct.ParseFromString(serialized_params)
|
453
|
-
value_by_name = {p.name: p for p in proto_struct.parameters}
|
454
479
|
python_params = {}
|
455
|
-
for
|
456
|
-
if schema_param.name not in value_by_name:
|
457
|
-
# TODO: handle default values? Could just be a flag on the FunctionParameter schema spec,
|
458
|
-
# allowing it to not be supplied in the FunctionParameterSet?
|
459
|
-
raise AttributeError(f"Constructor arguments don't match declared parameters (missing {schema_param.name})")
|
460
|
-
param_value = value_by_name[schema_param.name]
|
461
|
-
if schema_param.type != param_value.type:
|
462
|
-
raise ValueError(
|
463
|
-
"Constructor arguments types don't match declared parameters "
|
464
|
-
f"({schema_param.name}: type {schema_param.type} != type {param_value.type})"
|
465
|
-
)
|
480
|
+
for param in proto_struct.parameters:
|
466
481
|
python_value: Any
|
467
|
-
if
|
468
|
-
python_value =
|
469
|
-
elif
|
470
|
-
python_value =
|
471
|
-
elif
|
472
|
-
python_value =
|
482
|
+
if param.type == api_pb2.PARAM_TYPE_STRING:
|
483
|
+
python_value = param.string_value
|
484
|
+
elif param.type == api_pb2.PARAM_TYPE_INT:
|
485
|
+
python_value = param.int_value
|
486
|
+
elif param.type == api_pb2.PARAM_TYPE_BYTES:
|
487
|
+
python_value = param.bytes_value
|
473
488
|
else:
|
474
|
-
|
475
|
-
# custom non proto types encoded as bytes in the proto, e.g. PARAM_TYPE_PYTHON_PICKLE
|
476
|
-
raise NotImplementedError("Only strings and ints are supported parameter value types at the moment")
|
489
|
+
raise NotImplementedError(f"Unimplemented parameter type: {param.type}.")
|
477
490
|
|
478
|
-
python_params[
|
491
|
+
python_params[param.name] = python_value
|
479
492
|
|
480
493
|
return python_params
|
481
494
|
|
482
495
|
|
496
|
+
def validate_params(params: dict[str, Any], schema: typing.Sequence[api_pb2.ClassParameterSpec]):
|
497
|
+
# first check that all declared values are provided
|
498
|
+
for schema_param in schema:
|
499
|
+
if schema_param.name not in params:
|
500
|
+
# we expect all values to be present - even defaulted ones (defaults are applied on payload construction)
|
501
|
+
raise InvalidError(f"Missing required parameter: {schema_param.name}")
|
502
|
+
python_value = params[schema_param.name]
|
503
|
+
python_type = type(python_value)
|
504
|
+
param_protobuf_type = get_proto_parameter_type(python_type)
|
505
|
+
if schema_param.type != param_protobuf_type:
|
506
|
+
expected_python_type = PROTO_TYPE_INFO[schema_param.type].type
|
507
|
+
raise TypeError(
|
508
|
+
f"Parameter '{schema_param.name}' type error: expected {expected_python_type.__name__}, "
|
509
|
+
f"got {python_type.__name__}"
|
510
|
+
)
|
511
|
+
|
512
|
+
schema_fields = {p.name for p in schema}
|
513
|
+
# then check that no extra values are provided
|
514
|
+
non_declared_fields = params.keys() - schema_fields
|
515
|
+
if non_declared_fields:
|
516
|
+
raise InvalidError(
|
517
|
+
f"The following parameter names were provided but are not present in the schema: {non_declared_fields}"
|
518
|
+
)
|
519
|
+
|
520
|
+
|
483
521
|
def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function, _client: "modal.client._Client"):
|
484
522
|
if function_def.class_parameter_info.format in (
|
485
523
|
api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_UNSPECIFIED,
|
@@ -488,11 +526,21 @@ def deserialize_params(serialized_params: bytes, function_def: api_pb2.Function,
|
|
488
526
|
# legacy serialization format - pickle of `(args, kwargs)` w/ support for modal object arguments
|
489
527
|
param_args, param_kwargs = deserialize(serialized_params, _client)
|
490
528
|
elif function_def.class_parameter_info.format == api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO:
|
491
|
-
param_args = ()
|
492
|
-
param_kwargs = deserialize_proto_params(serialized_params
|
529
|
+
param_args = () # we use kwargs only for our implicit constructors
|
530
|
+
param_kwargs = deserialize_proto_params(serialized_params)
|
531
|
+
# TODO: We can probably remove the validation below since we do validation in the caller?
|
532
|
+
validate_params(param_kwargs, list(function_def.class_parameter_info.schema))
|
493
533
|
else:
|
494
534
|
raise ExecutionError(
|
495
535
|
f"Unknown class parameter serialization format: {function_def.class_parameter_info.format}"
|
496
536
|
)
|
497
537
|
|
498
538
|
return param_args, param_kwargs
|
539
|
+
|
540
|
+
|
541
|
+
def get_proto_parameter_type(parameter_type: type) -> "api_pb2.ParameterType.ValueType":
|
542
|
+
if parameter_type not in PYTHON_TO_PROTO_TYPE:
|
543
|
+
type_name = getattr(parameter_type, "__name__", repr(parameter_type))
|
544
|
+
supported = ", ".join(parameter_type.__name__ for parameter_type in PYTHON_TO_PROTO_TYPE.keys())
|
545
|
+
raise InvalidError(f"{type_name} is not a supported parameter type. Use one of: {supported}")
|
546
|
+
return PYTHON_TO_PROTO_TYPE[parameter_type]
|
modal/_utils/async_utils.py
CHANGED
@@ -12,6 +12,7 @@ from dataclasses import dataclass
|
|
12
12
|
from typing import (
|
13
13
|
Any,
|
14
14
|
Callable,
|
15
|
+
Generic,
|
15
16
|
Optional,
|
16
17
|
TypeVar,
|
17
18
|
Union,
|
@@ -26,6 +27,10 @@ from typing_extensions import ParamSpec, assert_type
|
|
26
27
|
from ..exception import InvalidError
|
27
28
|
from .logger import logger
|
28
29
|
|
30
|
+
T = TypeVar("T")
|
31
|
+
P = ParamSpec("P")
|
32
|
+
V = TypeVar("V")
|
33
|
+
|
29
34
|
synchronizer = synchronicity.Synchronizer()
|
30
35
|
|
31
36
|
|
@@ -260,7 +265,72 @@ def run_coro_blocking(coro):
|
|
260
265
|
return fut.result()
|
261
266
|
|
262
267
|
|
263
|
-
|
268
|
+
class TimestampPriorityQueue(Generic[T]):
|
269
|
+
"""
|
270
|
+
A priority queue that schedules items to be processed at specific timestamps.
|
271
|
+
"""
|
272
|
+
|
273
|
+
_MAX_PRIORITY = float("inf")
|
274
|
+
|
275
|
+
def __init__(self, maxsize: int = 0):
|
276
|
+
self.condition = asyncio.Condition()
|
277
|
+
self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
|
278
|
+
|
279
|
+
async def close(self):
|
280
|
+
await self.put(self._MAX_PRIORITY, None)
|
281
|
+
|
282
|
+
async def put(self, timestamp: float, item: Union[T, None]):
|
283
|
+
"""
|
284
|
+
Add an item to the queue to be processed at a specific timestamp.
|
285
|
+
"""
|
286
|
+
await self._queue.put((timestamp, item))
|
287
|
+
async with self.condition:
|
288
|
+
self.condition.notify_all() # notify any waiting coroutines
|
289
|
+
|
290
|
+
async def get(self) -> Union[T, None]:
|
291
|
+
"""
|
292
|
+
Get the next item from the queue that is ready to be processed.
|
293
|
+
"""
|
294
|
+
while True:
|
295
|
+
async with self.condition:
|
296
|
+
while self.empty():
|
297
|
+
await self.condition.wait()
|
298
|
+
# peek at the next item
|
299
|
+
timestamp, item = await self._queue.get()
|
300
|
+
now = time.time()
|
301
|
+
if timestamp < now:
|
302
|
+
return item
|
303
|
+
if timestamp == self._MAX_PRIORITY:
|
304
|
+
return None
|
305
|
+
# not ready yet, calculate sleep time
|
306
|
+
sleep_time = timestamp - now
|
307
|
+
self._queue.put_nowait((timestamp, item)) # put it back
|
308
|
+
# wait until either the timeout or a new item is added
|
309
|
+
try:
|
310
|
+
await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
|
311
|
+
except asyncio.TimeoutError:
|
312
|
+
continue
|
313
|
+
|
314
|
+
def empty(self) -> bool:
|
315
|
+
return self._queue.empty()
|
316
|
+
|
317
|
+
def qsize(self) -> int:
|
318
|
+
return self._queue.qsize()
|
319
|
+
|
320
|
+
async def clear(self):
|
321
|
+
"""
|
322
|
+
Clear the retry queue. Used for testing to simulate reading all elements from queue using queue_batch_iterator.
|
323
|
+
"""
|
324
|
+
while not self.empty():
|
325
|
+
await self.get()
|
326
|
+
|
327
|
+
def __len__(self):
|
328
|
+
return self._queue.qsize()
|
329
|
+
|
330
|
+
|
331
|
+
async def queue_batch_iterator(
|
332
|
+
q: Union[asyncio.Queue, TimestampPriorityQueue], max_batch_size=100, debounce_time=0.015
|
333
|
+
):
|
264
334
|
"""
|
265
335
|
Read from a queue but return lists of items when queue is large
|
266
336
|
|
@@ -405,11 +475,6 @@ def on_shutdown(coro):
|
|
405
475
|
_shutdown_tasks.append(asyncio.create_task(wrapper()))
|
406
476
|
|
407
477
|
|
408
|
-
T = TypeVar("T")
|
409
|
-
P = ParamSpec("P")
|
410
|
-
V = TypeVar("V")
|
411
|
-
|
412
|
-
|
413
478
|
def asyncify(f: Callable[P, T]) -> Callable[P, typing.Coroutine[None, None, T]]:
|
414
479
|
"""Convert a blocking function into one that runs in the current loop's executor."""
|
415
480
|
|
modal/_utils/function_utils.py
CHANGED
@@ -15,7 +15,14 @@ from synchronicity.exceptions import UserCodeException
|
|
15
15
|
import modal_proto
|
16
16
|
from modal_proto import api_pb2
|
17
17
|
|
18
|
-
from .._serialization import
|
18
|
+
from .._serialization import (
|
19
|
+
PROTO_TYPE_INFO,
|
20
|
+
PYTHON_TO_PROTO_TYPE,
|
21
|
+
deserialize,
|
22
|
+
deserialize_data_format,
|
23
|
+
get_proto_parameter_type,
|
24
|
+
serialize,
|
25
|
+
)
|
19
26
|
from .._traceback import append_modal_tb
|
20
27
|
from ..config import config, logger
|
21
28
|
from ..exception import (
|
@@ -99,6 +106,24 @@ def get_function_type(is_generator: Optional[bool]) -> "api_pb2.Function.Functio
|
|
99
106
|
return api_pb2.Function.FUNCTION_TYPE_GENERATOR if is_generator else api_pb2.Function.FUNCTION_TYPE_FUNCTION
|
100
107
|
|
101
108
|
|
109
|
+
def signature_to_protobuf_schema(signature: inspect.Signature) -> list[api_pb2.ClassParameterSpec]:
|
110
|
+
modal_parameters: list[api_pb2.ClassParameterSpec] = []
|
111
|
+
for param in signature.parameters.values():
|
112
|
+
has_default = param.default is not param.empty
|
113
|
+
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default)
|
114
|
+
if param.annotation not in PYTHON_TO_PROTO_TYPE:
|
115
|
+
class_param_spec.type = api_pb2.PARAM_TYPE_UNKNOWN
|
116
|
+
else:
|
117
|
+
proto_type = PYTHON_TO_PROTO_TYPE[param.annotation]
|
118
|
+
class_param_spec.type = proto_type
|
119
|
+
proto_type_info = PROTO_TYPE_INFO[proto_type]
|
120
|
+
if has_default and proto_type is not api_pb2.PARAM_TYPE_UNKNOWN:
|
121
|
+
setattr(class_param_spec, proto_type_info.default_field, param.default)
|
122
|
+
|
123
|
+
modal_parameters.append(class_param_spec)
|
124
|
+
return modal_parameters
|
125
|
+
|
126
|
+
|
102
127
|
class FunctionInfo:
|
103
128
|
"""Utility that determines serialization/deserialization mechanisms for functions
|
104
129
|
|
@@ -277,28 +302,23 @@ class FunctionInfo:
|
|
277
302
|
return api_pb2.ClassParameterInfo()
|
278
303
|
|
279
304
|
# TODO(elias): Resolve circular dependencies... maybe we'll need some cls_utils module
|
280
|
-
from modal.cls import _get_class_constructor_signature, _use_annotation_parameters
|
305
|
+
from modal.cls import _get_class_constructor_signature, _use_annotation_parameters
|
281
306
|
|
282
307
|
if not _use_annotation_parameters(self.user_cls):
|
283
308
|
return api_pb2.ClassParameterInfo(format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PICKLE)
|
284
309
|
|
285
310
|
# annotation parameters trigger strictly typed parametrization
|
286
311
|
# which enables web endpoint for parametrized classes
|
287
|
-
|
288
|
-
modal_parameters: list[api_pb2.ClassParameterSpec] = []
|
289
312
|
signature = _get_class_constructor_signature(self.user_cls)
|
313
|
+
# validate that the schema has no unspecified fields/unsupported class parameter types
|
290
314
|
for param in signature.parameters.values():
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
proto_type_info = PROTO_TYPE_INFO[proto_type]
|
295
|
-
class_param_spec = api_pb2.ClassParameterSpec(name=param.name, has_default=has_default, type=proto_type)
|
296
|
-
if has_default:
|
297
|
-
setattr(class_param_spec, proto_type_info.default_field, param.default)
|
298
|
-
modal_parameters.append(class_param_spec)
|
315
|
+
get_proto_parameter_type(param.annotation)
|
316
|
+
|
317
|
+
protobuf_schema = signature_to_protobuf_schema(signature)
|
299
318
|
|
300
319
|
return api_pb2.ClassParameterInfo(
|
301
|
-
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
|
320
|
+
format=api_pb2.ClassParameterInfo.PARAM_SERIALIZATION_FORMAT_PROTO,
|
321
|
+
schema=protobuf_schema,
|
302
322
|
)
|
303
323
|
|
304
324
|
def get_entrypoint_mount(self) -> dict[str, _Mount]:
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# Copyright Modal Labs 2025
|
2
|
+
import base64
|
3
|
+
import json
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from typing import Any, Dict
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
|
9
|
+
class DecodedJwt:
|
10
|
+
header: Dict[str, Any]
|
11
|
+
payload: Dict[str, Any]
|
12
|
+
|
13
|
+
@staticmethod
|
14
|
+
def decode_without_verification(token: str) -> "DecodedJwt":
|
15
|
+
# Split the JWT into its three parts
|
16
|
+
header_b64, payload_b64, _ = token.split(".")
|
17
|
+
|
18
|
+
# Decode Base64 (with padding handling)
|
19
|
+
header_json = base64.urlsafe_b64decode(header_b64 + "==").decode("utf-8")
|
20
|
+
payload_json = base64.urlsafe_b64decode(payload_b64 + "==").decode("utf-8")
|
21
|
+
|
22
|
+
# Convert JSON strings to dictionaries
|
23
|
+
header = json.loads(header_json)
|
24
|
+
payload = json.loads(payload_json)
|
25
|
+
|
26
|
+
return DecodedJwt(header, payload)
|
27
|
+
|
28
|
+
@staticmethod
|
29
|
+
def _base64url_encode(data: str) -> str:
|
30
|
+
"""Encodes data to Base64 URL-safe format without padding."""
|
31
|
+
return base64.urlsafe_b64encode(data.encode()).rstrip(b"=").decode()
|
32
|
+
|
33
|
+
@staticmethod
|
34
|
+
def encode_without_signature(fields: Dict[str, Any]) -> str:
|
35
|
+
"""Encodes an Unsecured JWT (without a signature)."""
|
36
|
+
header_b64 = DecodedJwt._base64url_encode(json.dumps({"alg": "none", "typ": "JWT"}))
|
37
|
+
payload_b64 = DecodedJwt._base64url_encode(json.dumps(fields))
|
38
|
+
return f"{header_b64}.{payload_b64}." # No signature
|
modal/cli/app.py
CHANGED
@@ -227,6 +227,8 @@ async def history(
|
|
227
227
|
]
|
228
228
|
rows = []
|
229
229
|
deployments_with_tags = False
|
230
|
+
deployments_with_commit_info = False
|
231
|
+
deployments_with_dirty_commit = False
|
230
232
|
for idx, app_stats in enumerate(resp.app_deployment_histories):
|
231
233
|
style = "bold green" if idx == 0 else ""
|
232
234
|
|
@@ -241,10 +243,23 @@ async def history(
|
|
241
243
|
deployments_with_tags = True
|
242
244
|
row.append(Text(app_stats.tag, style=style))
|
243
245
|
|
246
|
+
if app_stats.commit_info.commit_hash:
|
247
|
+
deployments_with_commit_info = True
|
248
|
+
short_hash = app_stats.commit_info.commit_hash[:7]
|
249
|
+
if app_stats.commit_info.dirty:
|
250
|
+
deployments_with_dirty_commit = True
|
251
|
+
short_hash = f"{short_hash}*"
|
252
|
+
row.append(Text(short_hash, style=style))
|
253
|
+
|
244
254
|
rows.append(row)
|
245
255
|
|
246
256
|
if deployments_with_tags:
|
247
257
|
columns.append("Tag")
|
258
|
+
if deployments_with_commit_info:
|
259
|
+
columns.append("Commit")
|
248
260
|
|
249
261
|
rows = sorted(rows, key=lambda x: int(str(x[0])[1:]), reverse=True)
|
250
262
|
display_table(columns, rows, json)
|
263
|
+
|
264
|
+
if deployments_with_dirty_commit and not json:
|
265
|
+
rich.print("* - repo had uncommitted changes")
|
modal/client.pyi
CHANGED
@@ -31,7 +31,7 @@ class _Client:
|
|
31
31
|
server_url: str,
|
32
32
|
client_type: int,
|
33
33
|
credentials: typing.Optional[tuple[str, str]],
|
34
|
-
version: str = "0.73.
|
34
|
+
version: str = "0.73.126",
|
35
35
|
): ...
|
36
36
|
def is_closed(self) -> bool: ...
|
37
37
|
@property
|
@@ -93,7 +93,7 @@ class Client:
|
|
93
93
|
server_url: str,
|
94
94
|
client_type: int,
|
95
95
|
credentials: typing.Optional[tuple[str, str]],
|
96
|
-
version: str = "0.73.
|
96
|
+
version: str = "0.73.126",
|
97
97
|
): ...
|
98
98
|
def is_closed(self) -> bool: ...
|
99
99
|
@property
|