modal 1.1.5.dev66__py3-none-any.whl → 1.3.1.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. Click here for more details.

Files changed (143)
  1. modal/__init__.py +4 -4
  2. modal/__main__.py +4 -29
  3. modal/_billing.py +84 -0
  4. modal/_clustered_functions.py +1 -3
  5. modal/_container_entrypoint.py +33 -208
  6. modal/_functions.py +171 -138
  7. modal/_grpc_client.py +191 -0
  8. modal/_ipython.py +16 -6
  9. modal/_load_context.py +106 -0
  10. modal/_object.py +72 -21
  11. modal/_output.py +12 -14
  12. modal/_partial_function.py +31 -4
  13. modal/_resolver.py +44 -57
  14. modal/_runtime/container_io_manager.py +30 -28
  15. modal/_runtime/container_io_manager.pyi +42 -44
  16. modal/_runtime/gpu_memory_snapshot.py +9 -7
  17. modal/_runtime/user_code_event_loop.py +80 -0
  18. modal/_runtime/user_code_imports.py +236 -10
  19. modal/_serialization.py +2 -1
  20. modal/_traceback.py +4 -13
  21. modal/_tunnel.py +16 -11
  22. modal/_tunnel.pyi +25 -3
  23. modal/_utils/async_utils.py +337 -10
  24. modal/_utils/auth_token_manager.py +1 -4
  25. modal/_utils/blob_utils.py +29 -22
  26. modal/_utils/function_utils.py +20 -21
  27. modal/_utils/grpc_testing.py +6 -3
  28. modal/_utils/grpc_utils.py +223 -64
  29. modal/_utils/mount_utils.py +26 -1
  30. modal/_utils/name_utils.py +2 -3
  31. modal/_utils/package_utils.py +0 -1
  32. modal/_utils/rand_pb_testing.py +8 -1
  33. modal/_utils/task_command_router_client.py +524 -0
  34. modal/_vendor/cloudpickle.py +144 -48
  35. modal/app.py +285 -105
  36. modal/app.pyi +216 -53
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +6 -3
  39. modal/builder/PREVIEW.txt +2 -1
  40. modal/builder/base-images.json +4 -2
  41. modal/cli/_download.py +19 -3
  42. modal/cli/cluster.py +4 -2
  43. modal/cli/config.py +3 -1
  44. modal/cli/container.py +5 -4
  45. modal/cli/dict.py +5 -2
  46. modal/cli/entry_point.py +26 -2
  47. modal/cli/environment.py +2 -16
  48. modal/cli/launch.py +1 -76
  49. modal/cli/network_file_system.py +5 -20
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/vscode.py +1 -1
  52. modal/cli/queues.py +5 -4
  53. modal/cli/run.py +24 -204
  54. modal/cli/secret.py +1 -2
  55. modal/cli/shell.py +375 -0
  56. modal/cli/utils.py +1 -13
  57. modal/cli/volume.py +11 -17
  58. modal/client.py +16 -125
  59. modal/client.pyi +94 -144
  60. modal/cloud_bucket_mount.py +3 -1
  61. modal/cloud_bucket_mount.pyi +4 -0
  62. modal/cls.py +101 -64
  63. modal/cls.pyi +9 -8
  64. modal/config.py +21 -1
  65. modal/container_process.py +288 -12
  66. modal/container_process.pyi +99 -38
  67. modal/dict.py +72 -33
  68. modal/dict.pyi +88 -57
  69. modal/environments.py +16 -8
  70. modal/environments.pyi +6 -2
  71. modal/exception.py +154 -16
  72. modal/experimental/__init__.py +24 -53
  73. modal/experimental/flash.py +161 -74
  74. modal/experimental/flash.pyi +97 -49
  75. modal/file_io.py +50 -92
  76. modal/file_io.pyi +117 -89
  77. modal/functions.pyi +70 -87
  78. modal/image.py +82 -47
  79. modal/image.pyi +51 -30
  80. modal/io_streams.py +500 -149
  81. modal/io_streams.pyi +279 -189
  82. modal/mount.py +60 -46
  83. modal/mount.pyi +41 -17
  84. modal/network_file_system.py +19 -11
  85. modal/network_file_system.pyi +72 -39
  86. modal/object.pyi +114 -22
  87. modal/parallel_map.py +42 -44
  88. modal/parallel_map.pyi +9 -17
  89. modal/partial_function.pyi +4 -2
  90. modal/proxy.py +14 -6
  91. modal/proxy.pyi +10 -2
  92. modal/queue.py +45 -38
  93. modal/queue.pyi +88 -52
  94. modal/runner.py +96 -96
  95. modal/runner.pyi +44 -27
  96. modal/sandbox.py +225 -107
  97. modal/sandbox.pyi +226 -60
  98. modal/secret.py +58 -56
  99. modal/secret.pyi +28 -13
  100. modal/serving.py +7 -11
  101. modal/serving.pyi +7 -8
  102. modal/snapshot.py +29 -15
  103. modal/snapshot.pyi +18 -10
  104. modal/token_flow.py +1 -1
  105. modal/token_flow.pyi +4 -6
  106. modal/volume.py +102 -55
  107. modal/volume.pyi +125 -66
  108. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/METADATA +10 -9
  109. modal-1.3.1.dev8.dist-info/RECORD +189 -0
  110. modal_proto/api.proto +141 -70
  111. modal_proto/api_grpc.py +42 -26
  112. modal_proto/api_pb2.py +1123 -1103
  113. modal_proto/api_pb2.pyi +331 -83
  114. modal_proto/api_pb2_grpc.py +80 -48
  115. modal_proto/api_pb2_grpc.pyi +26 -18
  116. modal_proto/modal_api_grpc.py +175 -174
  117. modal_proto/task_command_router.proto +164 -0
  118. modal_proto/task_command_router_grpc.py +138 -0
  119. modal_proto/task_command_router_pb2.py +180 -0
  120. modal_proto/{sandbox_router_pb2.pyi → task_command_router_pb2.pyi} +148 -57
  121. modal_proto/task_command_router_pb2_grpc.py +272 -0
  122. modal_proto/task_command_router_pb2_grpc.pyi +100 -0
  123. modal_version/__init__.py +1 -1
  124. modal_version/__main__.py +1 -1
  125. modal/cli/programs/launch_instance_ssh.py +0 -94
  126. modal/cli/programs/run_marimo.py +0 -95
  127. modal-1.1.5.dev66.dist-info/RECORD +0 -191
  128. modal_proto/modal_options_grpc.py +0 -3
  129. modal_proto/options.proto +0 -19
  130. modal_proto/options_grpc.py +0 -3
  131. modal_proto/options_pb2.py +0 -35
  132. modal_proto/options_pb2.pyi +0 -20
  133. modal_proto/options_pb2_grpc.py +0 -4
  134. modal_proto/options_pb2_grpc.pyi +0 -7
  135. modal_proto/sandbox_router.proto +0 -125
  136. modal_proto/sandbox_router_grpc.py +0 -89
  137. modal_proto/sandbox_router_pb2.py +0 -128
  138. modal_proto/sandbox_router_pb2_grpc.py +0 -169
  139. modal_proto/sandbox_router_pb2_grpc.pyi +0 -63
  140. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/WHEEL +0 -0
  141. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/entry_points.txt +0 -0
  142. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/licenses/LICENSE +0 -0
  143. {modal-1.1.5.dev66.dist-info → modal-1.3.1.dev8.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,7 @@ from ._functions import _Function
19
19
  from ._utils.async_utils import synchronizer
20
20
  from ._utils.deprecation import deprecation_warning
21
21
  from ._utils.function_utils import callable_has_non_self_params
22
- from .config import logger
22
+ from .config import config, logger
23
23
  from .exception import InvalidError
24
24
 
25
25
  MAX_MAX_BATCH_SIZE = 1000
@@ -46,6 +46,7 @@ class _PartialFunctionFlags(enum.IntFlag):
46
46
  BATCHED = 64
47
47
  CONCURRENT = 128
48
48
  CLUSTERED = 256 # Experimental: Clustered functions
49
+ HTTP_WEB_INTERFACE = 512 # Experimental: HTTP server
49
50
 
50
51
  @staticmethod
51
52
  def all() -> int:
@@ -76,6 +77,7 @@ class _PartialFunctionParams:
76
77
  target_concurrent_inputs: Optional[int] = None
77
78
  build_timeout: Optional[int] = None
78
79
  rdma: Optional[bool] = None
80
+ http_config: Optional[api_pb2.HTTPConfig] = None
79
81
 
80
82
  def update(self, other: "_PartialFunctionParams") -> None:
81
83
  """Update self with params set in other."""
@@ -93,6 +95,26 @@ NullaryFuncOrMethod = Union[Callable[[], Any], Callable[[Any], Any]]
93
95
  NullaryMethod = Callable[[Any], Any]
94
96
 
95
97
 
98
+ def verify_concurrent_params(params: _PartialFunctionParams, is_flash: bool = False) -> None:
99
+ def _verify_concurrent_params_with_flash_settings(params: _PartialFunctionParams) -> None:
100
+ if params.max_concurrent_inputs is not None:
101
+ raise TypeError(
102
+ "@modal.concurrent(max_inputs=...) is not yet supported for Flash functions. "
103
+ "Use `@modal.concurrent(target_inputs=...)` instead."
104
+ )
105
+ if params.target_concurrent_inputs is None:
106
+ raise TypeError("`@modal.concurrent()` missing required argument: `target_inputs`.")
107
+
108
+ def _verify_concurrent_params(params: _PartialFunctionParams) -> None:
109
+ if params.max_concurrent_inputs is None:
110
+ raise TypeError("`@modal.concurrent()` missing required argument: `max_inputs`.")
111
+
112
+ if is_flash:
113
+ _verify_concurrent_params_with_flash_settings(params)
114
+ else:
115
+ _verify_concurrent_params(params)
116
+
117
+
96
118
  class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
97
119
  """Object produced by a decorator in the `modal` namespace
98
120
 
@@ -199,7 +221,7 @@ class _PartialFunction(typing.Generic[P, ReturnType, OriginalReturnType]):
199
221
  # of the type PartialFunction and this descriptor would be triggered when accessing it,
200
222
  #
201
223
  # However, modal classes are *actually* Cls instances (which isn't reflected in type checkers
202
- # due to Python's lack of type chekcing intersection types), so at runtime the Cls instance would
224
+ # due to Python's lack of type checking intersection types), so at runtime the Cls instance would
203
225
  # use its __getattr__ rather than this descriptor.
204
226
  assert self.raw_f is not None # Should only be relevant in a method context
205
227
  k = self.raw_f.__name__
@@ -378,6 +400,7 @@ def _fastapi_endpoint(
378
400
  method=method,
379
401
  web_endpoint_docs=docs,
380
402
  requested_suffix=label or "",
403
+ ephemeral_suffix=config.get("dev_suffix"),
381
404
  async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
382
405
  custom_domains=_parse_custom_domains(custom_domains),
383
406
  requires_proxy_auth=requires_proxy_auth,
@@ -446,6 +469,7 @@ def _web_endpoint(
446
469
  method=method,
447
470
  web_endpoint_docs=docs,
448
471
  requested_suffix=label or "",
472
+ ephemeral_suffix=config.get("dev_suffix"),
449
473
  async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
450
474
  custom_domains=_parse_custom_domains(custom_domains),
451
475
  requires_proxy_auth=requires_proxy_auth,
@@ -505,6 +529,7 @@ def _asgi_app(
505
529
  webhook_config = api_pb2.WebhookConfig(
506
530
  type=api_pb2.WEBHOOK_TYPE_ASGI_APP,
507
531
  requested_suffix=label or "",
532
+ ephemeral_suffix=config.get("dev_suffix"),
508
533
  async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
509
534
  custom_domains=_parse_custom_domains(custom_domains),
510
535
  requires_proxy_auth=requires_proxy_auth,
@@ -562,6 +587,7 @@ def _wsgi_app(
562
587
  webhook_config = api_pb2.WebhookConfig(
563
588
  type=api_pb2.WEBHOOK_TYPE_WSGI_APP,
564
589
  requested_suffix=label or "",
590
+ ephemeral_suffix=config.get("dev_suffix"),
565
591
  async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
566
592
  custom_domains=_parse_custom_domains(custom_domains),
567
593
  requires_proxy_auth=requires_proxy_auth,
@@ -623,6 +649,7 @@ def _web_server(
623
649
  webhook_config = api_pb2.WebhookConfig(
624
650
  type=api_pb2.WEBHOOK_TYPE_WEB_SERVER,
625
651
  requested_suffix=label or "",
652
+ ephemeral_suffix=config.get("dev_suffix"),
626
653
  async_mode=api_pb2.WEBHOOK_ASYNC_MODE_AUTO,
627
654
  custom_domains=_parse_custom_domains(custom_domains),
628
655
  web_server_port=port,
@@ -760,7 +787,7 @@ def _batched(
760
787
  def _concurrent(
761
788
  _warn_parentheses_missing=None, # mdmd:line-hidden
762
789
  *,
763
- max_inputs: int, # Hard limit on each container's input concurrency
790
+ max_inputs: Optional[int] = None, # Hard limit on each container's input concurrency
764
791
  target_inputs: Optional[int] = None, # Input concurrency that Modal's autoscaler should target
765
792
  ) -> Callable[
766
793
  [Union[Callable[P, ReturnType], _PartialFunction[P, ReturnType, ReturnType]]],
@@ -812,7 +839,7 @@ def _concurrent(
812
839
  "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.concurrent()`."
813
840
  )
814
841
 
815
- if target_inputs and target_inputs > max_inputs:
842
+ if max_inputs is not None and target_inputs is not None and target_inputs > max_inputs:
816
843
  raise InvalidError("`target_inputs` parameter cannot be greater than `max_inputs`.")
817
844
 
818
845
  flags = _PartialFunctionFlags.CONCURRENT
modal/_resolver.py CHANGED
@@ -8,17 +8,16 @@ from asyncio import Future
8
8
  from collections.abc import Hashable
9
9
  from typing import TYPE_CHECKING, Optional
10
10
 
11
- from modal._traceback import suppress_tb_frames
11
+ import modal._object
12
+ from modal._traceback import suppress_tb_frame
12
13
  from modal_proto import api_pb2
13
14
 
15
+ from ._load_context import LoadContext
14
16
  from ._utils.async_utils import TaskContext
15
- from .client import _Client
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from rich.tree import Tree
19
20
 
20
- import modal._object
21
-
22
21
 
23
22
  class StatusRow:
24
23
  def __init__(self, progress: "typing.Optional[Tree]"):
@@ -48,19 +47,10 @@ class StatusRow:
48
47
 
49
48
  class Resolver:
50
49
  _local_uuid_to_future: dict[str, Future]
51
- _environment_name: Optional[str]
52
- _app_id: Optional[str]
53
50
  _deduplication_cache: dict[Hashable, Future]
54
- _client: _Client
55
51
  _build_start: float
56
52
 
57
- def __init__(
58
- self,
59
- client: _Client,
60
- *,
61
- environment_name: Optional[str] = None,
62
- app_id: Optional[str] = None,
63
- ):
53
+ def __init__(self):
64
54
  try:
65
55
  # TODO(michael) If we don't clean this up more thoroughly, it would probably
66
56
  # be good to have a single source of truth for "rich is installed" rather than
@@ -75,9 +65,6 @@ class Resolver:
75
65
 
76
66
  self._local_uuid_to_future = {}
77
67
  self._tree = tree
78
- self._client = client
79
- self._app_id = app_id
80
- self._environment_name = environment_name
81
68
  self._deduplication_cache = {}
82
69
 
83
70
  with tempfile.TemporaryFile() as temp_file:
@@ -85,27 +72,24 @@ class Resolver:
85
72
  # to the mtime on mounted files, and want those measurements to have the same resolution.
86
73
  self._build_start = os.fstat(temp_file.fileno()).st_mtime
87
74
 
88
- @property
89
- def app_id(self) -> Optional[str]:
90
- return self._app_id
91
-
92
- @property
93
- def client(self):
94
- return self._client
95
-
96
- @property
97
- def environment_name(self):
98
- return self._environment_name
99
-
100
75
  @property
101
76
  def build_start(self) -> float:
102
77
  return self._build_start
103
78
 
104
- async def preload(self, obj, existing_object_id: Optional[str]):
79
+ async def preload(
80
+ self, obj: "modal._object._Object", parent_load_context: "LoadContext", existing_object_id: Optional[str]
81
+ ):
105
82
  if obj._preload is not None:
106
- await obj._preload(obj, self, existing_object_id)
83
+ load_context = obj._load_context_overrides.merged_with(parent_load_context)
84
+ await obj._preload(obj, self, load_context, existing_object_id)
107
85
 
108
- async def load(self, obj: "modal._object._Object", existing_object_id: Optional[str] = None):
86
+ async def load(
87
+ self,
88
+ obj: "modal._object._Object",
89
+ parent_load_context: "LoadContext",
90
+ *,
91
+ existing_object_id: Optional[str] = None,
92
+ ):
109
93
  if obj._is_hydrated and obj._is_another_app:
110
94
  # No need to reload this, it won't typically change
111
95
  if obj.local_uuid not in self._local_uuid_to_future:
@@ -129,42 +113,45 @@ class Resolver:
129
113
  cached_future = self._deduplication_cache.get(deduplication_key)
130
114
  if cached_future:
131
115
  hydrated_object = await cached_future
132
- obj._hydrate(hydrated_object.object_id, self._client, hydrated_object._get_metadata())
116
+ # Use the client from the already-hydrated object
117
+ obj._hydrate(hydrated_object.object_id, hydrated_object.client, hydrated_object._get_metadata())
133
118
  return obj
134
119
 
135
120
  if not cached_future:
136
121
  # don't run any awaits within this if-block to prevent race conditions
137
122
  async def loader():
138
- # Wait for all its dependencies
139
- # TODO(erikbern): do we need existing_object_id for those?
140
- await TaskContext.gather(*[self.load(dep) for dep in obj.deps()])
141
-
142
- # Load the object itself
143
- if not obj._load:
144
- raise Exception(f"Object {obj} has no loader function")
145
-
146
- await obj._load(obj, self, existing_object_id)
147
-
148
- # Check that the id of functions didn't change
149
- # Persisted refs are ignored because their life cycle is managed independently.
150
- if (
151
- not obj._is_another_app
152
- and existing_object_id is not None
153
- and existing_object_id.startswith("fu-")
154
- and obj.object_id != existing_object_id
155
- ):
156
- raise Exception(
157
- f"Tried creating an object using existing id {existing_object_id} but it has id {obj.object_id}"
158
- )
123
+ with suppress_tb_frame():
124
+ load_context = await obj._load_context_overrides.merged_with(parent_load_context).apply_defaults()
159
125
 
160
- return obj
126
+ # TODO(erikbern): do we need existing_object_id for those?
127
+ await TaskContext.gather(*[self.load(dep, load_context) for dep in obj.deps()])
128
+
129
+ # Load the object itself
130
+ if not obj._load:
131
+ raise Exception(f"Object {obj} has no loader function")
132
+
133
+ await obj._load(obj, self, load_context, existing_object_id)
134
+
135
+ # Check that the id of functions didn't change
136
+ # Persisted refs are ignored because their life cycle is managed independently.
137
+ if (
138
+ not obj._is_another_app
139
+ and existing_object_id is not None
140
+ and existing_object_id.startswith("fu-")
141
+ and obj.object_id != existing_object_id
142
+ ):
143
+ raise Exception(
144
+ f"Tried creating an object using existing id {existing_object_id} "
145
+ f"but it has id {obj.object_id}"
146
+ )
147
+
148
+ return obj
161
149
 
162
150
  cached_future = asyncio.create_task(loader())
163
151
  self._local_uuid_to_future[obj.local_uuid] = cached_future
164
152
  if deduplication_key is not None:
165
153
  self._deduplication_cache[deduplication_key] = cached_future
166
- with suppress_tb_frames(2):
167
- # skip current frame + `loader()` closure frame from above
154
+ with suppress_tb_frame():
168
155
  return await cached_future
169
156
 
170
157
  def objects(self) -> list["modal._object._Object"]:
@@ -36,7 +36,7 @@ from modal._traceback import print_exception
36
36
  from modal._utils.async_utils import TaskContext, aclosing, asyncify, synchronize_api, synchronizer
37
37
  from modal._utils.blob_utils import MAX_OBJECT_SIZE_BYTES, blob_download, blob_upload, format_blob_data
38
38
  from modal._utils.function_utils import _stream_function_call_data
39
- from modal._utils.grpc_utils import retry_transient_errors
39
+ from modal._utils.grpc_utils import Retry
40
40
  from modal._utils.package_utils import parse_major_minor_version
41
41
  from modal.client import HEARTBEAT_INTERVAL, HEARTBEAT_TIMEOUT, _Client
42
42
  from modal.config import config, logger
@@ -278,11 +278,13 @@ class IOContext:
278
278
  logger.debug(f"Finished generator input {self.input_ids}")
279
279
 
280
280
  async def output_items_cancellation(self, started_at: float):
281
+ output_created_at = time.time()
281
282
  # Create terminated outputs for these inputs to signal that the cancellations have been completed.
282
283
  return [
283
284
  api_pb2.FunctionPutOutputsItem(
284
285
  input_id=input_id,
285
286
  input_started_at=started_at,
287
+ output_created_at=output_created_at,
286
288
  result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_TERMINATED),
287
289
  retry_count=retry_count,
288
290
  )
@@ -354,10 +356,12 @@ class IOContext:
354
356
  }
355
357
 
356
358
  # all inputs in the batch get the same failure:
359
+ output_created_at = time.time()
357
360
  return [
358
361
  api_pb2.FunctionPutOutputsItem(
359
362
  input_id=input_id,
360
363
  input_started_at=started_at,
364
+ output_created_at=output_created_at,
361
365
  retry_count=retry_count,
362
366
  **data_format_specific_output(function_input.data_format),
363
367
  )
@@ -619,8 +623,8 @@ class _ContainerIOManager:
619
623
  await self.heartbeat_condition.wait()
620
624
 
621
625
  request = api_pb2.ContainerHeartbeatRequest(canceled_inputs_return_outputs_v2=True)
622
- response = await retry_transient_errors(
623
- self._client.stub.ContainerHeartbeat, request, attempt_timeout=HEARTBEAT_TIMEOUT
626
+ response = await self._client.stub.ContainerHeartbeat(
627
+ request, retry=Retry(attempt_timeout=HEARTBEAT_TIMEOUT)
624
628
  )
625
629
 
626
630
  if response.HasField("cancel_input_event"):
@@ -667,10 +671,9 @@ class _ContainerIOManager:
667
671
  target_concurrency=self._target_concurrency,
668
672
  max_concurrency=self._max_concurrency,
669
673
  )
670
- resp = await retry_transient_errors(
671
- self._client.stub.FunctionGetDynamicConcurrency,
674
+ resp = await self._client.stub.FunctionGetDynamicConcurrency(
672
675
  request,
673
- attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS,
676
+ retry=Retry(attempt_timeout=DYNAMIC_CONCURRENCY_TIMEOUT_SECS),
674
677
  )
675
678
  if resp.concurrency != self._input_slots.value and not self._stop_concurrency_loop:
676
679
  logger.debug(f"Dynamic concurrency set from {self._input_slots.value} to {resp.concurrency}")
@@ -721,9 +724,9 @@ class _ContainerIOManager:
721
724
 
722
725
  if self.input_plane_server_url:
723
726
  stub = await self._client.get_stub(self.input_plane_server_url)
724
- await retry_transient_errors(stub.FunctionCallPutDataOut, req)
727
+ await stub.FunctionCallPutDataOut(req)
725
728
  else:
726
- await retry_transient_errors(self._client.stub.FunctionCallPutDataOut, req)
729
+ await self._client.stub.FunctionCallPutDataOut(req)
727
730
 
728
731
  @asynccontextmanager
729
732
  async def generator_output_sender(
@@ -811,9 +814,7 @@ class _ContainerIOManager:
811
814
  try:
812
815
  # If number of active inputs is at max queue size, this will block.
813
816
  iteration += 1
814
- response: api_pb2.FunctionGetInputsResponse = await retry_transient_errors(
815
- self._client.stub.FunctionGetInputs, request
816
- )
817
+ response: api_pb2.FunctionGetInputsResponse = await self._client.stub.FunctionGetInputs(request)
817
818
 
818
819
  if response.rate_limit_sleep_duration:
819
820
  logger.info(
@@ -844,8 +845,9 @@ class _ContainerIOManager:
844
845
  yield inputs
845
846
  yielded = True
846
847
 
847
- # We only support max_inputs = 1 at the moment
848
- if final_input_received or self.function_def.max_inputs == 1:
848
+ # TODO(michael): Remove use of max_inputs after worker rollover
849
+ single_use_container = self.function_def.single_use_containers or self.function_def.max_inputs == 1
850
+ if final_input_received or single_use_container:
849
851
  return
850
852
  finally:
851
853
  if not yielded:
@@ -883,11 +885,12 @@ class _ContainerIOManager:
883
885
  # Limit the batch size to 20 to stay within message size limits and buffer size limits.
884
886
  output_batch_size = 20
885
887
  for i in range(0, len(outputs), output_batch_size):
886
- await retry_transient_errors(
887
- self._client.stub.FunctionPutOutputs,
888
+ await self._client.stub.FunctionPutOutputs(
888
889
  api_pb2.FunctionPutOutputsRequest(outputs=outputs[i : i + output_batch_size]),
889
- additional_status_codes=[Status.RESOURCE_EXHAUSTED],
890
- max_retries=None, # Retry indefinitely, trying every 1s.
890
+ retry=Retry(
891
+ additional_status_codes=[Status.RESOURCE_EXHAUSTED],
892
+ max_retries=None, # Retry indefinitely, trying every 1s.
893
+ ),
891
894
  )
892
895
  input_ids = [output.input_id for output in outputs]
893
896
  self.exit_context(started_at, input_ids)
@@ -928,7 +931,7 @@ class _ContainerIOManager:
928
931
  )
929
932
 
930
933
  req = api_pb2.TaskResultRequest(result=result)
931
- await retry_transient_errors(self._client.stub.TaskResult, req)
934
+ await self._client.stub.TaskResult(req)
932
935
 
933
936
  # Shut down the task gracefully
934
937
  raise UserException()
@@ -989,12 +992,10 @@ class _ContainerIOManager:
989
992
  # Busy-wait for restore. `/__modal/restore-state.json` is created
990
993
  # by the worker process with updates to the container config.
991
994
  restored_path = Path(config.get("restore_state_path"))
992
- start = time.perf_counter()
995
+ logger.debug("Waiting for restore")
993
996
  while not restored_path.exists():
994
- logger.debug(f"Waiting for restore (elapsed={time.perf_counter() - start:.3f}s)")
995
997
  await asyncio.sleep(0.01)
996
998
  continue
997
-
998
999
  logger.debug("Container: restored")
999
1000
 
1000
1001
  # Look for state file and create new client with updated credentials.
@@ -1005,7 +1006,7 @@ class _ContainerIOManager:
1005
1006
  # Start a debugger if the worker tells us to
1006
1007
  if int(restored_state.get("snapshot_debug", 0)):
1007
1008
  logger.debug("Entering snapshot debugger")
1008
- breakpoint()
1009
+ breakpoint() # noqa: T100
1009
1010
 
1010
1011
  # Local ContainerIOManager state.
1011
1012
  for key in ["task_id", "function_id"]:
@@ -1078,13 +1079,14 @@ class _ContainerIOManager:
1078
1079
  await asyncify(os.sync)()
1079
1080
  results = await asyncio.gather(
1080
1081
  *[
1081
- retry_transient_errors(
1082
- self._client.stub.VolumeCommit,
1082
+ self._client.stub.VolumeCommit(
1083
1083
  api_pb2.VolumeCommitRequest(volume_id=v_id),
1084
- max_retries=9,
1085
- base_delay=0.25,
1086
- max_delay=256,
1087
- delay_factor=2,
1084
+ retry=Retry(
1085
+ max_retries=9,
1086
+ base_delay=0.25,
1087
+ max_delay=256,
1088
+ delay_factor=2,
1089
+ ),
1088
1090
  )
1089
1091
  for v_id in volume_ids
1090
1092
  ],
@@ -252,8 +252,6 @@ class _ContainerIOManager:
252
252
  @classmethod
253
253
  def stop_fetching_inputs(cls): ...
254
254
 
255
- SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)
256
-
257
255
  class ContainerIOManager:
258
256
  """Synchronizes all RPC calls and network operations for a running container.
259
257
 
@@ -298,47 +296,47 @@ class ContainerIOManager:
298
296
  """Only used for tests."""
299
297
  ...
300
298
 
301
- class __hello_spec(typing_extensions.Protocol[SUPERSELF]):
299
+ class __hello_spec(typing_extensions.Protocol):
302
300
  def __call__(self, /): ...
303
301
  async def aio(self, /): ...
304
302
 
305
- hello: __hello_spec[typing_extensions.Self]
303
+ hello: __hello_spec
306
304
 
307
- class ___run_heartbeat_loop_spec(typing_extensions.Protocol[SUPERSELF]):
305
+ class ___run_heartbeat_loop_spec(typing_extensions.Protocol):
308
306
  def __call__(self, /): ...
309
307
  async def aio(self, /): ...
310
308
 
311
- _run_heartbeat_loop: ___run_heartbeat_loop_spec[typing_extensions.Self]
309
+ _run_heartbeat_loop: ___run_heartbeat_loop_spec
312
310
 
313
- class ___heartbeat_handle_cancellations_spec(typing_extensions.Protocol[SUPERSELF]):
311
+ class ___heartbeat_handle_cancellations_spec(typing_extensions.Protocol):
314
312
  def __call__(self, /) -> bool: ...
315
313
  async def aio(self, /) -> bool: ...
316
314
 
317
- _heartbeat_handle_cancellations: ___heartbeat_handle_cancellations_spec[typing_extensions.Self]
315
+ _heartbeat_handle_cancellations: ___heartbeat_handle_cancellations_spec
318
316
 
319
- class __heartbeats_spec(typing_extensions.Protocol[SUPERSELF]):
317
+ class __heartbeats_spec(typing_extensions.Protocol):
320
318
  def __call__(
321
319
  self, /, wait_for_mem_snap: bool
322
320
  ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]: ...
323
321
  def aio(self, /, wait_for_mem_snap: bool) -> typing.AsyncContextManager[None]: ...
324
322
 
325
- heartbeats: __heartbeats_spec[typing_extensions.Self]
323
+ heartbeats: __heartbeats_spec
326
324
 
327
325
  def stop_heartbeat(self): ...
328
326
 
329
- class __dynamic_concurrency_manager_spec(typing_extensions.Protocol[SUPERSELF]):
327
+ class __dynamic_concurrency_manager_spec(typing_extensions.Protocol):
330
328
  def __call__(self, /) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]: ...
