modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. Click here for more details.

Files changed (147):
  1. modal/__main__.py +3 -4
  2. modal/_billing.py +80 -0
  3. modal/_clustered_functions.py +7 -3
  4. modal/_clustered_functions.pyi +4 -2
  5. modal/_container_entrypoint.py +41 -49
  6. modal/_functions.py +424 -195
  7. modal/_grpc_client.py +171 -0
  8. modal/_load_context.py +105 -0
  9. modal/_object.py +68 -20
  10. modal/_output.py +58 -45
  11. modal/_partial_function.py +36 -11
  12. modal/_pty.py +7 -3
  13. modal/_resolver.py +21 -35
  14. modal/_runtime/asgi.py +4 -3
  15. modal/_runtime/container_io_manager.py +301 -186
  16. modal/_runtime/container_io_manager.pyi +70 -61
  17. modal/_runtime/execution_context.py +18 -2
  18. modal/_runtime/execution_context.pyi +4 -1
  19. modal/_runtime/gpu_memory_snapshot.py +170 -63
  20. modal/_runtime/user_code_imports.py +28 -58
  21. modal/_serialization.py +57 -1
  22. modal/_utils/async_utils.py +33 -12
  23. modal/_utils/auth_token_manager.py +2 -5
  24. modal/_utils/blob_utils.py +110 -53
  25. modal/_utils/function_utils.py +49 -42
  26. modal/_utils/grpc_utils.py +80 -50
  27. modal/_utils/mount_utils.py +26 -1
  28. modal/_utils/name_utils.py +17 -3
  29. modal/_utils/task_command_router_client.py +536 -0
  30. modal/_utils/time_utils.py +34 -6
  31. modal/app.py +219 -83
  32. modal/app.pyi +229 -56
  33. modal/billing.py +5 -0
  34. modal/{requirements → builder}/2025.06.txt +1 -0
  35. modal/{requirements → builder}/PREVIEW.txt +1 -0
  36. modal/cli/_download.py +19 -3
  37. modal/cli/_traceback.py +3 -2
  38. modal/cli/app.py +4 -4
  39. modal/cli/cluster.py +15 -7
  40. modal/cli/config.py +5 -3
  41. modal/cli/container.py +7 -6
  42. modal/cli/dict.py +22 -16
  43. modal/cli/entry_point.py +12 -5
  44. modal/cli/environment.py +5 -4
  45. modal/cli/import_refs.py +3 -3
  46. modal/cli/launch.py +102 -5
  47. modal/cli/network_file_system.py +9 -13
  48. modal/cli/profile.py +3 -2
  49. modal/cli/programs/launch_instance_ssh.py +94 -0
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/run_marimo.py +95 -0
  52. modal/cli/programs/vscode.py +1 -1
  53. modal/cli/queues.py +57 -26
  54. modal/cli/run.py +58 -16
  55. modal/cli/secret.py +48 -22
  56. modal/cli/utils.py +3 -4
  57. modal/cli/volume.py +28 -25
  58. modal/client.py +13 -116
  59. modal/client.pyi +9 -91
  60. modal/cloud_bucket_mount.py +5 -3
  61. modal/cloud_bucket_mount.pyi +5 -1
  62. modal/cls.py +130 -102
  63. modal/cls.pyi +45 -85
  64. modal/config.py +29 -10
  65. modal/container_process.py +291 -13
  66. modal/container_process.pyi +95 -32
  67. modal/dict.py +282 -63
  68. modal/dict.pyi +423 -73
  69. modal/environments.py +15 -27
  70. modal/environments.pyi +5 -15
  71. modal/exception.py +8 -0
  72. modal/experimental/__init__.py +143 -38
  73. modal/experimental/flash.py +247 -78
  74. modal/experimental/flash.pyi +137 -9
  75. modal/file_io.py +14 -28
  76. modal/file_io.pyi +2 -2
  77. modal/file_pattern_matcher.py +25 -16
  78. modal/functions.pyi +134 -61
  79. modal/image.py +255 -86
  80. modal/image.pyi +300 -62
  81. modal/io_streams.py +436 -126
  82. modal/io_streams.pyi +236 -171
  83. modal/mount.py +62 -157
  84. modal/mount.pyi +45 -172
  85. modal/network_file_system.py +30 -53
  86. modal/network_file_system.pyi +16 -76
  87. modal/object.pyi +42 -8
  88. modal/parallel_map.py +821 -113
  89. modal/parallel_map.pyi +134 -0
  90. modal/partial_function.pyi +4 -1
  91. modal/proxy.py +16 -7
  92. modal/proxy.pyi +10 -2
  93. modal/queue.py +263 -61
  94. modal/queue.pyi +409 -66
  95. modal/runner.py +112 -92
  96. modal/runner.pyi +45 -27
  97. modal/sandbox.py +451 -124
  98. modal/sandbox.pyi +513 -67
  99. modal/secret.py +291 -67
  100. modal/secret.pyi +425 -19
  101. modal/serving.py +7 -11
  102. modal/serving.pyi +7 -8
  103. modal/snapshot.py +11 -8
  104. modal/token_flow.py +4 -4
  105. modal/volume.py +344 -98
  106. modal/volume.pyi +464 -68
  107. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
  108. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  109. modal_docs/mdmd/mdmd.py +11 -1
  110. modal_proto/api.proto +399 -67
  111. modal_proto/api_grpc.py +241 -1
  112. modal_proto/api_pb2.py +1395 -1000
  113. modal_proto/api_pb2.pyi +1239 -79
  114. modal_proto/api_pb2_grpc.py +499 -4
  115. modal_proto/api_pb2_grpc.pyi +162 -14
  116. modal_proto/modal_api_grpc.py +175 -160
  117. modal_proto/sandbox_router.proto +145 -0
  118. modal_proto/sandbox_router_grpc.py +105 -0
  119. modal_proto/sandbox_router_pb2.py +149 -0
  120. modal_proto/sandbox_router_pb2.pyi +333 -0
  121. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  122. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  123. modal_proto/task_command_router.proto +144 -0
  124. modal_proto/task_command_router_grpc.py +105 -0
  125. modal_proto/task_command_router_pb2.py +149 -0
  126. modal_proto/task_command_router_pb2.pyi +333 -0
  127. modal_proto/task_command_router_pb2_grpc.py +203 -0
  128. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  129. modal_version/__init__.py +1 -1
  130. modal-1.0.6.dev58.dist-info/RECORD +0 -183
  131. modal_proto/modal_options_grpc.py +0 -3
  132. modal_proto/options.proto +0 -19
  133. modal_proto/options_grpc.py +0 -3
  134. modal_proto/options_pb2.py +0 -35
  135. modal_proto/options_pb2.pyi +0 -20
  136. modal_proto/options_pb2_grpc.py +0 -4
  137. modal_proto/options_pb2_grpc.pyi +0 -7
  138. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  139. /modal/{requirements → builder}/2023.12.txt +0 -0
  140. /modal/{requirements → builder}/2024.04.txt +0 -0
  141. /modal/{requirements → builder}/2024.10.txt +0 -0
  142. /modal/{requirements → builder}/README.md +0 -0
  143. /modal/{requirements → builder}/base-images.json +0 -0
  144. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  145. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  146. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  147. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,7 @@ class FinalizedFunction:
