modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of modal might be problematic. Click here for more details.
- modal/__main__.py +3 -4
- modal/_billing.py +80 -0
- modal/_clustered_functions.py +7 -3
- modal/_clustered_functions.pyi +4 -2
- modal/_container_entrypoint.py +41 -49
- modal/_functions.py +424 -195
- modal/_grpc_client.py +171 -0
- modal/_load_context.py +105 -0
- modal/_object.py +68 -20
- modal/_output.py +58 -45
- modal/_partial_function.py +36 -11
- modal/_pty.py +7 -3
- modal/_resolver.py +21 -35
- modal/_runtime/asgi.py +4 -3
- modal/_runtime/container_io_manager.py +301 -186
- modal/_runtime/container_io_manager.pyi +70 -61
- modal/_runtime/execution_context.py +18 -2
- modal/_runtime/execution_context.pyi +4 -1
- modal/_runtime/gpu_memory_snapshot.py +170 -63
- modal/_runtime/user_code_imports.py +28 -58
- modal/_serialization.py +57 -1
- modal/_utils/async_utils.py +33 -12
- modal/_utils/auth_token_manager.py +2 -5
- modal/_utils/blob_utils.py +110 -53
- modal/_utils/function_utils.py +49 -42
- modal/_utils/grpc_utils.py +80 -50
- modal/_utils/mount_utils.py +26 -1
- modal/_utils/name_utils.py +17 -3
- modal/_utils/task_command_router_client.py +536 -0
- modal/_utils/time_utils.py +34 -6
- modal/app.py +219 -83
- modal/app.pyi +229 -56
- modal/billing.py +5 -0
- modal/{requirements → builder}/2025.06.txt +1 -0
- modal/{requirements → builder}/PREVIEW.txt +1 -0
- modal/cli/_download.py +19 -3
- modal/cli/_traceback.py +3 -2
- modal/cli/app.py +4 -4
- modal/cli/cluster.py +15 -7
- modal/cli/config.py +5 -3
- modal/cli/container.py +7 -6
- modal/cli/dict.py +22 -16
- modal/cli/entry_point.py +12 -5
- modal/cli/environment.py +5 -4
- modal/cli/import_refs.py +3 -3
- modal/cli/launch.py +102 -5
- modal/cli/network_file_system.py +9 -13
- modal/cli/profile.py +3 -2
- modal/cli/programs/launch_instance_ssh.py +94 -0
- modal/cli/programs/run_jupyter.py +1 -1
- modal/cli/programs/run_marimo.py +95 -0
- modal/cli/programs/vscode.py +1 -1
- modal/cli/queues.py +57 -26
- modal/cli/run.py +58 -16
- modal/cli/secret.py +48 -22
- modal/cli/utils.py +3 -4
- modal/cli/volume.py +28 -25
- modal/client.py +13 -116
- modal/client.pyi +9 -91
- modal/cloud_bucket_mount.py +5 -3
- modal/cloud_bucket_mount.pyi +5 -1
- modal/cls.py +130 -102
- modal/cls.pyi +45 -85
- modal/config.py +29 -10
- modal/container_process.py +291 -13
- modal/container_process.pyi +95 -32
- modal/dict.py +282 -63
- modal/dict.pyi +423 -73
- modal/environments.py +15 -27
- modal/environments.pyi +5 -15
- modal/exception.py +8 -0
- modal/experimental/__init__.py +143 -38
- modal/experimental/flash.py +247 -78
- modal/experimental/flash.pyi +137 -9
- modal/file_io.py +14 -28
- modal/file_io.pyi +2 -2
- modal/file_pattern_matcher.py +25 -16
- modal/functions.pyi +134 -61
- modal/image.py +255 -86
- modal/image.pyi +300 -62
- modal/io_streams.py +436 -126
- modal/io_streams.pyi +236 -171
- modal/mount.py +62 -157
- modal/mount.pyi +45 -172
- modal/network_file_system.py +30 -53
- modal/network_file_system.pyi +16 -76
- modal/object.pyi +42 -8
- modal/parallel_map.py +821 -113
- modal/parallel_map.pyi +134 -0
- modal/partial_function.pyi +4 -1
- modal/proxy.py +16 -7
- modal/proxy.pyi +10 -2
- modal/queue.py +263 -61
- modal/queue.pyi +409 -66
- modal/runner.py +112 -92
- modal/runner.pyi +45 -27
- modal/sandbox.py +451 -124
- modal/sandbox.pyi +513 -67
- modal/secret.py +291 -67
- modal/secret.pyi +425 -19
- modal/serving.py +7 -11
- modal/serving.pyi +7 -8
- modal/snapshot.py +11 -8
- modal/token_flow.py +4 -4
- modal/volume.py +344 -98
- modal/volume.pyi +464 -68
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
- modal-1.2.3.dev7.dist-info/RECORD +195 -0
- modal_docs/mdmd/mdmd.py +11 -1
- modal_proto/api.proto +399 -67
- modal_proto/api_grpc.py +241 -1
- modal_proto/api_pb2.py +1395 -1000
- modal_proto/api_pb2.pyi +1239 -79
- modal_proto/api_pb2_grpc.py +499 -4
- modal_proto/api_pb2_grpc.pyi +162 -14
- modal_proto/modal_api_grpc.py +175 -160
- modal_proto/sandbox_router.proto +145 -0
- modal_proto/sandbox_router_grpc.py +105 -0
- modal_proto/sandbox_router_pb2.py +149 -0
- modal_proto/sandbox_router_pb2.pyi +333 -0
- modal_proto/sandbox_router_pb2_grpc.py +203 -0
- modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
- modal_proto/task_command_router.proto +144 -0
- modal_proto/task_command_router_grpc.py +105 -0
- modal_proto/task_command_router_pb2.py +149 -0
- modal_proto/task_command_router_pb2.pyi +333 -0
- modal_proto/task_command_router_pb2_grpc.py +203 -0
- modal_proto/task_command_router_pb2_grpc.pyi +75 -0
- modal_version/__init__.py +1 -1
- modal-1.0.6.dev58.dist-info/RECORD +0 -183
- modal_proto/modal_options_grpc.py +0 -3
- modal_proto/options.proto +0 -19
- modal_proto/options_grpc.py +0 -3
- modal_proto/options_pb2.py +0 -35
- modal_proto/options_pb2.pyi +0 -20
- modal_proto/options_pb2_grpc.py +0 -4
- modal_proto/options_pb2_grpc.pyi +0 -7
- /modal/{requirements → builder}/2023.12.312.txt +0 -0
- /modal/{requirements → builder}/2023.12.txt +0 -0
- /modal/{requirements → builder}/2024.04.txt +0 -0
- /modal/{requirements → builder}/2024.10.txt +0 -0
- /modal/{requirements → builder}/README.md +0 -0
- /modal/{requirements → builder}/base-images.json +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
- {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
|
@@ -29,7 +29,7 @@ class FinalizedFunction:
|
|
|
29
29
|
callable: Callable[..., Any]
|
|
30
30
|
is_async: bool
|
|
31
31
|
is_generator: bool
|
|
32
|
-
|
|
32
|
+
supported_output_formats: Sequence["api_pb2.DataFormat.ValueType"]
|
|
33
33
|
lifespan_manager: Optional["LifespanManager"] = None
|
|
34
34
|
|
|
35
35
|
|
|
@@ -93,9 +93,9 @@ def construct_webhook_callable(
|
|
|
93
93
|
|
|
94
94
|
@dataclass
|
|
95
95
|
class ImportedFunction(Service):
|
|
96
|
-
user_cls_instance: Any
|
|
97
96
|
app: modal.app._App
|
|
98
97
|
service_deps: Optional[Sequence["modal._object._Object"]]
|
|
98
|
+
user_cls_instance = None
|
|
99
99
|
|
|
100
100
|
_user_defined_callable: Callable[..., Any]
|
|
101
101
|
|
|
@@ -108,6 +108,7 @@ class ImportedFunction(Service):
|
|
|
108
108
|
is_generator = fun_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
|
|
109
109
|
|
|
110
110
|
webhook_config = fun_def.webhook_config
|
|
111
|
+
|
|
111
112
|
if not webhook_config.type:
|
|
112
113
|
# for non-webhooks, the runnable is straight forward:
|
|
113
114
|
return {
|
|
@@ -115,7 +116,10 @@ class ImportedFunction(Service):
|
|
|
115
116
|
callable=self._user_defined_callable,
|
|
116
117
|
is_async=is_async,
|
|
117
118
|
is_generator=is_generator,
|
|
118
|
-
|
|
119
|
+
supported_output_formats=fun_def.supported_output_formats
|
|
120
|
+
# FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
|
|
121
|
+
# needed for tests
|
|
122
|
+
or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
|
|
119
123
|
)
|
|
120
124
|
}
|
|
121
125
|
|
|
@@ -129,7 +133,8 @@ class ImportedFunction(Service):
|
|
|
129
133
|
lifespan_manager=lifespan_manager,
|
|
130
134
|
is_async=True,
|
|
131
135
|
is_generator=True,
|
|
132
|
-
|
|
136
|
+
# FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
|
|
137
|
+
supported_output_formats=fun_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
|
|
133
138
|
)
|
|
134
139
|
}
|
|
135
140
|
|
|
@@ -154,6 +159,7 @@ class ImportedClass(Service):
|
|
|
154
159
|
# Use the function definition for whether this is a generator (overriden by webhooks)
|
|
155
160
|
is_generator = _partial.params.is_generator
|
|
156
161
|
webhook_config = _partial.params.webhook_config
|
|
162
|
+
method_def = fun_def.method_definitions[method_name]
|
|
157
163
|
|
|
158
164
|
bound_func = user_func.__get__(self.user_cls_instance)
|
|
159
165
|
|
|
@@ -163,7 +169,10 @@ class ImportedClass(Service):
|
|
|
163
169
|
callable=bound_func,
|
|
164
170
|
is_async=is_async,
|
|
165
171
|
is_generator=bool(is_generator),
|
|
166
|
-
|
|
172
|
+
# FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
|
|
173
|
+
# needed for tests
|
|
174
|
+
supported_output_formats=method_def.supported_output_formats
|
|
175
|
+
or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
|
|
167
176
|
)
|
|
168
177
|
else:
|
|
169
178
|
web_callable, lifespan_manager = construct_webhook_callable(
|
|
@@ -174,7 +183,8 @@ class ImportedClass(Service):
|
|
|
174
183
|
lifespan_manager=lifespan_manager,
|
|
175
184
|
is_async=True,
|
|
176
185
|
is_generator=True,
|
|
177
|
-
|
|
186
|
+
# FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
|
|
187
|
+
supported_output_formats=method_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
|
|
178
188
|
)
|
|
179
189
|
finalized_functions[method_name] = finalized_function
|
|
180
190
|
return finalized_functions
|
|
@@ -199,7 +209,6 @@ def get_user_class_instance(_cls: modal.cls._Cls, args: tuple[Any, ...], kwargs:
|
|
|
199
209
|
|
|
200
210
|
def import_single_function_service(
|
|
201
211
|
function_def: api_pb2.Function,
|
|
202
|
-
ser_cls: Optional[type], # used only for @build functions
|
|
203
212
|
ser_fun: Optional[Callable[..., Any]],
|
|
204
213
|
) -> Service:
|
|
205
214
|
"""Imports a function dynamically, and locates the app.
|
|
@@ -228,12 +237,9 @@ def import_single_function_service(
|
|
|
228
237
|
service_deps: Optional[Sequence["modal._object._Object"]] = None
|
|
229
238
|
active_app: modal.app._App
|
|
230
239
|
|
|
231
|
-
user_cls_or_cls: typing.Union[None, type, modal.cls.Cls]
|
|
232
|
-
user_cls_instance = None
|
|
233
|
-
|
|
234
240
|
if ser_fun is not None:
|
|
235
241
|
# This is a serialized function we already fetched from the server
|
|
236
|
-
|
|
242
|
+
user_defined_callable = ser_fun
|
|
237
243
|
active_app = get_active_app_fallback(function_def)
|
|
238
244
|
else:
|
|
239
245
|
# Load the module dynamically
|
|
@@ -244,58 +250,22 @@ def import_single_function_service(
|
|
|
244
250
|
raise LocalFunctionError("Attempted to load a function defined in a function scope")
|
|
245
251
|
|
|
246
252
|
parts = qual_name.split(".")
|
|
247
|
-
if len(parts)
|
|
248
|
-
# This is a function
|
|
249
|
-
user_cls_or_cls = None
|
|
250
|
-
f = getattr(module, qual_name)
|
|
251
|
-
if isinstance(f, Function):
|
|
252
|
-
_function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f) # type: ignore
|
|
253
|
-
service_deps = _function.deps(only_explicit_mounts=True)
|
|
254
|
-
user_defined_callable = _function.get_raw_f()
|
|
255
|
-
assert _function._app # app should always be set on a decorated function
|
|
256
|
-
active_app = _function._app
|
|
257
|
-
else:
|
|
258
|
-
user_defined_callable = f
|
|
259
|
-
active_app = get_active_app_fallback(function_def)
|
|
260
|
-
|
|
261
|
-
elif len(parts) == 2:
|
|
262
|
-
# This path should only be triggered by @build class builder methods and can be removed
|
|
263
|
-
# once @build is deprecated.
|
|
264
|
-
assert not function_def.use_method_name # new "placeholder methods" should not be invoked directly!
|
|
265
|
-
assert function_def.is_builder_function
|
|
266
|
-
cls_name, fun_name = parts
|
|
267
|
-
user_cls_or_cls = getattr(module, cls_name)
|
|
268
|
-
if isinstance(user_cls_or_cls, modal.cls.Cls):
|
|
269
|
-
# The cls decorator is in global scope
|
|
270
|
-
_cls = typing.cast(modal.cls._Cls, synchronizer._translate_in(user_cls_or_cls))
|
|
271
|
-
user_defined_callable = _cls._callables[fun_name]
|
|
272
|
-
# Intentionally not including these, since @build functions don't actually
|
|
273
|
-
# forward the information from their parent class.
|
|
274
|
-
# service_deps = _cls._get_class_service_function().deps(only_explicit_mounts=True)
|
|
275
|
-
assert _cls._app
|
|
276
|
-
active_app = _cls._app
|
|
277
|
-
else:
|
|
278
|
-
# This is non-decorated class
|
|
279
|
-
user_defined_callable = getattr(user_cls_or_cls, fun_name) # unbound method
|
|
280
|
-
active_app = get_active_app_fallback(function_def)
|
|
281
|
-
else:
|
|
253
|
+
if len(parts) != 1:
|
|
282
254
|
raise InvalidError(f"Invalid function qualname {qual_name}")
|
|
283
255
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
256
|
+
f = getattr(module, qual_name)
|
|
257
|
+
if isinstance(f, Function):
|
|
258
|
+
_function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f) # type: ignore
|
|
259
|
+
service_deps = _function.deps(only_explicit_mounts=True)
|
|
260
|
+
user_defined_callable = _function.get_raw_f()
|
|
261
|
+
assert _function._app # app should always be set on a decorated function
|
|
262
|
+
active_app = _function._app
|
|
291
263
|
else:
|
|
292
|
-
#
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
user_defined_callable = user_defined_callable.__get__(user_cls_instance)
|
|
264
|
+
# function isn't decorated in global scope
|
|
265
|
+
user_defined_callable = f
|
|
266
|
+
active_app = get_active_app_fallback(function_def)
|
|
296
267
|
|
|
297
268
|
return ImportedFunction(
|
|
298
|
-
user_cls_instance,
|
|
299
269
|
active_app,
|
|
300
270
|
service_deps,
|
|
301
271
|
user_defined_callable,
|
modal/_serialization.py
CHANGED
|
@@ -6,6 +6,14 @@ import typing
|
|
|
6
6
|
from inspect import Parameter
|
|
7
7
|
from typing import Any
|
|
8
8
|
|
|
9
|
+
from modal._traceback import extract_traceback
|
|
10
|
+
from modal.config import config
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import cbor2 # type: ignore
|
|
14
|
+
except ImportError: # pragma: no cover - optional dependency
|
|
15
|
+
cbor2 = None
|
|
16
|
+
|
|
9
17
|
import google.protobuf.message
|
|
10
18
|
|
|
11
19
|
from modal._utils.async_utils import synchronizer
|
|
@@ -15,7 +23,7 @@ from ._object import _Object
|
|
|
15
23
|
from ._type_manager import parameter_serde_registry, schema_registry
|
|
16
24
|
from ._vendor import cloudpickle
|
|
17
25
|
from .config import logger
|
|
18
|
-
from .exception import DeserializationError, ExecutionError, InvalidError
|
|
26
|
+
from .exception import DeserializationError, ExecutionError, InvalidError, SerializationError
|
|
19
27
|
from .object import Object
|
|
20
28
|
|
|
21
29
|
if typing.TYPE_CHECKING:
|
|
@@ -346,6 +354,12 @@ def _deserialize_asgi(asgi: api_pb2.Asgi) -> Any:
|
|
|
346
354
|
return None
|
|
347
355
|
|
|
348
356
|
|
|
357
|
+
def get_preferred_payload_format() -> "api_pb2.DataFormat.ValueType":
|
|
358
|
+
payload_format = (config.get("payload_format") or "pickle").lower()
|
|
359
|
+
data_format = api_pb2.DATA_FORMAT_CBOR if payload_format == "cbor" else api_pb2.DATA_FORMAT_PICKLE
|
|
360
|
+
return data_format
|
|
361
|
+
|
|
362
|
+
|
|
349
363
|
def serialize_data_format(obj: Any, data_format: int) -> bytes:
|
|
350
364
|
"""Similar to serialize(), but supports other data formats."""
|
|
351
365
|
if data_format == api_pb2.DATA_FORMAT_PICKLE:
|
|
@@ -355,6 +369,21 @@ def serialize_data_format(obj: Any, data_format: int) -> bytes:
|
|
|
355
369
|
elif data_format == api_pb2.DATA_FORMAT_GENERATOR_DONE:
|
|
356
370
|
assert isinstance(obj, api_pb2.GeneratorDone)
|
|
357
371
|
return obj.SerializeToString(deterministic=True)
|
|
372
|
+
elif data_format == api_pb2.DATA_FORMAT_CBOR:
|
|
373
|
+
if cbor2 is None:
|
|
374
|
+
raise InvalidError("CBOR support requires the 'cbor2' package to be installed.")
|
|
375
|
+
try:
|
|
376
|
+
return cbor2.dumps(obj)
|
|
377
|
+
except cbor2.CBOREncodeTypeError:
|
|
378
|
+
try:
|
|
379
|
+
typename = f"{type(obj).__module__}.{type(obj).__name__}"
|
|
380
|
+
except Exception:
|
|
381
|
+
typename = str(type(obj))
|
|
382
|
+
raise SerializationError(
|
|
383
|
+
# TODO (elias): add documentation link for more information on this
|
|
384
|
+
f"Can not serialize type {typename} as cbor. If you need to use a custom data type, "
|
|
385
|
+
"try to serialize it yourself e.g. by using pickle.dumps(my_data)"
|
|
386
|
+
)
|
|
358
387
|
else:
|
|
359
388
|
raise InvalidError(f"Unknown data format {data_format!r}")
|
|
360
389
|
|
|
@@ -366,6 +395,10 @@ def deserialize_data_format(s: bytes, data_format: int, client) -> Any:
|
|
|
366
395
|
return _deserialize_asgi(api_pb2.Asgi.FromString(s))
|
|
367
396
|
elif data_format == api_pb2.DATA_FORMAT_GENERATOR_DONE:
|
|
368
397
|
return api_pb2.GeneratorDone.FromString(s)
|
|
398
|
+
elif data_format == api_pb2.DATA_FORMAT_CBOR:
|
|
399
|
+
if cbor2 is None:
|
|
400
|
+
raise InvalidError("CBOR support requires the 'cbor2' package to be installed.")
|
|
401
|
+
return cbor2.loads(s)
|
|
369
402
|
else:
|
|
370
403
|
raise InvalidError(f"Unknown data format {data_format!r}")
|
|
371
404
|
|
|
@@ -579,3 +612,26 @@ def get_callable_schema(
|
|
|
579
612
|
arguments=arguments,
|
|
580
613
|
return_type=return_type_proto,
|
|
581
614
|
)
|
|
615
|
+
|
|
616
|
+
|
|
617
|
+
def pickle_exception(exc: BaseException) -> bytes:
|
|
618
|
+
try:
|
|
619
|
+
return serialize(exc)
|
|
620
|
+
except Exception as serialization_exc:
|
|
621
|
+
# We can't always serialize exceptions.
|
|
622
|
+
err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
|
|
623
|
+
logger.info(err)
|
|
624
|
+
return serialize(SerializationError(err))
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def pickle_traceback(exc: BaseException, task_id: str) -> tuple[bytes, bytes]:
|
|
628
|
+
serialized_tb, tb_line_cache = b"", b""
|
|
629
|
+
|
|
630
|
+
try:
|
|
631
|
+
tb_dict, line_cache = extract_traceback(exc, task_id)
|
|
632
|
+
serialized_tb = serialize(tb_dict)
|
|
633
|
+
tb_line_cache = serialize(line_cache)
|
|
634
|
+
except Exception:
|
|
635
|
+
logger.info("Failed to serialize exception traceback.")
|
|
636
|
+
|
|
637
|
+
return serialized_tb, tb_line_cache
|
modal/_utils/async_utils.py
CHANGED
|
@@ -1,12 +1,14 @@
|
|
|
1
1
|
# Copyright Modal Labs 2022
|
|
2
2
|
import asyncio
|
|
3
3
|
import concurrent.futures
|
|
4
|
+
import contextlib
|
|
4
5
|
import functools
|
|
5
6
|
import inspect
|
|
6
7
|
import itertools
|
|
7
8
|
import sys
|
|
8
9
|
import time
|
|
9
10
|
import typing
|
|
11
|
+
import warnings
|
|
10
12
|
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
|
|
11
13
|
from contextlib import asynccontextmanager
|
|
12
14
|
from dataclasses import dataclass
|
|
@@ -51,6 +53,10 @@ def synchronize_api(obj, target_module=None):
|
|
|
51
53
|
return synchronizer.create_blocking(obj, blocking_name, target_module=target_module)
|
|
52
54
|
|
|
53
55
|
|
|
56
|
+
# Used for testing to configure the `n_attempts` that `retry` will use.
|
|
57
|
+
RETRY_N_ATTEMPTS_OVERRIDE: Optional[int] = None
|
|
58
|
+
|
|
59
|
+
|
|
54
60
|
def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout=90):
|
|
55
61
|
"""Decorator that calls an async function multiple times, with a given timeout.
|
|
56
62
|
|
|
@@ -75,8 +81,13 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
|
|
|
75
81
|
def decorator(fn):
|
|
76
82
|
@functools.wraps(fn)
|
|
77
83
|
async def f_wrapped(*args, **kwargs):
|
|
84
|
+
if RETRY_N_ATTEMPTS_OVERRIDE is not None:
|
|
85
|
+
local_n_attempts = RETRY_N_ATTEMPTS_OVERRIDE
|
|
86
|
+
else:
|
|
87
|
+
local_n_attempts = n_attempts
|
|
88
|
+
|
|
78
89
|
delay = base_delay
|
|
79
|
-
for i in range(
|
|
90
|
+
for i in range(local_n_attempts):
|
|
80
91
|
t0 = time.time()
|
|
81
92
|
try:
|
|
82
93
|
return await asyncio.wait_for(fn(*args, **kwargs), timeout=timeout)
|
|
@@ -84,12 +95,12 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
|
|
|
84
95
|
logger.debug(f"Function {fn} was cancelled")
|
|
85
96
|
raise
|
|
86
97
|
except Exception as e:
|
|
87
|
-
if i >=
|
|
98
|
+
if i >= local_n_attempts - 1:
|
|
88
99
|
raise
|
|
89
100
|
logger.debug(
|
|
90
101
|
f"Failed invoking function {fn}: {e}"
|
|
91
102
|
f" (took {time.time() - t0}s, sleeping {delay}s"
|
|
92
|
-
f" and trying {
|
|
103
|
+
f" and trying {local_n_attempts - i - 1} more times)"
|
|
93
104
|
)
|
|
94
105
|
await asyncio.sleep(delay)
|
|
95
106
|
delay *= delay_factor
|
|
@@ -125,7 +136,8 @@ class TaskContext:
|
|
|
125
136
|
_loops: set[asyncio.Task]
|
|
126
137
|
|
|
127
138
|
def __init__(self, grace: Optional[float] = None):
|
|
128
|
-
self._grace = grace
|
|
139
|
+
self._grace = grace # grace is the time we want for tasks to finish before cancelling them
|
|
140
|
+
self._cancellation_grace: float = 1.0 # extra graceperiod for the cancellation itself to "bubble up"
|
|
129
141
|
self._loops = set()
|
|
130
142
|
|
|
131
143
|
async def start(self):
|
|
@@ -157,22 +169,29 @@ class TaskContext:
|
|
|
157
169
|
# still needs to be handled
|
|
158
170
|
# (https://stackoverflow.com/a/63356323/2475114)
|
|
159
171
|
if gather_future:
|
|
160
|
-
|
|
172
|
+
with contextlib.suppress(asyncio.CancelledError):
|
|
161
173
|
await gather_future
|
|
162
|
-
except asyncio.CancelledError:
|
|
163
|
-
pass
|
|
164
174
|
|
|
175
|
+
cancelled_tasks: list[asyncio.Task] = []
|
|
165
176
|
for task in self._tasks:
|
|
166
177
|
if task.done() and not task.cancelled():
|
|
167
178
|
# Raise any exceptions if they happened.
|
|
168
179
|
# Only tasks without a done_callback will still be present in self._tasks
|
|
169
180
|
task.result()
|
|
170
181
|
|
|
171
|
-
if task.done()
|
|
182
|
+
if task.done():
|
|
172
183
|
continue
|
|
173
184
|
|
|
174
185
|
# Cancel any remaining unfinished tasks.
|
|
175
186
|
task.cancel()
|
|
187
|
+
cancelled_tasks.append(task)
|
|
188
|
+
|
|
189
|
+
cancellation_gather = asyncio.gather(*cancelled_tasks, return_exceptions=True)
|
|
190
|
+
try:
|
|
191
|
+
await asyncio.wait_for(cancellation_gather, timeout=self._cancellation_grace)
|
|
192
|
+
except asyncio.TimeoutError:
|
|
193
|
+
warnings.warn(f"Internal warning: Tasks did not cancel in a timely manner: {cancelled_tasks}")
|
|
194
|
+
|
|
176
195
|
await asyncio.sleep(0) # wake up coroutines waiting for cancellations
|
|
177
196
|
|
|
178
197
|
async def __aexit__(self, exc_type, value, tb):
|
|
@@ -279,7 +298,9 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
279
298
|
|
|
280
299
|
def __init__(self, maxsize: int = 0):
|
|
281
300
|
self.condition = asyncio.Condition()
|
|
282
|
-
self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
|
|
301
|
+
self._queue: asyncio.PriorityQueue[tuple[float, int, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
|
|
302
|
+
# Used to tiebreak items with the same timestamp that are not comparable. (eg. protos)
|
|
303
|
+
self._counter = itertools.count()
|
|
283
304
|
|
|
284
305
|
async def close(self):
|
|
285
306
|
await self.put(self._MAX_PRIORITY, None)
|
|
@@ -288,7 +309,7 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
288
309
|
"""
|
|
289
310
|
Add an item to the queue to be processed at a specific timestamp.
|
|
290
311
|
"""
|
|
291
|
-
await self._queue.put((timestamp, item))
|
|
312
|
+
await self._queue.put((timestamp, next(self._counter), item))
|
|
292
313
|
async with self.condition:
|
|
293
314
|
self.condition.notify_all() # notify any waiting coroutines
|
|
294
315
|
|
|
@@ -301,7 +322,7 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
301
322
|
while self.empty():
|
|
302
323
|
await self.condition.wait()
|
|
303
324
|
# peek at the next item
|
|
304
|
-
timestamp, item = await self._queue.get()
|
|
325
|
+
timestamp, counter, item = await self._queue.get()
|
|
305
326
|
now = time.time()
|
|
306
327
|
if timestamp < now:
|
|
307
328
|
return item
|
|
@@ -309,7 +330,7 @@ class TimestampPriorityQueue(Generic[T]):
|
|
|
309
330
|
return None
|
|
310
331
|
# not ready yet, calculate sleep time
|
|
311
332
|
sleep_time = timestamp - now
|
|
312
|
-
self._queue.put_nowait((timestamp, item)) # put it back
|
|
333
|
+
self._queue.put_nowait((timestamp, counter, item)) # put it back
|
|
313
334
|
# wait until either the timeout or a new item is added
|
|
314
335
|
try:
|
|
315
336
|
await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
|
|
@@ -9,7 +9,6 @@ from typing import Any
|
|
|
9
9
|
from modal.exception import ExecutionError
|
|
10
10
|
from modal_proto import api_pb2, modal_api_grpc
|
|
11
11
|
|
|
12
|
-
from .grpc_utils import retry_transient_errors
|
|
13
12
|
from .logger import logger
|
|
14
13
|
|
|
15
14
|
|
|
@@ -27,7 +26,7 @@ class _AuthTokenManager:
|
|
|
27
26
|
self._expiry = 0.0
|
|
28
27
|
self._lock: typing.Union[asyncio.Lock, None] = None
|
|
29
28
|
|
|
30
|
-
async def get_token(self):
|
|
29
|
+
async def get_token(self) -> str:
|
|
31
30
|
"""
|
|
32
31
|
When called, the AuthTokenManager can be in one of three states:
|
|
33
32
|
1. Has a valid cached token. It is returned to the caller.
|
|
@@ -66,9 +65,7 @@ class _AuthTokenManager:
|
|
|
66
65
|
# new token. Once we have a new token, the other coroutines will unblock and return from here.
|
|
67
66
|
if self._token and not self._needs_refresh():
|
|
68
67
|
return
|
|
69
|
-
resp: api_pb2.AuthTokenGetResponse = await
|
|
70
|
-
self._stub.AuthTokenGet, api_pb2.AuthTokenGetRequest()
|
|
71
|
-
)
|
|
68
|
+
resp: api_pb2.AuthTokenGetResponse = await self._stub.AuthTokenGet(api_pb2.AuthTokenGetRequest())
|
|
72
69
|
if not resp.token:
|
|
73
70
|
# Not expected
|
|
74
71
|
raise ExecutionError(
|