modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic; see the package's registry page for more details.

Files changed (160)
  1. modal/__init__.py +0 -2
  2. modal/__main__.py +3 -4
  3. modal/_billing.py +80 -0
  4. modal/_clustered_functions.py +7 -3
  5. modal/_clustered_functions.pyi +15 -3
  6. modal/_container_entrypoint.py +51 -69
  7. modal/_functions.py +508 -240
  8. modal/_grpc_client.py +171 -0
  9. modal/_load_context.py +105 -0
  10. modal/_object.py +81 -21
  11. modal/_output.py +58 -45
  12. modal/_partial_function.py +48 -73
  13. modal/_pty.py +7 -3
  14. modal/_resolver.py +26 -46
  15. modal/_runtime/asgi.py +4 -3
  16. modal/_runtime/container_io_manager.py +358 -220
  17. modal/_runtime/container_io_manager.pyi +296 -101
  18. modal/_runtime/execution_context.py +18 -2
  19. modal/_runtime/execution_context.pyi +64 -7
  20. modal/_runtime/gpu_memory_snapshot.py +262 -57
  21. modal/_runtime/user_code_imports.py +28 -58
  22. modal/_serialization.py +90 -6
  23. modal/_traceback.py +42 -1
  24. modal/_tunnel.pyi +380 -12
  25. modal/_utils/async_utils.py +84 -29
  26. modal/_utils/auth_token_manager.py +111 -0
  27. modal/_utils/blob_utils.py +181 -58
  28. modal/_utils/deprecation.py +19 -0
  29. modal/_utils/function_utils.py +91 -47
  30. modal/_utils/grpc_utils.py +89 -66
  31. modal/_utils/mount_utils.py +26 -1
  32. modal/_utils/name_utils.py +17 -3
  33. modal/_utils/task_command_router_client.py +536 -0
  34. modal/_utils/time_utils.py +34 -6
  35. modal/app.py +256 -88
  36. modal/app.pyi +909 -92
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +18 -0
  39. modal/builder/PREVIEW.txt +18 -0
  40. modal/builder/base-images.json +58 -0
  41. modal/cli/_download.py +19 -3
  42. modal/cli/_traceback.py +3 -2
  43. modal/cli/app.py +4 -4
  44. modal/cli/cluster.py +15 -7
  45. modal/cli/config.py +5 -3
  46. modal/cli/container.py +7 -6
  47. modal/cli/dict.py +22 -16
  48. modal/cli/entry_point.py +12 -5
  49. modal/cli/environment.py +5 -4
  50. modal/cli/import_refs.py +3 -3
  51. modal/cli/launch.py +102 -5
  52. modal/cli/network_file_system.py +11 -12
  53. modal/cli/profile.py +3 -2
  54. modal/cli/programs/launch_instance_ssh.py +94 -0
  55. modal/cli/programs/run_jupyter.py +1 -1
  56. modal/cli/programs/run_marimo.py +95 -0
  57. modal/cli/programs/vscode.py +1 -1
  58. modal/cli/queues.py +57 -26
  59. modal/cli/run.py +91 -23
  60. modal/cli/secret.py +48 -22
  61. modal/cli/token.py +7 -8
  62. modal/cli/utils.py +4 -7
  63. modal/cli/volume.py +31 -25
  64. modal/client.py +15 -85
  65. modal/client.pyi +183 -62
  66. modal/cloud_bucket_mount.py +5 -3
  67. modal/cloud_bucket_mount.pyi +197 -5
  68. modal/cls.py +200 -126
  69. modal/cls.pyi +446 -68
  70. modal/config.py +29 -11
  71. modal/container_process.py +319 -19
  72. modal/container_process.pyi +190 -20
  73. modal/dict.py +290 -71
  74. modal/dict.pyi +835 -83
  75. modal/environments.py +15 -27
  76. modal/environments.pyi +46 -24
  77. modal/exception.py +14 -2
  78. modal/experimental/__init__.py +194 -40
  79. modal/experimental/flash.py +618 -0
  80. modal/experimental/flash.pyi +380 -0
  81. modal/experimental/ipython.py +11 -7
  82. modal/file_io.py +29 -36
  83. modal/file_io.pyi +251 -53
  84. modal/file_pattern_matcher.py +56 -16
  85. modal/functions.pyi +673 -92
  86. modal/gpu.py +1 -1
  87. modal/image.py +528 -176
  88. modal/image.pyi +1572 -145
  89. modal/io_streams.py +458 -128
  90. modal/io_streams.pyi +433 -52
  91. modal/mount.py +216 -151
  92. modal/mount.pyi +225 -78
  93. modal/network_file_system.py +45 -62
  94. modal/network_file_system.pyi +277 -56
  95. modal/object.pyi +93 -17
  96. modal/parallel_map.py +942 -129
  97. modal/parallel_map.pyi +294 -15
  98. modal/partial_function.py +0 -2
  99. modal/partial_function.pyi +234 -19
  100. modal/proxy.py +17 -8
  101. modal/proxy.pyi +36 -3
  102. modal/queue.py +270 -65
  103. modal/queue.pyi +817 -57
  104. modal/runner.py +115 -101
  105. modal/runner.pyi +205 -49
  106. modal/sandbox.py +512 -136
  107. modal/sandbox.pyi +845 -111
  108. modal/schedule.py +1 -1
  109. modal/secret.py +300 -70
  110. modal/secret.pyi +589 -34
  111. modal/serving.py +7 -11
  112. modal/serving.pyi +7 -8
  113. modal/snapshot.py +11 -8
  114. modal/snapshot.pyi +25 -4
  115. modal/token_flow.py +4 -4
  116. modal/token_flow.pyi +28 -8
  117. modal/volume.py +416 -158
  118. modal/volume.pyi +1117 -121
  119. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
  120. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  121. modal_docs/mdmd/mdmd.py +17 -4
  122. modal_proto/api.proto +534 -79
  123. modal_proto/api_grpc.py +337 -1
  124. modal_proto/api_pb2.py +1522 -968
  125. modal_proto/api_pb2.pyi +1619 -134
  126. modal_proto/api_pb2_grpc.py +699 -4
  127. modal_proto/api_pb2_grpc.pyi +226 -14
  128. modal_proto/modal_api_grpc.py +175 -154
  129. modal_proto/sandbox_router.proto +145 -0
  130. modal_proto/sandbox_router_grpc.py +105 -0
  131. modal_proto/sandbox_router_pb2.py +149 -0
  132. modal_proto/sandbox_router_pb2.pyi +333 -0
  133. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  134. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  135. modal_proto/task_command_router.proto +144 -0
  136. modal_proto/task_command_router_grpc.py +105 -0
  137. modal_proto/task_command_router_pb2.py +149 -0
  138. modal_proto/task_command_router_pb2.pyi +333 -0
  139. modal_proto/task_command_router_pb2_grpc.py +203 -0
  140. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  141. modal_version/__init__.py +1 -1
  142. modal/requirements/PREVIEW.txt +0 -16
  143. modal/requirements/base-images.json +0 -26
  144. modal-1.0.3.dev10.dist-info/RECORD +0 -179
  145. modal_proto/modal_options_grpc.py +0 -3
  146. modal_proto/options.proto +0 -19
  147. modal_proto/options_grpc.py +0 -3
  148. modal_proto/options_pb2.py +0 -35
  149. modal_proto/options_pb2.pyi +0 -20
  150. modal_proto/options_pb2_grpc.py +0 -4
  151. modal_proto/options_pb2_grpc.pyi +0 -7
  152. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  153. /modal/{requirements → builder}/2023.12.txt +0 -0
  154. /modal/{requirements → builder}/2024.04.txt +0 -0
  155. /modal/{requirements → builder}/2024.10.txt +0 -0
  156. /modal/{requirements → builder}/README.md +0 -0
  157. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  158. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  159. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  160. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import asyncio
