modal 1.2.1.dev8__py3-none-any.whl → 1.2.2.dev19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. modal/_clustered_functions.py +1 -3
  2. modal/_container_entrypoint.py +4 -1
  3. modal/_functions.py +33 -49
  4. modal/_grpc_client.py +148 -0
  5. modal/_output.py +3 -4
  6. modal/_partial_function.py +22 -2
  7. modal/_runtime/container_io_manager.py +21 -22
  8. modal/_utils/async_utils.py +12 -3
  9. modal/_utils/auth_token_manager.py +1 -4
  10. modal/_utils/blob_utils.py +3 -4
  11. modal/_utils/function_utils.py +4 -0
  12. modal/_utils/grpc_utils.py +80 -51
  13. modal/_utils/mount_utils.py +26 -1
  14. modal/_utils/task_command_router_client.py +536 -0
  15. modal/app.py +7 -5
  16. modal/cli/cluster.py +4 -2
  17. modal/cli/config.py +3 -1
  18. modal/cli/container.py +5 -4
  19. modal/cli/entry_point.py +1 -0
  20. modal/cli/launch.py +1 -2
  21. modal/cli/network_file_system.py +1 -4
  22. modal/cli/queues.py +1 -2
  23. modal/cli/secret.py +1 -2
  24. modal/client.py +5 -115
  25. modal/client.pyi +2 -91
  26. modal/cls.py +1 -2
  27. modal/config.py +3 -1
  28. modal/container_process.py +287 -11
  29. modal/container_process.pyi +95 -32
  30. modal/dict.py +12 -12
  31. modal/environments.py +1 -2
  32. modal/exception.py +4 -0
  33. modal/experimental/__init__.py +2 -3
  34. modal/experimental/flash.py +27 -57
  35. modal/experimental/flash.pyi +6 -20
  36. modal/file_io.py +13 -27
  37. modal/functions.pyi +6 -6
  38. modal/image.py +24 -3
  39. modal/image.pyi +4 -0
  40. modal/io_streams.py +433 -127
  41. modal/io_streams.pyi +236 -171
  42. modal/mount.py +4 -4
  43. modal/network_file_system.py +5 -6
  44. modal/parallel_map.py +29 -31
  45. modal/parallel_map.pyi +3 -9
  46. modal/partial_function.pyi +4 -1
  47. modal/queue.py +17 -18
  48. modal/runner.py +12 -11
  49. modal/sandbox.py +148 -42
  50. modal/sandbox.pyi +139 -0
  51. modal/secret.py +4 -5
  52. modal/snapshot.py +1 -4
  53. modal/token_flow.py +1 -1
  54. modal/volume.py +22 -22
  55. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/METADATA +1 -1
  56. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/RECORD +70 -68
  57. modal_proto/api.proto +2 -24
  58. modal_proto/api_grpc.py +0 -32
  59. modal_proto/api_pb2.py +838 -878
  60. modal_proto/api_pb2.pyi +8 -70
  61. modal_proto/api_pb2_grpc.py +0 -67
  62. modal_proto/api_pb2_grpc.pyi +0 -22
  63. modal_proto/modal_api_grpc.py +175 -177
  64. modal_proto/sandbox_router.proto +0 -4
  65. modal_proto/sandbox_router_pb2.pyi +0 -4
  66. modal_version/__init__.py +1 -1
  67. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/WHEEL +0 -0
  68. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/entry_points.txt +0 -0
  69. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/licenses/LICENSE +0 -0
  70. {modal-1.2.1.dev8.dist-info → modal-1.2.2.dev19.dist-info}/top_level.txt +0 -0
@@ -75,6 +75,10 @@ def is_global_object(object_qual_name: str):
75
75
  return "<locals>" not in object_qual_name.split(".")
76
76
 
77
77
 
78
+ def is_flash_object(experimental_options: Optional[dict[str, Any]]) -> bool:
79
+ return experimental_options.get("flash", False) if experimental_options else False
80
+
81
+
78
82
  def is_method_fn(object_qual_name: str):
79
83
  # methods have names like Cls.foo.
80
84
  if "<locals>" in object_qual_name:
@@ -8,12 +8,8 @@ import typing
8
8
  import urllib.parse
9
9
  import uuid
10
10
  from collections.abc import AsyncIterator
11
- from dataclasses import dataclass
12
- from typing import (
13
- Any,
14
- Optional,
15
- TypeVar,
16
- )
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, Optional, TypeVar
17
13
 
18
14
  import grpclib.client
19
15
  import grpclib.config
@@ -28,6 +24,7 @@ from grpclib.protocol import H2Protocol
28
24
  from modal.exception import AuthError, ConnectionError
29
25
  from modal_version import __version__
30
26
 
27
+ from .._traceback import suppress_tb_frames
31
28
  from .async_utils import retry
32
29
  from .logger import logger
33
30
 
@@ -35,6 +32,7 @@ RequestType = TypeVar("RequestType", bound=Message)
35
32
  ResponseType = TypeVar("ResponseType", bound=Message)
36
33
 
37
34
  if typing.TYPE_CHECKING:
35
+ import modal._grpc_client
38
36
  import modal.client
39
37
 
40
38
  # Monkey patches grpclib to have a Modal User Agent header.
@@ -165,7 +163,7 @@ if typing.TYPE_CHECKING:
165
163
 
166
164
 
167
165
  async def unary_stream(
168
- method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
166
+ method: "modal._grpc_client.UnaryStreamWrapper[RequestType, ResponseType]",
169
167
  request: RequestType,
170
168
  metadata: Optional[Any] = None,
171
169
  ) -> AsyncIterator[ResponseType]:
@@ -174,37 +172,66 @@ async def unary_stream(
174
172
  yield item
175
173
 
176
174
 
175
+ @dataclass(frozen=True)
176
+ class Retry:
177
+ base_delay: float = 0.1
178
+ max_delay: float = 1
179
+ delay_factor: float = 2
180
+ max_retries: Optional[int] = 3
181
+ additional_status_codes: list = field(default_factory=list)
182
+ attempt_timeout: Optional[float] = None # timeout for each attempt
183
+ total_timeout: Optional[float] = None # timeout for the entire function call
184
+ attempt_timeout_floor: float = 2.0 # always have at least this much timeout (only for total_timeout)
185
+ warning_message: Optional[RetryWarningMessage] = None
186
+
187
+
177
188
  async def retry_transient_errors(
178
- fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
179
- *args,
180
- base_delay: float = 0.1,
181
- max_delay: float = 1,
182
- delay_factor: float = 2,
189
+ fn: "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
190
+ req: RequestType,
183
191
  max_retries: Optional[int] = 3,
184
- additional_status_codes: list = [],
185
- attempt_timeout: Optional[float] = None, # timeout for each attempt
186
- total_timeout: Optional[float] = None, # timeout for the entire function call
187
- attempt_timeout_floor=2.0, # always have at least this much timeout (only for total_timeout)
188
- retry_warning_message: Optional[RetryWarningMessage] = None,
189
- metadata: list[tuple[str, str]] = [],
192
+ ) -> ResponseType:
193
+ """Minimum API version of _retry_transient_errors that works with grpclib.client.UnaryUnaryMethod.
194
+
195
+ Used by modal server.
196
+ """
197
+ return await _retry_transient_errors(fn, req, retry=Retry(max_retries=max_retries))
198
+
199
+
200
+ async def _retry_transient_errors(
201
+ fn: typing.Union[
202
+ "modal._grpc_client.UnaryUnaryWrapper[RequestType, ResponseType]",
203
+ "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
204
+ ],
205
+ req: RequestType,
206
+ retry: Retry,
207
+ metadata: Optional[list[tuple[str, str]]] = None,
190
208
  ) -> ResponseType:
191
209
  """Retry on transient gRPC failures with back-off until max_retries is reached.
192
210
  If max_retries is None, retry forever."""
211
+ import modal._grpc_client
212
+
213
+ if isinstance(fn, modal._grpc_client.UnaryUnaryWrapper):
214
+ fn_callable = fn.direct
215
+ elif isinstance(fn, grpclib.client.UnaryUnaryMethod):
216
+ fn_callable = fn # type: ignore
217
+ else:
218
+ raise ValueError("Only modal._grpc_client.UnaryUnaryWrapper and grpclib.client.UnaryUnaryMethod are supported")
193
219
 
194
- delay = base_delay
220
+ delay = retry.base_delay
195
221
  n_retries = 0
196
222
 
197
- status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *additional_status_codes]
223
+ status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *retry.additional_status_codes]
198
224
 
199
225
  idempotency_key = str(uuid.uuid4())
200
226
 
201
227
  t0 = time.time()
202
- if total_timeout is not None:
203
- total_deadline = t0 + total_timeout
228
+ if retry.total_timeout is not None:
229
+ total_deadline = t0 + retry.total_timeout
204
230
  else:
205
231
  total_deadline = None
206
232
 
207
- metadata = metadata + [("x-modal-timestamp", str(time.time()))]
233
+ metadata = (metadata or []) + [("x-modal-timestamp", str(time.time()))]
234
+
208
235
  while True:
209
236
  attempt_metadata = [
210
237
  ("x-idempotency-key", idempotency_key),
@@ -214,16 +241,17 @@ async def retry_transient_errors(
214
241
  if n_retries > 0:
215
242
  attempt_metadata.append(("x-retry-delay", str(time.time() - t0)))
216
243
  timeouts = []
217
- if attempt_timeout is not None:
218
- timeouts.append(attempt_timeout)
219
- if total_timeout is not None:
220
- timeouts.append(max(total_deadline - time.time(), attempt_timeout_floor))
244
+ if retry.attempt_timeout is not None:
245
+ timeouts.append(retry.attempt_timeout)
246
+ if retry.total_timeout is not None and total_deadline is not None:
247
+ timeouts.append(max(total_deadline - time.time(), retry.attempt_timeout_floor))
221
248
  if timeouts:
222
249
  timeout = min(timeouts) # In case the function provided both types of timeouts
223
250
  else:
224
251
  timeout = None
225
252
  try:
226
- return await fn(*args, metadata=attempt_metadata, timeout=timeout)
253
+ with suppress_tb_frames(1):
254
+ return await fn_callable(req, metadata=attempt_metadata, timeout=timeout)
227
255
  except (StreamTerminatedError, GRPCError, OSError, asyncio.TimeoutError, AttributeError) as exc:
228
256
  if isinstance(exc, GRPCError) and exc.status not in status_codes:
229
257
  if exc.status == Status.UNAUTHENTICATED:
@@ -231,45 +259,46 @@ async def retry_transient_errors(
231
259
  else:
232
260
  raise exc
233
261
 
234
- if max_retries is not None and n_retries >= max_retries:
262
+ if retry.max_retries is not None and n_retries >= retry.max_retries:
235
263
  final_attempt = True
236
- elif total_deadline is not None and time.time() + delay + attempt_timeout_floor >= total_deadline:
264
+ elif total_deadline is not None and time.time() + delay + retry.attempt_timeout_floor >= total_deadline:
237
265
  final_attempt = True
238
266
  else:
239
267
  final_attempt = False
240
268
 
241
- if final_attempt:
242
- logger.debug(
243
- f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} "
244
- f"{total_deadline=} for {fn.name} ({idempotency_key[:8]})"
245
- )
246
- if isinstance(exc, OSError):
247
- raise ConnectionError(str(exc))
248
- elif isinstance(exc, asyncio.TimeoutError):
249
- raise ConnectionError(str(exc))
250
- else:
269
+ with suppress_tb_frames(1):
270
+ if final_attempt:
271
+ logger.debug(
272
+ f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} "
273
+ f"{total_deadline=} for {fn.name} ({idempotency_key[:8]})"
274
+ )
275
+ if isinstance(exc, OSError):
276
+ raise ConnectionError(str(exc))
277
+ elif isinstance(exc, asyncio.TimeoutError):
278
+ raise ConnectionError(str(exc))
279
+ else:
280
+ raise exc
281
+
282
+ if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
283
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
284
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
285
+ # TODO: update to newer version (>=0.4.8) once stable
251
286
  raise exc
252
287
 
253
- if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
254
- # StreamTerminatedError are not properly raised in grpclib<=0.4.7
255
- # fixed in https://github.com/vmagamedov/grpclib/issues/185
256
- # TODO: update to newer version (>=0.4.8) once stable
257
- raise exc
258
-
259
288
  logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn.name} ({idempotency_key[:8]})")
260
289
 
261
290
  n_retries += 1
262
291
 
263
292
  if (
264
- retry_warning_message
265
- and n_retries % retry_warning_message.warning_interval == 0
293
+ retry.warning_message
294
+ and n_retries % retry.warning_message.warning_interval == 0
266
295
  and isinstance(exc, GRPCError)
267
- and exc.status in retry_warning_message.errors_to_warn_for
296
+ and exc.status in retry.warning_message.errors_to_warn_for
268
297
  ):
269
- logger.warning(retry_warning_message.message)
298
+ logger.warning(retry.warning_message.message)
270
299
 
271
300
  await asyncio.sleep(delay)
272
- delay = min(delay * delay_factor, max_delay)
301
+ delay = min(delay * retry.delay_factor, retry.max_delay)
273
302
 
274
303
 
275
304
  def find_free_port() -> int:
@@ -3,7 +3,9 @@ import posixpath
3
3
  import typing
4
4
  from collections.abc import Mapping, Sequence
5
5
  from pathlib import PurePath, PurePosixPath
6
- from typing import Union
6
+ from typing import Optional, Union
7
+
8
+ from typing_extensions import TypeGuard
7
9
 
8
10
  from ..cloud_bucket_mount import _CloudBucketMount
9
11
  from ..exception import InvalidError
@@ -76,3 +78,26 @@ def validate_volumes(
76
78
  )
77
79
 
78
80
  return validated_volumes
81
+
82
+
83
+ def validate_only_modal_volumes(
84
+ volumes: Optional[Optional[dict[Union[str, PurePosixPath], _Volume]]],
85
+ caller_name: str,
86
+ ) -> Sequence[tuple[str, _Volume]]:
87
+ """Validate all volumes are `modal.Volume`."""
88
+ if volumes is None:
89
+ return []
90
+
91
+ validated_volumes = validate_volumes(volumes)
92
+
93
+ # Although the typing forbids `_CloudBucketMount` for type checking, one can still pass a `_CloudBucketMount`
94
+ # during runtime, so we'll check the type here.
95
+ def all_modal_volumes(
96
+ vols: Sequence[tuple[str, Union[_Volume, _CloudBucketMount]]],
97
+ ) -> TypeGuard[Sequence[tuple[str, _Volume]]]:
98
+ return all(isinstance(v, _Volume) for _, v in vols)
99
+
100
+ if not all_modal_volumes(validated_volumes):
101
+ raise InvalidError(f"{caller_name} only supports volumes that are modal.Volume")
102
+
103
+ return validated_volumes