modal 1.1.5.dev83__py3-none-any.whl → 1.3.1.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic. See the registry's release page for more details.

Files changed (139)
  1. modal/__init__.py +4 -4
  2. modal/__main__.py +4 -29
  3. modal/_billing.py +84 -0
  4. modal/_clustered_functions.py +1 -3
  5. modal/_container_entrypoint.py +33 -208
  6. modal/_functions.py +146 -121
  7. modal/_grpc_client.py +191 -0
  8. modal/_ipython.py +16 -6
  9. modal/_load_context.py +106 -0
  10. modal/_object.py +72 -21
  11. modal/_output.py +12 -14
  12. modal/_partial_function.py +31 -4
  13. modal/_resolver.py +44 -57
  14. modal/_runtime/container_io_manager.py +26 -28
  15. modal/_runtime/container_io_manager.pyi +42 -44
  16. modal/_runtime/gpu_memory_snapshot.py +9 -7
  17. modal/_runtime/user_code_event_loop.py +80 -0
  18. modal/_runtime/user_code_imports.py +236 -10
  19. modal/_serialization.py +2 -1
  20. modal/_traceback.py +4 -13
  21. modal/_tunnel.py +16 -11
  22. modal/_tunnel.pyi +25 -3
  23. modal/_utils/async_utils.py +337 -10
  24. modal/_utils/auth_token_manager.py +1 -4
  25. modal/_utils/blob_utils.py +29 -22
  26. modal/_utils/function_utils.py +20 -21
  27. modal/_utils/grpc_testing.py +6 -3
  28. modal/_utils/grpc_utils.py +223 -64
  29. modal/_utils/mount_utils.py +26 -1
  30. modal/_utils/package_utils.py +0 -1
  31. modal/_utils/rand_pb_testing.py +8 -1
  32. modal/_utils/task_command_router_client.py +524 -0
  33. modal/_vendor/cloudpickle.py +144 -48
  34. modal/app.py +215 -96
  35. modal/app.pyi +78 -37
  36. modal/billing.py +5 -0
  37. modal/builder/2025.06.txt +6 -3
  38. modal/builder/PREVIEW.txt +2 -1
  39. modal/builder/base-images.json +4 -2
  40. modal/cli/_download.py +19 -3
  41. modal/cli/cluster.py +4 -2
  42. modal/cli/config.py +3 -1
  43. modal/cli/container.py +5 -4
  44. modal/cli/dict.py +5 -2
  45. modal/cli/entry_point.py +26 -2
  46. modal/cli/environment.py +2 -16
  47. modal/cli/launch.py +1 -76
  48. modal/cli/network_file_system.py +5 -20
  49. modal/cli/queues.py +5 -4
  50. modal/cli/run.py +24 -204
  51. modal/cli/secret.py +1 -2
  52. modal/cli/shell.py +375 -0
  53. modal/cli/utils.py +1 -13
  54. modal/cli/volume.py +11 -17
  55. modal/client.py +16 -125
  56. modal/client.pyi +94 -144
  57. modal/cloud_bucket_mount.py +3 -1
  58. modal/cloud_bucket_mount.pyi +4 -0
  59. modal/cls.py +101 -64
  60. modal/cls.pyi +9 -8
  61. modal/config.py +21 -1
  62. modal/container_process.py +288 -12
  63. modal/container_process.pyi +99 -38
  64. modal/dict.py +72 -33
  65. modal/dict.pyi +88 -57
  66. modal/environments.py +16 -8
  67. modal/environments.pyi +6 -2
  68. modal/exception.py +154 -16
  69. modal/experimental/__init__.py +23 -5
  70. modal/experimental/flash.py +161 -74
  71. modal/experimental/flash.pyi +97 -49
  72. modal/file_io.py +50 -92
  73. modal/file_io.pyi +117 -89
  74. modal/functions.pyi +70 -87
  75. modal/image.py +73 -47
  76. modal/image.pyi +33 -30
  77. modal/io_streams.py +500 -149
  78. modal/io_streams.pyi +279 -189
  79. modal/mount.py +60 -45
  80. modal/mount.pyi +41 -17
  81. modal/network_file_system.py +19 -11
  82. modal/network_file_system.pyi +72 -39
  83. modal/object.pyi +114 -22
  84. modal/parallel_map.py +42 -44
  85. modal/parallel_map.pyi +9 -17
  86. modal/partial_function.pyi +4 -2
  87. modal/proxy.py +14 -6
  88. modal/proxy.pyi +10 -2
  89. modal/queue.py +45 -38
  90. modal/queue.pyi +88 -52
  91. modal/runner.py +96 -96
  92. modal/runner.pyi +44 -27
  93. modal/sandbox.py +225 -108
  94. modal/sandbox.pyi +226 -63
  95. modal/secret.py +58 -56
  96. modal/secret.pyi +28 -13
  97. modal/serving.py +7 -11
  98. modal/serving.pyi +7 -8
  99. modal/snapshot.py +29 -15
  100. modal/snapshot.pyi +18 -10
  101. modal/token_flow.py +1 -1
  102. modal/token_flow.pyi +4 -6
  103. modal/volume.py +102 -55
  104. modal/volume.pyi +125 -66
  105. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/METADATA +10 -9
  106. modal-1.3.1.dev8.dist-info/RECORD +189 -0
  107. modal_proto/api.proto +86 -30
  108. modal_proto/api_grpc.py +10 -25
  109. modal_proto/api_pb2.py +1080 -1047
  110. modal_proto/api_pb2.pyi +253 -79
  111. modal_proto/api_pb2_grpc.py +14 -48
  112. modal_proto/api_pb2_grpc.pyi +6 -18
  113. modal_proto/modal_api_grpc.py +175 -176
  114. modal_proto/{sandbox_router.proto → task_command_router.proto} +62 -45
  115. modal_proto/task_command_router_grpc.py +138 -0
  116. modal_proto/task_command_router_pb2.py +180 -0
  117. modal_proto/{sandbox_router_pb2.pyi → task_command_router_pb2.pyi} +110 -63
  118. modal_proto/task_command_router_pb2_grpc.py +272 -0
  119. modal_proto/task_command_router_pb2_grpc.pyi +100 -0
  120. modal_version/__init__.py +1 -1
  121. modal_version/__main__.py +1 -1
  122. modal/cli/programs/launch_instance_ssh.py +0 -94
  123. modal/cli/programs/run_marimo.py +0 -95
  124. modal-1.1.5.dev83.dist-info/RECORD +0 -191
  125. modal_proto/modal_options_grpc.py +0 -3
  126. modal_proto/options.proto +0 -19
  127. modal_proto/options_grpc.py +0 -3
  128. modal_proto/options_pb2.py +0 -35
  129. modal_proto/options_pb2.pyi +0 -20
  130. modal_proto/options_pb2_grpc.py +0 -4
  131. modal_proto/options_pb2_grpc.pyi +0 -7
  132. modal_proto/sandbox_router_grpc.py +0 -105
  133. modal_proto/sandbox_router_pb2.py +0 -148
  134. modal_proto/sandbox_router_pb2_grpc.py +0 -203
  135. modal_proto/sandbox_router_pb2_grpc.pyi +0 -75
  136. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/WHEEL +0 -0
  137. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/entry_points.txt +0 -0
  138. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/licenses/LICENSE +0 -0
  139. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/top_level.txt +0 -0