3
- import enum
4
3
  import inspect
5
4
  import os
5
+ import typing
6
6
  from collections.abc import AsyncGenerator
7
7
  from enum import Enum
8
8
  from pathlib import Path, PurePosixPath
@@ -10,7 +10,6 @@ from typing import Any, Callable, Literal, Optional
10
10
 
11
11
  from grpclib import GRPCError
12
12
  from grpclib.exceptions import StreamTerminatedError
13
- from synchronicity.exceptions import UserCodeException
14
13
 
15
14
  import modal_proto
16
15
  from modal_proto import api_pb2
@@ -19,7 +18,9 @@ from modal_proto.modal_api_grpc import ModalClientModal
19
18
  from .._serialization import (
20
19
  deserialize,
21
20
  deserialize_data_format,
21
+ get_preferred_payload_format,
22
22
  serialize,
23
+ serialize_data_format as _serialize_data_format,
23
24
  signature_to_parameter_specs,
24
25
  )
25
26
  from .._traceback import append_modal_tb
@@ -33,9 +34,16 @@ from ..exception import (
33
34
  RemoteError,
34
35
  )
35
36
  from ..mount import ROOT_DIR, _is_modal_path, _Mount
36
- from .blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload
37
+ from .blob_utils import (
38
+ MAX_ASYNC_OBJECT_SIZE_BYTES,
39
+ blob_download,
40
+ blob_upload_with_r2_failure_info,
41
+ )
37
42
  from .grpc_utils import RETRYABLE_GRPC_STATUS_CODES
