modal 1.1.5.dev66__py3-none-any.whl → 1.3.1.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. Click here for more details.

Files changed (143) hide show
  1. modal/__init__.py +4 -4
  2. modal/__main__.py +4 -29
  3. modal/_billing.py +84 -0
  4. modal/_clustered_functions.py +1 -3
  5. modal/_container_entrypoint.py +33 -208
  6. modal/_functions.py +171 -138
  7. modal/_grpc_client.py +191 -0
  8. modal/_ipython.py +16 -6
  9. modal/_load_context.py +106 -0
  10. modal/_object.py +72 -21
  11. modal/_output.py +12 -14
  12. modal/_partial_function.py +31 -4
  13. modal/_resolver.py +44 -57
  14. modal/_runtime/container_io_manager.py +30 -28
  15. modal/_runtime/container_io_manager.pyi +42 -44
  16. modal/_runtime/gpu_memory_snapshot.py +9 -7
  17. modal/_runtime/user_code_event_loop.py +80 -0
  18. modal/_runtime/user_code_imports.py +236 -10
  19. modal/_serialization.py +2 -1
  20. modal/_traceback.py +4 -13
  21. modal/_tunnel.py +16 -11
  22. modal/_tunnel.pyi +25 -3
  23. modal/_utils/async_utils.py +337 -10
  24. modal/_utils/auth_token_manager.py +1 -4
  25. modal/_utils/blob_utils.py +29 -22
  26. modal/_utils/function_utils.py +20 -21
  27. modal/_utils/grpc_testing.py +6 -3
  28. modal/_utils/grpc_utils.py +223 -64
  29. modal/_utils/mount_utils.py +26 -1
  30. modal/_utils/name_utils.py +2 -3
  31. modal/_utils/package_utils.py +0 -1
  32. modal/_utils/rand_pb_testing.py +8 -1
  33. modal/_utils/task_command_router_client.py +524 -0
  34. modal/_vendor/cloudpickle.py +144 -48
  35. modal/app.py +285 -105
  36. modal/app.pyi +216 -53
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +6 -3
  39. modal/builder/PREVIEW.txt +2 -1
  40. modal/builder/base-images.json +4 -2
  41. modal/cli/_download.py +19 -3
  42. modal/cli/cluster.py +4 -2
  43. modal/cli/config.py +3 -1
  44. modal/cli/container.py +5 -4
  45. modal/cli/dict.py +5 -2
  46. modal/cli/entry_point.py +26 -2
  47. modal/cli/environment.py +2 -16
  48. modal/cli/launch.py +1 -76
  49. modal/cli/network_file_system.py +5 -20
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/vscode.py +1 -1
  52. modal/cli/queues.py +5 -4
  53. modal/cli/run.py +24 -204
  54. modal/cli/secret.py +1 -2
  55. modal/cli/shell.py +375 -0
  56. modal/cli/utils.py +1 -13
  57. modal/cli/volume.py +11 -17
  58. modal/client.py +16 -125
  59. modal/client.pyi +94 -144
  60. modal/cloud_bucket_mount.py +3 -1
  61. modal/cloud_bucket_mount.pyi +4 -0
  62. modal/cls.py +101 -64
  63. modal/cls.pyi +9 -8
  64. modal/config.py +21 -1
  65. modal/container_process.py +288 -12
  66. modal/container_process.pyi +99 -38
  67. modal/dict.py +72 -33
  68. modal/dict.pyi +88 -57
  69. modal/environments.py +16 -8
  70. modal/environments.pyi +6 -2
  71. modal/exception.py +154 -16
  72. modal/experimental/__init__.py +24 -53
  73. modal/experimental/flash.py +161 -74
  74. modal/experimental/flash.pyi +97 -49
  75. modal/file_io.py +50 -92
  76. modal/file_io.pyi +117 -89
  77. modal/functions.pyi +70 -87
  78. modal/image.py +82 -47
  79. modal/image.pyi +51 -30
  80. modal/io_streams.py +500 -149
  81. modal/io_streams.pyi +279 -189
  82. modal/mount.py +60 -46
  83. modal/mount.pyi +41 -17
  84. modal/network_file_system.py +19 -11
  85. modal/network_file_system.pyi +72 -39
  86. modal/object.pyi +114 -22
  87. modal/parallel_map.py +42 -44
  88. modal/parallel_map.pyi +9 -17
  89. modal/partial_function.pyi +4 -2
  90. modal/proxy.py +14 -6
  91. modal/proxy.pyi +10 -2
  92. modal/queue.py +45 -38
  93. modal/queue.pyi +88 -52
  94. modal/runner.py +96 -96
  95. modal/runner.pyi +44 -27
  96. modal/sandbox.py +225 -107
  97. modal/sandbox.pyi +226 -60
  98. modal/secret.py +58 -56
  99. modal/secret.pyi +28 -13
  100. modal/serving.py +7 -11
  101. modal/serving.pyi +7 -8
  102. modal/snapshot.py +29 -15
  103. modal/snapshot.pyi +18 -10
  104. modal/token_flow.py +1 -1
  105. modal/token_flow.pyi +4 -6
  106. modal/volume.py +102 -55
  107. modal/volume.pyi +125 -66
  108. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/METADATA +10 -9
  109. modal-1.3.1.dev8.dist-info/RECORD +189 -0
  110. modal_proto/api.proto +141 -70
  111. modal_proto/api_grpc.py +42 -26
  112. modal_proto/api_pb2.py +1123 -1103
  113. modal_proto/api_pb2.pyi +331 -83
  114. modal_proto/api_pb2_grpc.py +80 -48
  115. modal_proto/api_pb2_grpc.pyi +26 -18
  116. modal_proto/modal_api_grpc.py +175 -174
  117. modal_proto/task_command_router.proto +164 -0
  118. modal_proto/task_command_router_grpc.py +138 -0
  119. modal_proto/task_command_router_pb2.py +180 -0
  120. modal_proto/{sandbox_router_pb2.pyi → task_command_router_pb2.pyi} +148 -57
  121. modal_proto/task_command_router_pb2_grpc.py +272 -0
  122. modal_proto/task_command_router_pb2_grpc.pyi +100 -0
  123. modal_version/__init__.py +1 -1
  124. modal_version/__main__.py +1 -1
  125. modal/cli/programs/launch_instance_ssh.py +0 -94
  126. modal/cli/programs/run_marimo.py +0 -95
  127. modal-1.1.5.dev66.dist-info/RECORD +0 -191
  128. modal_proto/modal_options_grpc.py +0 -3
  129. modal_proto/options.proto +0 -19
  130. modal_proto/options_grpc.py +0 -3
  131. modal_proto/options_pb2.py +0 -35
  132. modal_proto/options_pb2.pyi +0 -20
  133. modal_proto/options_pb2_grpc.py +0 -4
  134. modal_proto/options_pb2_grpc.pyi +0 -7
  135. modal_proto/sandbox_router.proto +0 -125
  136. modal_proto/sandbox_router_grpc.py +0 -89
  137. modal_proto/sandbox_router_pb2.py +0 -128
  138. modal_proto/sandbox_router_pb2_grpc.py +0 -169
  139. modal_proto/sandbox_router_pb2_grpc.pyi +0 -63
  140. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/WHEEL +0 -0
  141. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/entry_points.txt +0 -0
  142. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/licenses/LICENSE +0 -0
  143. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/top_level.txt +0 -0