29
29
  callable: Callable[..., Any]
30
30
  is_async: bool
31
31
  is_generator: bool
32
- data_format: int # api_pb2.DataFormat
32
+ supported_output_formats: Sequence["api_pb2.DataFormat.ValueType"]
33
33
  lifespan_manager: Optional["LifespanManager"] = None
34
34
 
35
35
 
@@ -93,9 +93,9 @@ def construct_webhook_callable(
93
93
 
94
94
  @dataclass
95
95
  class ImportedFunction(Service):
96
- user_cls_instance: Any
97
96
  app: modal.app._App
98
97
  service_deps: Optional[Sequence["modal._object._Object"]]
98
+ user_cls_instance = None
99
99
 
100
100
  _user_defined_callable: Callable[..., Any]
101
101
 
@@ -108,6 +108,7 @@ class ImportedFunction(Service):
108
108
  is_generator = fun_def.function_type == api_pb2.Function.FUNCTION_TYPE_GENERATOR
109
109
 
110
110
  webhook_config = fun_def.webhook_config
111
+
111
112
  if not webhook_config.type:
112
113
  # for non-webhooks, the runnable is straight forward:
113
114
  return {
@@ -115,7 +116,10 @@ class ImportedFunction(Service):
115
116
  callable=self._user_defined_callable,
116
117
  is_async=is_async,
117
118
  is_generator=is_generator,
118
- data_format=api_pb2.DATA_FORMAT_PICKLE,
119
+ supported_output_formats=fun_def.supported_output_formats
120
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
121
+ # needed for tests
122
+ or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
119
123
  )
120
124
  }