modal/exception.py CHANGED
@@ -1,7 +1,45 @@
1
1
  # Copyright Modal Labs 2022
2
+ """
3
+ Modal-specific exception types.
4
+
5
+ ## Notes on `grpclib.GRPCError` migration
6
+
7
+ Historically, the Modal SDK could propagate `grpclib.GRPCError` exceptions out
8
+ to user code. As of v1.3, we are in the process of gracefully migrating to
9
+ always raising a Modal exception type in these cases. To avoid breaking user
10
+ code that relies on catching `grpclib.GRPCError`, a subset of Modal exception
11
+ types temporarily inherit from `grpclib.GRPCError`.
12
+
13
+ We encourage users to migrate any code that currently catches `grpclib.GRPCError`
14
+ to instead catch the appropriate Modal exception type. The following mapping
15
+ between GRPCError status codes and Modal exception types is currently in use:
16
+
17
+ ```
18
+ CANCELLED -> ServiceError
19
+ UNKNOWN -> ServiceError
20
+ INVALID_ARGUMENT -> InvalidError
21
+ DEADLINE_EXCEEDED -> ServiceError
22
+ NOT_FOUND -> NotFoundError
23
+ ALREADY_EXISTS -> AlreadyExistsError
24
+ PERMISSION_DENIED -> PermissionDeniedError
25
+ RESOURCE_EXHAUSTED -> ResourceExhaustedError
26
+ FAILED_PRECONDITION -> ConflictError
27
+ ABORTED -> ConflictError
28
+ OUT_OF_RANGE -> InvalidError
29
+ UNIMPLEMENTED -> UnimplementedError
30
+ INTERNAL -> InternalError
31
+ UNAVAILABLE -> ServiceError
32
+ DATA_LOSS -> DataLossError
33
+ UNAUTHENTICATED -> AuthError
34
+ ```
35
+
36
+ """
37
+
2
38
  import random