38
43
 
44
+ if typing.TYPE_CHECKING:
45
+ import modal._functions
46
+
39
47
 
40
48
  class FunctionInfoType(Enum):
41
49
  PACKAGE = "package"
@@ -67,6 +75,10 @@ def is_global_object(object_qual_name: str):
67
75
  return "<locals>" not in object_qual_name.split(".")
68
76
 
69
77
 
78
+ def is_flash_object(experimental_options: Optional[dict[str, Any]]) -> bool:
79
+ return experimental_options.get("flash", False) if experimental_options else False
80
+
81
+
70
82
  def is_method_fn(object_qual_name: str):
71
83
  # methods have names like Cls.foo.
72
84
  if "<locals>" in object_qual_name:
@@ -383,9 +395,19 @@ def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:
383
395
 
384
396
 
385
397
  async def _stream_function_call_data(
386
- client, function_call_id: str, variant: Literal["data_in", "data_out"]
398
+ client,
399
+ stub,
400
+ function_call_id: Optional[str],
401
+ variant: Literal["data_in", "data_out"],
402
+ attempt_token: Optional[str] = None,
387
403
  ) -> AsyncGenerator[Any, None]:
388
404
  """Read from the `data_in` or `data_out` stream of a function call."""
405
+ if not function_call_id and not attempt_token:
406
+ raise ValueError("function_call_id or attempt_token is required to read from a data stream")
407
+
408
+ if stub is None:
409
+ stub = client.stub
410
+
389
411
  last_index = 0
390
412
 
391
413
  # TODO(gongy): generalize this logic as util for unary streams