121
125
 
@@ -129,7 +133,8 @@ class ImportedFunction(Service):
129
133
  lifespan_manager=lifespan_manager,
130
134
  is_async=True,
131
135
  is_generator=True,
132
- data_format=api_pb2.DATA_FORMAT_ASGI,
136
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
137
+ supported_output_formats=fun_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
133
138
  )
134
139
  }
135
140
 
@@ -154,6 +159,7 @@ class ImportedClass(Service):
154
159
  # Use the function definition for whether this is a generator (overriden by webhooks)
155
160
  is_generator = _partial.params.is_generator
156
161
  webhook_config = _partial.params.webhook_config
162
+ method_def = fun_def.method_definitions[method_name]
157
163
 
158
164
  bound_func = user_func.__get__(self.user_cls_instance)
159
165
 
@@ -163,7 +169,10 @@ class ImportedClass(Service):
163
169
  callable=bound_func,
164
170
  is_async=is_async,
165
171
  is_generator=bool(is_generator),
166
- data_format=api_pb2.DATA_FORMAT_PICKLE,
172
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR]` is only
173
+ # needed for tests
174
+ supported_output_formats=method_def.supported_output_formats
175
+ or [api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_CBOR],
167
176
  )
168
177
  else:
169
178
  web_callable, lifespan_manager = construct_webhook_callable(
@@ -174,7 +183,8 @@ class ImportedClass(Service):
174
183
  lifespan_manager=lifespan_manager,
175
184
  is_async=True,
176
185
  is_generator=True,
177
- data_format=api_pb2.DATA_FORMAT_ASGI,
186
+ # FIXME (elias): the following `or [api_pb2.DATA_FORMAT_ASGI]` is only needed for tests
187
+ supported_output_formats=method_def.supported_output_formats or [api_pb2.DATA_FORMAT_ASGI],
178
188
  )
179
189
  finalized_functions[method_name] = finalized_function
180
190
  return finalized_functions
@@ -199,7 +209,6 @@ def get_user_class_instance(_cls: modal.cls._Cls, args: tuple[Any, ...], kwargs:
199
209
 
200
210
  def import_single_function_service(
201
211
  function_def: api_pb2.Function,
202
- ser_cls: Optional[type], # used only for @build functions
203
212
  ser_fun: Optional[Callable[..., Any]],
204
213
  ) -> Service:
205
214
  """Imports a function dynamically, and locates the app.
@@ -228,12 +237,9 @@ def import_single_function_service(
228
237
  service_deps: Optional[Sequence["modal._object._Object"]] = None
229
238
  active_app: modal.app._App
230
239
 
231
- user_cls_or_cls: typing.Union[None, type, modal.cls.Cls]
232
- user_cls_instance = None
233
-
234
240
  if ser_fun is not None:
235
241
  # This is a serialized function we already fetched from the server
236
- user_cls_or_cls, user_defined_callable = ser_cls, ser_fun
242
+ user_defined_callable = ser_fun
237
243
  active_app = get_active_app_fallback(function_def)
238
244
  else:
239
245
  # Load the module dynamically
@@ -244,58 +250,22 @@ def import_single_function_service(
244
250
  raise LocalFunctionError("Attempted to load a function defined in a function scope")
245
251
 
246
252
  parts = qual_name.split(".")
247
- if len(parts) == 1:
248
- # This is a function
249
- user_cls_or_cls = None
250
- f = getattr(module, qual_name)
251
- if isinstance(f, Function):
252
- _function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f) # type: ignore
253
- service_deps = _function.deps(only_explicit_mounts=True)
254
- user_defined_callable = _function.get_raw_f()
255
- assert _function._app # app should always be set on a decorated function
256
- active_app = _function._app
257
- else:
258
- user_defined_callable = f
259
- active_app = get_active_app_fallback(function_def)
260
-
261
- elif len(parts) == 2:
262
- # This path should only be triggered by @build class builder methods and can be removed
263
- # once @build is deprecated.
264
- assert not function_def.use_method_name # new "placeholder methods" should not be invoked directly!
265
- assert function_def.is_builder_function
266
- cls_name, fun_name = parts
267
- user_cls_or_cls = getattr(module, cls_name)
268
- if isinstance(user_cls_or_cls, modal.cls.Cls):
269
- # The cls decorator is in global scope
270
- _cls = typing.cast(modal.cls._Cls, synchronizer._translate_in(user_cls_or_cls))
271
- user_defined_callable = _cls._callables[fun_name]
272
- # Intentionally not including these, since @build functions don't actually
273
- # forward the information from their parent class.
274
- # service_deps = _cls._get_class_service_function().deps(only_explicit_mounts=True)
275
- assert _cls._app
276
- active_app = _cls._app
277
- else:
278
- # This is non-decorated class
279
- user_defined_callable = getattr(user_cls_or_cls, fun_name) # unbound method
280
- active_app = get_active_app_fallback(function_def)
281
- else:
253
+ if len(parts) != 1:
282
254
  raise InvalidError(f"Invalid function qualname {qual_name}")
283
255
 
284
- # Instantiate the class if it's defined
285
- if user_cls_or_cls:
286
- if isinstance(user_cls_or_cls, modal.cls.Cls):
287
- # This code is only used for @build methods on classes
288
- _cls = typing.cast(modal.cls._Cls, user_cls_or_cls)
289
- user_cls_instance = get_user_class_instance(_cls, (), {})
290
- # Bind the unbound method to the instance as self (using the descriptor protocol!)
256
+ f = getattr(module, qual_name)
257
+ if isinstance(f, Function):
258
+ _function: modal._functions._Function[Any, Any, Any] = synchronizer._translate_in(f) # type: ignore
259
+ service_deps = _function.deps(only_explicit_mounts=True)
260
+ user_defined_callable = _function.get_raw_f()
261
+ assert _function._app # app should always be set on a decorated function
262
+ active_app = _function._app
291
263
  else:
292
- # serialized=True or "undecorated"
293
- user_cls_instance = user_cls_or_cls()
294
-
295
- user_defined_callable = user_defined_callable.__get__(user_cls_instance)
264
+ # function isn't decorated in global scope
265
+ user_defined_callable = f
266
+ active_app = get_active_app_fallback(function_def)
296
267
 
297
268
  return ImportedFunction(
298
- user_cls_instance,
299
269
  active_app,
300
270
  service_deps,
301
271
  user_defined_callable,
modal/_serialization.py CHANGED
@@ -6,6 +6,14 @@ import typing
6
6
  from inspect import Parameter
7
7
  from typing import Any
8
8
 
9
+ from modal._traceback import extract_traceback
10
+ from modal.config import config
11
+
12
+ try:
13
+ import cbor2 # type: ignore
14
+ except ImportError: # pragma: no cover - optional dependency
15
+ cbor2 = None
16
+
9
17
  import google.protobuf.message
10
18
 
11
19
  from modal._utils.async_utils import synchronizer
@@ -15,7 +23,7 @@ from ._object import _Object
15
23
  from ._type_manager import parameter_serde_registry, schema_registry
16
24
  from ._vendor import cloudpickle
17
25
  from .config import logger
18
- from .exception import DeserializationError, ExecutionError, InvalidError
26
+ from .exception import DeserializationError, ExecutionError, InvalidError, SerializationError
19
27
  from .object import Object
20
28
 
21
29
  if typing.TYPE_CHECKING:
@@ -346,6 +354,12 @@ def _deserialize_asgi(asgi: api_pb2.Asgi) -> Any:
346
354
  return None
347
355
 
348
356
 
357
+ def get_preferred_payload_format() -> "api_pb2.DataFormat.ValueType":
358
+ payload_format = (config.get("payload_format") or "pickle").lower()
359
+ data_format = api_pb2.DATA_FORMAT_CBOR if payload_format == "cbor" else api_pb2.DATA_FORMAT_PICKLE
360
+ return data_format
361
+
362
+
349
363
  def serialize_data_format(obj: Any, data_format: int) -> bytes:
350
364
  """Similar to serialize(), but supports other data formats."""
351
365
  if data_format == api_pb2.DATA_FORMAT_PICKLE:
@@ -355,6 +369,21 @@ def serialize_data_format(obj: Any, data_format: int) -> bytes:
355
369
  elif data_format == api_pb2.DATA_FORMAT_GENERATOR_DONE:
356
370
  assert isinstance(obj, api_pb2.GeneratorDone)
357
371
  return obj.SerializeToString(deterministic=True)
372
+ elif data_format == api_pb2.DATA_FORMAT_CBOR:
373
+ if cbor2 is None:
374
+ raise InvalidError("CBOR support requires the 'cbor2' package to be installed.")
375
+ try:
376
+ return cbor2.dumps(obj)
377
+ except cbor2.CBOREncodeTypeError:
378
+ try:
379
+ typename = f"{type(obj).__module__}.{type(obj).__name__}"
380
+ except Exception:
381
+ typename = str(type(obj))
382
+ raise SerializationError(
383
+ # TODO (elias): add documentation link for more information on this
384
+ f"Can not serialize type {typename} as cbor. If you need to use a custom data type, "
385
+ "try to serialize it yourself e.g. by using pickle.dumps(my_data)"
386
+ )
358
387
  else:
359
388
  raise InvalidError(f"Unknown data format {data_format!r}")
360
389
 
@@ -366,6 +395,10 @@ def deserialize_data_format(s: bytes, data_format: int, client) -> Any:
366
395
  return _deserialize_asgi(api_pb2.Asgi.FromString(s))
367
396
  elif data_format == api_pb2.DATA_FORMAT_GENERATOR_DONE:
368
397
  return api_pb2.GeneratorDone.FromString(s)
398
+ elif data_format == api_pb2.DATA_FORMAT_CBOR:
399
+ if cbor2 is None:
400
+ raise InvalidError("CBOR support requires the 'cbor2' package to be installed.")
401
+ return cbor2.loads(s)
369
402
  else:
370
403
  raise InvalidError(f"Unknown data format {data_format!r}")
371
404
 
@@ -579,3 +612,26 @@ def get_callable_schema(
579
612
  arguments=arguments,
580
613
  return_type=return_type_proto,
581
614
  )
615
+
616
+
617
+ def pickle_exception(exc: BaseException) -> bytes:
618
+ try:
619
+ return serialize(exc)
620
+ except Exception as serialization_exc:
621
+ # We can't always serialize exceptions.
622
+ err = f"Failed to serialize exception {exc} of type {type(exc)}: {serialization_exc}"
623
+ logger.info(err)
624
+ return serialize(SerializationError(err))
625
+
626
+
627
+ def pickle_traceback(exc: BaseException, task_id: str) -> tuple[bytes, bytes]:
628
+ serialized_tb, tb_line_cache = b"", b""
629
+
630
+ try:
631
+ tb_dict, line_cache = extract_traceback(exc, task_id)
632
+ serialized_tb = serialize(tb_dict)
633
+ tb_line_cache = serialize(line_cache)
634
+ except Exception:
635
+ logger.info("Failed to serialize exception traceback.")
636
+
637
+ return serialized_tb, tb_line_cache
@@ -1,12 +1,14 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import asyncio
3
3
  import concurrent.futures
4
+ import contextlib
4
5
  import functools
5
6
  import inspect
6
7
  import itertools
7
8
  import sys
8
9
  import time
9
10
  import typing
11
+ import warnings
10
12
  from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable, Iterator
11
13
  from contextlib import asynccontextmanager
12
14
  from dataclasses import dataclass
@@ -51,6 +53,10 @@ def synchronize_api(obj, target_module=None):
51
53
  return synchronizer.create_blocking(obj, blocking_name, target_module=target_module)
52
54
 
53
55
 
56
+ # Used for testing to configure the `n_attempts` that `retry` will use.
57
+ RETRY_N_ATTEMPTS_OVERRIDE: Optional[int] = None
58
+
59
+
54
60
  def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout=90):
55
61
  """Decorator that calls an async function multiple times, with a given timeout.
56
62
 
@@ -75,8 +81,13 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
75
81
  def decorator(fn):
76
82
  @functools.wraps(fn)
77
83
  async def f_wrapped(*args, **kwargs):
84
+ if RETRY_N_ATTEMPTS_OVERRIDE is not None:
85
+ local_n_attempts = RETRY_N_ATTEMPTS_OVERRIDE
86
+ else:
87
+ local_n_attempts = n_attempts
88
+
78
89
  delay = base_delay
79
- for i in range(n_attempts):
90
+ for i in range(local_n_attempts):
80
91
  t0 = time.time()
81
92
  try:
82
93
  return await asyncio.wait_for(fn(*args, **kwargs), timeout=timeout)
@@ -84,12 +95,12 @@ def retry(direct_fn=None, *, n_attempts=3, base_delay=0, delay_factor=2, timeout
84
95
  logger.debug(f"Function {fn} was cancelled")
85
96
  raise
86
97
  except Exception as e:
87
- if i >= n_attempts - 1:
98
+ if i >= local_n_attempts - 1:
88
99
  raise
89
100
  logger.debug(
90
101
  f"Failed invoking function {fn}: {e}"
91
102
  f" (took {time.time() - t0}s, sleeping {delay}s"
92
- f" and trying {n_attempts - i - 1} more times)"
103
+ f" and trying {local_n_attempts - i - 1} more times)"
93
104
  )
94
105
  await asyncio.sleep(delay)
95
106
  delay *= delay_factor
@@ -125,7 +136,8 @@ class TaskContext:
125
136
  _loops: set[asyncio.Task]
126
137
 
127
138
  def __init__(self, grace: Optional[float] = None):
128
- self._grace = grace
139
+ self._grace = grace # grace is the time we want for tasks to finish before cancelling them
140
+ self._cancellation_grace: float = 1.0 # extra graceperiod for the cancellation itself to "bubble up"
129
141
  self._loops = set()
130
142
 
131
143
  async def start(self):
@@ -157,22 +169,29 @@ class TaskContext:
157
169
  # still needs to be handled
158
170
  # (https://stackoverflow.com/a/63356323/2475114)
159
171
  if gather_future:
160
- try:
172
+ with contextlib.suppress(asyncio.CancelledError):
161
173
  await gather_future
162
- except asyncio.CancelledError:
163
- pass
164
174
 
175
+ cancelled_tasks: list[asyncio.Task] = []
165
176
  for task in self._tasks:
166
177
  if task.done() and not task.cancelled():
167
178
  # Raise any exceptions if they happened.
168
179
  # Only tasks without a done_callback will still be present in self._tasks
169
180
  task.result()
170
181
 
171
- if task.done() or task in self._loops: # Note: Legacy code, we can probably cancel loops.
182
+ if task.done():
172
183
  continue
173
184
 
174
185
  # Cancel any remaining unfinished tasks.
175
186
  task.cancel()
187
+ cancelled_tasks.append(task)
188
+
189
+ cancellation_gather = asyncio.gather(*cancelled_tasks, return_exceptions=True)
190
+ try:
191
+ await asyncio.wait_for(cancellation_gather, timeout=self._cancellation_grace)
192
+ except asyncio.TimeoutError:
193
+ warnings.warn(f"Internal warning: Tasks did not cancel in a timely manner: {cancelled_tasks}")
194
+
176
195
  await asyncio.sleep(0) # wake up coroutines waiting for cancellations
177
196
 
178
197
  async def __aexit__(self, exc_type, value, tb):
@@ -279,7 +298,9 @@ class TimestampPriorityQueue(Generic[T]):
279
298
 
280
299
  def __init__(self, maxsize: int = 0):
281
300
  self.condition = asyncio.Condition()
282
- self._queue: asyncio.PriorityQueue[tuple[float, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
301
+ self._queue: asyncio.PriorityQueue[tuple[float, int, Union[T, None]]] = asyncio.PriorityQueue(maxsize=maxsize)
302
+ # Used to tiebreak items with the same timestamp that are not comparable. (eg. protos)
303
+ self._counter = itertools.count()
283
304
 
284
305
  async def close(self):
285
306
  await self.put(self._MAX_PRIORITY, None)
@@ -288,7 +309,7 @@ class TimestampPriorityQueue(Generic[T]):
288
309
  """
289
310
  Add an item to the queue to be processed at a specific timestamp.
290
311
  """
291
- await self._queue.put((timestamp, item))
312
+ await self._queue.put((timestamp, next(self._counter), item))
292
313
  async with self.condition:
293
314
  self.condition.notify_all() # notify any waiting coroutines
294
315
 
@@ -301,7 +322,7 @@ class TimestampPriorityQueue(Generic[T]):
301
322
  while self.empty():
302
323
  await self.condition.wait()
303
324
  # peek at the next item
304
- timestamp, item = await self._queue.get()
325
+ timestamp, counter, item = await self._queue.get()
305
326
  now = time.time()
306
327
  if timestamp < now:
307
328
  return item
@@ -309,7 +330,7 @@ class TimestampPriorityQueue(Generic[T]):
309
330
  return None
310
331
  # not ready yet, calculate sleep time
311
332
  sleep_time = timestamp - now
312
- self._queue.put_nowait((timestamp, item)) # put it back
333
+ self._queue.put_nowait((timestamp, counter, item)) # put it back
313
334
  # wait until either the timeout or a new item is added
314
335
  try:
315
336
  await asyncio.wait_for(self.condition.wait(), timeout=sleep_time)
@@ -9,7 +9,6 @@ from typing import Any
9
9
  from modal.exception import ExecutionError
10
10
  from modal_proto import api_pb2, modal_api_grpc
11
11
 
12
- from .grpc_utils import retry_transient_errors
13
12
  from .logger import logger
14
13
 
15
14
 
@@ -27,7 +26,7 @@ class _AuthTokenManager:
27
26
  self._expiry = 0.0
28
27
  self._lock: typing.Union[asyncio.Lock, None] = None
29
28
 
30
- async def get_token(self):
29
+ async def get_token(self) -> str:
31
30
  """
32
31
  When called, the AuthTokenManager can be in one of three states:
33
32
  1. Has a valid cached token. It is returned to the caller.
@@ -66,9 +65,7 @@ class _AuthTokenManager:
66
65
  # new token. Once we have a new token, the other coroutines will unblock and return from here.
67
66
  if self._token and not self._needs_refresh():
68
67
  return
69
- resp: api_pb2.AuthTokenGetResponse = await retry_transient_errors(
70
- self._stub.AuthTokenGet, api_pb2.AuthTokenGetRequest()
71
- )
68
+ resp: api_pb2.AuthTokenGetResponse = await self._stub.AuthTokenGet(api_pb2.AuthTokenGetRequest())
72
69
  if not resp.token:
73
70
  # Not expected
74
71
  raise ExecutionError(