331
329
  def aio(self, /) -> typing.AsyncContextManager[None]: ...
332
330
 
333
- dynamic_concurrency_manager: __dynamic_concurrency_manager_spec[typing_extensions.Self]
331
+ dynamic_concurrency_manager: __dynamic_concurrency_manager_spec
334
332
 
335
- class ___dynamic_concurrency_loop_spec(typing_extensions.Protocol[SUPERSELF]):
333
+ class ___dynamic_concurrency_loop_spec(typing_extensions.Protocol):
336
334
  def __call__(self, /): ...
337
335
  async def aio(self, /): ...
338
336
 
339
- _dynamic_concurrency_loop: ___dynamic_concurrency_loop_spec[typing_extensions.Self]
337
+ _dynamic_concurrency_loop: ___dynamic_concurrency_loop_spec
340
338
 
341
- class __get_data_in_spec(typing_extensions.Protocol[SUPERSELF]):
339
+ class __get_data_in_spec(typing_extensions.Protocol):
342
340
  def __call__(
343
341
  self, /, function_call_id: str, attempt_token: typing.Optional[str]
344
342
  ) -> typing.Iterator[typing.Any]:
@@ -351,9 +349,9 @@ class ContainerIOManager:
351
349
  """Read from the `data_in` stream of a function call."""