@@ -393,14 +415,19 @@ async def _stream_function_call_data(
393
415
  delay_ms = 1
394
416
 
395
417
  if variant == "data_in":
396
- stub_fn = client.stub.FunctionCallGetDataIn
418
+ stub_fn = stub.FunctionCallGetDataIn
397
419
  elif variant == "data_out":
398
- stub_fn = client.stub.FunctionCallGetDataOut
420
+ stub_fn = stub.FunctionCallGetDataOut
399
421
  else:
400
422
  raise ValueError(f"Invalid variant {variant}")
401
423
 
402
424
  while True:
403
- req = api_pb2.FunctionCallGetDataRequest(function_call_id=function_call_id, last_index=last_index)
425
+ req = api_pb2.FunctionCallGetDataRequest(
426
+ function_call_id=function_call_id,
427
+ last_index=last_index,
428
+ )
429
+ if attempt_token:
430
+ req.attempt_token = attempt_token # oneof clears function_call_id.
404
431
  try:
405
432
  async for chunk in stub_fn.unary_stream(req):
406
433
  if chunk.index <= last_index:
@@ -469,7 +496,12 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub,
469
496
  elif result.status == api_pb2.GenericResult.GENERIC_STATUS_INTERNAL_FAILURE:
470
497
  raise InternalFailure(result.exception)
471
498
  elif result.status != api_pb2.GenericResult.GENERIC_STATUS_SUCCESS:
472
- if data:
499
+ if data and data_format in (api_pb2.DATA_FORMAT_PICKLE, api_pb2.DATA_FORMAT_UNSPECIFIED):
500
+ # *Unspecified data format here but data present usually means that the exception
501
+ # was created by the server representing an exception that occurred during container
502
+ # startup (crash looping) that eventually got escalated to input failures.
503
+ # TaskResult doesn't specify data format, so these results don't have that metadata
504
+ # the moment.
473
505
  try:
474
506
  exc = deserialize(data, client)
475
507
  except DeserializationError as deser_exc:
@@ -497,8 +529,9 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub,
497
529
  append_modal_tb(exc, tb_dict, line_cache)
498
530
  except Exception:
499
531
  pass
500
- uc_exc = UserCodeException(exc_with_hints(exc))
501
- raise uc_exc
532
+
533
+ raise exc_with_hints(exc)
534
+
502
535
  raise RemoteError(result.exception)
503
536
 
504
537
  try:
@@ -511,35 +544,66 @@ async def _process_result(result: api_pb2.GenericResult, data_format: int, stub,
511
544
  ) from deser_exc
512
545
 
513
546
 
547
+ def should_upload(
548
+ num_bytes: int,
549
+ max_object_size_bytes: int,
550
+ function_call_invocation_type: Optional["api_pb2.FunctionCallInvocationType.ValueType"],
551
+ ) -> bool:
552
+ """
553
+ Determine if the input should be uploaded to blob storage.
554
+ """
555
+ return num_bytes > max_object_size_bytes or (
556
+ function_call_invocation_type == api_pb2.FUNCTION_CALL_INVOCATION_TYPE_ASYNC
557
+ and num_bytes > MAX_ASYNC_OBJECT_SIZE_BYTES
558
+ )
559
+
560
+
561
+ # This must be called against the client stub, not the input-plane stub.
514
562
  async def _create_input(
515
- args, kwargs, stub: ModalClientModal, *, idx: Optional[int] = None, method_name: Optional[str] = None
563
+ args,
564
+ kwargs,
565
+ stub: ModalClientModal,
566
+ *,
567
+ function: "modal._functions._Function",
568
+ idx: Optional[int] = None,
569
+ function_call_invocation_type: Optional["api_pb2.FunctionCallInvocationType.ValueType"] = None,
516
570
  ) -> api_pb2.FunctionPutInputsItem:
517
571
  """Serialize function arguments and create a FunctionInput protobuf,
518
572
  uploading to blob storage if needed.
519
573
  """
574
+ method_name = function._use_method_name
575
+ max_object_size_bytes = function._max_object_size_bytes
576
+
520
577
  if idx is None:
521
578
  idx = 0
522
- if method_name is None:
523
- method_name = "" # proto compatible
524
579
 
525
- args_serialized = serialize((args, kwargs))
580
+ data_format = get_preferred_payload_format()
581
+ if not function._metadata:
582
+ raise ExecutionError("Attempted to call function that has not been hydrated with metadata")
526
583
 
527
- if len(args_serialized) > MAX_OBJECT_SIZE_BYTES:
528
- args_blob_id = await blob_upload(args_serialized, stub)
584
+ supported_input_formats = function._metadata.supported_input_formats or [api_pb2.DATA_FORMAT_PICKLE]
585
+ if data_format not in supported_input_formats:
586
+ data_format = supported_input_formats[0]
529
587
 
588
+ args_serialized = _serialize_data_format((args, kwargs), data_format)
589
+
590
+ if should_upload(len(args_serialized), max_object_size_bytes, function_call_invocation_type):
591
+ args_blob_id, r2_failed, r2_throughput_bytes_s = await blob_upload_with_r2_failure_info(args_serialized, stub)
530
592
  return api_pb2.FunctionPutInputsItem(
531
593
  input=api_pb2.FunctionInput(
532
594
  args_blob_id=args_blob_id,
533
- data_format=api_pb2.DATA_FORMAT_PICKLE,
595
+ data_format=data_format,
534
596
  method_name=method_name,
535
597
  ),
536
598
  idx=idx,
599
+ r2_failed=r2_failed,
600
+ r2_throughput_bytes_s=r2_throughput_bytes_s,
537
601
  )
538
602
  else:
539
603
  return api_pb2.FunctionPutInputsItem(
540
604
  input=api_pb2.FunctionInput(
541
605
  args=args_serialized,
542
- data_format=api_pb2.DATA_FORMAT_PICKLE,
606
+ data_format=data_format,
543
607
  method_name=method_name,
544
608
  ),
545
609
  idx=idx,
@@ -581,14 +645,13 @@ class FunctionCreationStatus:
581
645
  if not self.response:
582
646
  self.status_row.finish(f"Unknown error when creating function {self.tag}")
583
647
 
584
- elif self.response.function.web_url:
648
+ elif web_url := self.response.handle_metadata.web_url:
585
649
  url_info = self.response.function.web_url_info
586
650
  requires_proxy_auth = self.response.function.webhook_config.requires_proxy_auth
587
651
  proxy_auth_suffix = " 🔑" if requires_proxy_auth else ""
588
652
  # Ensure terms used here match terms used in modal.com/docs/guide/webhook-urls doc.
589
653
  suffix = _get_suffix_from_web_url_info(url_info)
590
654
  # TODO: this is only printed when we're showing progress. Maybe move this somewhere else.
591
- web_url = self.response.handle_metadata.web_url
592
655
  for warning in self.response.server_warnings:
593
656
  self.status_row.warning(warning)
594
657
  self.status_row.finish(
@@ -603,6 +666,14 @@ class FunctionCreationStatus:
603
666
  f"Custom domain for {self.tag} => [magenta underline]{custom_domain.url}[/magenta underline]"
604
667
  )
605
668
 
669
+ elif self.response.function.flash_service_urls:
670
+ for flash_service_url in self.response.function.flash_service_urls:
671
+ flash_service_url_status_row = self.resolver.add_status_row()
672
+ flash_service_url_status_row.finish(
673
+ f"Created flash service endpoint for {self.tag} => "
674
+ f"[magenta underline]{flash_service_url}[/magenta underline]"
675
+ )
676
+
606
677
  else:
607
678
  for warning in self.response.server_warnings:
608
679
  self.status_row.warning(warning)
@@ -623,30 +694,3 @@ class FunctionCreationStatus:
623
694
  f"Custom domain for {method_definition.function_name} => [magenta underline]"
624
695
  f"{custom_domain.url}[/magenta underline]"
625
696
  )
626
-
627
-
628
- class IncludeSourceMode(enum.Enum):
629
- INCLUDE_NOTHING = False # can only be set in source, can't be set in config
630
- INCLUDE_MAIN_PACKAGE = True # Default behavior
631
-
632
-
633
- def get_include_source_mode(function_or_app_specific) -> IncludeSourceMode:
634
- """Which "automount" behavior should a function use
635
-
636
- function_or_app_specific: explicit value given in the @function or @cls decorator, in an App constructor, or None
637
-
638
- If function_or_app_specific is specified, validate and return the IncludeSourceMode
639
- If function_or_app_specific is None, infer it from config
640
- """
641
- if function_or_app_specific is not None:
642
- if not isinstance(function_or_app_specific, bool):
643
- raise ValueError(
644
- f"Invalid `include_source` value: {function_or_app_specific}. Use one of:\n"
645
- f"True - include function's package source\n"
646
- f"False - include no Python source (module expected to be present in Image)\n"
647
- )
648
-
649
- # explicitly set in app/function
650
- return IncludeSourceMode(function_or_app_specific)
651
-
652
- return IncludeSourceMode.INCLUDE_MAIN_PACKAGE
@@ -8,12 +8,8 @@ import typing
8
8
  import urllib.parse
9
9
  import uuid
10
10
  from collections.abc import AsyncIterator
11
- from dataclasses import dataclass
12
- from typing import (
13
- Any,
14
- Optional,
15
- TypeVar,
16
- )
11
+ from dataclasses import dataclass, field
12
+ from typing import Any, Optional, TypeVar
17
13
 
18
14
  import grpclib.client
19
15
  import grpclib.config
@@ -28,12 +24,15 @@ from grpclib.protocol import H2Protocol
28
24
  from modal.exception import AuthError, ConnectionError
29
25
  from modal_version import __version__
30
26
 
27
+ from .._traceback import suppress_tb_frames
28
+ from .async_utils import retry
31
29
  from .logger import logger
32
30
 
33
31
  RequestType = TypeVar("RequestType", bound=Message)
34
32
  ResponseType = TypeVar("ResponseType", bound=Message)
35
33
 
36
34
  if typing.TYPE_CHECKING:
35
+ import modal._grpc_client
37
36
  import modal.client
38
37
 
39
38
  # Monkey patches grpclib to have a Modal User Agent header.
@@ -148,27 +147,14 @@ def create_channel(
148
147
 
149
148
  logger.debug(f"Sending request to {event.method_name}")
150
149
 
151
- async def recv_initial_metadata(initial_metadata: grpclib.events.RecvInitialMetadata) -> None:
152
- # If we receive an auth token from the server, include it in all future requests.
153
- # TODO(nathan): This isn't perfect because the metadata isn't propagated when the
154
- # process is forked and a new channel is created. This is OK for now since this
155
- # token is only used by the experimental input plane
156
- if token := initial_metadata.metadata.get("x-modal-auth-token"):
157
- metadata["x-modal-auth-token"] = str(token)
158
-
159
- async def recv_trailing_metadata(trailing_metadata: grpclib.events.RecvTrailingMetadata) -> None:
160
- if token := trailing_metadata.metadata.get("x-modal-auth-token"):
161
- metadata["x-modal-auth-token"] = str(token)
162
-
163
150
  grpclib.events.listen(channel, grpclib.events.SendRequest, send_request)
164
- grpclib.events.listen(channel, grpclib.events.RecvInitialMetadata, recv_initial_metadata)
165
- grpclib.events.listen(channel, grpclib.events.RecvTrailingMetadata, recv_trailing_metadata)
166
151
 
167
152
  return channel
168
153
 
169
154
 
155
+ @retry(n_attempts=5, base_delay=0.1)
170
156
  async def connect_channel(channel: grpclib.client.Channel):
171
- """Connects socket (potentially raising errors raising to connectivity."""
157
+ """Connect to socket and raise exceptions when there is a connection issue."""
172
158
  await channel.__connect__()
173
159
 
174
160
 
@@ -177,7 +163,7 @@ if typing.TYPE_CHECKING:
177
163
 
178
164
 
179
165
  async def unary_stream(
180
- method: "modal.client.UnaryStreamWrapper[RequestType, ResponseType]",
166
+ method: "modal._grpc_client.UnaryStreamWrapper[RequestType, ResponseType]",
181
167
  request: RequestType,
182
168
  metadata: Optional[Any] = None,
183
169
  ) -> AsyncIterator[ResponseType]:
@@ -186,50 +172,86 @@ async def unary_stream(
186
172
  yield item
187
173
 
188
174
 
175
+ @dataclass(frozen=True)
176
+ class Retry:
177
+ base_delay: float = 0.1
178
+ max_delay: float = 1
179
+ delay_factor: float = 2
180
+ max_retries: Optional[int] = 3
181
+ additional_status_codes: list = field(default_factory=list)
182
+ attempt_timeout: Optional[float] = None # timeout for each attempt
183
+ total_timeout: Optional[float] = None # timeout for the entire function call
184
+ attempt_timeout_floor: float = 2.0 # always have at least this much timeout (only for total_timeout)
185
+ warning_message: Optional[RetryWarningMessage] = None
186
+
187
+
189
188
  async def retry_transient_errors(
190
- fn: "modal.client.UnaryUnaryWrapper[RequestType, ResponseType]",
191
- *args,
192
- base_delay: float = 0.1,
193
- max_delay: float = 1,
194
- delay_factor: float = 2,
189
+ fn: "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
190
+ req: RequestType,
195
191
  max_retries: Optional[int] = 3,
196
- additional_status_codes: list = [],
197
- attempt_timeout: Optional[float] = None, # timeout for each attempt
198
- total_timeout: Optional[float] = None, # timeout for the entire function call
199
- attempt_timeout_floor=2.0, # always have at least this much timeout (only for total_timeout)
200
- retry_warning_message: Optional[RetryWarningMessage] = None,
192
+ ) -> ResponseType:
193
+ """Minimum API version of _retry_transient_errors that works with grpclib.client.UnaryUnaryMethod.
194
+
195
+ Used by modal server.
196
+ """
197
+ return await _retry_transient_errors(fn, req, retry=Retry(max_retries=max_retries))
198
+
199
+
200
+ async def _retry_transient_errors(
201
+ fn: typing.Union[
202
+ "modal._grpc_client.UnaryUnaryWrapper[RequestType, ResponseType]",
203
+ "grpclib.client.UnaryUnaryMethod[RequestType, ResponseType]",
204
+ ],
205
+ req: RequestType,
206
+ retry: Retry,
207
+ metadata: Optional[list[tuple[str, str]]] = None,
201
208
  ) -> ResponseType:
202
209
  """Retry on transient gRPC failures with back-off until max_retries is reached.
203
210
  If max_retries is None, retry forever."""
211
+ import modal._grpc_client
212
+
213
+ if isinstance(fn, modal._grpc_client.UnaryUnaryWrapper):
214
+ fn_callable = fn.direct
215
+ elif isinstance(fn, grpclib.client.UnaryUnaryMethod):
216
+ fn_callable = fn # type: ignore
217
+ else:
218
+ raise ValueError("Only modal._grpc_client.UnaryUnaryWrapper and grpclib.client.UnaryUnaryMethod are supported")
204
219
 
205
- delay = base_delay
220
+ delay = retry.base_delay
206
221
  n_retries = 0
207
222
 
208
- status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *additional_status_codes]
223
+ status_codes = [*RETRYABLE_GRPC_STATUS_CODES, *retry.additional_status_codes]
209
224
 
210
225
  idempotency_key = str(uuid.uuid4())
211
226
 
212
227
  t0 = time.time()
213
- if total_timeout is not None:
214
- total_deadline = t0 + total_timeout
228
+ if retry.total_timeout is not None:
229
+ total_deadline = t0 + retry.total_timeout
215
230
  else:
216
231
  total_deadline = None
217
232
 
233
+ metadata = (metadata or []) + [("x-modal-timestamp", str(time.time()))]
234
+
218
235
  while True:
219
- metadata = [("x-idempotency-key", idempotency_key), ("x-retry-attempt", str(n_retries))]
236
+ attempt_metadata = [
237
+ ("x-idempotency-key", idempotency_key),
238
+ ("x-retry-attempt", str(n_retries)),
239
+ *metadata,
240
+ ]
220
241
  if n_retries > 0:
221
- metadata.append(("x-retry-delay", str(time.time() - t0)))
242
+ attempt_metadata.append(("x-retry-delay", str(time.time() - t0)))
222
243
  timeouts = []
223
- if attempt_timeout is not None:
224
- timeouts.append(attempt_timeout)
225
- if total_timeout is not None:
226
- timeouts.append(max(total_deadline - time.time(), attempt_timeout_floor))
244
+ if retry.attempt_timeout is not None:
245
+ timeouts.append(retry.attempt_timeout)
246
+ if retry.total_timeout is not None and total_deadline is not None:
247
+ timeouts.append(max(total_deadline - time.time(), retry.attempt_timeout_floor))
227
248
  if timeouts:
228
249
  timeout = min(timeouts) # In case the function provided both types of timeouts
229
250
  else:
230
251
  timeout = None
231
252
  try:
232
- return await fn(*args, metadata=metadata, timeout=timeout)
253
+ with suppress_tb_frames(1):
254
+ return await fn_callable(req, metadata=attempt_metadata, timeout=timeout)
233
255
  except (StreamTerminatedError, GRPCError, OSError, asyncio.TimeoutError, AttributeError) as exc:
234
256
  if isinstance(exc, GRPCError) and exc.status not in status_codes:
235
257
  if exc.status == Status.UNAUTHENTICATED:
@@ -237,45 +259,46 @@ async def retry_transient_errors(
237
259
  else:
238
260
  raise exc
239
261
 
240
- if max_retries is not None and n_retries >= max_retries:
262
+ if retry.max_retries is not None and n_retries >= retry.max_retries:
241
263
  final_attempt = True
242
- elif total_deadline is not None and time.time() + delay + attempt_timeout_floor >= total_deadline:
264
+ elif total_deadline is not None and time.time() + delay + retry.attempt_timeout_floor >= total_deadline:
243
265
  final_attempt = True
244
266
  else:
245
267
  final_attempt = False
246
268
 
247
- if final_attempt:
248
- logger.debug(
249
- f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} "
250
- f"{total_deadline=} for {fn.name} ({idempotency_key[:8]})"
251
- )
252
- if isinstance(exc, OSError):
253
- raise ConnectionError(str(exc))
254
- elif isinstance(exc, asyncio.TimeoutError):
255
- raise ConnectionError(str(exc))
256
- else:
269
+ with suppress_tb_frames(1):
270
+ if final_attempt:
271
+ logger.debug(
272
+ f"Final attempt failed with {repr(exc)} {n_retries=} {delay=} "
273
+ f"{total_deadline=} for {fn.name} ({idempotency_key[:8]})"
274
+ )
275
+ if isinstance(exc, OSError):
276
+ raise ConnectionError(str(exc))
277
+ elif isinstance(exc, asyncio.TimeoutError):
278
+ raise ConnectionError(str(exc))
279
+ else:
280
+ raise exc
281
+
282
+ if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
283
+ # StreamTerminatedError are not properly raised in grpclib<=0.4.7
284
+ # fixed in https://github.com/vmagamedov/grpclib/issues/185
285
+ # TODO: update to newer version (>=0.4.8) once stable
257
286
  raise exc
258
287
 
259
- if isinstance(exc, AttributeError) and "_write_appdata" not in str(exc):
260
- # StreamTerminatedError are not properly raised in grpclib<=0.4.7
261
- # fixed in https://github.com/vmagamedov/grpclib/issues/185
262
- # TODO: update to newer version (>=0.4.8) once stable
263
- raise exc
264
-
265
288
  logger.debug(f"Retryable failure {repr(exc)} {n_retries=} {delay=} for {fn.name} ({idempotency_key[:8]})")
266
289
 
267
290
  n_retries += 1
268
291
 
269
292
  if (
270
- retry_warning_message
271
- and n_retries % retry_warning_message.warning_interval == 0
293
+ retry.warning_message
294
+ and n_retries % retry.warning_message.warning_interval == 0
272
295
  and isinstance(exc, GRPCError)
273
- and exc.status in retry_warning_message.errors_to_warn_for
296
+ and exc.status in retry.warning_message.errors_to_warn_for
274
297
  ):
275
- logger.warning(retry_warning_message.message)
298
+ logger.warning(retry.warning_message.message)
276
299
 
277
300
  await asyncio.sleep(delay)
278
- delay = min(delay * delay_factor, max_delay)
301
+ delay = min(delay * retry.delay_factor, retry.max_delay)
279
302
 
280
303
 
281
304
  def find_free_port() -> int:
@@ -3,7 +3,9 @@ import posixpath
3
3
  import typing
4
4
  from collections.abc import Mapping, Sequence
5
5
  from pathlib import PurePath, PurePosixPath
6
- from typing import Union
6
+ from typing import Optional, Union
7
+
8
+ from typing_extensions import TypeGuard
7
9
 
8
10
  from ..cloud_bucket_mount import _CloudBucketMount
9
11
  from ..exception import InvalidError
@@ -76,3 +78,26 @@ def validate_volumes(
76
78
  )
77
79
 
78
80
  return validated_volumes
81
+
82
+
83
+ def validate_only_modal_volumes(
84
+ volumes: Optional[Optional[dict[Union[str, PurePosixPath], _Volume]]],
85
+ caller_name: str,
86
+ ) -> Sequence[tuple[str, _Volume]]:
87
+ """Validate all volumes are `modal.Volume`."""
88
+ if volumes is None:
89
+ return []
90
+
91
+ validated_volumes = validate_volumes(volumes)
92
+
93
+ # Although the typing forbids `_CloudBucketMount` for type checking, one can still pass a `_CloudBucketMount`
94
+ # during runtime, so we'll check the type here.
95
+ def all_modal_volumes(
96
+ vols: Sequence[tuple[str, Union[_Volume, _CloudBucketMount]]],
97
+ ) -> TypeGuard[Sequence[tuple[str, _Volume]]]:
98
+ return all(isinstance(v, _Volume) for _, v in vols)
99
+
100
+ if not all_modal_volumes(validated_volumes):
101
+ raise InvalidError(f"{caller_name} only supports volumes that are modal.Volume")
102
+
103
+ return validated_volumes
@@ -1,5 +1,6 @@
1
1
  # Copyright Modal Labs 2022
2
2
  import re
3
+ from collections.abc import Mapping
3
4
 
4
5
  from ..exception import InvalidError
5
6
 
@@ -31,12 +32,25 @@ def is_valid_environment_name(name: str) -> bool:
31
32
  return len(name) <= 64 and re.match(r"^[a-zA-Z0-9][a-zA-Z0-9-_.]+$", name) is not None
32
33
 
33
34
 
34
- def is_valid_tag(tag: str) -> bool:
35
- """Tags are alphanumeric, dashes, periods, and underscores, and must be 50 characters or less"""
36
- pattern = r"^[a-zA-Z0-9._-]{1,50}$"
35
+ def is_valid_tag(tag: str, max_length: int = 50) -> bool:
36
+ """Tags are alphanumeric, dashes, periods, and underscores, and not longer than the max_length."""
37
+ pattern = rf"^[a-zA-Z0-9._-]{{1,{max_length}}}$"
37
38
  return bool(re.match(pattern, tag))
38
39
 
39
40
 
41
+ def check_tag_dict(tags: Mapping[str, str]) -> None:
42
+ rules = (
43
+ "\n\nTags may contain only alphanumeric characters, dashes, periods, or underscores, "
44
+ "and must be 63 characters or less."
45
+ )
46
+ max_length = 63
47
+ for key, value in tags.items():
48
+ if not is_valid_tag(key, max_length):
49
+ raise InvalidError(f"Invalid tag key: {key!r}.{rules}")
50
+ if not is_valid_tag(value, max_length):
51
+ raise InvalidError(f"Invalid tag value: {value!r}.{rules}")
52
+
53
+
40
54
  def check_object_name(name: str, object_type: str) -> None:
41
55
  message = (
42
56
  f"Invalid {object_type} name: '{name}'."