modal 1.0.6.dev58__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic; see the registry's advisory page for more details.

Files changed (147):
  1. modal/__main__.py +3 -4
  2. modal/_billing.py +80 -0
  3. modal/_clustered_functions.py +7 -3
  4. modal/_clustered_functions.pyi +4 -2
  5. modal/_container_entrypoint.py +41 -49
  6. modal/_functions.py +424 -195
  7. modal/_grpc_client.py +171 -0
  8. modal/_load_context.py +105 -0
  9. modal/_object.py +68 -20
  10. modal/_output.py +58 -45
  11. modal/_partial_function.py +36 -11
  12. modal/_pty.py +7 -3
  13. modal/_resolver.py +21 -35
  14. modal/_runtime/asgi.py +4 -3
  15. modal/_runtime/container_io_manager.py +301 -186
  16. modal/_runtime/container_io_manager.pyi +70 -61
  17. modal/_runtime/execution_context.py +18 -2
  18. modal/_runtime/execution_context.pyi +4 -1
  19. modal/_runtime/gpu_memory_snapshot.py +170 -63
  20. modal/_runtime/user_code_imports.py +28 -58
  21. modal/_serialization.py +57 -1
  22. modal/_utils/async_utils.py +33 -12
  23. modal/_utils/auth_token_manager.py +2 -5
  24. modal/_utils/blob_utils.py +110 -53
  25. modal/_utils/function_utils.py +49 -42
  26. modal/_utils/grpc_utils.py +80 -50
  27. modal/_utils/mount_utils.py +26 -1
  28. modal/_utils/name_utils.py +17 -3
  29. modal/_utils/task_command_router_client.py +536 -0
  30. modal/_utils/time_utils.py +34 -6
  31. modal/app.py +219 -83
  32. modal/app.pyi +229 -56
  33. modal/billing.py +5 -0
  34. modal/{requirements → builder}/2025.06.txt +1 -0
  35. modal/{requirements → builder}/PREVIEW.txt +1 -0
  36. modal/cli/_download.py +19 -3
  37. modal/cli/_traceback.py +3 -2
  38. modal/cli/app.py +4 -4
  39. modal/cli/cluster.py +15 -7
  40. modal/cli/config.py +5 -3
  41. modal/cli/container.py +7 -6
  42. modal/cli/dict.py +22 -16
  43. modal/cli/entry_point.py +12 -5
  44. modal/cli/environment.py +5 -4
  45. modal/cli/import_refs.py +3 -3
  46. modal/cli/launch.py +102 -5
  47. modal/cli/network_file_system.py +9 -13
  48. modal/cli/profile.py +3 -2
  49. modal/cli/programs/launch_instance_ssh.py +94 -0
  50. modal/cli/programs/run_jupyter.py +1 -1
  51. modal/cli/programs/run_marimo.py +95 -0
  52. modal/cli/programs/vscode.py +1 -1
  53. modal/cli/queues.py +57 -26
  54. modal/cli/run.py +58 -16
  55. modal/cli/secret.py +48 -22
  56. modal/cli/utils.py +3 -4
  57. modal/cli/volume.py +28 -25
  58. modal/client.py +13 -116
  59. modal/client.pyi +9 -91
  60. modal/cloud_bucket_mount.py +5 -3
  61. modal/cloud_bucket_mount.pyi +5 -1
  62. modal/cls.py +130 -102
  63. modal/cls.pyi +45 -85
  64. modal/config.py +29 -10
  65. modal/container_process.py +291 -13
  66. modal/container_process.pyi +95 -32
  67. modal/dict.py +282 -63
  68. modal/dict.pyi +423 -73
  69. modal/environments.py +15 -27
  70. modal/environments.pyi +5 -15
  71. modal/exception.py +8 -0
  72. modal/experimental/__init__.py +143 -38
  73. modal/experimental/flash.py +247 -78
  74. modal/experimental/flash.pyi +137 -9
  75. modal/file_io.py +14 -28
  76. modal/file_io.pyi +2 -2
  77. modal/file_pattern_matcher.py +25 -16
  78. modal/functions.pyi +134 -61
  79. modal/image.py +255 -86
  80. modal/image.pyi +300 -62
  81. modal/io_streams.py +436 -126
  82. modal/io_streams.pyi +236 -171
  83. modal/mount.py +62 -157
  84. modal/mount.pyi +45 -172
  85. modal/network_file_system.py +30 -53
  86. modal/network_file_system.pyi +16 -76
  87. modal/object.pyi +42 -8
  88. modal/parallel_map.py +821 -113
  89. modal/parallel_map.pyi +134 -0
  90. modal/partial_function.pyi +4 -1
  91. modal/proxy.py +16 -7
  92. modal/proxy.pyi +10 -2
  93. modal/queue.py +263 -61
  94. modal/queue.pyi +409 -66
  95. modal/runner.py +112 -92
  96. modal/runner.pyi +45 -27
  97. modal/sandbox.py +451 -124
  98. modal/sandbox.pyi +513 -67
  99. modal/secret.py +291 -67
  100. modal/secret.pyi +425 -19
  101. modal/serving.py +7 -11
  102. modal/serving.pyi +7 -8
  103. modal/snapshot.py +11 -8
  104. modal/token_flow.py +4 -4
  105. modal/volume.py +344 -98
  106. modal/volume.pyi +464 -68
  107. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +9 -8
  108. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  109. modal_docs/mdmd/mdmd.py +11 -1
  110. modal_proto/api.proto +399 -67
  111. modal_proto/api_grpc.py +241 -1
  112. modal_proto/api_pb2.py +1395 -1000
  113. modal_proto/api_pb2.pyi +1239 -79
  114. modal_proto/api_pb2_grpc.py +499 -4
  115. modal_proto/api_pb2_grpc.pyi +162 -14
  116. modal_proto/modal_api_grpc.py +175 -160
  117. modal_proto/sandbox_router.proto +145 -0
  118. modal_proto/sandbox_router_grpc.py +105 -0
  119. modal_proto/sandbox_router_pb2.py +149 -0
  120. modal_proto/sandbox_router_pb2.pyi +333 -0
  121. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  122. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  123. modal_proto/task_command_router.proto +144 -0
  124. modal_proto/task_command_router_grpc.py +105 -0
  125. modal_proto/task_command_router_pb2.py +149 -0
  126. modal_proto/task_command_router_pb2.pyi +333 -0
  127. modal_proto/task_command_router_pb2_grpc.py +203 -0
  128. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  129. modal_version/__init__.py +1 -1
  130. modal-1.0.6.dev58.dist-info/RECORD +0 -183
  131. modal_proto/modal_options_grpc.py +0 -3
  132. modal_proto/options.proto +0 -19
  133. modal_proto/options_grpc.py +0 -3
  134. modal_proto/options_pb2.py +0 -35
  135. modal_proto/options_pb2.pyi +0 -20
  136. modal_proto/options_pb2_grpc.py +0 -4
  137. modal_proto/options_pb2_grpc.pyi +0 -7
  138. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  139. /modal/{requirements → builder}/2023.12.txt +0 -0
  140. /modal/{requirements → builder}/2024.04.txt +0 -0
  141. /modal/{requirements → builder}/2024.10.txt +0 -0
  142. /modal/{requirements → builder}/README.md +0 -0
  143. /modal/{requirements → builder}/base-images.json +0 -0
  144. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  145. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  146. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  147. {modal-1.0.6.dev58.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,8 @@
1
1
  # Copyright Modal Labs 2025
2
2
  import asyncio
3
3
  import math
4
+ import os
5
+ import subprocess
4
6
  import sys
5
7
  import time
6
8
  import traceback
@@ -14,45 +16,118 @@ from modal_proto import api_pb2
14
16
 
15
17
  from .._tunnel import _forward as _forward_tunnel
16
18
  from .._utils.async_utils import synchronize_api, synchronizer
17
- from .._utils.grpc_utils import retry_transient_errors
18
19
  from ..client import _Client
19
20
  from ..config import logger
20
21
  from ..exception import InvalidError
21
22
 
23
+ _MAX_FAILURES = 10
24
+
22
25
 
23
26
  class _FlashManager:
24
- def __init__(self, client: _Client, port: int, health_check_url: Optional[str] = None):
27
+ def __init__(
28
+ self,
29
+ client: _Client,
30
+ port: int,
31
+ process: Optional[subprocess.Popen] = None,
32
+ health_check_url: Optional[str] = None,
33
+ ):
25
34
  self.client = client
26
35
  self.port = port
36
+ # Health check is not currently being used
27
37
  self.health_check_url = health_check_url
38
+ self.process = process
28
39
  self.tunnel_manager = _forward_tunnel(port, client=client)
29
40
  self.stopped = False
41
+ self.num_failures = 0
42
+ self.task_id = os.environ["MODAL_TASK_ID"]
43
+
44
+ async def is_port_connection_healthy(
45
+ self, process: Optional[subprocess.Popen], timeout: float = 0.5
46
+ ) -> tuple[bool, Optional[Exception]]:
47
+ import socket
48
+
49
+ start_time = time.monotonic()
50
+
51
+ while time.monotonic() - start_time < timeout:
52
+ try:
53
+ if process is not None and process.poll() is not None:
54
+ return False, Exception(f"Process {process.pid} exited with code {process.returncode}")
55
+ with socket.create_connection(("localhost", self.port), timeout=0.5):
56
+ return True, None
57
+ except (ConnectionRefusedError, OSError):
58
+ await asyncio.sleep(0.1)
59
+
60
+ return False, Exception(f"Waited too long for port {self.port} to start accepting connections")
30
61
 
31
62
  async def _start(self):
32
63
  self.tunnel = await self.tunnel_manager.__aenter__()
33
-
34
64
  parsed_url = urlparse(self.tunnel.url)
35
65
  host = parsed_url.hostname
36
66
  port = parsed_url.port or 443
37
67
 
38
68
  self.heartbeat_task = asyncio.create_task(self._run_heartbeat(host, port))
69
+ self.drain_task = asyncio.create_task(self._drain_container())
70
+
71
+ async def _drain_container(self):
72
+ """
73
+ Background task that checks if we've encountered too many failures and drains the container if so.
74
+ """
75
+ while True:
76
+ try:
77
+ # Check if the container should be drained (e.g., too many failures)
78
+ if self.num_failures > _MAX_FAILURES:
79
+ logger.warning(
80
+ f"[Modal Flash] Draining task {self.task_id} on {self.tunnel.url} due to too many failures."
81
+ )
82
+ await self.stop()
83
+ # handle close upon container exit
84
+
85
+ if self.task_id:
86
+ await self.client.stub.ContainerStop(api_pb2.ContainerStopRequest(task_id=self.task_id))
87
+ return
88
+ except asyncio.CancelledError:
89
+ logger.warning("[Modal Flash] Shutting down...")
90
+ return
91
+ except Exception as e:
92
+ logger.error(f"[Modal Flash] Error draining container: {e}")
93
+ await asyncio.sleep(1)
94
+
95
+ try:
96
+ await asyncio.sleep(1)
97
+ except asyncio.CancelledError:
98
+ logger.warning("[Modal Flash] Shutting down...")
99
+ return
39
100
 
40
101
  async def _run_heartbeat(self, host: str, port: int):
41
102
  first_registration = True
42
103
  while True:
43
104
  try:
44
- resp = await self.client.stub.FlashContainerRegister(
45
- api_pb2.FlashContainerRegisterRequest(
46
- priority=10,
47
- weight=5,
48
- host=host,
49
- port=port,
50
- ),
51
- timeout=10,
52
- )
53
- if first_registration:
54
- logger.warning(f"[Modal Flash] Listening at {resp.url}")
55
- first_registration = False
105
+ port_check_resp, port_check_error = await self.is_port_connection_healthy(process=self.process)
106
+ if port_check_resp:
107
+ resp = await self.client.stub.FlashContainerRegister(
108
+ api_pb2.FlashContainerRegisterRequest(
109
+ priority=10,
110
+ weight=5,
111
+ host=host,
112
+ port=port,
113
+ ),
114
+ timeout=10,
115
+ retry=None,
116
+ )
117
+ self.num_failures = 0
118
+ if first_registration:
119
+ logger.warning(
120
+ f"[Modal Flash] Listening at {resp.url} over {self.tunnel.url} for task_id {self.task_id}"
121
+ )
122
+ first_registration = False
123
+ else:
124
+ logger.error(
125
+ f"[Modal Flash] Deregistering container {self.task_id} on {self.tunnel.url} "
126
+ f"due to error: {port_check_error}, num_failures: {self.num_failures}"
127
+ )
128
+ self.num_failures += 1
129
+ await self.client.stub.FlashContainerDeregister(api_pb2.FlashContainerDeregisterRequest())
130
+
56
131
  except asyncio.CancelledError:
57
132
  logger.warning("[Modal Flash] Shutting down...")
58
133
  break
@@ -71,10 +146,7 @@ class _FlashManager:
71
146
 
72
147
  async def stop(self):
73
148
  self.heartbeat_task.cancel()
74
- await retry_transient_errors(
75
- self.client.stub.FlashContainerDeregister,
76
- api_pb2.FlashContainerDeregisterRequest(),
77
- )
149
+ await self.client.stub.FlashContainerDeregister(api_pb2.FlashContainerDeregisterRequest())
78
150
 
79
151
  self.stopped = True
80
152
  logger.warning(f"[Modal Flash] No longer accepting new requests on {self.tunnel.url}.")
@@ -94,16 +166,19 @@ FlashManager = synchronize_api(_FlashManager)
94
166
 
95
167
 
96
168
  @synchronizer.create_blocking
97
- async def flash_forward(port: int, health_check_url: Optional[str] = None) -> _FlashManager:
169
+ async def flash_forward(
170
+ port: int,
171
+ process: Optional[subprocess.Popen] = None,
172
+ health_check_url: Optional[str] = None,
173
+ ) -> _FlashManager:
98
174
  """
99
175
  Forward a port to the Modal Flash service, exposing that port as a stable web endpoint.
100
-
101
176
  This is a highly experimental method that can break or be removed at any time without warning.
102
177
  Do not use this method unless explicitly instructed to do so by Modal support.
103
178
  """
104
179
  client = await _Client.from_env()
105
180
 
106
- manager = _FlashManager(client, port, health_check_url)
181
+ manager = _FlashManager(client, port, process=process, health_check_url=health_check_url)
107
182
  await manager._start()
108
183
  return manager
109
184
 
@@ -121,12 +196,15 @@ class _FlashPrometheusAutoscaler:
121
196
  target_metric_value: float,
122
197
  min_containers: Optional[int],
123
198
  max_containers: Optional[int],
199
+ buffer_containers: Optional[int],
124
200
  scale_up_tolerance: float,
125
201
  scale_down_tolerance: float,
126
202
  scale_up_stabilization_window_seconds: int,
127
203
  scale_down_stabilization_window_seconds: int,
128
204
  autoscaling_interval_seconds: int,
129
205
  ):