352
350
  ...
353
351
 
354
- get_data_in: __get_data_in_spec[typing_extensions.Self]
352
+ get_data_in: __get_data_in_spec
355
353
 
356
- class __put_data_out_spec(typing_extensions.Protocol[SUPERSELF]):
354
+ class __put_data_out_spec(typing_extensions.Protocol):
357
355
  def __call__(
358
356
  self,
359
357
  /,
@@ -388,9 +386,9 @@ class ContainerIOManager:
388
386
  """
389
387
  ...
390
388
 
391
- put_data_out: __put_data_out_spec[typing_extensions.Self]
389
+ put_data_out: __put_data_out_spec
392
390
 
393
- class __generator_output_sender_spec(typing_extensions.Protocol[SUPERSELF]):
391
+ class __generator_output_sender_spec(typing_extensions.Protocol):
394
392
  def __call__(
395
393
  self, /, function_call_id: str, attempt_token: str, data_format: int, message_rx: asyncio.queues.Queue
396
394
  ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
@@ -403,9 +401,9 @@ class ContainerIOManager:
403
401
  """Runs background task that feeds generator outputs into a function call's `data_out` stream."""
404
402
  ...
405
403
 
406
- generator_output_sender: __generator_output_sender_spec[typing_extensions.Self]
404
+ generator_output_sender: __generator_output_sender_spec
407
405
 
408
- class ___queue_create_spec(typing_extensions.Protocol[SUPERSELF]):
406
+ class ___queue_create_spec(typing_extensions.Protocol):
409
407
  def __call__(self, /, size: int) -> asyncio.queues.Queue:
410
408
  """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
411
409
  ...
@@ -414,9 +412,9 @@ class ContainerIOManager:
414
412
  """Create a queue, on the synchronicity event loop (needed on Python 3.8 and 3.9)."""
415
413
  ...
416
414
 
417
- _queue_create: ___queue_create_spec[typing_extensions.Self]
415
+ _queue_create: ___queue_create_spec
418
416
 
419
- class ___queue_put_spec(typing_extensions.Protocol[SUPERSELF]):
417
+ class ___queue_put_spec(typing_extensions.Protocol):
420
418
  def __call__(self, /, queue: asyncio.queues.Queue, value: typing.Any) -> None:
421
419
  """Put a value onto a queue, using the synchronicity event loop."""
422
420
  ...
@@ -425,12 +423,12 @@ class ContainerIOManager:
425
423
  """Put a value onto a queue, using the synchronicity event loop."""
426
424
  ...
427
425
 
428
- _queue_put: ___queue_put_spec[typing_extensions.Self]
426
+ _queue_put: ___queue_put_spec
429
427
 
430
428
  def get_average_call_time(self) -> float: ...
431
429
  def get_max_inputs_to_fetch(self): ...
432
430
 
433
- class ___generate_inputs_spec(typing_extensions.Protocol[SUPERSELF]):
431
+ class ___generate_inputs_spec(typing_extensions.Protocol):
434
432
  def __call__(
435
433
  self, /, batch_max_size: int, batch_wait_ms: int
436
434
  ) -> typing.Iterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
@@ -438,9 +436,9 @@ class ContainerIOManager:
438
436
  self, /, batch_max_size: int, batch_wait_ms: int
439
437
  ) -> collections.abc.AsyncIterator[list[tuple[str, int, str, str, modal_proto.api_pb2.FunctionInput]]]: ...
440
438
 
441
- _generate_inputs: ___generate_inputs_spec[typing_extensions.Self]
439
+ _generate_inputs: ___generate_inputs_spec
442
440
 
443
- class __run_inputs_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
441
+ class __run_inputs_outputs_spec(typing_extensions.Protocol):
444
442
  def __call__(
445
443
  self,
446
444
  /,
@@ -456,9 +454,9 @@ class ContainerIOManager:
456
454
  batch_wait_ms: int = 0,
457
455
  ) -> collections.abc.AsyncIterator[IOContext]: ...
458
456
 
459
- run_inputs_outputs: __run_inputs_outputs_spec[typing_extensions.Self]
457
+ run_inputs_outputs: __run_inputs_outputs_spec
460
458
 
461
- class ___send_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
459
+ class ___send_outputs_spec(typing_extensions.Protocol):
462
460
  def __call__(self, /, started_at: float, outputs: list[modal_proto.api_pb2.FunctionPutOutputsItem]) -> None:
463
461
  """Send pre-built output items with retry and chunking."""
464
462
  ...
@@ -467,9 +465,9 @@ class ContainerIOManager:
467
465
  """Send pre-built output items with retry and chunking."""
468
466
  ...
469
467
 
470
- _send_outputs: ___send_outputs_spec[typing_extensions.Self]
468
+ _send_outputs: ___send_outputs_spec
471
469
 
472
- class __handle_user_exception_spec(typing_extensions.Protocol[SUPERSELF]):
470
+ class __handle_user_exception_spec(typing_extensions.Protocol):
473
471
  def __call__(self, /) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
474
472
  """Sets the task as failed in a way where it's not retried.
475
473
 
@@ -486,9 +484,9 @@ class ContainerIOManager:
486
484
  """
487
485
  ...
488
486
 
489
- handle_user_exception: __handle_user_exception_spec[typing_extensions.Self]
487
+ handle_user_exception: __handle_user_exception_spec
490
488
 
491
- class __handle_input_exception_spec(typing_extensions.Protocol[SUPERSELF]):
489
+ class __handle_input_exception_spec(typing_extensions.Protocol):
492
490
  def __call__(
493
491
  self, /, io_context: IOContext, started_at: float
494
492
  ) -> synchronicity.combined_types.AsyncAndBlockingContextManager[None]:
@@ -499,23 +497,23 @@ class ContainerIOManager:
499
497
  """Handle an exception while processing a function input."""
500
498
  ...
501
499
 
502
- handle_input_exception: __handle_input_exception_spec[typing_extensions.Self]
500
+ handle_input_exception: __handle_input_exception_spec
503
501
 
504
502
  def exit_context(self, started_at, input_ids: list[str]): ...
505
503
 
506
- class __push_outputs_spec(typing_extensions.Protocol[SUPERSELF]):
504
+ class __push_outputs_spec(typing_extensions.Protocol):
507
505
  def __call__(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
508
506
  async def aio(self, /, io_context: IOContext, started_at: float, output_data: list[typing.Any]) -> None: ...
509
507
 
510
- push_outputs: __push_outputs_spec[typing_extensions.Self]
508
+ push_outputs: __push_outputs_spec
511
509
 
512
- class __memory_restore_spec(typing_extensions.Protocol[SUPERSELF]):
510
+ class __memory_restore_spec(typing_extensions.Protocol):
513
511
  def __call__(self, /) -> None: ...
514
512
  async def aio(self, /) -> None: ...
515
513
 
516
- memory_restore: __memory_restore_spec[typing_extensions.Self]
514
+ memory_restore: __memory_restore_spec
517
515
 
518
- class __memory_snapshot_spec(typing_extensions.Protocol[SUPERSELF]):
516
+ class __memory_snapshot_spec(typing_extensions.Protocol):
519
517
  def __call__(self, /) -> None:
520
518
  """Message server indicating that function is ready to be checkpointed."""
521
519
  ...
@@ -524,9 +522,9 @@ class ContainerIOManager:
524
522
  """Message server indicating that function is ready to be checkpointed."""
525
523
  ...
526
524
 
527
- memory_snapshot: __memory_snapshot_spec[typing_extensions.Self]
525
+ memory_snapshot: __memory_snapshot_spec
528
526
 
529
- class __volume_commit_spec(typing_extensions.Protocol[SUPERSELF]):
527
+ class __volume_commit_spec(typing_extensions.Protocol):
530
528
  def __call__(self, /, volume_ids: list[str]) -> None:
531
529
  """Perform volume commit for given `volume_ids`.
532
530
  Only used on container exit to persist uncommitted changes on behalf of user.
@@ -539,13 +537,13 @@ class ContainerIOManager:
539
537
  """
540
538
  ...
541
539
 
542
- volume_commit: __volume_commit_spec[typing_extensions.Self]
540
+ volume_commit: __volume_commit_spec
543
541
 
544
- class __interact_spec(typing_extensions.Protocol[SUPERSELF]):
542
+ class __interact_spec(typing_extensions.Protocol):
545
543
  def __call__(self, /, from_breakpoint: bool = False): ...
546
544
  async def aio(self, /, from_breakpoint: bool = False): ...
547
545
 
548
- interact: __interact_spec[typing_extensions.Self]
546
+ interact: __interact_spec
549
547
 
550
548
  @property
551
549
  def target_concurrency(self) -> int: ...