3
39
  import signal
40
+ from typing import Any, Optional
4
41
 
42
+ import grpclib
5
43
  import synchronicity.exceptions
6
44
 
7
45
  UserCodeException = synchronicity.exceptions.UserCodeException # Deprecated type used for return_exception wrapping
@@ -26,10 +64,116 @@ class Error(Exception):
26
64
  """
27
65
 
28
66
 
29
- class AlreadyExistsError(Error):
67
+ class _GRPCErrorWrapper(grpclib.GRPCError):
68
+ """This transitional class helps us migrate away from propagating `grpclib.GRPCError` to users.
69
+
70
+ It serves two purposes:
71
+ - It avoids abruptly breaking user code that catches `grpclib.GRPCError`
72
+ - It actively warns when users access attributes defined by `grpclib.GRPCError`
73
+
74
+ This won't catch all cases (users might react indiscriminately to GRPCError without checking the status).
75
+
76
+ The mapping between GRPCError status codes and our error types is defined in `modal._grpc_client`.
77
+
78
+ """
79
+
80
+ # These will be set on the instance in our error handling middleware
81
+ _grpc_message: str
82
+ _grpc_status: grpclib.Status
83
+ _grpc_details: Any
84
+
85
+ def __init__(self, message: Optional[str] = None):
86
+ # Override GRPCError's init and repr to behave more like a regular Exception
87
+ # (We don't customize these anywhere in our custom error types currently).
88
+ self._message = message or ""
89
+
90
+ def __repr__(self) -> str:
91
+ return f"{type(self).__name__}({self._message!r})"
92
+
93
+ def _warn_on_grpc_error_attribute_access(self) -> None:
94
+ from ._utils.deprecation import deprecation_warning # Avoid circular import
95
+
96
+ exc_type = type(self).__name__
97
+ deprecation_warning(
98
+ (2025, 12, 9),
99
+ "Modal will stop propagating the `grpclib.GRPCError` type in the future. "
100
+ f"Update your code so that it catches `modal.exception.{exc_type}` directly "
101
+ "to avoid changes to error handling behavior in the future.",
102
+ pending=True,
103
+ )
104
+
105
+ @property
106
+ def message(self) -> str:
107
+ self._warn_on_grpc_error_attribute_access()
108
+ return self._grpc_message
109
+
110
+ @message.setter
111
+ def message(self, value: str) -> None:
112
+ self._grpc_message = value
113
+
114
+ @property
115
+ def status(self) -> grpclib.Status:
116
+ self._warn_on_grpc_error_attribute_access()
117
+ return self._grpc_status
118
+
119
+ @status.setter
120
+ def status(self, value: grpclib.Status) -> None:
121
+ self._grpc_status = value
122
+
123
+ @property
124
+ def details(self) -> Any:
125
+ self._warn_on_grpc_error_attribute_access()
126
+ return self._grpc_details
127
+
128
+ @details.setter
129
+ def details(self, value: Any) -> None:
130
+ self._grpc_details = value
131
+
132
+
133
+ class AlreadyExistsError(Error, _GRPCErrorWrapper):
30
134
  """Raised when a resource creation conflicts with an existing resource."""
31
135
 
32
136
 
137
+ class AuthError(Error, _GRPCErrorWrapper):
138
+ """Raised when a client has missing or invalid authentication."""
139
+
140
+
141
+ class InternalError(Error, _GRPCErrorWrapper):
142
+ """Raised when an internal error occurs in the Modal system."""
143
+
144
+
145
+ class InvalidError(Error, _GRPCErrorWrapper):
146
+ """Raised when user does something invalid."""
147
+
148
+
149
+ class ConflictError(InvalidError, _GRPCErrorWrapper):
150
+ """Raised when a resource conflict occurs between the request and current system state."""
151
+
152
+
153
+ class DataLossError(Error, _GRPCErrorWrapper):
154
+ """Raised when data is lost or corrupted."""
155
+
156
+
157
+ class NotFoundError(Error, _GRPCErrorWrapper):
158
+ """Raised when a requested resource was not found."""
159
+
160
+
161
+ class PermissionDeniedError(Error, _GRPCErrorWrapper):
162
+ """Raised when a user does not have permission to perform the requested operation."""
163
+
164
+
165
+ class ResourceExhaustedError(Error, _GRPCErrorWrapper):
166
+ """Raised when a server-side resource has been exhausted, e.g. a quota or rate limit."""
167
+
168
+
169
+ class ServiceError(Error, _GRPCErrorWrapper):
170
+ """Raised when an error occurs in basic client/server communication."""
171
+
172
+
173
+ class UnimplementedError(Error, _GRPCErrorWrapper):
174
+ """Raised when a requested operation is not implemented or not supported."""
175
+
176
+
33
177
  class RemoteError(Error):
34
178
  """Raised when an error occurs on the Modal server."""
35
179
 
@@ -42,6 +186,10 @@ class SandboxTimeoutError(TimeoutError):
42
186
  """Raised when a Sandbox exceeds its execution duration limit and times out."""
43
187
 
44
188
 
189
+ class ExecTimeoutError(TimeoutError):
190
+ """Raised when a container process exceeds its execution duration limit and times out."""
191
+
192
+
45
193
  class SandboxTerminatedError(Error):
46
194
  """Raised when a Sandbox is terminated for an internal reason."""
47
195
 
@@ -66,26 +214,14 @@ class OutputExpiredError(TimeoutError):
66
214
  """Raised when the Output exceeds expiration and times out."""
67
215
 
68
216
 
69
- class AuthError(Error):
70
- """Raised when a client has missing or invalid authentication."""
71
-
72
-
73
217
  class ConnectionError(Error):
74
218
  """Raised when an issue occurs while connecting to the Modal servers."""
75
219
 
76
220
 
77
- class InvalidError(Error):
78
- """Raised when user does something invalid."""
79
-
80
-
81
221
  class VersionError(Error):
82
222
  """Raised when the current client version of Modal is unsupported."""
83
223
 
84
224
 
85
- class NotFoundError(Error):
86
- """Raised when a requested resource was not found."""
87
-
88
-
89
225
  class ExecutionError(Error):
90
226
  """Raised when something unexpected happened during runtime."""
91
227
 
@@ -116,10 +252,12 @@ class ServerWarning(UserWarning):
116
252
  """Warning originating from the Modal server and re-issued in client code."""
117
253
 
118
254
 
255
+ class AsyncUsageWarning(UserWarning):
256
+ """Warning emitted when a blocking Modal interface is used in an async context."""
257
+
258
+
119
259
  class InternalFailure(Error):
120
- """
121
- Retriable internal error.
122
- """
260
+ """Retriable internal error."""
123
261
 
124
262
 
125
263
  class _CliUserExecutionError(Exception):
@@ -13,14 +13,18 @@ from .._object import _get_environment_name
13
13
  from .._partial_function import _clustered
14
14
  from .._runtime.container_io_manager import _ContainerIOManager
15
15
  from .._utils.async_utils import synchronize_api, synchronizer
16
- from .._utils.grpc_utils import retry_transient_errors
17
16
  from ..app import _App
18
17
  from ..client import _Client
19
18
  from ..cls import _Cls
20
19
  from ..exception import InvalidError
21
20
  from ..image import DockerfileSpec, ImageBuilderVersion, _Image, _ImageRegistryConfig
22
21
  from ..secret import _Secret
23
- from .flash import flash_forward, flash_get_containers, flash_prometheus_autoscaler # noqa: F401
22
+ from .flash import ( # noqa: F401
23
+ flash_forward,
24
+ flash_get_containers,
25
+ flash_prometheus_autoscaler,
26
+ http_server,
27
+ )
24
28
 
25
29
 
26
30
  def stop_fetching_inputs():
@@ -86,6 +90,19 @@ async def list_deployed_apps(environment_name: str = "", client: Optional[_Clien
86
90
  return app_infos
87
91
 
88
92
 
93
+ @synchronizer.create_blocking
94
+ async def stop_app(name: str, *, environment_name: Optional[str] = None, client: Optional[_Client] = None) -> None:
95
+ """Stop a deployed App.
96
+
97
+ This interface is experimental and may change in the future,
98
+ although the functionality will continue to be supported.
99
+ """
100
+ client_ = client or await _Client.from_env()
101
+ app = await _App.lookup(name, environment_name=environment_name, client=client_)
102
+ req = api_pb2.AppStopRequest(app_id=app.app_id, source=api_pb2.APP_STOP_SOURCE_PYTHON_CLIENT)
103
+ await client_.stub.AppStop(req)
104
+
105
+
89
106
  @synchronizer.create_blocking
90
107
  async def get_app_objects(
91
108
  app_name: str, *, environment_name: Optional[str] = None, client: Optional[_Client] = None
@@ -116,7 +133,7 @@ async def get_app_objects(
116
133
 
117
134
  app = await _App.lookup(app_name, environment_name=environment_name, client=client)
118
135
  req = api_pb2.AppGetLayoutRequest(app_id=app.app_id)
119
- app_layout_resp = await retry_transient_errors(client.stub.AppGetLayout, req)
136
+ app_layout_resp = await client.stub.AppGetLayout(req)
120
137
 
121
138
  app_objects: dict[str, Union[_Function, _Cls]] = {}
122
139
 
@@ -347,7 +364,8 @@ async def image_delete(
347
364
  ) -> None:
348
365
  """Delete an Image by its ID.
349
366
 
350
- Deletion is irreversible and will prevent Apps from using the Image.
367
+ Deletion is irreversible and will prevent Functions/Sandboxes from using
368
+ the Image.
351
369
 
352
370
  This is an experimental interface for a feature that we will be adding to
353
371
  the main Image class. The stable form of this interface may look different.
@@ -361,4 +379,4 @@ async def image_delete(
361
379
  client = await _Client.from_env()
362
380
 
363
381
  req = api_pb2.ImageDeleteRequest(image_id=image_id)
364
- await retry_transient_errors(client.stub.ImageDelete, req)
382
+ await client.stub.ImageDelete(req)
@@ -7,16 +7,16 @@ import sys
7
7
  import time
8
8
  import traceback
9
9
  from collections import defaultdict
10
- from typing import Any, Optional
10
+ from typing import Any, Callable, Optional, Union
11
11
  from urllib.parse import urlparse
12
12
 
13
+ from modal._partial_function import _PartialFunctionFlags
13
14
  from modal.cls import _Cls
14
15
  from modal.dict import _Dict
15
16
  from modal_proto import api_pb2
16
17
 
17
18
  from .._tunnel import _forward as _forward_tunnel
18
19
  from .._utils.async_utils import synchronize_api, synchronizer
19
- from .._utils.grpc_utils import retry_transient_errors
20
20
  from ..client import _Client
21
21
  from ..config import logger
22
22
  from ..exception import InvalidError
@@ -29,15 +29,20 @@ class _FlashManager:
29
29
  self,
30
30
  client: _Client,
31
31
  port: int,
32
- process: Optional[subprocess.Popen] = None,
32
+ process: Optional[subprocess.Popen] = None, # to be deprecated
33
33
  health_check_url: Optional[str] = None,
34
+ startup_timeout: int = 30,
35
+ exit_grace_period: int = 0,
36
+ h2_enabled: bool = False,
34
37
  ):
35
38
  self.client = client
36
39
  self.port = port
40
+ self.process = process
37
41
  # Health check is not currently being used
38
42
  self.health_check_url = health_check_url
39
- self.process = process
40
- self.tunnel_manager = _forward_tunnel(port, client=client)
43
+ self.startup_timeout = startup_timeout
44
+ self.exit_grace_period = exit_grace_period
45
+ self.tunnel_manager = _forward_tunnel(port, h2_enabled=h2_enabled, client=client)
41
46
  self.stopped = False
42
47
  self.num_failures = 0
43
48
  self.task_id = os.environ["MODAL_TASK_ID"]
@@ -49,10 +54,15 @@ class _FlashManager:
49
54
 
50
55
  start_time = time.monotonic()
51
56
 
57
+ def check_process_is_running() -> Optional[Exception]:
58
+ if process is not None and process.poll() is not None:
59
+ return Exception(f"Process {process.pid} exited with code {process.returncode}")
60
+ return None
61
+
52
62
  while time.monotonic() - start_time < timeout:
53
63
  try:
54
- if process is not None and process.poll() is not None:
55
- return False, Exception(f"Process {process.pid} exited with code {process.returncode}")
64
+ if error := check_process_is_running():
65
+ return False, error
56
66
  with socket.create_connection(("localhost", self.port), timeout=0.5):
57
67
  return True, None
58
68
  except (ConnectionRefusedError, OSError):
@@ -101,6 +111,7 @@ class _FlashManager:
101
111
 
102
112
  async def _run_heartbeat(self, host: str, port: int):
103
113
  first_registration = True
114
+ start_time = time.monotonic()
104
115
  while True:
105
116
  try:
106
117
  port_check_resp, port_check_error = await self.is_port_connection_healthy(process=self.process)
@@ -113,6 +124,7 @@ class _FlashManager:
113
124
  port=port,
114
125
  ),
115
126
  timeout=10,
127
+ retry=None,
116
128
  )
117
129
  self.num_failures = 0
118
130
  if first_registration:
@@ -121,15 +133,16 @@ class _FlashManager:
121
133
  )
122
134
  first_registration = False
123
135
  else:
124
- logger.error(
125
- f"[Modal Flash] Deregistering container {self.task_id} on {self.tunnel.url} "
126
- f"due to error: {port_check_error}, num_failures: {self.num_failures}"
127
- )
128
- self.num_failures += 1
129
- await retry_transient_errors(
130
- self.client.stub.FlashContainerDeregister,
131
- api_pb2.FlashContainerDeregisterRequest(),
132
- )
136
+ if first_registration and (time.monotonic() - start_time < self.startup_timeout):
137
+ continue
138
+ else:
139
+ logger.error(
140
+ f"[Modal Flash] Deregistering container {self.task_id} on {self.tunnel.url} "
141
+ f"due to error: {port_check_error}, num_failures: {self.num_failures}"
142
+ )
143
+ self.num_failures += 1
144
+ await self.client.stub.FlashContainerDeregister(api_pb2.FlashContainerDeregisterRequest())
145
+
133
146
  except asyncio.CancelledError:
134
147
  logger.warning("[Modal Flash] Shutting down...")
135
148
  break
@@ -147,12 +160,12 @@ class _FlashManager:
147
160
  return self.tunnel.url
148
161
 
149
162
  async def stop(self):
150
- self.heartbeat_task.cancel()
151
- await retry_transient_errors(
152
- self.client.stub.FlashContainerDeregister,
153
- api_pb2.FlashContainerDeregisterRequest(),
154
- )
163
+ try:
164
+ self.heartbeat_task.cancel()
165
+ except Exception as e:
166
+ logger.error(f"[Modal Flash] Error stopping: {e}")
155
167
 
168
+ await self.client.stub.FlashContainerDeregister(api_pb2.FlashContainerDeregisterRequest())
156
169
  self.stopped = True
157
170
  logger.warning(f"[Modal Flash] No longer accepting new requests on {self.tunnel.url}.")
158
171
 
@@ -163,18 +176,23 @@ class _FlashManager:
163
176
  if not self.stopped:
164
177
  await self.stop()
165
178
 
179
+ await asyncio.sleep(self.exit_grace_period)
180
+
166
181
  logger.warning(f"[Modal Flash] Closing tunnel on {self.tunnel.url}.")
167
182
  await self.tunnel_manager.__aexit__(*sys.exc_info())
168
183
 
169
184
 
170
- FlashManager = synchronize_api(_FlashManager)
185
+ FlashManager = synchronize_api(_FlashManager, target_module=__name__)
171
186
 
172
187
 
173
188
  @synchronizer.create_blocking
174
189
  async def flash_forward(
175
190
  port: int,
176
- process: Optional[subprocess.Popen] = None,
191
+ process: Optional[subprocess.Popen] = None, # to be deprecated
177
192
  health_check_url: Optional[str] = None,
193
+ startup_timeout: int = 30,
194
+ exit_grace_period: int = 0,
195
+ h2_enabled: bool = False,
178
196
  ) -> _FlashManager:
179
197
  """
180
198
  Forward a port to the Modal Flash service, exposing that port as a stable web endpoint.
@@ -183,7 +201,15 @@ async def flash_forward(
183
201
  """
184
202
  client = await _Client.from_env()
185
203
 
186
- manager = _FlashManager(client, port, process=process, health_check_url=health_check_url)
204
+ manager = _FlashManager(
205
+ client,
206
+ port,
207
+ process=process,
208
+ health_check_url=health_check_url,
209
+ startup_timeout=startup_timeout,
210
+ exit_grace_period=exit_grace_period,
211
+ h2_enabled=h2_enabled,
212
+ )
187
213
  await manager._start()
188
214
  return manager
189
215
 
@@ -321,7 +347,7 @@ class _FlashPrometheusAutoscaler:
321
347
 
322
348
  async def _compute_target_containers(self, current_replicas: int) -> int:
323
349
  """
324
- Gets internal metrics from container to autoscale up or down.
350
+ Gets metrics from container to autoscale up or down.
325
351
  """
326
352
  containers = await self._get_all_containers()
327
353
  if len(containers) > current_replicas:
@@ -334,7 +360,7 @@ class _FlashPrometheusAutoscaler:
334
360
  if current_replicas == 0:
335
361
  return 1
336
362
 
337
- # Get metrics based on autoscaler type (prometheus or internal)
363
+ # Get metrics based on autoscaler type
338
364
  sum_metric, n_containers_with_metrics = await self._get_scaling_info(containers)
339
365
 
340
366
  desired_replicas = self._calculate_desired_replicas(
@@ -406,39 +432,26 @@ class _FlashPrometheusAutoscaler:
406
432
  return desired_replicas
407
433
 
408
434
  async def _get_scaling_info(self, containers) -> tuple[float, int]:
409
- """Get metrics using either internal container metrics API or prometheus HTTP endpoints."""
410
- if self.metrics_endpoint == "internal":
411
- container_metrics_results = await asyncio.gather(
412
- *[self._get_container_metrics(container.task_id) for container in containers]
413
- )
414
- container_metrics_list = []
415
- for container_metric in container_metrics_results:
416
- if container_metric is None:
417
- continue
418
- container_metrics_list.append(getattr(container_metric.metrics, self.target_metric))
419
-
420
- sum_metric = sum(container_metrics_list)
421
- n_containers_with_metrics = len(container_metrics_list)
422
- else:
423
- sum_metric = 0
424
- n_containers_with_metrics = 0
425
-
426
- container_metrics_list = await asyncio.gather(
427
- *[
428
- self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
429
- for container in containers
430
- ]
431
- )
435
+ """Get metrics using container exposed metrics endpoints."""
436
+ sum_metric = 0
437
+ n_containers_with_metrics = 0
438
+
439
+ container_metrics_list = await asyncio.gather(
440
+ *[
441
+ self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
442
+ for container in containers
443
+ ]
444
+ )
432
445
 
433
- for container_metrics in container_metrics_list:
434
- if (
435
- container_metrics is None
436
- or self.target_metric not in container_metrics
437
- or len(container_metrics[self.target_metric]) == 0
438
- ):
439
- continue
440
- sum_metric += container_metrics[self.target_metric][0].value
441
- n_containers_with_metrics += 1
446
+ for container_metrics in container_metrics_list:
447
+ if (
448
+ container_metrics is None
449
+ or self.target_metric not in container_metrics
450
+ or len(container_metrics[self.target_metric]) == 0
451
+ ):
452
+ continue
453
+ sum_metric += container_metrics[self.target_metric][0].value
454
+ n_containers_with_metrics += 1
442
455
 
443
456
  return sum_metric, n_containers_with_metrics
444
457
 
@@ -474,23 +487,14 @@ class _FlashPrometheusAutoscaler:
474
487
 
475
488
  return metrics
476
489
 
477
- async def _get_container_metrics(self, container_id: str) -> Optional[api_pb2.TaskGetAutoscalingMetricsResponse]:
478
- req = api_pb2.TaskGetAutoscalingMetricsRequest(task_id=container_id)
479
- try:
480
- resp = await retry_transient_errors(self.client.stub.TaskGetAutoscalingMetrics, req)
481
- return resp
482
- except Exception as e:
483
- logger.warning(f"[Modal Flash] Error getting metrics for container {container_id}: {e}")
484
- return None
485
-
486
490
  async def _get_all_containers(self):
487
491
  req = api_pb2.FlashContainerListRequest(function_id=self.fn.object_id)
488
- resp = await retry_transient_errors(self.client.stub.FlashContainerList, req)
492
+ resp = await self.client.stub.FlashContainerList(req)
489
493
  return resp.containers
490
494
 
491
495
  async def _set_target_slots(self, target_slots: int):
492
496
  req = api_pb2.FlashSetTargetSlotsMetricsRequest(function_id=self.fn.object_id, target_slots=target_slots)
493
- await retry_transient_errors(self.client.stub.FlashSetTargetSlotsMetrics, req)
497
+ await self.client.stub.FlashSetTargetSlotsMetrics(req)
494
498
  return
495
499
 
496
500
  def _make_scaling_decision(
@@ -572,14 +576,10 @@ async def flash_prometheus_autoscaler(
572
576
  app_name: str,
573
577
  cls_name: str,
574
578
  # Endpoint to fetch metrics from. Must be in Prometheus format. Example: "/metrics"
575
- # If metrics_endpoint is "internal", we will use containers' internal metrics to autoscale instead.
576
579
  metrics_endpoint: str,
577
580
  # Target metric to autoscale on. Example: "vllm:num_requests_running"
578
- # If metrics_endpoint is "internal", target_metrics options are: [cpu_usage_percent, memory_usage_percent]
579
581
  target_metric: str,
580
582
  # Target metric value. Example: 25
581
- # If metrics_endpoint is "internal", target_metric_value is a percentage value between 0.1 and 1.0 (inclusive),
582
- # indicating container's usage of that metric.
583
583
  target_metric_value: float,
584
584
  min_containers: Optional[int] = None,
585
585
  max_containers: Optional[int] = None,
@@ -645,5 +645,92 @@ async def flash_get_containers(app_name: str, cls_name: str) -> list[dict[str, A
645
645
  assert fn is not None
646
646
  await fn.hydrate(client=client)
647
647
  req = api_pb2.FlashContainerListRequest(function_id=fn.object_id)
648
- resp = await retry_transient_errors(client.stub.FlashContainerList, req)
648
+ resp = await client.stub.FlashContainerList(req)
649
649
  return resp.containers
650
+
651
+
652
+ def _http_server(
653
+ port: Optional[int] = None,
654
+ *,
655
+ proxy_regions: list[str] = [], # The regions to proxy the HTTP server to.
656
+ startup_timeout: int = 30, # Maximum number of seconds to wait for the HTTP server to start.
657
+ exit_grace_period: Optional[int] = None, # The time to wait for the HTTP server to exit gracefully.
658
+ h2_enabled: bool = False, # Whether to enable HTTP/2 support.
659
+ ):
660
+ """Decorator for Flash-enabled HTTP servers on Modal classes.
661
+
662
+ Args:
663
+ port: The local port to forward to the HTTP server.
664
+ proxy_regions: The regions to proxy the HTTP server to.
665
+ startup_timeout: The maximum time to wait for the HTTP server to start.
666
+ exit_grace_period: The time to wait for the HTTP server to exit gracefully.
667
+
668
+ """
669
+ if port is None:
670
+ raise InvalidError(
671
+ "Positional arguments are not allowed. Did you forget parentheses? Suggestion: `@modal.http_server()`."
672
+ )
673
+ if not isinstance(port, int) or port < 1 or port > 65535:
674
+ raise InvalidError("First argument of `@http_server` must be a local port, such as `@http_server(8000)`.")
675
+ if startup_timeout <= 0:
676
+ raise InvalidError("The `startup_timeout` argument of `@http_server` must be positive.")
677
+ if exit_grace_period is not None and exit_grace_period < 0:
678
+ raise InvalidError("The `exit_grace_period` argument of `@http_server` must be non-negative.")
679
+
680
+ from modal._partial_function import _PartialFunction, _PartialFunctionParams
681
+
682
+ params = _PartialFunctionParams(
683
+ http_config=api_pb2.HTTPConfig(
684
+ port=port,
685
+ proxy_regions=proxy_regions,
686
+ startup_timeout=startup_timeout or 0,
687
+ exit_grace_period=exit_grace_period or 0,
688
+ h2_enabled=h2_enabled,
689
+ )
690
+ )
691
+
692
+ def wrapper(obj: Union[Callable[..., Any], _PartialFunction]) -> _PartialFunction:
693
+ flags = _PartialFunctionFlags.HTTP_WEB_INTERFACE
694
+
695
+ if isinstance(obj, _PartialFunction):
696
+ pf = obj.stack(flags, params)
697
+ else:
698
+ pf = _PartialFunction(obj, flags, params)
699
+ pf.validate_obj_compatibility("`http_server`")
700
+ return pf
701
+
702
+ return wrapper
703
+
704
+
705
+ http_server = synchronize_api(_http_server, target_module=__name__)
706
+
707
+
708
+ class _FlashContainerEntry:
709
+ """
710
+ A class that manages the lifecycle of Flash manager for Flash containers.
711
+
712
+ It is intentional that stop() runs before exit handlers and close().
713
+ This ensures the container is deregistered first, preventing new requests from being routed to it
714
+ while exit handlers execute and the exit grace period elapses, before finally closing the tunnel.
715
+ """
716
+
717
+ def __init__(self, http_config: api_pb2.HTTPConfig):
718
+ self.http_config: api_pb2.HTTPConfig = http_config
719
+ self.flash_manager: Optional[FlashManager] = None # type: ignore
720
+
721
+ def enter(self):
722
+ if self.http_config != api_pb2.HTTPConfig():
723
+ self.flash_manager = flash_forward(
724
+ self.http_config.port,
725
+ startup_timeout=self.http_config.startup_timeout,
726
+ exit_grace_period=self.http_config.exit_grace_period,
727
+ h2_enabled=self.http_config.h2_enabled,
728
+ )
729
+
730
+ def stop(self):
731
+ if self.flash_manager:
732
+ self.flash_manager.stop()
733
+
734
+ def close(self):
735
+ if self.flash_manager:
736
+ self.flash_manager.close()