modal 1.1.3.dev7__py3-none-any.whl → 1.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,6 +14,7 @@ from modal_proto import api_pb2
14
14
  @dataclass
15
15
  class ClusterInfo:
16
16
  rank: int
17
+ cluster_id: str
17
18
  container_ips: list[str]
18
19
  container_ipv4_ips: list[str]
19
20
 
@@ -69,12 +70,14 @@ async def _initialize_clustered_function(client: _Client, task_id: str, world_si
69
70
  )
70
71
  cluster_info = ClusterInfo(
71
72
  rank=resp.cluster_rank,
73
+ cluster_id=resp.cluster_id,
72
74
  container_ips=resp.container_ips,
73
75
  container_ipv4_ips=resp.container_ipv4_ips,
74
76
  )
75
77
  else:
76
78
  cluster_info = ClusterInfo(
77
79
  rank=0,
80
+ cluster_id="", # No cluster ID for single-node # TODO(irfansharif): Is this right?
78
81
  container_ips=[container_ip],
79
82
  container_ipv4_ips=[], # No IPv4 IPs for single-node
80
83
  )
@@ -3,13 +3,14 @@ import typing
3
3
  import typing_extensions
4
4
 
5
5
  class ClusterInfo:
6
- """ClusterInfo(rank: int, container_ips: list[str], container_ipv4_ips: list[str])"""
6
+ """ClusterInfo(rank: int, cluster_id: str, container_ips: list[str], container_ipv4_ips: list[str])"""
7
7
 
8
8
  rank: int
9
+ cluster_id: str
9
10
  container_ips: list[str]
10
11
  container_ipv4_ips: list[str]
11
12
 
12
- def __init__(self, rank: int, container_ips: list[str], container_ipv4_ips: list[str]) -> None:
13
+ def __init__(self, rank: int, cluster_id: str, container_ips: list[str], container_ipv4_ips: list[str]) -> None:
13
14
  """Initialize self. See help(type(self)) for accurate signature."""
14
15
  ...
15
16
 
modal/_functions.py CHANGED
@@ -674,6 +674,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
674
674
  proxy: Optional[_Proxy] = None,
675
675
  retries: Optional[Union[int, Retries]] = None,
676
676
  timeout: int = 300,
677
+ startup_timeout: Optional[int] = None,
677
678
  min_containers: Optional[int] = None,
678
679
  max_containers: Optional[int] = None,
679
680
  buffer_containers: Optional[int] = None,
@@ -966,6 +967,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
966
967
  proxy_id=(proxy.object_id if proxy else None),
967
968
  retry_policy=retry_policy,
968
969
  timeout_secs=timeout_secs or 0,
970
+ startup_timeout_secs=startup_timeout or timeout_secs,
969
971
  pty_info=pty_info,
970
972
  cloud_provider_str=cloud if cloud else "",
971
973
  runtime=config.get("function_runtime"),
@@ -1019,6 +1021,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1019
1021
  autoscaler_settings=function_definition.autoscaler_settings,
1020
1022
  worker_id=function_definition.worker_id,
1021
1023
  timeout_secs=function_definition.timeout_secs,
1024
+ startup_timeout_secs=function_definition.startup_timeout_secs,
1022
1025
  web_url=function_definition.web_url,
1023
1026
  web_url_info=function_definition.web_url_info,
1024
1027
  webhook_config=function_definition.webhook_config,
@@ -1471,6 +1474,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1471
1474
  self._info = None
1472
1475
  self._serve_mounts = frozenset()
1473
1476
  self._metadata = None
1477
+ self._experimental_flash_urls = None
1474
1478
 
1475
1479
  def _hydrate_metadata(self, metadata: Optional[Message]):
1476
1480
  # Overridden concrete implementation of base class method
@@ -1498,6 +1502,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1498
1502
  self._max_object_size_bytes = (
1499
1503
  metadata.max_object_size_bytes if metadata.HasField("max_object_size_bytes") else MAX_OBJECT_SIZE_BYTES
1500
1504
  )
1505
+ self._experimental_flash_urls = metadata._experimental_flash_urls
1501
1506
 
1502
1507
  def _get_metadata(self):
1503
1508
  # Overridden concrete implementation of base class method
@@ -1515,6 +1520,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
1515
1520
  input_plane_url=self._input_plane_url,
1516
1521
  input_plane_region=self._input_plane_region,
1517
1522
  max_object_size_bytes=self._max_object_size_bytes,
1523
+ _experimental_flash_urls=self._experimental_flash_urls,
1518
1524
  )
1519
1525
 
1520
1526
  def _check_no_web_url(self, fn_name: str):
@@ -1545,6 +1551,11 @@ Use the `Function.get_web_url()` method instead.
1545
1551
  """URL of a Function running as a web endpoint."""
1546
1552
  return self._web_url
1547
1553
 
1554
+ @live_method
1555
+ async def _experimental_get_flash_urls(self) -> Optional[list[str]]:
1556
+ """URL of the flash service for the function."""
1557
+ return list(self._experimental_flash_urls) if self._experimental_flash_urls else None
1558
+
1548
1559
  @property
1549
1560
  async def is_generator(self) -> bool:
1550
1561
  """mdmd:hidden"""
modal/_runtime/asgi.py CHANGED
@@ -120,7 +120,7 @@ def asgi_app_wrapper(asgi_app, container_io_manager) -> tuple[Callable[..., Asyn
120
120
 
121
121
  async def handle_first_input_timeout():
122
122
  if scope["type"] == "http":
123
- await messages_from_app.put({"type": "http.response.start", "status": 502})
123
+ await messages_from_app.put({"type": "http.response.start", "status": 408})
124
124
  await messages_from_app.put(
125
125
  {
126
126
  "type": "http.response.body",
@@ -204,6 +204,7 @@ async def retry_transient_errors(
204
204
  else:
205
205
  total_deadline = None
206
206
 
207
+ metadata = metadata + [("x-modal-timestamp", str(time.time()))]
207
208
  while True:
208
209
  attempt_metadata = [
209
210
  ("x-idempotency-key", idempotency_key),
modal/app.py CHANGED
@@ -641,7 +641,8 @@ class _App:
641
641
  scaledown_window: Optional[int] = None, # Max time (in seconds) a container can remain idle while scaling down.
642
642
  proxy: Optional[_Proxy] = None, # Reference to a Modal Proxy to use in front of this function.
643
643
  retries: Optional[Union[int, Retries]] = None, # Number of times to retry each input in case of failure.
644
- timeout: int = 300, # Maximum execution time in seconds.
644
+ timeout: int = 300, # Maximum execution time for inputs and startup time in seconds.
645
+ startup_timeout: Optional[int] = None, # Maximum startup time in seconds with higher precedence than `timeout`.
645
646
  name: Optional[str] = None, # Sets the Modal name of the function within the app
646
647
  is_generator: Optional[
647
648
  bool
@@ -816,6 +817,7 @@ class _App:
816
817
  batch_max_size=batch_max_size,
817
818
  batch_wait_ms=batch_wait_ms,
818
819
  timeout=timeout,
820
+ startup_timeout=startup_timeout or timeout,
819
821
  cloud=cloud,
820
822
  webhook_config=webhook_config,
821
823
  enable_memory_snapshot=enable_memory_snapshot,
@@ -869,7 +871,8 @@ class _App:
869
871
  scaledown_window: Optional[int] = None, # Max time (in seconds) a container can remain idle while scaling down.
870
872
  proxy: Optional[_Proxy] = None, # Reference to a Modal Proxy to use in front of this function.
871
873
  retries: Optional[Union[int, Retries]] = None, # Number of times to retry each input in case of failure.
872
- timeout: int = 300, # Maximum execution time in seconds; applies independently to startup and each input.
874
+ timeout: int = 300, # Maximum execution time for inputs and startup time in seconds.
875
+ startup_timeout: Optional[int] = None, # Maximum startup time in seconds with higher precedence than `timeout`.
873
876
  cloud: Optional[str] = None, # Cloud provider to run the function on. Possible values are aws, gcp, oci, auto.
874
877
  region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the function on.
875
878
  enable_memory_snapshot: bool = False, # Enable memory checkpointing for faster cold starts.
@@ -1002,6 +1005,7 @@ class _App:
1002
1005
  batch_max_size=batch_max_size,
1003
1006
  batch_wait_ms=batch_wait_ms,
1004
1007
  timeout=timeout,
1008
+ startup_timeout=startup_timeout or timeout,
1005
1009
  cloud=cloud,
1006
1010
  enable_memory_snapshot=enable_memory_snapshot,
1007
1011
  block_network=block_network,
modal/app.pyi CHANGED
@@ -411,6 +411,7 @@ class _App:
411
411
  proxy: typing.Optional[modal.proxy._Proxy] = None,
412
412
  retries: typing.Union[int, modal.retries.Retries, None] = None,
413
413
  timeout: int = 300,
414
+ startup_timeout: typing.Optional[int] = None,
414
415
  name: typing.Optional[str] = None,
415
416
  is_generator: typing.Optional[bool] = None,
416
417
  cloud: typing.Optional[str] = None,
@@ -464,6 +465,7 @@ class _App:
464
465
  proxy: typing.Optional[modal.proxy._Proxy] = None,
465
466
  retries: typing.Union[int, modal.retries.Retries, None] = None,
466
467
  timeout: int = 300,
468
+ startup_timeout: typing.Optional[int] = None,
467
469
  cloud: typing.Optional[str] = None,
468
470
  region: typing.Union[str, collections.abc.Sequence[str], None] = None,
469
471
  enable_memory_snapshot: bool = False,
@@ -1014,6 +1016,7 @@ class App:
1014
1016
  proxy: typing.Optional[modal.proxy.Proxy] = None,
1015
1017
  retries: typing.Union[int, modal.retries.Retries, None] = None,
1016
1018
  timeout: int = 300,
1019
+ startup_timeout: typing.Optional[int] = None,
1017
1020
  name: typing.Optional[str] = None,
1018
1021
  is_generator: typing.Optional[bool] = None,
1019
1022
  cloud: typing.Optional[str] = None,
@@ -1067,6 +1070,7 @@ class App:
1067
1070
  proxy: typing.Optional[modal.proxy.Proxy] = None,
1068
1071
  retries: typing.Union[int, modal.retries.Retries, None] = None,
1069
1072
  timeout: int = 300,
1073
+ startup_timeout: typing.Optional[int] = None,
1070
1074
  cloud: typing.Optional[str] = None,
1071
1075
  region: typing.Union[str, collections.abc.Sequence[str], None] = None,
1072
1076
  enable_memory_snapshot: bool = False,
modal/builder/2025.06.txt CHANGED
@@ -3,6 +3,7 @@ aiohttp==3.12.7
3
3
  aiosignal==1.3.2
4
4
  async-timeout==5.0.1 ; python_version < "3.11"
5
5
  attrs==25.3.0
6
+ cbor2==5.7.0
6
7
  certifi==2025.4.26
7
8
  frozenlist==1.6.0
8
9
  grpclib==0.4.8
modal/builder/PREVIEW.txt CHANGED
@@ -3,6 +3,7 @@ aiohttp==3.12.7
3
3
  aiosignal==1.3.2
4
4
  async-timeout==5.0.1 ; python_version < "3.11"
5
5
  attrs==25.3.0
6
+ cbor2==5.7.0
6
7
  certifi==2025.4.26
7
8
  frozenlist==1.6.0
8
9
  grpclib==0.4.8
modal/client.pyi CHANGED
@@ -29,11 +29,7 @@ class _Client:
29
29
  _snapshotted: bool
30
30
 
31
31
  def __init__(
32
- self,
33
- server_url: str,
34
- client_type: int,
35
- credentials: typing.Optional[tuple[str, str]],
36
- version: str = "1.1.3.dev7",
32
+ self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "1.1.4"
37
33
  ):
38
34
  """mdmd:hidden
39
35
  The Modal client object is not intended to be instantiated directly by users.
@@ -160,11 +156,7 @@ class Client:
160
156
  _snapshotted: bool
161
157
 
162
158
  def __init__(
163
- self,
164
- server_url: str,
165
- client_type: int,
166
- credentials: typing.Optional[tuple[str, str]],
167
- version: str = "1.1.3.dev7",
159
+ self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "1.1.4"
168
160
  ):
169
161
  """mdmd:hidden
170
162
  The Modal client object is not intended to be instantiated directly by users.
modal/cls.py CHANGED
@@ -12,7 +12,7 @@ from grpclib import GRPCError, Status
12
12
  from modal_proto import api_pb2
13
13
 
14
14
  from ._functions import _Function, _parse_retries
15
- from ._object import _Object
15
+ from ._object import _Object, live_method
16
16
  from ._partial_function import (
17
17
  _find_callables_for_obj,
18
18
  _find_partial_methods_for_user_cls,
@@ -510,6 +510,11 @@ class _Cls(_Object, type_prefix="cs"):
510
510
  # returns method names for a *local* class only for now (used by cli)
511
511
  return self._method_partials.keys()
512
512
 
513
+ @live_method
514
+ async def _experimental_get_flash_urls(self) -> Optional[list[str]]:
515
+ """URL of the flash service for the class."""
516
+ return await self._get_class_service_function()._experimental_get_flash_urls()
517
+
513
518
  def _hydrate_metadata(self, metadata: Message):
514
519
  assert isinstance(metadata, api_pb2.ClassHandleMetadata)
515
520
  class_service_function = self._get_class_service_function()
modal/cls.pyi CHANGED
@@ -354,6 +354,10 @@ class _Cls(modal._object._Object):
354
354
  def _get_name(self) -> str: ...
355
355
  def _get_class_service_function(self) -> modal._functions._Function: ...
356
356
  def _get_method_names(self) -> collections.abc.Collection[str]: ...
357
+ async def _experimental_get_flash_urls(self) -> typing.Optional[list[str]]:
358
+ """URL of the flash service for the class."""
359
+ ...
360
+
357
361
  def _hydrate_metadata(self, metadata: google.protobuf.message.Message): ...
358
362
  @staticmethod
359
363
  def validate_construction_mechanism(user_cls):
@@ -520,6 +524,18 @@ class Cls(modal.object.Object):
520
524
  def _get_name(self) -> str: ...
521
525
  def _get_class_service_function(self) -> modal.functions.Function: ...
522
526
  def _get_method_names(self) -> collections.abc.Collection[str]: ...
527
+
528
+ class ___experimental_get_flash_urls_spec(typing_extensions.Protocol[SUPERSELF]):
529
+ def __call__(self, /) -> typing.Optional[list[str]]:
530
+ """URL of the flash service for the class."""
531
+ ...
532
+
533
+ async def aio(self, /) -> typing.Optional[list[str]]:
534
+ """URL of the flash service for the class."""
535
+ ...
536
+
537
+ _experimental_get_flash_urls: ___experimental_get_flash_urls_spec[typing_extensions.Self]
538
+
523
539
  def _hydrate_metadata(self, metadata: google.protobuf.message.Message): ...
524
540
  @staticmethod
525
541
  def validate_construction_mechanism(user_cls):
@@ -311,7 +311,8 @@ async def notebook_base_image(*, python_version: Optional[str] = None, force_bui
311
311
 
312
312
  commands: list[str] = [
313
313
  "apt-get update",
314
- "apt-get install -y libpq-dev pkg-config cmake git curl wget unzip zip libsqlite3-dev openssh-server vim",
314
+ "apt-get install -y "
315
+ + "libpq-dev pkg-config cmake git curl wget unzip zip libsqlite3-dev openssh-server vim ffmpeg",
315
316
  _install_cuda_command(),
316
317
  # Install uv since it's faster than pip for installing packages.
317
318
  "pip install uv",
@@ -1,6 +1,8 @@
1
1
  # Copyright Modal Labs 2025
2
2
  import asyncio
3
3
  import math
4
+ import os
5
+ import subprocess
4
6
  import sys
5
7
  import time
6
8
  import traceback
@@ -19,28 +21,87 @@ from ..client import _Client
19
21
  from ..config import logger
20
22
  from ..exception import InvalidError
21
23
 
24
+ MAX_FAILURES = 3
25
+
22
26
 
23
27
  class _FlashManager:
24
- def __init__(self, client: _Client, port: int, health_check_url: Optional[str] = None):
28
+ def __init__(
29
+ self,
30
+ client: _Client,
31
+ port: int,
32
+ process: Optional[subprocess.Popen] = None,
33
+ health_check_url: Optional[str] = None,
34
+ ):
25
35
  self.client = client
26
36
  self.port = port
37
+ # Health check is not currently being used
27
38
  self.health_check_url = health_check_url
39
+ self.process = process
28
40
  self.tunnel_manager = _forward_tunnel(port, client=client)
29
41
  self.stopped = False
42
+ self.num_failures = 0
43
+ self.task_id = os.environ["MODAL_TASK_ID"]
44
+
45
+ async def check_port_connection(self, process: Optional[subprocess.Popen], timeout: int = 10):
46
+ import socket
47
+
48
+ start_time = time.monotonic()
49
+
50
+ while time.monotonic() - start_time < timeout:
51
+ try:
52
+ if process is not None and process.poll() is not None:
53
+ return Exception(f"Process {process.pid} exited with code {process.returncode}")
54
+ with socket.create_connection(("localhost", self.port), timeout=1):
55
+ return
56
+ except (ConnectionRefusedError, OSError):
57
+ await asyncio.sleep(0.1)
58
+
59
+ return Exception(f"Waited too long for port {self.port} to start accepting connections")
30
60
 
31
61
  async def _start(self):
32
62
  self.tunnel = await self.tunnel_manager.__aenter__()
33
-
34
63
  parsed_url = urlparse(self.tunnel.url)
35
64
  host = parsed_url.hostname
36
65
  port = parsed_url.port or 443
37
66
 
38
67
  self.heartbeat_task = asyncio.create_task(self._run_heartbeat(host, port))
68
+ self.drain_task = asyncio.create_task(self._drain_container())
69
+
70
+ async def _drain_container(self):
71
+ """
72
+ Background task that checks if we've encountered too many failures and drains the container if so.
73
+ """
74
+ while True:
75
+ try:
76
+ # Check if the container should be drained (e.g., too many failures)
77
+ if self.num_failures > MAX_FAILURES:
78
+ logger.warning(
79
+ f"[Modal Flash] Draining task {self.task_id} on {self.tunnel.url} due to too many failures."
80
+ )
81
+ await self.stop()
82
+ # handle close upon container exit
83
+
84
+ if self.task_id:
85
+ await self.client.stub.ContainerStop(api_pb2.ContainerStopRequest(task_id=self.task_id))
86
+ return
87
+ except asyncio.CancelledError:
88
+ logger.warning("[Modal Flash] Shutting down...")
89
+ return
90
+ except Exception as e:
91
+ logger.error(f"[Modal Flash] Error draining container: {e}")
92
+ await asyncio.sleep(1)
93
+
94
+ try:
95
+ await asyncio.sleep(1)
96
+ except asyncio.CancelledError:
97
+ logger.warning("[Modal Flash] Shutting down...")
98
+ return
39
99
 
40
100
  async def _run_heartbeat(self, host: str, port: int):
41
101
  first_registration = True
42
102
  while True:
43
103
  try:
104
+ await self.check_port_connection(process=self.process)
44
105
  resp = await self.client.stub.FlashContainerRegister(
45
106
  api_pb2.FlashContainerRegisterRequest(
46
107
  priority=10,
@@ -50,14 +111,25 @@ class _FlashManager:
50
111
  ),
51
112
  timeout=10,
52
113
  )
114
+ self.num_failures = 0
53
115
  if first_registration:
54
- logger.warning(f"[Modal Flash] Listening at {resp.url} over {self.tunnel.url}")
116
+ logger.warning(
117
+ f"[Modal Flash] Listening at {resp.url} over {self.tunnel.url} for task_id {self.task_id}"
118
+ )
55
119
  first_registration = False
56
120
  except asyncio.CancelledError:
57
121
  logger.warning("[Modal Flash] Shutting down...")
58
122
  break
59
123
  except Exception as e:
60
124
  logger.error(f"[Modal Flash] Heartbeat failed: {e}")
125
+ self.num_failures += 1
126
+ logger.error(
127
+ f"[Modal Flash] Deregistering container {self.tunnel.url}, num_failures: {self.num_failures}"
128
+ )
129
+ await retry_transient_errors(
130
+ self.client.stub.FlashContainerDeregister,
131
+ api_pb2.FlashContainerDeregisterRequest(),
132
+ )
61
133
 
62
134
  try:
63
135
  await asyncio.sleep(1)
@@ -94,16 +166,17 @@ FlashManager = synchronize_api(_FlashManager)
94
166
 
95
167
 
96
168
  @synchronizer.create_blocking
97
- async def flash_forward(port: int, health_check_url: Optional[str] = None) -> _FlashManager:
169
+ async def flash_forward(
170
+ port: int, process: Optional[subprocess.Popen] = None, health_check_url: Optional[str] = None
171
+ ) -> _FlashManager:
98
172
  """
99
173
  Forward a port to the Modal Flash service, exposing that port as a stable web endpoint.
100
-
101
174
  This is a highly experimental method that can break or be removed at any time without warning.
102
175
  Do not use this method unless explicitly instructed to do so by Modal support.
103
176
  """
104
177
  client = await _Client.from_env()
105
178
 
106
- manager = _FlashManager(client, port, health_check_url)
179
+ manager = _FlashManager(client, port, process=process, health_check_url=health_check_url)
107
180
  await manager._start()
108
181
  return manager
109
182
 
@@ -127,6 +200,8 @@ class _FlashPrometheusAutoscaler:
127
200
  scale_down_stabilization_window_seconds: int,
128
201
  autoscaling_interval_seconds: int,
129
202
  ):
203
+ import aiohttp
204
+
130
205
  if scale_up_stabilization_window_seconds > self._max_window_seconds:
131
206
  raise InvalidError(
132
207
  f"scale_up_stabilization_window_seconds must be less than or equal to {self._max_window_seconds}"
@@ -138,8 +213,6 @@ class _FlashPrometheusAutoscaler:
138
213
  if target_metric_value <= 0:
139
214
  raise InvalidError("target_metric_value must be greater than 0")
140
215
 
141
- import aiohttp
142
-
143
216
  self.client = client
144
217
  self.app_name = app_name
145
218
  self.cls_name = cls_name
@@ -200,7 +273,10 @@ class _FlashPrometheusAutoscaler:
200
273
  if timestamp >= autoscaling_time - self._max_window_seconds
201
274
  ]
202
275
 
203
- current_target_containers = await self._compute_target_containers(current_replicas)
276
+ if self.metrics_endpoint == "internal":
277
+ current_target_containers = await self._compute_target_containers_internal(current_replicas)
278
+ else:
279
+ current_target_containers = await self._compute_target_containers_prometheus(current_replicas)
204
280
  autoscaling_decisions.append((autoscaling_time, current_target_containers))
205
281
 
206
282
  actual_target_containers = self._make_scaling_decision(
@@ -213,8 +289,8 @@ class _FlashPrometheusAutoscaler:
213
289
  )
214
290
 
215
291
  logger.warning(
216
- f"[Modal Flash] Scaling to {actual_target_containers} containers. Autoscaling decision "
217
- f"made in {time.time() - autoscaling_time} seconds."
292
+ f"[Modal Flash] Scaling to {actual_target_containers=} containers. "
293
+ f" Autoscaling decision made in {time.time() - autoscaling_time} seconds."
218
294
  )
219
295
 
220
296
  await self.autoscaling_decisions_dict.put(
@@ -223,9 +299,7 @@ class _FlashPrometheusAutoscaler:
223
299
  )
224
300
  await self.autoscaling_decisions_dict.put("current_replicas", actual_target_containers)
225
301
 
226
- await self.cls.update_autoscaler(
227
- min_containers=actual_target_containers,
228
- )
302
+ await self.cls.update_autoscaler(min_containers=actual_target_containers)
229
303
 
230
304
  if time.time() - autoscaling_time < self.autoscaling_interval_seconds:
231
305
  await asyncio.sleep(self.autoscaling_interval_seconds - (time.time() - autoscaling_time))
@@ -238,7 +312,55 @@ class _FlashPrometheusAutoscaler:
238
312
  logger.error(traceback.format_exc())
239
313
  await asyncio.sleep(self.autoscaling_interval_seconds)
240
314
 
241
- async def _compute_target_containers(self, current_replicas: int) -> int:
315
+ async def _compute_target_containers_internal(self, current_replicas: int) -> int:
316
+ """
317
+ Gets internal metrics from container to autoscale up or down.
318
+ """
319
+ containers = await self._get_all_containers()
320
+ if len(containers) > current_replicas:
321
+ logger.info(
322
+ f"[Modal Flash] Current replicas {current_replicas} is less than the number of containers "
323
+ f"{len(containers)}. Setting current_replicas = num_containers."
324
+ )
325
+ current_replicas = len(containers)
326
+
327
+ if current_replicas == 0:
328
+ return 1
329
+
330
+ internal_metrics_list = []
331
+ for container in containers:
332
+ internal_metric = await self._get_container_metrics(container.task_id)
333
+ if internal_metric is None:
334
+ continue
335
+ internal_metrics_list.append(getattr(internal_metric.metrics, self.target_metric))
336
+
337
+ if not internal_metrics_list:
338
+ return current_replicas
339
+
340
+ avg_internal_metric = sum(internal_metrics_list) / len(internal_metrics_list)
341
+
342
+ scale_factor = avg_internal_metric / self.target_metric_value
343
+
344
+ desired_replicas = current_replicas
345
+ if scale_factor > 1 + self.scale_up_tolerance:
346
+ desired_replicas = math.ceil(current_replicas * scale_factor)
347
+ elif scale_factor < 1 - self.scale_down_tolerance:
348
+ desired_replicas = math.ceil(current_replicas * scale_factor)
349
+
350
+ logger.warning(
351
+ f"[Modal Flash] Current replicas: {current_replicas}, "
352
+ f"avg internal metric `{self.target_metric}`: {avg_internal_metric}, "
353
+ f"target internal metric value: {self.target_metric_value}, "
354
+ f"scale factor: {scale_factor}, "
355
+ f"desired replicas: {desired_replicas}"
356
+ )
357
+
358
+ desired_replicas = max(1, min(desired_replicas, self.max_containers or 1000))
359
+ return desired_replicas
360
+
361
+ async def _compute_target_containers_prometheus(self, current_replicas: int) -> int:
362
+ # current_replicas is the number of live containers + cold starting containers (not yet live)
363
+ # containers is the number of live containers that are registered in flash dns
242
364
  containers = await self._get_all_containers()
243
365
  if len(containers) > current_replicas:
244
366
  logger.info(
@@ -253,6 +375,7 @@ class _FlashPrometheusAutoscaler:
253
375
  target_metric = self.target_metric
254
376
  target_metric_value = float(self.target_metric_value)
255
377
 
378
+ # Gets metrics from prometheus
256
379
  sum_metric = 0
257
380
  containers_with_metrics = 0
258
381
  container_metrics_list = await asyncio.gather(
@@ -271,11 +394,17 @@ class _FlashPrometheusAutoscaler:
271
394
  sum_metric += container_metrics[target_metric][0].value
272
395
  containers_with_metrics += 1
273
396
 
397
+ # n_containers_missing_metric is the number of unhealthy containers + number of cold starting containers
274
398
  n_containers_missing_metric = current_replicas - containers_with_metrics
399
+ # n_containers_unhealthy is the number of live containers that are not emitting metrics i.e. unhealthy
400
+ n_containers_unhealthy = len(containers) - containers_with_metrics
401
+
402
+ # Scale up assuming that every unhealthy container is at 2x the target metric value.
403
+ scale_up_target_metric_value = (sum_metric + n_containers_unhealthy * target_metric_value) / (
404
+ (containers_with_metrics + n_containers_unhealthy) or 1
405
+ )
275
406
 
276
- # Scale up / down conservatively: Any container that is missing the metric is assumed to be at the minimum
277
- # value of the metric when scaling up and the maximum value of the metric when scaling down.
278
- scale_up_target_metric_value = sum_metric / current_replicas
407
+ # Scale down assuming that every container (including cold starting containers) are at the target metric value.
279
408
  scale_down_target_metric_value = (
280
409
  sum_metric + n_containers_missing_metric * target_metric_value
281
410
  ) / current_replicas
@@ -290,9 +419,14 @@ class _FlashPrometheusAutoscaler:
290
419
  desired_replicas = math.ceil(current_replicas * scale_down_ratio)
291
420
 
292
421
  logger.warning(
293
- f"[Modal Flash] Current replicas: {current_replicas}, target metric value: {target_metric_value}, "
294
- f"current sum of metric values: {sum_metric}, number of containers missing metric: "
295
- f"{n_containers_missing_metric}, scale up ratio: {scale_up_ratio}, scale down ratio: {scale_down_ratio}, "
422
+ f"[Modal Flash] Current replicas: {current_replicas}, "
423
+ f"target metric value: {target_metric_value}, "
424
+ f"current sum of metric values: {sum_metric}, "
425
+ f"number of containers with metrics: {containers_with_metrics}, "
426
+ f"number of containers unhealthy: {n_containers_unhealthy}, "
427
+ f"number of containers missing metric (includes unhealthy): {n_containers_missing_metric}, "
428
+ f"scale up ratio: {scale_up_ratio}, "
429
+ f"scale down ratio: {scale_down_ratio}, "
296
430
  f"desired replicas: {desired_replicas}"
297
431
  )
298
432
 
@@ -303,20 +437,42 @@ class _FlashPrometheusAutoscaler:
303
437
 
304
438
  # Fetch the metrics from the endpoint
305
439
  try:
306
- response = await self.http_client.get(url)
440
+ response = await self.http_client.get(url, timeout=3)
307
441
  response.raise_for_status()
442
+ except asyncio.TimeoutError:
443
+ logger.warning(f"[Modal Flash] Timeout getting metrics from {url}")
444
+ return None
308
445
  except Exception as e:
309
446
  logger.warning(f"[Modal Flash] Error getting metrics from {url}: {e}")
310
447
  return None
311
448
 
449
+ # Read body with timeout/error handling and parse Prometheus metrics
450
+ try:
451
+ text_body = await response.text()
452
+ except asyncio.TimeoutError:
453
+ logger.warning(f"[Modal Flash] Timeout reading metrics body from {url}")
454
+ return None
455
+ except Exception as e:
456
+ logger.warning(f"[Modal Flash] Error reading metrics body from {url}: {e}")
457
+ return None
458
+
312
459
  # Parse the text-based Prometheus metrics format
313
460
  metrics: dict[str, list[Sample]] = defaultdict(list)
314
- for family in text_string_to_metric_families(await response.text()):
461
+ for family in text_string_to_metric_families(text_body):
315
462
  for sample in family.samples:
316
463
  metrics[sample.name] += [sample]
317
464
 
318
465
  return metrics
319
466
 
467
+ async def _get_container_metrics(self, container_id: str) -> Optional[api_pb2.TaskGetAutoscalingMetricsResponse]:
468
+ req = api_pb2.TaskGetAutoscalingMetricsRequest(task_id=container_id)
469
+ try:
470
+ resp = await retry_transient_errors(self.client.stub.TaskGetAutoscalingMetrics, req)
471
+ return resp
472
+ except Exception as e:
473
+ logger.warning(f"[Modal Flash] Error getting metrics for container {container_id}: {e}")
474
+ return None
475
+
320
476
  async def _get_all_containers(self):
321
477
  req = api_pb2.FlashContainerListRequest(function_id=self.fn.object_id)
322
478
  resp = await retry_transient_errors(self.client.stub.FlashContainerList, req)
@@ -395,10 +551,14 @@ async def flash_prometheus_autoscaler(
395
551
  app_name: str,
396
552
  cls_name: str,
397
553
  # Endpoint to fetch metrics from. Must be in Prometheus format. Example: "/metrics"
554
+ # If metrics_endpoint is "internal", we will use containers' internal metrics to autoscale instead.
398
555
  metrics_endpoint: str,
399
556
  # Target metric to autoscale on. Example: "vllm:num_requests_running"
557
+ # If metrics_endpoint is "internal", target_metrics options are: [cpu_usage_percent, memory_usage_percent]
400
558
  target_metric: str,
401
559
  # Target metric value. Example: 25
560
+ # If metrics_endpoint is "internal", target_metric_value is a percentage value between 0.1 and 1.0 (inclusive),
561
+ # indicating container's usage of that metric.
402
562
  target_metric_value: float,
403
563
  min_containers: Optional[int] = None,
404
564
  max_containers: Optional[int] = None,