modal 1.1.3.dev7__py3-none-any.whl → 1.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- modal/_clustered_functions.py +3 -0
- modal/_clustered_functions.pyi +3 -2
- modal/_functions.py +11 -0
- modal/_runtime/asgi.py +1 -1
- modal/_utils/grpc_utils.py +1 -0
- modal/app.py +6 -2
- modal/app.pyi +4 -0
- modal/builder/2025.06.txt +1 -0
- modal/builder/PREVIEW.txt +1 -0
- modal/client.pyi +2 -10
- modal/cls.py +6 -1
- modal/cls.pyi +16 -0
- modal/experimental/__init__.py +2 -1
- modal/experimental/flash.py +183 -23
- modal/experimental/flash.pyi +83 -9
- modal/functions.pyi +18 -6
- modal/image.py +8 -2
- modal/image.pyi +16 -4
- modal/mount.py +17 -11
- modal/mount.pyi +4 -0
- modal/parallel_map.py +26 -6
- modal/parallel_map.pyi +1 -0
- modal/sandbox.py +31 -4
- modal/sandbox.pyi +12 -3
- {modal-1.1.3.dev7.dist-info → modal-1.1.4.dist-info}/METADATA +1 -1
- {modal-1.1.3.dev7.dist-info → modal-1.1.4.dist-info}/RECORD +38 -38
- modal_proto/api.proto +30 -0
- modal_proto/api_grpc.py +32 -0
- modal_proto/api_pb2.py +893 -853
- modal_proto/api_pb2.pyi +94 -5
- modal_proto/api_pb2_grpc.py +68 -1
- modal_proto/api_pb2_grpc.pyi +25 -3
- modal_proto/modal_api_grpc.py +2 -0
- modal_version/__init__.py +1 -1
- {modal-1.1.3.dev7.dist-info → modal-1.1.4.dist-info}/WHEEL +0 -0
- {modal-1.1.3.dev7.dist-info → modal-1.1.4.dist-info}/entry_points.txt +0 -0
- {modal-1.1.3.dev7.dist-info → modal-1.1.4.dist-info}/licenses/LICENSE +0 -0
- {modal-1.1.3.dev7.dist-info → modal-1.1.4.dist-info}/top_level.txt +0 -0
modal/_clustered_functions.py
CHANGED
@@ -14,6 +14,7 @@ from modal_proto import api_pb2
|
|
14
14
|
@dataclass
|
15
15
|
class ClusterInfo:
|
16
16
|
rank: int
|
17
|
+
cluster_id: str
|
17
18
|
container_ips: list[str]
|
18
19
|
container_ipv4_ips: list[str]
|
19
20
|
|
@@ -69,12 +70,14 @@ async def _initialize_clustered_function(client: _Client, task_id: str, world_si
|
|
69
70
|
)
|
70
71
|
cluster_info = ClusterInfo(
|
71
72
|
rank=resp.cluster_rank,
|
73
|
+
cluster_id=resp.cluster_id,
|
72
74
|
container_ips=resp.container_ips,
|
73
75
|
container_ipv4_ips=resp.container_ipv4_ips,
|
74
76
|
)
|
75
77
|
else:
|
76
78
|
cluster_info = ClusterInfo(
|
77
79
|
rank=0,
|
80
|
+
cluster_id="", # No cluster ID for single-node # TODO(irfansharif): Is this right?
|
78
81
|
container_ips=[container_ip],
|
79
82
|
container_ipv4_ips=[], # No IPv4 IPs for single-node
|
80
83
|
)
|
modal/_clustered_functions.pyi
CHANGED
@@ -3,13 +3,14 @@ import typing
|
|
3
3
|
import typing_extensions
|
4
4
|
|
5
5
|
class ClusterInfo:
|
6
|
-
"""ClusterInfo(rank: int, container_ips: list[str], container_ipv4_ips: list[str])"""
|
6
|
+
"""ClusterInfo(rank: int, cluster_id: str, container_ips: list[str], container_ipv4_ips: list[str])"""
|
7
7
|
|
8
8
|
rank: int
|
9
|
+
cluster_id: str
|
9
10
|
container_ips: list[str]
|
10
11
|
container_ipv4_ips: list[str]
|
11
12
|
|
12
|
-
def __init__(self, rank: int, container_ips: list[str], container_ipv4_ips: list[str]) -> None:
|
13
|
+
def __init__(self, rank: int, cluster_id: str, container_ips: list[str], container_ipv4_ips: list[str]) -> None:
|
13
14
|
"""Initialize self. See help(type(self)) for accurate signature."""
|
14
15
|
...
|
15
16
|
|
modal/_functions.py
CHANGED
@@ -674,6 +674,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
674
674
|
proxy: Optional[_Proxy] = None,
|
675
675
|
retries: Optional[Union[int, Retries]] = None,
|
676
676
|
timeout: int = 300,
|
677
|
+
startup_timeout: Optional[int] = None,
|
677
678
|
min_containers: Optional[int] = None,
|
678
679
|
max_containers: Optional[int] = None,
|
679
680
|
buffer_containers: Optional[int] = None,
|
@@ -966,6 +967,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
966
967
|
proxy_id=(proxy.object_id if proxy else None),
|
967
968
|
retry_policy=retry_policy,
|
968
969
|
timeout_secs=timeout_secs or 0,
|
970
|
+
startup_timeout_secs=startup_timeout or timeout_secs,
|
969
971
|
pty_info=pty_info,
|
970
972
|
cloud_provider_str=cloud if cloud else "",
|
971
973
|
runtime=config.get("function_runtime"),
|
@@ -1019,6 +1021,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1019
1021
|
autoscaler_settings=function_definition.autoscaler_settings,
|
1020
1022
|
worker_id=function_definition.worker_id,
|
1021
1023
|
timeout_secs=function_definition.timeout_secs,
|
1024
|
+
startup_timeout_secs=function_definition.startup_timeout_secs,
|
1022
1025
|
web_url=function_definition.web_url,
|
1023
1026
|
web_url_info=function_definition.web_url_info,
|
1024
1027
|
webhook_config=function_definition.webhook_config,
|
@@ -1471,6 +1474,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1471
1474
|
self._info = None
|
1472
1475
|
self._serve_mounts = frozenset()
|
1473
1476
|
self._metadata = None
|
1477
|
+
self._experimental_flash_urls = None
|
1474
1478
|
|
1475
1479
|
def _hydrate_metadata(self, metadata: Optional[Message]):
|
1476
1480
|
# Overridden concrete implementation of base class method
|
@@ -1498,6 +1502,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1498
1502
|
self._max_object_size_bytes = (
|
1499
1503
|
metadata.max_object_size_bytes if metadata.HasField("max_object_size_bytes") else MAX_OBJECT_SIZE_BYTES
|
1500
1504
|
)
|
1505
|
+
self._experimental_flash_urls = metadata._experimental_flash_urls
|
1501
1506
|
|
1502
1507
|
def _get_metadata(self):
|
1503
1508
|
# Overridden concrete implementation of base class method
|
@@ -1515,6 +1520,7 @@ class _Function(typing.Generic[P, ReturnType, OriginalReturnType], _Object, type
|
|
1515
1520
|
input_plane_url=self._input_plane_url,
|
1516
1521
|
input_plane_region=self._input_plane_region,
|
1517
1522
|
max_object_size_bytes=self._max_object_size_bytes,
|
1523
|
+
_experimental_flash_urls=self._experimental_flash_urls,
|
1518
1524
|
)
|
1519
1525
|
|
1520
1526
|
def _check_no_web_url(self, fn_name: str):
|
@@ -1545,6 +1551,11 @@ Use the `Function.get_web_url()` method instead.
|
|
1545
1551
|
"""URL of a Function running as a web endpoint."""
|
1546
1552
|
return self._web_url
|
1547
1553
|
|
1554
|
+
@live_method
|
1555
|
+
async def _experimental_get_flash_urls(self) -> Optional[list[str]]:
|
1556
|
+
"""URL of the flash service for the function."""
|
1557
|
+
return list(self._experimental_flash_urls) if self._experimental_flash_urls else None
|
1558
|
+
|
1548
1559
|
@property
|
1549
1560
|
async def is_generator(self) -> bool:
|
1550
1561
|
"""mdmd:hidden"""
|
modal/_runtime/asgi.py
CHANGED
@@ -120,7 +120,7 @@ def asgi_app_wrapper(asgi_app, container_io_manager) -> tuple[Callable[..., Asyn
|
|
120
120
|
|
121
121
|
async def handle_first_input_timeout():
|
122
122
|
if scope["type"] == "http":
|
123
|
-
await messages_from_app.put({"type": "http.response.start", "status":
|
123
|
+
await messages_from_app.put({"type": "http.response.start", "status": 408})
|
124
124
|
await messages_from_app.put(
|
125
125
|
{
|
126
126
|
"type": "http.response.body",
|
modal/_utils/grpc_utils.py
CHANGED
modal/app.py
CHANGED
@@ -641,7 +641,8 @@ class _App:
|
|
641
641
|
scaledown_window: Optional[int] = None, # Max time (in seconds) a container can remain idle while scaling down.
|
642
642
|
proxy: Optional[_Proxy] = None, # Reference to a Modal Proxy to use in front of this function.
|
643
643
|
retries: Optional[Union[int, Retries]] = None, # Number of times to retry each input in case of failure.
|
644
|
-
timeout: int = 300, # Maximum execution time in seconds.
|
644
|
+
timeout: int = 300, # Maximum execution time for inputs and startup time in seconds.
|
645
|
+
startup_timeout: Optional[int] = None, # Maximum startup time in seconds with higher precedence than `timeout`.
|
645
646
|
name: Optional[str] = None, # Sets the Modal name of the function within the app
|
646
647
|
is_generator: Optional[
|
647
648
|
bool
|
@@ -816,6 +817,7 @@ class _App:
|
|
816
817
|
batch_max_size=batch_max_size,
|
817
818
|
batch_wait_ms=batch_wait_ms,
|
818
819
|
timeout=timeout,
|
820
|
+
startup_timeout=startup_timeout or timeout,
|
819
821
|
cloud=cloud,
|
820
822
|
webhook_config=webhook_config,
|
821
823
|
enable_memory_snapshot=enable_memory_snapshot,
|
@@ -869,7 +871,8 @@ class _App:
|
|
869
871
|
scaledown_window: Optional[int] = None, # Max time (in seconds) a container can remain idle while scaling down.
|
870
872
|
proxy: Optional[_Proxy] = None, # Reference to a Modal Proxy to use in front of this function.
|
871
873
|
retries: Optional[Union[int, Retries]] = None, # Number of times to retry each input in case of failure.
|
872
|
-
timeout: int = 300, # Maximum execution time
|
874
|
+
timeout: int = 300, # Maximum execution time for inputs and startup time in seconds.
|
875
|
+
startup_timeout: Optional[int] = None, # Maximum startup time in seconds with higher precedence than `timeout`.
|
873
876
|
cloud: Optional[str] = None, # Cloud provider to run the function on. Possible values are aws, gcp, oci, auto.
|
874
877
|
region: Optional[Union[str, Sequence[str]]] = None, # Region or regions to run the function on.
|
875
878
|
enable_memory_snapshot: bool = False, # Enable memory checkpointing for faster cold starts.
|
@@ -1002,6 +1005,7 @@ class _App:
|
|
1002
1005
|
batch_max_size=batch_max_size,
|
1003
1006
|
batch_wait_ms=batch_wait_ms,
|
1004
1007
|
timeout=timeout,
|
1008
|
+
startup_timeout=startup_timeout or timeout,
|
1005
1009
|
cloud=cloud,
|
1006
1010
|
enable_memory_snapshot=enable_memory_snapshot,
|
1007
1011
|
block_network=block_network,
|
modal/app.pyi
CHANGED
@@ -411,6 +411,7 @@ class _App:
|
|
411
411
|
proxy: typing.Optional[modal.proxy._Proxy] = None,
|
412
412
|
retries: typing.Union[int, modal.retries.Retries, None] = None,
|
413
413
|
timeout: int = 300,
|
414
|
+
startup_timeout: typing.Optional[int] = None,
|
414
415
|
name: typing.Optional[str] = None,
|
415
416
|
is_generator: typing.Optional[bool] = None,
|
416
417
|
cloud: typing.Optional[str] = None,
|
@@ -464,6 +465,7 @@ class _App:
|
|
464
465
|
proxy: typing.Optional[modal.proxy._Proxy] = None,
|
465
466
|
retries: typing.Union[int, modal.retries.Retries, None] = None,
|
466
467
|
timeout: int = 300,
|
468
|
+
startup_timeout: typing.Optional[int] = None,
|
467
469
|
cloud: typing.Optional[str] = None,
|
468
470
|
region: typing.Union[str, collections.abc.Sequence[str], None] = None,
|
469
471
|
enable_memory_snapshot: bool = False,
|
@@ -1014,6 +1016,7 @@ class App:
|
|
1014
1016
|
proxy: typing.Optional[modal.proxy.Proxy] = None,
|
1015
1017
|
retries: typing.Union[int, modal.retries.Retries, None] = None,
|
1016
1018
|
timeout: int = 300,
|
1019
|
+
startup_timeout: typing.Optional[int] = None,
|
1017
1020
|
name: typing.Optional[str] = None,
|
1018
1021
|
is_generator: typing.Optional[bool] = None,
|
1019
1022
|
cloud: typing.Optional[str] = None,
|
@@ -1067,6 +1070,7 @@ class App:
|
|
1067
1070
|
proxy: typing.Optional[modal.proxy.Proxy] = None,
|
1068
1071
|
retries: typing.Union[int, modal.retries.Retries, None] = None,
|
1069
1072
|
timeout: int = 300,
|
1073
|
+
startup_timeout: typing.Optional[int] = None,
|
1070
1074
|
cloud: typing.Optional[str] = None,
|
1071
1075
|
region: typing.Union[str, collections.abc.Sequence[str], None] = None,
|
1072
1076
|
enable_memory_snapshot: bool = False,
|
modal/builder/2025.06.txt
CHANGED
modal/builder/PREVIEW.txt
CHANGED
modal/client.pyi
CHANGED
@@ -29,11 +29,7 @@ class _Client:
|
|
29
29
|
_snapshotted: bool
|
30
30
|
|
31
31
|
def __init__(
|
32
|
-
self,
|
33
|
-
server_url: str,
|
34
|
-
client_type: int,
|
35
|
-
credentials: typing.Optional[tuple[str, str]],
|
36
|
-
version: str = "1.1.3.dev7",
|
32
|
+
self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "1.1.4"
|
37
33
|
):
|
38
34
|
"""mdmd:hidden
|
39
35
|
The Modal client object is not intended to be instantiated directly by users.
|
@@ -160,11 +156,7 @@ class Client:
|
|
160
156
|
_snapshotted: bool
|
161
157
|
|
162
158
|
def __init__(
|
163
|
-
self,
|
164
|
-
server_url: str,
|
165
|
-
client_type: int,
|
166
|
-
credentials: typing.Optional[tuple[str, str]],
|
167
|
-
version: str = "1.1.3.dev7",
|
159
|
+
self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "1.1.4"
|
168
160
|
):
|
169
161
|
"""mdmd:hidden
|
170
162
|
The Modal client object is not intended to be instantiated directly by users.
|
modal/cls.py
CHANGED
@@ -12,7 +12,7 @@ from grpclib import GRPCError, Status
|
|
12
12
|
from modal_proto import api_pb2
|
13
13
|
|
14
14
|
from ._functions import _Function, _parse_retries
|
15
|
-
from ._object import _Object
|
15
|
+
from ._object import _Object, live_method
|
16
16
|
from ._partial_function import (
|
17
17
|
_find_callables_for_obj,
|
18
18
|
_find_partial_methods_for_user_cls,
|
@@ -510,6 +510,11 @@ class _Cls(_Object, type_prefix="cs"):
|
|
510
510
|
# returns method names for a *local* class only for now (used by cli)
|
511
511
|
return self._method_partials.keys()
|
512
512
|
|
513
|
+
@live_method
|
514
|
+
async def _experimental_get_flash_urls(self) -> Optional[list[str]]:
|
515
|
+
"""URL of the flash service for the class."""
|
516
|
+
return await self._get_class_service_function()._experimental_get_flash_urls()
|
517
|
+
|
513
518
|
def _hydrate_metadata(self, metadata: Message):
|
514
519
|
assert isinstance(metadata, api_pb2.ClassHandleMetadata)
|
515
520
|
class_service_function = self._get_class_service_function()
|
modal/cls.pyi
CHANGED
@@ -354,6 +354,10 @@ class _Cls(modal._object._Object):
|
|
354
354
|
def _get_name(self) -> str: ...
|
355
355
|
def _get_class_service_function(self) -> modal._functions._Function: ...
|
356
356
|
def _get_method_names(self) -> collections.abc.Collection[str]: ...
|
357
|
+
async def _experimental_get_flash_urls(self) -> typing.Optional[list[str]]:
|
358
|
+
"""URL of the flash service for the class."""
|
359
|
+
...
|
360
|
+
|
357
361
|
def _hydrate_metadata(self, metadata: google.protobuf.message.Message): ...
|
358
362
|
@staticmethod
|
359
363
|
def validate_construction_mechanism(user_cls):
|
@@ -520,6 +524,18 @@ class Cls(modal.object.Object):
|
|
520
524
|
def _get_name(self) -> str: ...
|
521
525
|
def _get_class_service_function(self) -> modal.functions.Function: ...
|
522
526
|
def _get_method_names(self) -> collections.abc.Collection[str]: ...
|
527
|
+
|
528
|
+
class ___experimental_get_flash_urls_spec(typing_extensions.Protocol[SUPERSELF]):
|
529
|
+
def __call__(self, /) -> typing.Optional[list[str]]:
|
530
|
+
"""URL of the flash service for the class."""
|
531
|
+
...
|
532
|
+
|
533
|
+
async def aio(self, /) -> typing.Optional[list[str]]:
|
534
|
+
"""URL of the flash service for the class."""
|
535
|
+
...
|
536
|
+
|
537
|
+
_experimental_get_flash_urls: ___experimental_get_flash_urls_spec[typing_extensions.Self]
|
538
|
+
|
523
539
|
def _hydrate_metadata(self, metadata: google.protobuf.message.Message): ...
|
524
540
|
@staticmethod
|
525
541
|
def validate_construction_mechanism(user_cls):
|
modal/experimental/__init__.py
CHANGED
@@ -311,7 +311,8 @@ async def notebook_base_image(*, python_version: Optional[str] = None, force_bui
|
|
311
311
|
|
312
312
|
commands: list[str] = [
|
313
313
|
"apt-get update",
|
314
|
-
"apt-get install -y
|
314
|
+
"apt-get install -y "
|
315
|
+
+ "libpq-dev pkg-config cmake git curl wget unzip zip libsqlite3-dev openssh-server vim ffmpeg",
|
315
316
|
_install_cuda_command(),
|
316
317
|
# Install uv since it's faster than pip for installing packages.
|
317
318
|
"pip install uv",
|
modal/experimental/flash.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
# Copyright Modal Labs 2025
|
2
2
|
import asyncio
|
3
3
|
import math
|
4
|
+
import os
|
5
|
+
import subprocess
|
4
6
|
import sys
|
5
7
|
import time
|
6
8
|
import traceback
|
@@ -19,28 +21,87 @@ from ..client import _Client
|
|
19
21
|
from ..config import logger
|
20
22
|
from ..exception import InvalidError
|
21
23
|
|
24
|
+
MAX_FAILURES = 3
|
25
|
+
|
22
26
|
|
23
27
|
class _FlashManager:
|
24
|
-
def __init__(
|
28
|
+
def __init__(
|
29
|
+
self,
|
30
|
+
client: _Client,
|
31
|
+
port: int,
|
32
|
+
process: Optional[subprocess.Popen] = None,
|
33
|
+
health_check_url: Optional[str] = None,
|
34
|
+
):
|
25
35
|
self.client = client
|
26
36
|
self.port = port
|
37
|
+
# Health check is not currently being used
|
27
38
|
self.health_check_url = health_check_url
|
39
|
+
self.process = process
|
28
40
|
self.tunnel_manager = _forward_tunnel(port, client=client)
|
29
41
|
self.stopped = False
|
42
|
+
self.num_failures = 0
|
43
|
+
self.task_id = os.environ["MODAL_TASK_ID"]
|
44
|
+
|
45
|
+
async def check_port_connection(self, process: Optional[subprocess.Popen], timeout: int = 10):
|
46
|
+
import socket
|
47
|
+
|
48
|
+
start_time = time.monotonic()
|
49
|
+
|
50
|
+
while time.monotonic() - start_time < timeout:
|
51
|
+
try:
|
52
|
+
if process is not None and process.poll() is not None:
|
53
|
+
return Exception(f"Process {process.pid} exited with code {process.returncode}")
|
54
|
+
with socket.create_connection(("localhost", self.port), timeout=1):
|
55
|
+
return
|
56
|
+
except (ConnectionRefusedError, OSError):
|
57
|
+
await asyncio.sleep(0.1)
|
58
|
+
|
59
|
+
return Exception(f"Waited too long for port {self.port} to start accepting connections")
|
30
60
|
|
31
61
|
async def _start(self):
|
32
62
|
self.tunnel = await self.tunnel_manager.__aenter__()
|
33
|
-
|
34
63
|
parsed_url = urlparse(self.tunnel.url)
|
35
64
|
host = parsed_url.hostname
|
36
65
|
port = parsed_url.port or 443
|
37
66
|
|
38
67
|
self.heartbeat_task = asyncio.create_task(self._run_heartbeat(host, port))
|
68
|
+
self.drain_task = asyncio.create_task(self._drain_container())
|
69
|
+
|
70
|
+
async def _drain_container(self):
|
71
|
+
"""
|
72
|
+
Background task that checks if we've encountered too many failures and drains the container if so.
|
73
|
+
"""
|
74
|
+
while True:
|
75
|
+
try:
|
76
|
+
# Check if the container should be drained (e.g., too many failures)
|
77
|
+
if self.num_failures > MAX_FAILURES:
|
78
|
+
logger.warning(
|
79
|
+
f"[Modal Flash] Draining task {self.task_id} on {self.tunnel.url} due to too many failures."
|
80
|
+
)
|
81
|
+
await self.stop()
|
82
|
+
# handle close upon container exit
|
83
|
+
|
84
|
+
if self.task_id:
|
85
|
+
await self.client.stub.ContainerStop(api_pb2.ContainerStopRequest(task_id=self.task_id))
|
86
|
+
return
|
87
|
+
except asyncio.CancelledError:
|
88
|
+
logger.warning("[Modal Flash] Shutting down...")
|
89
|
+
return
|
90
|
+
except Exception as e:
|
91
|
+
logger.error(f"[Modal Flash] Error draining container: {e}")
|
92
|
+
await asyncio.sleep(1)
|
93
|
+
|
94
|
+
try:
|
95
|
+
await asyncio.sleep(1)
|
96
|
+
except asyncio.CancelledError:
|
97
|
+
logger.warning("[Modal Flash] Shutting down...")
|
98
|
+
return
|
39
99
|
|
40
100
|
async def _run_heartbeat(self, host: str, port: int):
|
41
101
|
first_registration = True
|
42
102
|
while True:
|
43
103
|
try:
|
104
|
+
await self.check_port_connection(process=self.process)
|
44
105
|
resp = await self.client.stub.FlashContainerRegister(
|
45
106
|
api_pb2.FlashContainerRegisterRequest(
|
46
107
|
priority=10,
|
@@ -50,14 +111,25 @@ class _FlashManager:
|
|
50
111
|
),
|
51
112
|
timeout=10,
|
52
113
|
)
|
114
|
+
self.num_failures = 0
|
53
115
|
if first_registration:
|
54
|
-
logger.warning(
|
116
|
+
logger.warning(
|
117
|
+
f"[Modal Flash] Listening at {resp.url} over {self.tunnel.url} for task_id {self.task_id}"
|
118
|
+
)
|
55
119
|
first_registration = False
|
56
120
|
except asyncio.CancelledError:
|
57
121
|
logger.warning("[Modal Flash] Shutting down...")
|
58
122
|
break
|
59
123
|
except Exception as e:
|
60
124
|
logger.error(f"[Modal Flash] Heartbeat failed: {e}")
|
125
|
+
self.num_failures += 1
|
126
|
+
logger.error(
|
127
|
+
f"[Modal Flash] Deregistering container {self.tunnel.url}, num_failures: {self.num_failures}"
|
128
|
+
)
|
129
|
+
await retry_transient_errors(
|
130
|
+
self.client.stub.FlashContainerDeregister,
|
131
|
+
api_pb2.FlashContainerDeregisterRequest(),
|
132
|
+
)
|
61
133
|
|
62
134
|
try:
|
63
135
|
await asyncio.sleep(1)
|
@@ -94,16 +166,17 @@ FlashManager = synchronize_api(_FlashManager)
|
|
94
166
|
|
95
167
|
|
96
168
|
@synchronizer.create_blocking
|
97
|
-
async def flash_forward(
|
169
|
+
async def flash_forward(
|
170
|
+
port: int, process: Optional[subprocess.Popen] = None, health_check_url: Optional[str] = None
|
171
|
+
) -> _FlashManager:
|
98
172
|
"""
|
99
173
|
Forward a port to the Modal Flash service, exposing that port as a stable web endpoint.
|
100
|
-
|
101
174
|
This is a highly experimental method that can break or be removed at any time without warning.
|
102
175
|
Do not use this method unless explicitly instructed to do so by Modal support.
|
103
176
|
"""
|
104
177
|
client = await _Client.from_env()
|
105
178
|
|
106
|
-
manager = _FlashManager(client, port, health_check_url)
|
179
|
+
manager = _FlashManager(client, port, process=process, health_check_url=health_check_url)
|
107
180
|
await manager._start()
|
108
181
|
return manager
|
109
182
|
|
@@ -127,6 +200,8 @@ class _FlashPrometheusAutoscaler:
|
|
127
200
|
scale_down_stabilization_window_seconds: int,
|
128
201
|
autoscaling_interval_seconds: int,
|
129
202
|
):
|
203
|
+
import aiohttp
|
204
|
+
|
130
205
|
if scale_up_stabilization_window_seconds > self._max_window_seconds:
|
131
206
|
raise InvalidError(
|
132
207
|
f"scale_up_stabilization_window_seconds must be less than or equal to {self._max_window_seconds}"
|
@@ -138,8 +213,6 @@ class _FlashPrometheusAutoscaler:
|
|
138
213
|
if target_metric_value <= 0:
|
139
214
|
raise InvalidError("target_metric_value must be greater than 0")
|
140
215
|
|
141
|
-
import aiohttp
|
142
|
-
|
143
216
|
self.client = client
|
144
217
|
self.app_name = app_name
|
145
218
|
self.cls_name = cls_name
|
@@ -200,7 +273,10 @@ class _FlashPrometheusAutoscaler:
|
|
200
273
|
if timestamp >= autoscaling_time - self._max_window_seconds
|
201
274
|
]
|
202
275
|
|
203
|
-
|
276
|
+
if self.metrics_endpoint == "internal":
|
277
|
+
current_target_containers = await self._compute_target_containers_internal(current_replicas)
|
278
|
+
else:
|
279
|
+
current_target_containers = await self._compute_target_containers_prometheus(current_replicas)
|
204
280
|
autoscaling_decisions.append((autoscaling_time, current_target_containers))
|
205
281
|
|
206
282
|
actual_target_containers = self._make_scaling_decision(
|
@@ -213,8 +289,8 @@ class _FlashPrometheusAutoscaler:
|
|
213
289
|
)
|
214
290
|
|
215
291
|
logger.warning(
|
216
|
-
f"[Modal Flash] Scaling to {actual_target_containers} containers.
|
217
|
-
f"made in {time.time() - autoscaling_time} seconds."
|
292
|
+
f"[Modal Flash] Scaling to {actual_target_containers=} containers. "
|
293
|
+
f" Autoscaling decision made in {time.time() - autoscaling_time} seconds."
|
218
294
|
)
|
219
295
|
|
220
296
|
await self.autoscaling_decisions_dict.put(
|
@@ -223,9 +299,7 @@ class _FlashPrometheusAutoscaler:
|
|
223
299
|
)
|
224
300
|
await self.autoscaling_decisions_dict.put("current_replicas", actual_target_containers)
|
225
301
|
|
226
|
-
await self.cls.update_autoscaler(
|
227
|
-
min_containers=actual_target_containers,
|
228
|
-
)
|
302
|
+
await self.cls.update_autoscaler(min_containers=actual_target_containers)
|
229
303
|
|
230
304
|
if time.time() - autoscaling_time < self.autoscaling_interval_seconds:
|
231
305
|
await asyncio.sleep(self.autoscaling_interval_seconds - (time.time() - autoscaling_time))
|
@@ -238,7 +312,55 @@ class _FlashPrometheusAutoscaler:
|
|
238
312
|
logger.error(traceback.format_exc())
|
239
313
|
await asyncio.sleep(self.autoscaling_interval_seconds)
|
240
314
|
|
241
|
-
async def
|
315
|
+
async def _compute_target_containers_internal(self, current_replicas: int) -> int:
|
316
|
+
"""
|
317
|
+
Gets internal metrics from container to autoscale up or down.
|
318
|
+
"""
|
319
|
+
containers = await self._get_all_containers()
|
320
|
+
if len(containers) > current_replicas:
|
321
|
+
logger.info(
|
322
|
+
f"[Modal Flash] Current replicas {current_replicas} is less than the number of containers "
|
323
|
+
f"{len(containers)}. Setting current_replicas = num_containers."
|
324
|
+
)
|
325
|
+
current_replicas = len(containers)
|
326
|
+
|
327
|
+
if current_replicas == 0:
|
328
|
+
return 1
|
329
|
+
|
330
|
+
internal_metrics_list = []
|
331
|
+
for container in containers:
|
332
|
+
internal_metric = await self._get_container_metrics(container.task_id)
|
333
|
+
if internal_metric is None:
|
334
|
+
continue
|
335
|
+
internal_metrics_list.append(getattr(internal_metric.metrics, self.target_metric))
|
336
|
+
|
337
|
+
if not internal_metrics_list:
|
338
|
+
return current_replicas
|
339
|
+
|
340
|
+
avg_internal_metric = sum(internal_metrics_list) / len(internal_metrics_list)
|
341
|
+
|
342
|
+
scale_factor = avg_internal_metric / self.target_metric_value
|
343
|
+
|
344
|
+
desired_replicas = current_replicas
|
345
|
+
if scale_factor > 1 + self.scale_up_tolerance:
|
346
|
+
desired_replicas = math.ceil(current_replicas * scale_factor)
|
347
|
+
elif scale_factor < 1 - self.scale_down_tolerance:
|
348
|
+
desired_replicas = math.ceil(current_replicas * scale_factor)
|
349
|
+
|
350
|
+
logger.warning(
|
351
|
+
f"[Modal Flash] Current replicas: {current_replicas}, "
|
352
|
+
f"avg internal metric `{self.target_metric}`: {avg_internal_metric}, "
|
353
|
+
f"target internal metric value: {self.target_metric_value}, "
|
354
|
+
f"scale factor: {scale_factor}, "
|
355
|
+
f"desired replicas: {desired_replicas}"
|
356
|
+
)
|
357
|
+
|
358
|
+
desired_replicas = max(1, min(desired_replicas, self.max_containers or 1000))
|
359
|
+
return desired_replicas
|
360
|
+
|
361
|
+
async def _compute_target_containers_prometheus(self, current_replicas: int) -> int:
|
362
|
+
# current_replicas is the number of live containers + cold starting containers (not yet live)
|
363
|
+
# containers is the number of live containers that are registered in flash dns
|
242
364
|
containers = await self._get_all_containers()
|
243
365
|
if len(containers) > current_replicas:
|
244
366
|
logger.info(
|
@@ -253,6 +375,7 @@ class _FlashPrometheusAutoscaler:
|
|
253
375
|
target_metric = self.target_metric
|
254
376
|
target_metric_value = float(self.target_metric_value)
|
255
377
|
|
378
|
+
# Gets metrics from prometheus
|
256
379
|
sum_metric = 0
|
257
380
|
containers_with_metrics = 0
|
258
381
|
container_metrics_list = await asyncio.gather(
|
@@ -271,11 +394,17 @@ class _FlashPrometheusAutoscaler:
|
|
271
394
|
sum_metric += container_metrics[target_metric][0].value
|
272
395
|
containers_with_metrics += 1
|
273
396
|
|
397
|
+
# n_containers_missing_metric is the number of unhealthy containers + number of cold starting containers
|
274
398
|
n_containers_missing_metric = current_replicas - containers_with_metrics
|
399
|
+
# n_containers_unhealthy is the number of live containers that are not emitting metrics i.e. unhealthy
|
400
|
+
n_containers_unhealthy = len(containers) - containers_with_metrics
|
401
|
+
|
402
|
+
# Scale up assuming that every unhealthy container is at 2x the target metric value.
|
403
|
+
scale_up_target_metric_value = (sum_metric + n_containers_unhealthy * target_metric_value) / (
|
404
|
+
(containers_with_metrics + n_containers_unhealthy) or 1
|
405
|
+
)
|
275
406
|
|
276
|
-
# Scale
|
277
|
-
# value of the metric when scaling up and the maximum value of the metric when scaling down.
|
278
|
-
scale_up_target_metric_value = sum_metric / current_replicas
|
407
|
+
# Scale down assuming that every container (including cold starting containers) are at the target metric value.
|
279
408
|
scale_down_target_metric_value = (
|
280
409
|
sum_metric + n_containers_missing_metric * target_metric_value
|
281
410
|
) / current_replicas
|
@@ -290,9 +419,14 @@ class _FlashPrometheusAutoscaler:
|
|
290
419
|
desired_replicas = math.ceil(current_replicas * scale_down_ratio)
|
291
420
|
|
292
421
|
logger.warning(
|
293
|
-
f"[Modal Flash] Current replicas: {current_replicas},
|
294
|
-
f"
|
295
|
-
f"
|
422
|
+
f"[Modal Flash] Current replicas: {current_replicas}, "
|
423
|
+
f"target metric value: {target_metric_value}, "
|
424
|
+
f"current sum of metric values: {sum_metric}, "
|
425
|
+
f"number of containers with metrics: {containers_with_metrics}, "
|
426
|
+
f"number of containers unhealthy: {n_containers_unhealthy}, "
|
427
|
+
f"number of containers missing metric (includes unhealthy): {n_containers_missing_metric}, "
|
428
|
+
f"scale up ratio: {scale_up_ratio}, "
|
429
|
+
f"scale down ratio: {scale_down_ratio}, "
|
296
430
|
f"desired replicas: {desired_replicas}"
|
297
431
|
)
|
298
432
|
|
@@ -303,20 +437,42 @@ class _FlashPrometheusAutoscaler:
|
|
303
437
|
|
304
438
|
# Fetch the metrics from the endpoint
|
305
439
|
try:
|
306
|
-
response = await self.http_client.get(url)
|
440
|
+
response = await self.http_client.get(url, timeout=3)
|
307
441
|
response.raise_for_status()
|
442
|
+
except asyncio.TimeoutError:
|
443
|
+
logger.warning(f"[Modal Flash] Timeout getting metrics from {url}")
|
444
|
+
return None
|
308
445
|
except Exception as e:
|
309
446
|
logger.warning(f"[Modal Flash] Error getting metrics from {url}: {e}")
|
310
447
|
return None
|
311
448
|
|
449
|
+
# Read body with timeout/error handling and parse Prometheus metrics
|
450
|
+
try:
|
451
|
+
text_body = await response.text()
|
452
|
+
except asyncio.TimeoutError:
|
453
|
+
logger.warning(f"[Modal Flash] Timeout reading metrics body from {url}")
|
454
|
+
return None
|
455
|
+
except Exception as e:
|
456
|
+
logger.warning(f"[Modal Flash] Error reading metrics body from {url}: {e}")
|
457
|
+
return None
|
458
|
+
|
312
459
|
# Parse the text-based Prometheus metrics format
|
313
460
|
metrics: dict[str, list[Sample]] = defaultdict(list)
|
314
|
-
for family in text_string_to_metric_families(
|
461
|
+
for family in text_string_to_metric_families(text_body):
|
315
462
|
for sample in family.samples:
|
316
463
|
metrics[sample.name] += [sample]
|
317
464
|
|
318
465
|
return metrics
|
319
466
|
|
467
|
+
async def _get_container_metrics(self, container_id: str) -> Optional[api_pb2.TaskGetAutoscalingMetricsResponse]:
|
468
|
+
req = api_pb2.TaskGetAutoscalingMetricsRequest(task_id=container_id)
|
469
|
+
try:
|
470
|
+
resp = await retry_transient_errors(self.client.stub.TaskGetAutoscalingMetrics, req)
|
471
|
+
return resp
|
472
|
+
except Exception as e:
|
473
|
+
logger.warning(f"[Modal Flash] Error getting metrics for container {container_id}: {e}")
|
474
|
+
return None
|
475
|
+
|
320
476
|
async def _get_all_containers(self):
|
321
477
|
req = api_pb2.FlashContainerListRequest(function_id=self.fn.object_id)
|
322
478
|
resp = await retry_transient_errors(self.client.stub.FlashContainerList, req)
|
@@ -395,10 +551,14 @@ async def flash_prometheus_autoscaler(
|
|
395
551
|
app_name: str,
|
396
552
|
cls_name: str,
|
397
553
|
# Endpoint to fetch metrics from. Must be in Prometheus format. Example: "/metrics"
|
554
|
+
# If metrics_endpoint is "internal", we will use containers' internal metrics to autoscale instead.
|
398
555
|
metrics_endpoint: str,
|
399
556
|
# Target metric to autoscale on. Example: "vllm:num_requests_running"
|
557
|
+
# If metrics_endpoint is "internal", target_metrics options are: [cpu_usage_percent, memory_usage_percent]
|
400
558
|
target_metric: str,
|
401
559
|
# Target metric value. Example: 25
|
560
|
+
# If metrics_endpoint is "internal", target_metric_value is a percentage value between 0.1 and 1.0 (inclusive),
|
561
|
+
# indicating container's usage of that metric.
|
402
562
|
target_metric_value: float,
|
403
563
|
min_containers: Optional[int] = None,
|
404
564
|
max_containers: Optional[int] = None,
|