206
+ import aiohttp
207
+
130
208
  if scale_up_stabilization_window_seconds > self._max_window_seconds:
131
209
  raise InvalidError(
132
210
  f"scale_up_stabilization_window_seconds must be less than or equal to {self._max_window_seconds}"
@@ -138,8 +216,6 @@ class _FlashPrometheusAutoscaler:
138
216
  if target_metric_value <= 0:
139
217
  raise InvalidError("target_metric_value must be greater than 0")
140
218
 
141
- import aiohttp
142
-
143
219
  self.client = client
144
220
  self.app_name = app_name
145
221
  self.cls_name = cls_name
@@ -148,6 +224,7 @@ class _FlashPrometheusAutoscaler:
148
224
  self.target_metric_value = target_metric_value
149
225
  self.min_containers = min_containers
150
226
  self.max_containers = max_containers
227
+ self.buffer_containers = buffer_containers
151
228
  self.scale_up_tolerance = scale_up_tolerance
152
229
  self.scale_down_tolerance = scale_down_tolerance
153
230
  self.scale_up_stabilization_window_seconds = scale_up_stabilization_window_seconds
@@ -200,7 +277,7 @@ class _FlashPrometheusAutoscaler:
200
277
  if timestamp >= autoscaling_time - self._max_window_seconds
201
278
  ]