@@ -8,7 +8,6 @@ from enum import Enum
8
8
  from pathlib import Path, PurePosixPath
9
9
  from typing import Any, Callable, Literal, Optional
10
10
 
11
- from grpclib import GRPCError
12
11
  from grpclib.exceptions import StreamTerminatedError
13
12
 
14
13
  import modal_proto
@@ -29,9 +28,11 @@ from ..exception import (
29
28
  DeserializationError,
30
29
  ExecutionError,
31
30
  FunctionTimeoutError,
31
+ InternalError,
32
32
  InternalFailure,
33
33
  InvalidError,
34
34
  RemoteError,
35
+ ServiceError,
35
36
  )
36
37
  from ..mount import ROOT_DIR, _is_modal_path, _Mount
37
38
  from .blob_utils import (
@@ -39,7 +40,6 @@ from .blob_utils import (
39
40
  blob_download,
40
41
  blob_upload_with_r2_failure_info,
41
42
  )
42
- from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES
43
43
 
44
44
  if typing.TYPE_CHECKING:
45
45
  import modal._functions
@@ -75,6 +75,10 @@ def is_global_object(object_qual_name: str):
75
75
  return "<locals>" not in object_qual_name.split(".")
76
76
 
77
77
 
78
+ def is_flash_object(experimental_options: Optional[dict[str, Any]], http_config: Optional[api_pb2.HTTPConfig]) -> bool:
79
+ return bool(experimental_options and experimental_options.get("flash", False)) or http_config is not None
80
+
81
+
78
82
  def is_method_fn(object_qual_name: str):
79
83
  # methods have names like Cls.foo.
80
84
  if "<locals>" in object_qual_name:
@@ -125,6 +129,7 @@ class FunctionInfo:
125
129
 
126
130
  raw_f: Optional[Callable[..., Any]] # if None - this is a "class service function"
127
131
  function_name: str
132
+ implementation_name: str
128
133
  user_cls: Optional[type[Any]]
129
134
  module_name: Optional[str]
130
135
 
@@ -156,20 +161,16 @@ class FunctionInfo:
156
161
  self.raw_f = f
157
162
  self.user_cls = user_cls
158
163
 
159
- if name_override is not None:
160
- if not serialized:
161
- # We may relax this constraint in the future, but currently we don't track the distinction between
162
- # the Function's name inside modal and the name of the object that we need to import in a container.
163
- raise InvalidError("Setting a custom `name=` also requires setting `serialized=True`")
164
- self.function_name = name_override
165
- elif f is None and user_cls:
164
+ if f is None and user_cls:
166
165
  # "service function" for running all methods of a class
167
- self.function_name = f"{user_cls.__name__}.*"
166
+ self.implementation_name = f"{user_cls.__name__}.*"
168
167
  elif f and user_cls:
169
168
  # Method may be defined on superclass of the wrapped class
170
- self.function_name = f"{user_cls.__name__}.{f.__name__}"
169
+ self.implementation_name = f"{user_cls.__name__}.{f.__name__}"
171
170
  else:
172
- self.function_name = f.__qualname__
171
+ self.implementation_name = f.__qualname__
172
+
173
+ self.function_name = name_override or self.implementation_name
173
174
 
174
175
  # If it's a cls, the @method could be defined in a base class in a different file.
175
176
  if user_cls is not None:
@@ -436,15 +437,14 @@ async def _stream_function_call_data(
436
437
 
437
438
  last_index = chunk.index
438
439
  yield message
439
- except (GRPCError, StreamTerminatedError) as exc:
440
+ except (ServiceError, InternalError, StreamTerminatedError) as exc:
440
441
  if retries_remaining > 0:
441
442
  retries_remaining -= 1
442
- if isinstance(exc, GRPCError):
443
- if exc.status in RETRYABLE_GRPC_STATUS_CODES:
444
- logger.debug(f"{variant} stream retrying with delay {delay_ms}ms due to {exc}")
445
- await asyncio.sleep(delay_ms / 1000)
446
- delay_ms = min(1000, delay_ms * 10)
447
- continue
443
+ if isinstance(exc, (ServiceError, InternalError)):
444
+ logger.debug(f"{variant} stream retrying with delay {delay_ms}ms due to {exc}")
445
+ await asyncio.sleep(delay_ms / 1000)
446
+ delay_ms = min(1000, delay_ms * 10)
447
+ continue
448
448
  elif isinstance(exc, StreamTerminatedError):
449
449
  continue
450
450
  raise
@@ -641,14 +641,13 @@ class FunctionCreationStatus:
641
641
  if not self.response:
642
642
  self.status_row.finish(f"Unknown error when creating function {self.tag}")
643
643
 
644
- elif self.response.function.web_url:
644
+ elif web_url := self.response.handle_metadata.web_url:
645
645
  url_info = self.response.function.web_url_info
646
646
  requires_proxy_auth = self.response.function.webhook_config.requires_proxy_auth
647
647
  proxy_auth_suffix = " 🔑" if requires_proxy_auth else ""
648
648
  # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc.
649
649
  suffix = _get_suffix_from_web_url_info(url_info)
650
650
  # TODO: this is only printed when we're showing progress. Maybe move this somewhere else.
651
- web_url = self.response.handle_metadata.web_url
652
651
  for warning in self.response.server_warnings:
653
652
  self.status_row.warning(warning)
654
653
  self.status_row.finish(
@@ -45,8 +45,11 @@ def patch_mock_servicer(cls):
45
45
  Also patches all unimplemented abstract methods in a mock servicer with default error implementations.
46
46
  """
47
47
 
48
- async def fallback(self, stream) -> None:
49
- raise GRPCError(Status.UNIMPLEMENTED, "Not implemented in mock servicer " + repr(cls))
48
+ def fallback(name: str):
49
+ async def _fallback(self, stream) -> None:
50
+ raise GRPCError(Status.UNIMPLEMENTED, f"{name} not implemented in mock servicer " + repr(cls))
51
+
52
+ return _fallback
50
53
 
51
54
  @contextlib.contextmanager
52
55
  def intercept(servicer):
@@ -85,7 +88,7 @@ def patch_mock_servicer(cls):
85
88
  for name in dir(cls):
86
89
  method = getattr(cls, name)
87
90
  if getattr(method, "__isabstractmethod__", False):
88
- setattr(cls, name, patch_grpc_method(name, fallback))
91
+ setattr(cls, name, patch_grpc_method(name, fallback(name)))
89
92
  elif name[0].isupper() and inspect.isfunction(method):
90
93
  setattr(cls, name, patch_grpc_method(name, method))
91
94
 
@@ -1,6 +1,7 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import asyncio
3
3
  import contextlib
4
+ import os
4
5
  import platform
5
6
  import socket
6
7
  import time
@@ -8,26 +9,26 @@ import typing
8
9
  import urllib.parse
9
10
  import uuid
10
11
  from collections.abc import AsyncIterator
11
- from dataclasses import dataclass
12
- from typing import (
13
- Any,
14
- Optional,
15
- TypeVar,
16
- )
12
+ from dataclasses import dataclass, field
13
+ from functools import cache
14
+ from typing import Any, Optional, Sequence, TypeVar
17
15
 
18
16
  import grpclib.client
19
17
  import grpclib.config
20
18
  import grpclib.events
21
- import grpclib.protocol
22
- import grpclib.stream
23
19
  from google.protobuf.message import Message
20
+ from google.protobuf.symbol_database import SymbolDatabase
24
21
  from grpclib import GRPCError, Status
22
+ from grpclib.encoding.base import StatusDetailsCodecBase
25
23
  from grpclib.exceptions import StreamTerminatedError
26
24
  from grpclib.protocol import H2Protocol
27
25
 
28
- from modal.exception import AuthError, ConnectionError
26
+ from modal.exception import ConnectionError
27
+ from modal_proto import api_pb2
29
28
  from modal_version import __version__
30
29
 
30
+ from .._traceback import suppress_tb_frame
31
+ from ..config import config
31
32
  from .async_utils import retry
32
33
  from .logger import logger
33
34
 
@@ -35,6 +36,7 @@ RequestType = TypeVar("RequestType", bound=Message)
35
36
  ResponseType = TypeVar("ResponseType", bound=Message)
36
37
 
37
38
  if typing.TYPE_CHECKING:
39
+ import modal._grpc_client
38
40
  import modal.client
39
41
 
40
42
  # Monkey patches grpclib to have a Modal User Agent header.
@@ -70,6 +72,7 @@ RETRYABLE_GRPC_STATUS_CODES = [
70
72
  Status.INTERNAL,
71
73
  Status.UNKNOWN,
72
74
  ]
75
+ SERVER_RETRY_WARNING_TIME_INTERVAL = 30.0
73
76
 
74
77
 
75
78
  @dataclass
@@ -109,6 +112,56 @@ class ConnectionManager:
109
112
  self._channels.clear()
110
113
 
111
114
 
115
+ @cache
116
+ def _sym_db() -> SymbolDatabase:
117
+ from google.protobuf.symbol_database import Default
118
+
119
+ return Default()
120
+
121
+
122
+ class CustomProtoStatusDetailsCodec(StatusDetailsCodecBase):
123
+ """grpclib compatible details codec.
124
+
125
+ The server can encode the details using `google.rpc.Status` using grpclib's default codec and this custom codec
126
+ can decode it into a `api_pb2.RPCStatus`.
127
+ """
128
+
129
+ def encode(
130
+ self,
131
+ status: Status,
132
+ message: Optional[str],
133
+ details: Optional[Sequence[Message]],
134
+ ) -> bytes:
135
+ details_proto = api_pb2.RPCStatus(code=status.value, message=message or "")
136
+ if details is not None:
137
+ for detail in details:
138
+ detail_container = details_proto.details.add()
139
+ detail_container.Pack(detail)
140
+ return details_proto.SerializeToString()
141
+
142
+ def decode(
143
+ self,
144
+ status: Status,
145
+ message: Optional[str],
146
+ data: bytes,
147
+ ) -> Any:
148
+ sym_db = _sym_db()
149
+ details_proto = api_pb2.RPCStatus.FromString(data)
150
+
151
+ details = []
152
+ for detail_container in details_proto.details:
153
+ # If we do not know how to decode a message, we'll ignore it.
154
+ with contextlib.suppress(Exception):
155
+ msg_type = sym_db.GetSymbol(detail_container.TypeName())
156
+ detail = msg_type()
157
+ detail_container.Unpack(detail)
158
+ details.append(detail)
159
+ return details
160
+
161
+
162
+ custom_detail_codec = CustomProtoStatusDetailsCodec()
163
+
164
+
112
165
  def create_channel(
113
166
  server_url: str,
114
167
  metadata: dict[str, str] = {},
@@ -127,7 +180,7 @@ def create_channel(
127
180
  )
128
181
 
129
182
  if o.scheme == "unix":
130
- channel = grpclib.client.Channel(path=o.path, config=config) # probably pointless to use a pool ever
183
+ channel = grpclib.client.Channel(path=o.path, config=config, status_details_codec=custom_detail_codec)
131
184
  elif o.scheme in ("http", "https"):
132
185
  target = o.netloc
133
186
  parts = target.split(":")
@@ -135,7 +188,7 @@ def create_channel(
135
188
  ssl = o.scheme.endswith("s")
136
189
  host = parts[0]
137
190
  port = int(parts[1]) if len(parts) == 2 else 443 if ssl else 80
138
- channel = grpclib.client.Channel(host, port, ssl=ssl, config=config)
191
+ channel = grpclib.client.Channel(host, port, ssl=ssl, config=config, status_details_codec=custom_detail_codec)
139
192
  else:
140
193
  raise Exception(f"Unknown scheme: {o.scheme}")
141
194
 
@@ -165,7 +218,7 @@ if typing.TYPE_CHECKING:
165
218
 
166
219
 
167
220
  async def unary_stream(
168
- method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
221
+ method: "modal._grpc_client.UnaryStreamWrapper[RequestType, ResponseType]",
169
222
  request: RequestType,
170
223
  metadata: Optional[Any] = None,
171
224
  ) -> AsyncIterator[ResponseType]:
@@ -174,102 +227,208 @@ async def unary_stream(
174
227
  yield item
175
228
 
176
229
 
230
+ @dataclass(frozen=True)
231
+ class Retry:
232
+ base_delay: float = 0.1
233
+ max_delay: float = 1
234
+ delay_factor: float = 2
235
+ max_retries: Optional[int] = 3
236
+ additional_status_codes: list = field(default_factory=list)
237
+ attempt_timeout: Optional[float] = None # timeout for each attempt
238
+ total_timeout: Optional[float] = None # timeout for the entire function call
239
+ attempt_timeout_floor: float = 2.0 # always have at least this much timeout (only for total_timeout)
240
+ warning_message: Optional[RetryWarningMessage] = None
241
+
242
+
177
243
  async def retry_transient_errors(
178
- fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
179
- *args,
180
- base_delay: float = 0.1,
181
- max_delay: float = 1,
182
- delay_factor: float = 2,
244
+ fn: "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
245
+ req: RequestType,
183
246
  max_retries: Optional[int] = 3,
184
- additional_status_codes: list = [],
185
- attempt_timeout: Optional[float] = None, # timeout for each attempt
186
- total_timeout: Optional[float] = None, # timeout for the entire function call
187
- attempt_timeout_floor=2.0, # always have at least this much timeout (only for total_timeout)
188
- retry_warning_message: Optional[RetryWarningMessage] = None,
189
- metadata: list[tuple[str, str]] = [],
247
+ ) -> ResponseType:
248
+ """Minimum API version of _retry_transient_errors that works with grpclib.client.UnaryUnaryMethod.
249
+
250
+ Used by modal server.
251
+ """
252
+ return await _retry_transient_errors(fn, req, retry=Retry(max_retries=max_retries))
253
+
254
+
255
+ def get_server_retry_policy(exc: Exception) -> Optional[api_pb2.RPCRetryPolicy]:
256
+ """Get server retry policy."""
257
+ if not isinstance(exc, GRPCError) or not exc.details:
258
+ return None
259
+
260
+ # Server should not set multiple retry instructions, but if there is more than one, pick the first one
261
+ for entry in exc.details:
262
+ if isinstance(entry, api_pb2.RPCRetryPolicy):
263
+ return entry
264
+ return None
265
+
266
+
267
+ def process_exception_before_retry(
268
+ exc: Exception,
269
+ final_attempt: bool,
270
+ fn_name: str,
271
+ n_retries: int,
272
+ delay: float,
273
+ idempotency_key: str,
274
+ ):
275
+ """Process exception before retry, used by `_retry_transient_errors`."""
276
+ with suppress_tb_frame():
277
+ if final_attempt:
278
+ logger.debug(
279
+ f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} for {fn_name} ({idempotency_key[:8]})"
280
+ )
281
+ if isinstance(exc, OSError):
282
+ raise ConnectionError(str(exc))
283
+ elif isinstance(exc, asyncio.TimeoutError):
284
+ raise ConnectionError(str(exc))
285
+ else:
286
+ raise exc
287
+
288
+ if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
289
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
290
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
291
+ # TODO: update to newer version (>=0.4.8) once stable
292
+ # Also be sure to remove the AttributeError from the set of exceptions
293
+ # we handle in the retry logic once we drop this check!
294
+ raise exc
295
+
296
+ logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn_name} ({idempotency_key[:8]})")
297
+
298
+
299
+ async def _retry_transient_errors(
300
+ fn: typing.Union[
301
+ "modal._grpc_client.UnaryUnaryWrapper[RequestType, ResponseType]",
302
+ "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
303
+ ],
304
+ req: RequestType,
305
+ retry: Retry,
306
+ metadata: Optional[list[tuple[str, str]]] = None,
190
307
  ) -> ResponseType:
191
308
  """Retry on transient gRPC failures with back-off until max_retries is reached.
192
309
  If max_retries is None, retry forever."""
310
+ import modal._grpc_client
193
311
 
194
- delay = base_delay
312
+ if isinstance(fn, modal._grpc_client.UnaryUnaryWrapper):
313
+ fn_callable = fn.direct
314
+ elif isinstance(fn, grpclib.client.UnaryUnaryMethod):
315
+ fn_callable = fn # type: ignore
316
+ else:
317
+ raise ValueError("Only modal._grpc_client.UnaryUnaryWrapper and grpclib.client.UnaryUnaryMethod are supported")
318
+
319
+ delay = retry.base_delay
195
320
  n_retries = 0
321
+ n_throttled_retries = 0
196
322
 
197
- status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *additional_status_codes]
323
+ status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *retry.additional_status_codes]
198
324
 
199
325
  idempotency_key = str(uuid.uuid4())
200
326
 
201
327
  t0 = time.time()
202
- if total_timeout is not None:
203
- total_deadline = t0 + total_timeout
328
+ last_server_retry_warning_time = None
329
+
330
+ if retry.total_timeout is not None:
331
+ total_deadline = t0 + retry.total_timeout
204
332
  else:
205
333
  total_deadline = None
206
334
 
207
- metadata = metadata + [("x-modal-timestamp", str(time.time()))]
335
+ metadata = (metadata or []) + [("x-modal-timestamp", str(time.time()))]
336
+
208
337
  while True:
209
338
  attempt_metadata = [
210
339
  ("x-idempotency-key", idempotency_key),
211
340
  ("x-retry-attempt", str(n_retries)),
341
+ ("x-throttle-retry-attempt", str(n_throttled_retries)),
212
342
  *metadata,
213
343
  ]
214
344
  if n_retries > 0:
215
345
  attempt_metadata.append(("x-retry-delay", str(time.time() - t0)))
346
+ if n_throttled_retries > 0:
347
+ attempt_metadata.append(("x-throttle-retry-delay", str(time.time() - t0)))
348
+
216
349
  timeouts = []
217
- if attempt_timeout is not None:
218
- timeouts.append(attempt_timeout)
219
- if total_timeout is not None:
220
- timeouts.append(max(total_deadline - time.time(), attempt_timeout_floor))
350
+ if retry.attempt_timeout is not None:
351
+ timeouts.append(retry.attempt_timeout)
352
+ if total_deadline is not None:
353
+ timeouts.append(max(total_deadline - time.time(), retry.attempt_timeout_floor))
221
354
  if timeouts:
222
355
  timeout = min(timeouts) # In case the function provided both types of timeouts
223
356
  else:
224
357
  timeout = None
358
+
225
359
  try:
226
- return await fn(*args, metadata=attempt_metadata, timeout=timeout)
360
+ with suppress_tb_frame():
361
+ return await fn_callable(req, metadata=attempt_metadata, timeout=timeout)
227
362
  except (StreamTerminatedError, GRPCError, OSError, asyncio.TimeoutError, AttributeError) as exc:
228
- if isinstance(exc, GRPCError) and exc.status not in status_codes:
229
- if exc.status == Status.UNAUTHENTICATED:
230
- raise AuthError(exc.message)
231
- else:
232
- raise exc
363
+ # Note that we only catch AttributeError to handle a specific case that works around a bug
364
+ # in grpclib<=0.4.7. See above (search for `write_appdata`).
233
365
 
234
- if max_retries is not None and n_retries >= max_retries:
366
+ # Server side instruction for retries
367
+ max_throttle_wait: Optional[int] = config.get("max_throttle_wait")
368
+ if (
369
+ max_throttle_wait != 0
370
+ and isinstance(exc, GRPCError)
371
+ and (server_retry_policy := get_server_retry_policy(exc))
372
+ ):
373
+ server_delay = server_retry_policy.retry_after_secs
374
+
375
+ now = time.time()
376
+
377
+ # We check if the timeout will be reached **after** the sleep, so we can raise an error early
378
+ # without needing to actually sleep.
379
+ total_timeout_will_be_reached = (
380
+ retry.total_timeout is not None and (now + server_delay - t0) >= retry.total_timeout
381
+ )
382
+ max_throttle_will_be_reached = (
383
+ max_throttle_wait is not None and (now + server_delay - t0) >= max_throttle_wait
384
+ )
385
+ final_attempt = total_timeout_will_be_reached or max_throttle_will_be_reached
386
+
387
+ with suppress_tb_frame():
388
+ process_exception_before_retry(
389
+ exc, final_attempt, fn.name, n_retries, server_delay, idempotency_key
390
+ )
391
+
392
+ now = time.time()
393
+ if last_server_retry_warning_time is None or (
394
+ now - last_server_retry_warning_time >= SERVER_RETRY_WARNING_TIME_INTERVAL
395
+ ):
396
+ last_server_retry_warning_time = now
397
+ logger.warning(
398
+ f"Warning: Received {exc.status}{os.linesep}"
399
+ f"{exc.message}{os.linesep}"
400
+ f"Will retry in {server_delay:0.2f} seconds."
401
+ )
402
+
403
+ n_throttled_retries += 1
404
+ await asyncio.sleep(server_delay)
405
+ continue
406
+
407
+ # Client handles retry
408
+ if isinstance(exc, GRPCError) and exc.status not in status_codes:
409
+ raise exc
410
+ if retry.max_retries is not None and n_retries >= retry.max_retries:
235
411
  final_attempt = True
236
- elif total_deadline is not None and time.time() + delay + attempt_timeout_floor >= total_deadline:
412
+ elif total_deadline is not None and time.time() + delay + retry.attempt_timeout_floor >= total_deadline:
237
413
  final_attempt = True
238
414
  else:
239
415
  final_attempt = False
240
416
 
241
- if final_attempt:
242
- logger.debug(
243
- f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} "
244
- f"{total_deadline=} for {fn.name} ({idempotency_key[:8]})"
245
- )
246
- if isinstance(exc, OSError):
247
- raise ConnectionError(str(exc))
248
- elif isinstance(exc, asyncio.TimeoutError):
249
- raise ConnectionError(str(exc))
250
- else:
251
- raise exc
252
-
253
- if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
254
- # StreamTerminatedError are not properly raised in grpclib<=0.4.7
255
- # fixed in https://github.com/vmagamedov/grpclib/issues/185
256
- # TODO: update to newer version (>=0.4.8) once stable
257
- raise exc
258
-
259
- logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn.name} ({idempotency_key[:8]})")
417
+ with suppress_tb_frame():
418
+ process_exception_before_retry(exc, final_attempt, fn.name, n_retries, delay, idempotency_key)
260
419
 
261
420
  n_retries += 1
262
421
 
263
422
  if (
264
- retry_warning_message
265
- and n_retries % retry_warning_message.warning_interval == 0
423
+ retry.warning_message
424
+ and n_retries % retry.warning_message.warning_interval == 0
266
425
  and isinstance(exc, GRPCError)
267
- and exc.status in retry_warning_message.errors_to_warn_for
426
+ and exc.status in retry.warning_message.errors_to_warn_for
268
427
  ):
269
- logger.warning(retry_warning_message.message)
428
+ logger.warning(retry.warning_message.message)
270
429
 
271
430
  await asyncio.sleep(delay)
272
- delay = min(delay * delay_factor, max_delay)
431
+ delay = min(delay * retry.delay_factor, retry.max_delay)
273
432
 
274
433
 
275
434
  def find_free_port() -> int:
@@ -3,7 +3,9 @@ import posixpath
3
3
  import typing
4
4
  from collections.abc import Mapping, Sequence
5
5
  from pathlib import PurePath, PurePosixPath
6
- from typing import Union
6
+ from typing import Optional, Union
7
+
8
+ from typing_extensions import TypeGuard
7
9
 
8
10
  from ..cloud_bucket_mount import _CloudBucketMount
9
11
  from ..exception import InvalidError
@@ -76,3 +78,26 @@ def validate_volumes(
76
78
  )
77
79
 
78
80
  return validated_volumes
81
+
82
+
83
+ def validate_only_modal_volumes(
84
+ volumes: Optional[Optional[dict[Union[str, PurePosixPath], _Volume]]],
85
+ caller_name: str,
86
+ ) -> Sequence[tuple[str, _Volume]]:
87
+ """Validate all volumes are `modal.Volume`."""
88
+ if volumes is None:
89
+ return []
90
+
91
+ validated_volumes = validate_volumes(volumes)
92
+
93
+ # Although the typing forbids `_CloudBucketMount` for type checking, one can still pass a `_CloudBucketMount`
94
+ # during runtime, so we'll check the type here.
95
+ def all_modal_volumes(
96
+ vols: Sequence[tuple[str, Union[_Volume, _CloudBucketMount]]],
97
+ ) -> TypeGuard[Sequence[tuple[str, _Volume]]]:
98
+ return all(isinstance(v, _Volume) for _, v in vols)
99
+
100
+ if not all_modal_volumes(validated_volumes):
101
+ raise InvalidError(f"{caller_name} only supports volumes that are modal.Volume")
102
+
103
+ return validated_volumes
@@ -1,5 +1,6 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import re
3
+ from collections.abc import Mapping
3
4
 
4
5
  from ..exception import InvalidError
5
6
 
@@ -37,7 +38,7 @@ def is_valid_tag(tag: str, max_length: int = 50) -> bool:
37
38
  return bool(re.match(pattern, tag))
38
39
 
39
40
 
40
- def check_tag_dict(tags: dict[str, str]) -> dict[str, str]:
41
+ def check_tag_dict(tags: Mapping[str, str]) -> None:
41
42
  rules = (
42
43
  "\n\nTags may contain only alphanumeric characters, dashes, periods, or underscores, "
43
44
  "and must be 63 characters or less."
@@ -49,8 +50,6 @@ def check_tag_dict(tags: dict[str, str]) -> dict[str, str]:
49
50
  if not is_valid_tag(value, max_length):
50
51
  raise InvalidError(f"Invalid tag value: {value!r}.{rules}")
51
52
 
52
- return tags
53
-
54
53
 
55
54
  def check_object_name(name: str, object_type: str) -> None:
56
55
  message = (
@@ -1,5 +1,4 @@
1
1
  # Copyright Modal Labs 2022
2
- import importlib
3
2
  import importlib.util
4
3
  import typing
5
4
  from importlib.metadata import PackageNotFoundError, files
@@ -45,7 +45,14 @@ def _fill(msg, desc: Descriptor, rand: Random) -> None:
45
45
  if field.containing_oneof is not None and field.name not in oneof_fields:
46
46
  continue
47
47
  is_message = field.type == FieldDescriptor.TYPE_MESSAGE
48
- is_repeated = field.label == FieldDescriptor.LABEL_REPEATED
48
+
49
+ # In the C implementation of protobuf for Python 3.14, it raises a deprecation
50
+ # warning when labels is accessed, but it does not clean up the exception state,
51
+ # causing a SystemError.
52
+ if hasattr(field, "is_repeated"):
53
+ is_repeated = field.is_repeated # type: ignore
54
+ else:
55
+ is_repeated = field.label == FieldDescriptor.LABEL_REPEATED
49
56
  if is_message:
50
57
  msg_field = getattr(msg, field.name)
51
58
  if is_repeated: