modal 1.0.3.dev10__py3-none-any.whl → 1.2.3.dev7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic; see the package's registry page for more details.

Files changed (160)
  1. modal/__init__.py +0 -2
  2. modal/__main__.py +3 -4
  3. modal/_billing.py +80 -0
  4. modal/_clustered_functions.py +7 -3
  5. modal/_clustered_functions.pyi +15 -3
  6. modal/_container_entrypoint.py +51 -69
  7. modal/_functions.py +508 -240
  8. modal/_grpc_client.py +171 -0
  9. modal/_load_context.py +105 -0
  10. modal/_object.py +81 -21
  11. modal/_output.py +58 -45
  12. modal/_partial_function.py +48 -73
  13. modal/_pty.py +7 -3
  14. modal/_resolver.py +26 -46
  15. modal/_runtime/asgi.py +4 -3
  16. modal/_runtime/container_io_manager.py +358 -220
  17. modal/_runtime/container_io_manager.pyi +296 -101
  18. modal/_runtime/execution_context.py +18 -2
  19. modal/_runtime/execution_context.pyi +64 -7
  20. modal/_runtime/gpu_memory_snapshot.py +262 -57
  21. modal/_runtime/user_code_imports.py +28 -58
  22. modal/_serialization.py +90 -6
  23. modal/_traceback.py +42 -1
  24. modal/_tunnel.pyi +380 -12
  25. modal/_utils/async_utils.py +84 -29
  26. modal/_utils/auth_token_manager.py +111 -0
  27. modal/_utils/blob_utils.py +181 -58
  28. modal/_utils/deprecation.py +19 -0
  29. modal/_utils/function_utils.py +91 -47
  30. modal/_utils/grpc_utils.py +89 -66
  31. modal/_utils/mount_utils.py +26 -1
  32. modal/_utils/name_utils.py +17 -3
  33. modal/_utils/task_command_router_client.py +536 -0
  34. modal/_utils/time_utils.py +34 -6
  35. modal/app.py +256 -88
  36. modal/app.pyi +909 -92
  37. modal/billing.py +5 -0
  38. modal/builder/2025.06.txt +18 -0
  39. modal/builder/PREVIEW.txt +18 -0
  40. modal/builder/base-images.json +58 -0
  41. modal/cli/_download.py +19 -3
  42. modal/cli/_traceback.py +3 -2
  43. modal/cli/app.py +4 -4
  44. modal/cli/cluster.py +15 -7
  45. modal/cli/config.py +5 -3
  46. modal/cli/container.py +7 -6
  47. modal/cli/dict.py +22 -16
  48. modal/cli/entry_point.py +12 -5
  49. modal/cli/environment.py +5 -4
  50. modal/cli/import_refs.py +3 -3
  51. modal/cli/launch.py +102 -5
  52. modal/cli/network_file_system.py +11 -12
  53. modal/cli/profile.py +3 -2
  54. modal/cli/programs/launch_instance_ssh.py +94 -0
  55. modal/cli/programs/run_jupyter.py +1 -1
  56. modal/cli/programs/run_marimo.py +95 -0
  57. modal/cli/programs/vscode.py +1 -1
  58. modal/cli/queues.py +57 -26
  59. modal/cli/run.py +91 -23
  60. modal/cli/secret.py +48 -22
  61. modal/cli/token.py +7 -8
  62. modal/cli/utils.py +4 -7
  63. modal/cli/volume.py +31 -25
  64. modal/client.py +15 -85
  65. modal/client.pyi +183 -62
  66. modal/cloud_bucket_mount.py +5 -3
  67. modal/cloud_bucket_mount.pyi +197 -5
  68. modal/cls.py +200 -126
  69. modal/cls.pyi +446 -68
  70. modal/config.py +29 -11
  71. modal/container_process.py +319 -19
  72. modal/container_process.pyi +190 -20
  73. modal/dict.py +290 -71
  74. modal/dict.pyi +835 -83
  75. modal/environments.py +15 -27
  76. modal/environments.pyi +46 -24
  77. modal/exception.py +14 -2
  78. modal/experimental/__init__.py +194 -40
  79. modal/experimental/flash.py +618 -0
  80. modal/experimental/flash.pyi +380 -0
  81. modal/experimental/ipython.py +11 -7
  82. modal/file_io.py +29 -36
  83. modal/file_io.pyi +251 -53
  84. modal/file_pattern_matcher.py +56 -16
  85. modal/functions.pyi +673 -92
  86. modal/gpu.py +1 -1
  87. modal/image.py +528 -176
  88. modal/image.pyi +1572 -145
  89. modal/io_streams.py +458 -128
  90. modal/io_streams.pyi +433 -52
  91. modal/mount.py +216 -151
  92. modal/mount.pyi +225 -78
  93. modal/network_file_system.py +45 -62
  94. modal/network_file_system.pyi +277 -56
  95. modal/object.pyi +93 -17
  96. modal/parallel_map.py +942 -129
  97. modal/parallel_map.pyi +294 -15
  98. modal/partial_function.py +0 -2
  99. modal/partial_function.pyi +234 -19
  100. modal/proxy.py +17 -8
  101. modal/proxy.pyi +36 -3
  102. modal/queue.py +270 -65
  103. modal/queue.pyi +817 -57
  104. modal/runner.py +115 -101
  105. modal/runner.pyi +205 -49
  106. modal/sandbox.py +512 -136
  107. modal/sandbox.pyi +845 -111
  108. modal/schedule.py +1 -1
  109. modal/secret.py +300 -70
  110. modal/secret.pyi +589 -34
  111. modal/serving.py +7 -11
  112. modal/serving.pyi +7 -8
  113. modal/snapshot.py +11 -8
  114. modal/snapshot.pyi +25 -4
  115. modal/token_flow.py +4 -4
  116. modal/token_flow.pyi +28 -8
  117. modal/volume.py +416 -158
  118. modal/volume.pyi +1117 -121
  119. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/METADATA +10 -9
  120. modal-1.2.3.dev7.dist-info/RECORD +195 -0
  121. modal_docs/mdmd/mdmd.py +17 -4
  122. modal_proto/api.proto +534 -79
  123. modal_proto/api_grpc.py +337 -1
  124. modal_proto/api_pb2.py +1522 -968
  125. modal_proto/api_pb2.pyi +1619 -134
  126. modal_proto/api_pb2_grpc.py +699 -4
  127. modal_proto/api_pb2_grpc.pyi +226 -14
  128. modal_proto/modal_api_grpc.py +175 -154
  129. modal_proto/sandbox_router.proto +145 -0
  130. modal_proto/sandbox_router_grpc.py +105 -0
  131. modal_proto/sandbox_router_pb2.py +149 -0
  132. modal_proto/sandbox_router_pb2.pyi +333 -0
  133. modal_proto/sandbox_router_pb2_grpc.py +203 -0
  134. modal_proto/sandbox_router_pb2_grpc.pyi +75 -0
  135. modal_proto/task_command_router.proto +144 -0
  136. modal_proto/task_command_router_grpc.py +105 -0
  137. modal_proto/task_command_router_pb2.py +149 -0
  138. modal_proto/task_command_router_pb2.pyi +333 -0
  139. modal_proto/task_command_router_pb2_grpc.py +203 -0
  140. modal_proto/task_command_router_pb2_grpc.pyi +75 -0
  141. modal_version/__init__.py +1 -1
  142. modal/requirements/PREVIEW.txt +0 -16
  143. modal/requirements/base-images.json +0 -26
  144. modal-1.0.3.dev10.dist-info/RECORD +0 -179
  145. modal_proto/modal_options_grpc.py +0 -3
  146. modal_proto/options.proto +0 -19
  147. modal_proto/options_grpc.py +0 -3
  148. modal_proto/options_pb2.py +0 -35
  149. modal_proto/options_pb2.pyi +0 -20
  150. modal_proto/options_pb2_grpc.py +0 -4
  151. modal_proto/options_pb2_grpc.pyi +0 -7
  152. /modal/{requirements → builder}/2023.12.312.txt +0 -0
  153. /modal/{requirements → builder}/2023.12.txt +0 -0
  154. /modal/{requirements → builder}/2024.04.txt +0 -0
  155. /modal/{requirements → builder}/2024.10.txt +0 -0
  156. /modal/{requirements → builder}/README.md +0 -0
  157. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/WHEEL +0 -0
  158. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/entry_points.txt +0 -0
  159. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/licenses/LICENSE +0 -0
  160. {modal-1.0.3.dev10.dist-info → modal-1.2.3.dev7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,618 @@
1
+ # Copyright Modal Labs 2025
2
+ import asyncio
3
+ import math
4
+ import os
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ import traceback
9
+ from collections import defaultdict
10
+ from typing import Any, Optional
11
+ from urllib.parse import urlparse
12
+
13
+ from modal.cls import _Cls
14
+ from modal.dict import _Dict
15
+ from modal_proto import api_pb2
16
+
17
+ from .._tunnel import _forward as _forward_tunnel
18
+ from .._utils.async_utils import synchronize_api, synchronizer
19
+ from ..client import _Client
20
+ from ..config import logger
21
+ from ..exception import InvalidError
22
+
23
+ _MAX_FAILURES = 10
24
+
25
+
26
+ class _FlashManager:
27
+ def __init__(
28
+ self,
29
+ client: _Client,
30
+ port: int,
31
+ process: Optional[subprocess.Popen] = None,
32
+ health_check_url: Optional[str] = None,
33
+ ):
34
+ self.client = client
35
+ self.port = port
36
+ # Health check is not currently being used
37
+ self.health_check_url = health_check_url
38
+ self.process = process
39
+ self.tunnel_manager = _forward_tunnel(port, client=client)
40
+ self.stopped = False
41
+ self.num_failures = 0
42
+ self.task_id = os.environ["MODAL_TASK_ID"]
43
+
44
+ async def is_port_connection_healthy(
45
+ self, process: Optional[subprocess.Popen], timeout: float = 0.5
46
+ ) -> tuple[bool, Optional[Exception]]:
47
+ import socket
48
+
49
+ start_time = time.monotonic()
50
+
51
+ while time.monotonic() - start_time < timeout:
52
+ try:
53
+ if process is not None and process.poll() is not None:
54
+ return False, Exception(f"Process {process.pid} exited with code {process.returncode}")
55
+ with socket.create_connection(("localhost", self.port), timeout=0.5):
56
+ return True, None
57
+ except (ConnectionRefusedError, OSError):
58
+ await asyncio.sleep(0.1)
59
+
60
+ return False, Exception(f"Waited too long for port {self.port} to start accepting connections")
61
+
62
+ async def _start(self):
63
+ self.tunnel = await self.tunnel_manager.__aenter__()
64
+ parsed_url = urlparse(self.tunnel.url)
65
+ host = parsed_url.hostname
66
+ port = parsed_url.port or 443
67
+
68
+ self.heartbeat_task = asyncio.create_task(self._run_heartbeat(host, port))
69
+ self.drain_task = asyncio.create_task(self._drain_container())
70
+
71
+ async def _drain_container(self):
72
+ """
73
+ Background task that checks if we've encountered too many failures and drains the container if so.
74
+ """
75
+ while True:
76
+ try:
77
+ # Check if the container should be drained (e.g., too many failures)
78
+ if self.num_failures > _MAX_FAILURES:
79
+ logger.warning(
80
+ f"[Modal Flash] Draining task {self.task_id} on {self.tunnel.url} due to too many failures."
81
+ )
82
+ await self.stop()
83
+ # handle close upon container exit
84
+
85
+ if self.task_id:
86
+ await self.client.stub.ContainerStop(api_pb2.ContainerStopRequest(task_id=self.task_id))
87
+ return
88
+ except asyncio.CancelledError:
89
+ logger.warning("[Modal Flash] Shutting down...")
90
+ return
91
+ except Exception as e:
92
+ logger.error(f"[Modal Flash] Error draining container: {e}")
93
+ await asyncio.sleep(1)
94
+
95
+ try:
96
+ await asyncio.sleep(1)
97
+ except asyncio.CancelledError:
98
+ logger.warning("[Modal Flash] Shutting down...")
99
+ return
100
+
101
+ async def _run_heartbeat(self, host: str, port: int):
102
+ first_registration = True
103
+ while True:
104
+ try:
105
+ port_check_resp, port_check_error = await self.is_port_connection_healthy(process=self.process)
106
+ if port_check_resp:
107
+ resp = await self.client.stub.FlashContainerRegister(
108
+ api_pb2.FlashContainerRegisterRequest(
109
+ priority=10,
110
+ weight=5,
111
+ host=host,
112
+ port=port,
113
+ ),
114
+ timeout=10,
115
+ retry=None,
116
+ )
117
+ self.num_failures = 0
118
+ if first_registration:
119
+ logger.warning(
120
+ f"[Modal Flash] Listening at {resp.url} over {self.tunnel.url} for task_id {self.task_id}"
121
+ )
122
+ first_registration = False
123
+ else:
124
+ logger.error(
125
+ f"[Modal Flash] Deregistering container {self.task_id} on {self.tunnel.url} "
126
+ f"due to error: {port_check_error}, num_failures: {self.num_failures}"
127
+ )
128
+ self.num_failures += 1
129
+ await self.client.stub.FlashContainerDeregister(api_pb2.FlashContainerDeregisterRequest())
130
+
131
+ except asyncio.CancelledError:
132
+ logger.warning("[Modal Flash] Shutting down...")
133
+ break
134
+ except Exception as e:
135
+ logger.error(f"[Modal Flash] Heartbeat failed: {e}")
136
+
137
+ try:
138
+ await asyncio.sleep(1)
139
+ except asyncio.CancelledError:
140
+ logger.warning("[Modal Flash] Shutting down...")
141
+ break
142
+
143
+ def get_container_url(self):
144
+ # WARNING: Try not to use this method; we aren't sure if we will keep it.
145
+ return self.tunnel.url
146
+
147
+ async def stop(self):
148
+ self.heartbeat_task.cancel()
149
+ await self.client.stub.FlashContainerDeregister(api_pb2.FlashContainerDeregisterRequest())
150
+
151
+ self.stopped = True
152
+ logger.warning(f"[Modal Flash] No longer accepting new requests on {self.tunnel.url}.")
153
+
154
+ # NOTE(gongy): We skip calling TunnelStop to avoid interrupting in-flight requests.
155
+ # It is up to the user to wait after calling .stop() to drain in-flight requests.
156
+
157
+ async def close(self):
158
+ if not self.stopped:
159
+ await self.stop()
160
+
161
+ logger.warning(f"[Modal Flash] Closing tunnel on {self.tunnel.url}.")
162
+ await self.tunnel_manager.__aexit__(*sys.exc_info())
163
+
164
+
165
+ FlashManager = synchronize_api(_FlashManager)
166
+
167
+
168
+ @synchronizer.create_blocking
169
+ async def flash_forward(
170
+ port: int,
171
+ process: Optional[subprocess.Popen] = None,
172
+ health_check_url: Optional[str] = None,
173
+ ) -> _FlashManager:
174
+ """
175
+ Forward a port to the Modal Flash service, exposing that port as a stable web endpoint.
176
+ This is a highly experimental method that can break or be removed at any time without warning.
177
+ Do not use this method unless explicitly instructed to do so by Modal support.
178
+ """
179
+ client = await _Client.from_env()
180
+
181
+ manager = _FlashManager(client, port, process=process, health_check_url=health_check_url)
182
+ await manager._start()
183
+ return manager
184
+
185
+
186
+ class _FlashPrometheusAutoscaler:
187
+ _max_window_seconds = 60 * 60
188
+
189
+ def __init__(
190
+ self,
191
+ client: _Client,
192
+ app_name: str,
193
+ cls_name: str,
194
+ metrics_endpoint: str,
195
+ target_metric: str,
196
+ target_metric_value: float,
197
+ min_containers: Optional[int],
198
+ max_containers: Optional[int],
199
+ buffer_containers: Optional[int],
200
+ scale_up_tolerance: float,
201
+ scale_down_tolerance: float,
202
+ scale_up_stabilization_window_seconds: int,
203
+ scale_down_stabilization_window_seconds: int,
204
+ autoscaling_interval_seconds: int,
205
+ ):
206
+ import aiohttp
207
+
208
+ if scale_up_stabilization_window_seconds > self._max_window_seconds:
209
+ raise InvalidError(
210
+ f"scale_up_stabilization_window_seconds must be less than or equal to {self._max_window_seconds}"
211
+ )
212
+ if scale_down_stabilization_window_seconds > self._max_window_seconds:
213
+ raise InvalidError(
214
+ f"scale_down_stabilization_window_seconds must be less than or equal to {self._max_window_seconds}"
215
+ )
216
+ if target_metric_value <= 0:
217
+ raise InvalidError("target_metric_value must be greater than 0")
218
+
219
+ self.client = client
220
+ self.app_name = app_name
221
+ self.cls_name = cls_name
222
+ self.metrics_endpoint = metrics_endpoint
223
+ self.target_metric = target_metric
224
+ self.target_metric_value = target_metric_value
225
+ self.min_containers = min_containers
226
+ self.max_containers = max_containers
227
+ self.buffer_containers = buffer_containers
228
+ self.scale_up_tolerance = scale_up_tolerance
229
+ self.scale_down_tolerance = scale_down_tolerance
230
+ self.scale_up_stabilization_window_seconds = scale_up_stabilization_window_seconds
231
+ self.scale_down_stabilization_window_seconds = scale_down_stabilization_window_seconds
232
+ self.autoscaling_interval_seconds = autoscaling_interval_seconds
233
+
234
+ FlashClass = _Cls.from_name(app_name, cls_name)
235
+ self.fn = FlashClass._class_service_function
236
+ self.cls = FlashClass()
237
+
238
+ self.http_client = aiohttp.ClientSession()
239
+ self.autoscaling_decisions_dict = _Dict.from_name(
240
+ f"{app_name}-{cls_name}-autoscaling-decisions",
241
+ create_if_missing=True,
242
+ )
243
+
244
+ self.autoscaler_thread = None
245
+
246
+ async def start(self):
247
+ await self.fn.hydrate(client=self.client)
248
+ self.autoscaler_thread = asyncio.create_task(self._run_autoscaler_loop())
249
+
250
+ async def _run_autoscaler_loop(self):
251
+ while True:
252
+ try:
253
+ autoscaling_time = time.time()
254
+
255
+ current_replicas = await self.autoscaling_decisions_dict.get("current_replicas", 0)
256
+ autoscaling_decisions = await self.autoscaling_decisions_dict.get("autoscaling_decisions", [])
257
+ if not isinstance(current_replicas, int):
258
+ logger.warning(f"[Modal Flash] Invalid item in autoscaling decisions: {current_replicas}")
259
+ current_replicas = 0
260
+ if not isinstance(autoscaling_decisions, list):
261
+ logger.warning(f"[Modal Flash] Invalid item in autoscaling decisions: {autoscaling_decisions}")
262
+ autoscaling_decisions = []
263
+ for item in autoscaling_decisions:
264
+ if (
265
+ not isinstance(item, tuple)
266
+ or len(item) != 2
267
+ or not isinstance(item[0], float)
268
+ or not isinstance(item[1], int)
269
+ ):
270
+ logger.warning(f"[Modal Flash] Invalid item in autoscaling decisions: {item}")
271
+ autoscaling_decisions = []
272
+ break
273
+
274
+ autoscaling_decisions = [
275
+ (timestamp, decision)
276
+ for timestamp, decision in autoscaling_decisions
277
+ if timestamp >= autoscaling_time - self._max_window_seconds
278
+ ]
279
+
280
+ current_target_containers = await self._compute_target_containers(current_replicas=current_replicas)
281
+ autoscaling_decisions.append((autoscaling_time, current_target_containers))
282
+
283
+ actual_target_containers = self._make_scaling_decision(
284
+ current_replicas,
285
+ autoscaling_decisions,
286
+ scale_up_stabilization_window_seconds=self.scale_up_stabilization_window_seconds,
287
+ scale_down_stabilization_window_seconds=self.scale_down_stabilization_window_seconds,
288
+ min_containers=self.min_containers,
289
+ max_containers=self.max_containers,
290
+ buffer_containers=self.buffer_containers,
291
+ )
292
+
293
+ logger.warning(
294
+ f"[Modal Flash] Scaling to {actual_target_containers=} containers. "
295
+ f" Autoscaling decision made in {time.time() - autoscaling_time} seconds."
296
+ )
297
+
298
+ await self.autoscaling_decisions_dict.put(
299
+ "autoscaling_decisions",
300
+ autoscaling_decisions,
301
+ )
302
+ await self.autoscaling_decisions_dict.put("current_replicas", actual_target_containers)
303
+
304
+ await self._set_target_slots(actual_target_containers)
305
+
306
+ if time.time() - autoscaling_time < self.autoscaling_interval_seconds:
307
+ await asyncio.sleep(self.autoscaling_interval_seconds - (time.time() - autoscaling_time))
308
+ except asyncio.CancelledError:
309
+ logger.warning("[Modal Flash] Shutting down autoscaler...")
310
+ await self.http_client.close()
311
+ break
312
+ except Exception as e:
313
+ logger.error(f"[Modal Flash] Error in autoscaler: {e}")
314
+ logger.error(traceback.format_exc())
315
+ await asyncio.sleep(self.autoscaling_interval_seconds)
316
+
317
+ async def _compute_target_containers(self, current_replicas: int) -> int:
318
+ """
319
+ Gets metrics from container to autoscale up or down.
320
+ """
321
+ containers = await self._get_all_containers()
322
+ if len(containers) > current_replicas:
323
+ logger.info(
324
+ f"[Modal Flash] Current replicas {current_replicas} is less than the number of containers "
325
+ f"{len(containers)}. Setting current_replicas = num_containers."
326
+ )
327
+ current_replicas = len(containers)
328
+
329
+ if current_replicas == 0:
330
+ return 1
331
+
332
+ # Get metrics based on autoscaler type
333
+ sum_metric, n_containers_with_metrics = await self._get_scaling_info(containers)
334
+
335
+ desired_replicas = self._calculate_desired_replicas(
336
+ n_current_replicas=current_replicas,
337
+ sum_metric=sum_metric,
338
+ n_containers_with_metrics=n_containers_with_metrics,
339
+ n_total_containers=len(containers),
340
+ target_metric_value=self.target_metric_value,
341
+ )
342
+
343
+ return max(1, desired_replicas)
344
+
345
+ def _calculate_desired_replicas(
346
+ self,
347
+ n_current_replicas: int,
348
+ sum_metric: float,
349
+ n_containers_with_metrics: int,
350
+ n_total_containers: int,
351
+ target_metric_value: float,
352
+ ) -> int:
353
+ """
354
+ Calculate the desired number of replicas to autoscale to.
355
+ """
356
+ buffer_containers = self.buffer_containers or 0
357
+
358
+ # n_containers_missing = number of unhealthy containers + number of containers not registered in flash dns
359
+ n_containers_missing_metric = n_current_replicas - n_containers_with_metrics
360
+ # n_containers_unhealthy = number of dns registered containers that are not emitting metrics
361
+ n_containers_unhealthy = n_total_containers - n_containers_with_metrics
362
+
363
+ # Max is used to handle case when buffer_containers are first initialized.
364
+ num_provisioned_containers = max(n_current_replicas - buffer_containers, 1)
365
+
366
+ # Scale up assuming that every unhealthy container is at 1.5 x (1 + scale_up_tolerance) the target metric value.
367
+ # This way if all containers are unhealthy, we will increase our number of containers.
368
+ scale_up_target_metric_value = (
369
+ sum_metric + 1.5 * (1 + self.scale_up_tolerance) * n_containers_unhealthy * target_metric_value
370
+ ) / (num_provisioned_containers)
371
+
372
+ # Scale down assuming that every container (including cold starting containers) are at the target metric value.
373
+ # The denominator is just num_provisioned_containers because we don't want to account for the buffer containers.
374
+ scale_down_target_metric_value = (sum_metric + n_containers_missing_metric * target_metric_value) / (
375
+ num_provisioned_containers
376
+ )
377
+
378
+ scale_up_ratio = scale_up_target_metric_value / target_metric_value
379
+ scale_down_ratio = scale_down_target_metric_value / target_metric_value
380
+
381
+ desired_replicas = num_provisioned_containers
382
+ if scale_up_ratio > 1 + self.scale_up_tolerance:
383
+ desired_replicas = math.ceil(desired_replicas * scale_up_ratio)
384
+ elif scale_down_ratio < 1 - self.scale_down_tolerance:
385
+ desired_replicas = math.ceil(desired_replicas * scale_down_ratio)
386
+
387
+ logger.warning(
388
+ f"[Modal Flash] Current replicas: {n_current_replicas}, "
389
+ f"target metric: {self.target_metric}"
390
+ f"target metric value: {target_metric_value}, "
391
+ f"current sum of metric values: {sum_metric}, "
392
+ f"number of containers with metrics: {n_containers_with_metrics}, "
393
+ f"number of containers unhealthy: {n_containers_unhealthy}, "
394
+ f"number of containers missing metric (includes unhealthy): {n_containers_missing_metric}, "
395
+ f"number of provisioned containers: {num_provisioned_containers}, "
396
+ f"scale up ratio: {scale_up_ratio}, "
397
+ f"scale down ratio: {scale_down_ratio}, "
398
+ f"desired replicas: {desired_replicas}"
399
+ )
400
+
401
+ return desired_replicas
402
+
403
+ async def _get_scaling_info(self, containers) -> tuple[float, int]:
404
+ """Get metrics using container exposed metrics endpoints."""
405
+ sum_metric = 0
406
+ n_containers_with_metrics = 0
407
+
408
+ container_metrics_list = await asyncio.gather(
409
+ *[
410
+ self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
411
+ for container in containers
412
+ ]
413
+ )
414
+
415
+ for container_metrics in container_metrics_list:
416
+ if (
417
+ container_metrics is None
418
+ or self.target_metric not in container_metrics
419
+ or len(container_metrics[self.target_metric]) == 0
420
+ ):
421
+ continue
422
+ sum_metric += container_metrics[self.target_metric][0].value
423
+ n_containers_with_metrics += 1
424
+
425
+ return sum_metric, n_containers_with_metrics
426
+
427
+ async def _get_metrics(self, url: str) -> Optional[dict[str, list[Any]]]: # technically any should be Sample
428
+ from prometheus_client.parser import Sample, text_string_to_metric_families
429
+
430
+ # Fetch the metrics from the endpoint
431
+ try:
432
+ response = await self.http_client.get(url, timeout=3)
433
+ response.raise_for_status()
434
+ except asyncio.TimeoutError:
435
+ logger.warning(f"[Modal Flash] Timeout getting metrics from {url}")
436
+ return None
437
+ except Exception as e:
438
+ logger.warning(f"[Modal Flash] Error getting metrics from {url}: {e}")
439
+ return None
440
+
441
+ # Read body with timeout/error handling and parse Prometheus metrics
442
+ try:
443
+ text_body = await response.text()
444
+ except asyncio.TimeoutError:
445
+ logger.warning(f"[Modal Flash] Timeout reading metrics body from {url}")
446
+ return None
447
+ except Exception as e:
448
+ logger.warning(f"[Modal Flash] Error reading metrics body from {url}: {e}")
449
+ return None
450
+
451
+ # Parse the text-based Prometheus metrics format
452
+ metrics: dict[str, list[Sample]] = defaultdict(list)
453
+ for family in text_string_to_metric_families(text_body):
454
+ for sample in family.samples:
455
+ metrics[sample.name] += [sample]
456
+
457
+ return metrics
458
+
459
+ async def _get_all_containers(self):
460
+ req = api_pb2.FlashContainerListRequest(function_id=self.fn.object_id)
461
+ resp = await self.client.stub.FlashContainerList(req)
462
+ return resp.containers
463
+
464
+ async def _set_target_slots(self, target_slots: int):
465
+ req = api_pb2.FlashSetTargetSlotsMetricsRequest(function_id=self.fn.object_id, target_slots=target_slots)
466
+ await self.client.stub.FlashSetTargetSlotsMetrics(req)
467
+ return
468
+
469
+ def _make_scaling_decision(
470
+ self,
471
+ current_replicas: int,
472
+ autoscaling_decisions: list[tuple[float, int]],
473
+ scale_up_stabilization_window_seconds: int = 0,
474
+ scale_down_stabilization_window_seconds: int = 60 * 5,
475
+ min_containers: Optional[int] = None,
476
+ max_containers: Optional[int] = None,
477
+ buffer_containers: Optional[int] = None,
478
+ ) -> int:
479
+ """
480
+ Return the target number of containers following (simplified) Kubernetes HPA
481
+ stabilization-window semantics.
482
+
483
+ Args:
484
+ current_replicas: Current number of running Pods/containers.
485
+ autoscaling_decisions: List of (timestamp, desired_replicas) pairs, where
486
+ timestamp is a UNIX epoch float (seconds).
487
+ The list *must* contain at least one entry and should
488
+ already include the most-recent measurement.
489
+ scale_up_stabilization_window_seconds: 0 disables the up-window.
490
+ scale_down_stabilization_window_seconds: 0 disables the down-window.
491
+ min_containers / max_containers: Clamp the final decision to this range.
492
+
493
+ Returns:
494
+ The target number of containers.
495
+ """
496
+
497
+ if not autoscaling_decisions:
498
+ # Without data we can’t make a new decision – stay where we are.
499
+ return current_replicas
500
+
501
+ # Sort just once in case the caller didn’t: newest record is last.
502
+ autoscaling_decisions.sort(key=lambda rec: rec[0])
503
+ now_ts, latest_desired = autoscaling_decisions[-1]
504
+
505
+ if latest_desired > current_replicas:
506
+ # ---- SCALE-UP path ----
507
+ window_start = now_ts - scale_up_stabilization_window_seconds
508
+ # Consider only records *inside* the window.
509
+ desired_candidates = [desired for ts, desired in autoscaling_decisions if ts >= window_start]
510
+ # Use the *minimum* so that any temporary dip blocks the scale-up.
511
+ candidate = min(desired_candidates) if desired_candidates else latest_desired
512
+ new_replicas = max(current_replicas, candidate) # never scale *down* here
513
+ elif latest_desired < current_replicas:
514
+ # ---- SCALE-DOWN path ----
515
+ window_start = now_ts - scale_down_stabilization_window_seconds
516
+ desired_candidates = [desired for ts, desired in autoscaling_decisions if ts >= window_start]
517
+ # Use the *maximum* so that any temporary spike blocks the scale-down.
518
+ candidate = max(desired_candidates) if desired_candidates else latest_desired
519
+ new_replicas = min(current_replicas, candidate) # never scale *up* here
520
+ else:
521
+ # No change requested.
522
+ new_replicas = current_replicas
523
+
524
+ # Clamp to [min_containers, max_containers].
525
+ if min_containers is not None:
526
+ new_replicas = max(min_containers, new_replicas)
527
+ if max_containers is not None:
528
+ new_replicas = min(max_containers, new_replicas)
529
+
530
+ if buffer_containers is not None:
531
+ new_replicas += buffer_containers
532
+
533
+ return new_replicas
534
+
535
+ async def stop(self):
536
+ self.autoscaler_thread.cancel()
537
+ await self.autoscaler_thread
538
+
539
+
540
+ FlashPrometheusAutoscaler = synchronize_api(_FlashPrometheusAutoscaler)
541
+
542
+
543
+ @synchronizer.create_blocking
544
+ async def flash_prometheus_autoscaler(
545
+ app_name: str,
546
+ cls_name: str,
547
+ # Endpoint to fetch metrics from. Must be in Prometheus format. Example: "/metrics"
548
+ metrics_endpoint: str,
549
+ # Target metric to autoscale on. Example: "vllm:num_requests_running"
550
+ target_metric: str,
551
+ # Target metric value. Example: 25
552
+ target_metric_value: float,
553
+ min_containers: Optional[int] = None,
554
+ max_containers: Optional[int] = None,
555
+ # Corresponds to https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/#tolerance
556
+ scale_up_tolerance: float = 0.1,
557
+ # Corresponds to https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/#tolerance
558
+ scale_down_tolerance: float = 0.1,
559
+ # Corresponds to https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/#stabilization-window
560
+ scale_up_stabilization_window_seconds: int = 0,
561
+ # Corresponds to https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/#stabilization-window
562
+ scale_down_stabilization_window_seconds: int = 300,
563
+ # How often to make autoscaling decisions.
564
+ # Corresponds to --horizontal-pod-autoscaler-sync-period in Kubernetes.
565
+ autoscaling_interval_seconds: int = 15,
566
+ # Whether to include overprovisioned containers in the scale up calculation.
567
+ buffer_containers: Optional[int] = None,
568
+ ) -> _FlashPrometheusAutoscaler:
569
+ """
570
+ Autoscale a Flash service based on containers' Prometheus metrics.
571
+
572
+ The package `prometheus_client` is required to use this method.
573
+
574
+ This is a highly experimental method that can break or be removed at any time without warning.
575
+ Do not use this method unless explicitly instructed to do so by Modal support.
576
+ """
577
+
578
+ try:
579
+ import prometheus_client # noqa: F401
580
+ except ImportError:
581
+ raise ImportError("The package `prometheus_client` is required to use this method.")
582
+
583
+ client = await _Client.from_env()
584
+ autoscaler = _FlashPrometheusAutoscaler(
585
+ client=client,
586
+ app_name=app_name,
587
+ cls_name=cls_name,
588
+ metrics_endpoint=metrics_endpoint,
589
+ target_metric=target_metric,
590
+ target_metric_value=target_metric_value,
591
+ min_containers=min_containers,
592
+ max_containers=max_containers,
593
+ buffer_containers=buffer_containers,
594
+ scale_up_tolerance=scale_up_tolerance,
595
+ scale_down_tolerance=scale_down_tolerance,
596
+ scale_up_stabilization_window_seconds=scale_up_stabilization_window_seconds,
597
+ scale_down_stabilization_window_seconds=scale_down_stabilization_window_seconds,
598
+ autoscaling_interval_seconds=autoscaling_interval_seconds,
599
+ )
600
+ await autoscaler.start()
601
+ return autoscaler
602
+
603
+
604
+ @synchronizer.create_blocking
605
+ async def flash_get_containers(app_name: str, cls_name: str) -> list[dict[str, Any]]:
606
+ """
607
+ Return a list of flash containers for a deployed Flash service.
608
+
609
+ This is a highly experimental method that can break or be removed at any time without warning.
610
+ Do not use this method unless explicitly instructed to do so by Modal support.
611
+ """
612
+ client = await _Client.from_env()
613
+ fn = _Cls.from_name(app_name, cls_name)._class_service_function
614
+ assert fn is not None
615
+ await fn.hydrate(client=client)
616
+ req = api_pb2.FlashContainerListRequest(function_id=fn.object_id)
617
+ resp = await client.stub.FlashContainerList(req)
618
+ return resp.containers