202
279
 
203
- current_target_containers = await self._compute_target_containers(current_replicas)
280
+ current_target_containers = await self._compute_target_containers(current_replicas=current_replicas)
204
281
  autoscaling_decisions.append((autoscaling_time, current_target_containers))
205
282
 
206
283
  actual_target_containers = self._make_scaling_decision(
@@ -210,11 +287,12 @@ class _FlashPrometheusAutoscaler:
210
287
  scale_down_stabilization_window_seconds=self.scale_down_stabilization_window_seconds,
211
288
  min_containers=self.min_containers,
212
289
  max_containers=self.max_containers,
290
+ buffer_containers=self.buffer_containers,
213
291
  )
214
292
 
215
293
  logger.warning(
216
- f"[Modal Flash] Scaling to {actual_target_containers} containers. Autoscaling decision "
217
- f"made in {time.time() - autoscaling_time} seconds."
294
+ f"[Modal Flash] Scaling to {actual_target_containers=} containers. "
295
+ f" Autoscaling decision made in {time.time() - autoscaling_time} seconds."
218
296
  )
219
297
 
220
298
  await self.autoscaling_decisions_dict.put(
@@ -223,10 +301,7 @@ class _FlashPrometheusAutoscaler:
223
301
  )
224
302
  await self.autoscaling_decisions_dict.put("current_replicas", actual_target_containers)
225
303
 
226
- await self.cls.update_autoscaler(
227
- min_containers=actual_target_containers,
228
- max_containers=actual_target_containers,
229
- )
304
+ await self._set_target_slots(actual_target_containers)
230
305
 
231
306
  if time.time() - autoscaling_time < self.autoscaling_interval_seconds:
232
307
  await asyncio.sleep(self.autoscaling_interval_seconds - (time.time() - autoscaling_time))
@@ -240,6 +315,9 @@ class _FlashPrometheusAutoscaler:
240
315
  await asyncio.sleep(self.autoscaling_interval_seconds)
241
316
 
242
317
  async def _compute_target_containers(self, current_replicas: int) -> int:
318
+ """
319
+ Gets metrics from container to autoscale up or down.
320
+ """
243
321
  containers = await self._get_all_containers()
244
322
  if len(containers) > current_replicas:
245
323
  logger.info(
@@ -251,68 +329,128 @@ class _FlashPrometheusAutoscaler:
251
329
  if current_replicas == 0:
252
330
  return 1
253
331
 
254
- target_metric = self.target_metric
255
- target_metric_value = float(self.target_metric_value)
332
+ # Get metrics based on autoscaler type
333
+ sum_metric, n_containers_with_metrics = await self._get_scaling_info(containers)
256
334
 
257
- sum_metric = 0
258
- containers_with_metrics = 0
259
- container_metrics_list = await asyncio.gather(
260
- *[
261
- self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
262
- for container in containers
263
- ]
335
+ desired_replicas = self._calculate_desired_replicas(
336
+ n_current_replicas=current_replicas,
337
+ sum_metric=sum_metric,
338
+ n_containers_with_metrics=n_containers_with_metrics,
339
+ n_total_containers=len(containers),
340
+ target_metric_value=self.target_metric_value,
264
341
  )
265
- for container_metrics in container_metrics_list:
266
- if (
267
- container_metrics is None
268
- or target_metric not in container_metrics
269
- or len(container_metrics[target_metric]) == 0
270
- ):
271
- continue
272
- sum_metric += container_metrics[target_metric][0].value
273
- containers_with_metrics += 1
274
342
 
275
- n_containers_missing_metric = current_replicas - containers_with_metrics
343
+ return max(1, desired_replicas)
276
344
 
277
- # Scale up / down conservatively: Any container that is missing the metric is assumed to be at the minimum
278
- # value of the metric when scaling up and the maximum value of the metric when scaling down.
279
- scale_up_target_metric_value = sum_metric / current_replicas
280
- scale_down_target_metric_value = (
281
- sum_metric + n_containers_missing_metric * target_metric_value
282
- ) / current_replicas
345
+ def _calculate_desired_replicas(
346
+ self,
347
+ n_current_replicas: int,
348
+ sum_metric: float,
349
+ n_containers_with_metrics: int,
350
+ n_total_containers: int,
351
+ target_metric_value: float,
352
+ ) -> int:
353
+ """
354
+ Calculate the desired number of replicas to autoscale to.
355
+ """
356
+ buffer_containers = self.buffer_containers or 0
357
+
358
+ # n_containers_missing = number of unhealthy containers + number of containers not registered in flash dns
359
+ n_containers_missing_metric = n_current_replicas - n_containers_with_metrics
360
+ # n_containers_unhealthy = number of dns registered containers that are not emitting metrics
361
+ n_containers_unhealthy = n_total_containers - n_containers_with_metrics
362
+
363
+ # Max is used to handle case when buffer_containers are first initialized.
364
+ num_provisioned_containers = max(n_current_replicas - buffer_containers, 1)
365
+
366
+ # Scale up assuming that every unhealthy container is at 1.5 x (1 + scale_up_tolerance) the target metric value.
367
+ # This way if all containers are unhealthy, we will increase our number of containers.
368
+ scale_up_target_metric_value = (
369
+ sum_metric + 1.5 * (1 + self.scale_up_tolerance) * n_containers_unhealthy * target_metric_value
370
+ ) / (num_provisioned_containers)
371
+
372
+ # Scale down assuming that every container (including cold starting containers) are at the target metric value.
373
+ # The denominator is just num_provisioned_containers because we don't want to account for the buffer containers.
374
+ scale_down_target_metric_value = (sum_metric + n_containers_missing_metric * target_metric_value) / (
375
+ num_provisioned_containers
376
+ )
283
377
 
284
378
  scale_up_ratio = scale_up_target_metric_value / target_metric_value
285
379
  scale_down_ratio = scale_down_target_metric_value / target_metric_value
286
380
 
287
- desired_replicas = current_replicas
381
+ desired_replicas = num_provisioned_containers
288
382
  if scale_up_ratio > 1 + self.scale_up_tolerance:
289
- desired_replicas = math.ceil(current_replicas * scale_up_ratio)
383
+ desired_replicas = math.ceil(desired_replicas * scale_up_ratio)
290
384
  elif scale_down_ratio < 1 - self.scale_down_tolerance:
291
- desired_replicas = math.ceil(current_replicas * scale_down_ratio)
385
+ desired_replicas = math.ceil(desired_replicas * scale_down_ratio)
292
386
 
293
387
  logger.warning(
294
- f"[Modal Flash] Current replicas: {current_replicas}, target metric value: {target_metric_value}, "
295
- f"current sum of metric values: {sum_metric}, number of containers missing metric: "
296
- f"{n_containers_missing_metric}, scale up ratio: {scale_up_ratio}, scale down ratio: {scale_down_ratio}, "
388
+ f"[Modal Flash] Current replicas: {n_current_replicas}, "
389
+ f"target metric: {self.target_metric}"
390
+ f"target metric value: {target_metric_value}, "
391
+ f"current sum of metric values: {sum_metric}, "
392
+ f"number of containers with metrics: {n_containers_with_metrics}, "
393
+ f"number of containers unhealthy: {n_containers_unhealthy}, "
394
+ f"number of containers missing metric (includes unhealthy): {n_containers_missing_metric}, "
395
+ f"number of provisioned containers: {num_provisioned_containers}, "
396
+ f"scale up ratio: {scale_up_ratio}, "
397
+ f"scale down ratio: {scale_down_ratio}, "
297
398
  f"desired replicas: {desired_replicas}"
298
399
  )
299
400
 
300
401
  return desired_replicas
301
402
 
403
+ async def _get_scaling_info(self, containers) -> tuple[float, int]:
404
+ """Get metrics using container exposed metrics endpoints."""
405
+ sum_metric = 0
406
+ n_containers_with_metrics = 0
407
+
408
+ container_metrics_list = await asyncio.gather(
409
+ *[
410
+ self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
411
+ for container in containers
412
+ ]
413
+ )
414
+
415
+ for container_metrics in container_metrics_list:
416
+ if (
417
+ container_metrics is None
418
+ or self.target_metric not in container_metrics
419
+ or len(container_metrics[self.target_metric]) == 0
420
+ ):
421
+ continue
422
+ sum_metric += container_metrics[self.target_metric][0].value
423
+ n_containers_with_metrics += 1
424
+
425
+ return sum_metric, n_containers_with_metrics
426
+
302
427
  async def _get_metrics(self, url: str) -> Optional[dict[str, list[Any]]]: # technically any should be Sample
303
428
  from prometheus_client.parser import Sample, text_string_to_metric_families
304
429
 
305
430
  # Fetch the metrics from the endpoint
306
431
  try:
307
- response = await self.http_client.get(url)
432
+ response = await self.http_client.get(url, timeout=3)
308
433
  response.raise_for_status()
434
+ except asyncio.TimeoutError:
435
+ logger.warning(f"[Modal Flash] Timeout getting metrics from {url}")
436
+ return None
309
437
  except Exception as e:
310
438
  logger.warning(f"[Modal Flash] Error getting metrics from {url}: {e}")
311
439
  return None
312
440
 
441
+ # Read body with timeout/error handling and parse Prometheus metrics
442
+ try:
443
+ text_body = await response.text()
444
+ except asyncio.TimeoutError:
445
+ logger.warning(f"[Modal Flash] Timeout reading metrics body from {url}")
446
+ return None
447
+ except Exception as e:
448
+ logger.warning(f"[Modal Flash] Error reading metrics body from {url}: {e}")
449
+ return None
450
+
313
451
  # Parse the text-based Prometheus metrics format
314
452
  metrics: dict[str, list[Sample]] = defaultdict(list)
315
- for family in text_string_to_metric_families(await response.text()):
453
+ for family in text_string_to_metric_families(text_body):
316
454
  for sample in family.samples:
317
455
  metrics[sample.name] += [sample]
318
456
 
@@ -320,9 +458,14 @@ class _FlashPrometheusAutoscaler:
320
458
 
321
459
  async def _get_all_containers(self):
322
460
  req = api_pb2.FlashContainerListRequest(function_id=self.fn.object_id)
323
- resp = await retry_transient_errors(self.client.stub.FlashContainerList, req)
461
+ resp = await self.client.stub.FlashContainerList(req)
324
462
  return resp.containers
325
463
 
464
+ async def _set_target_slots(self, target_slots: int):
465
+ req = api_pb2.FlashSetTargetSlotsMetricsRequest(function_id=self.fn.object_id, target_slots=target_slots)
466
+ await self.client.stub.FlashSetTargetSlotsMetrics(req)
467
+ return
468
+
326
469
  def _make_scaling_decision(
327
470
  self,
328
471
  current_replicas: int,
@@ -331,6 +474,7 @@ class _FlashPrometheusAutoscaler:
331
474
  scale_down_stabilization_window_seconds: int = 60 * 5,
332
475
  min_containers: Optional[int] = None,
333
476
  max_containers: Optional[int] = None,
477
+ buffer_containers: Optional[int] = None,
334
478
  ) -> int:
335
479
  """
336
480
  Return the target number of containers following (simplified) Kubernetes HPA
@@ -349,6 +493,7 @@ class _FlashPrometheusAutoscaler:
349
493
  Returns:
350
494
  The target number of containers.
351
495
  """
496
+
352
497
  if not autoscaling_decisions:
353
498
  # Without data we can’t make a new decision – stay where we are.
354
499
  return current_replicas
@@ -381,6 +526,10 @@ class _FlashPrometheusAutoscaler:
381
526
  new_replicas = max(min_containers, new_replicas)
382
527
  if max_containers is not None:
383
528
  new_replicas = min(max_containers, new_replicas)
529
+
530
+ if buffer_containers is not None:
531
+ new_replicas += buffer_containers
532
+
384
533
  return new_replicas
385
534
 
386
535
  async def stop(self):
@@ -414,6 +563,8 @@ async def flash_prometheus_autoscaler(
414
563
  # How often to make autoscaling decisions.
415
564
  # Corresponds to --horizontal-pod-autoscaler-sync-period in Kubernetes.
416
565
  autoscaling_interval_seconds: int = 15,
566
+ # Whether to include overprovisioned containers in the scale up calculation.
567
+ buffer_containers: Optional[int] = None,
417
568
  ) -> _FlashPrometheusAutoscaler:
418
569
  """
419
570
  Autoscale a Flash service based on containers' Prometheus metrics.
@@ -431,19 +582,37 @@ async def flash_prometheus_autoscaler(
431
582
 
432
583
  client = await _Client.from_env()
433
584
  autoscaler = _FlashPrometheusAutoscaler(
434
- client,
435
- app_name,
436
- cls_name,
437
- metrics_endpoint,
438
- target_metric,
439
- target_metric_value,
440
- min_containers,
441
- max_containers,
442
- scale_up_tolerance,
443
- scale_down_tolerance,
444
- scale_up_stabilization_window_seconds,
445
- scale_down_stabilization_window_seconds,
446
- autoscaling_interval_seconds,
585
+ client=client,
586
+ app_name=app_name,
587
+ cls_name=cls_name,
588
+ metrics_endpoint=metrics_endpoint,
589
+ target_metric=target_metric,
590
+ target_metric_value=target_metric_value,
591
+ min_containers=min_containers,
592
+ max_containers=max_containers,
593
+ buffer_containers=buffer_containers,
594
+ scale_up_tolerance=scale_up_tolerance,
595
+ scale_down_tolerance=scale_down_tolerance,
596
+ scale_up_stabilization_window_seconds=scale_up_stabilization_window_seconds,
597
+ scale_down_stabilization_window_seconds=scale_down_stabilization_window_seconds,
598
+ autoscaling_interval_seconds=autoscaling_interval_seconds,
447
599
  )
448
600
  await autoscaler.start()
449
601
  return autoscaler
602
+
603
+
604
+ @synchronizer.create_blocking
605
+ async def flash_get_containers(app_name: str, cls_name: str) -> list[dict[str, Any]]:
606
+ """
607
+ Return a list of flash containers for a deployed Flash service.
608
+
609
+ This is a highly experimental method that can break or be removed at any time without warning.
610
+ Do not use this method unless explicitly instructed to do so by Modal support.
611
+ """
612
+ client = await _Client.from_env()
613
+ fn = _Cls.from_name(app_name, cls_name)._class_service_function
614
+ assert fn is not None
615
+ await fn.hydrate(client=client)
616
+ req = api_pb2.FlashContainerListRequest(function_id=fn.object_id)
617
+ resp = await client.stub.FlashContainerList(req)
618
+ return resp.containers