modal 1.2.1.dev13__py3-none-any.whl → 1.2.1.dev15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic; see the package's registry page for more details.

@@ -292,7 +292,9 @@ class TaskCommandRouterClient:
292
292
  lambda: self._call_with_auth_retry(self._stub.TaskExecStdinWrite, request)
293
293
  )
294
294
 
295
- async def exec_poll(self, task_id: str, exec_id: str) -> sr_pb2.TaskExecPollResponse:
295
+ async def exec_poll(
296
+ self, task_id: str, exec_id: str, deadline: Optional[float] = None
297
+ ) -> sr_pb2.TaskExecPollResponse:
296
298
  """Poll for the exit status of an exec'd command, properly retrying on transient errors.
297
299
 
298
300
  Args:
@@ -302,13 +304,25 @@ class TaskCommandRouterClient:
302
304
  sr_pb2.TaskExecPollResponse: The exit status of the command if it has completed.
303
305
 
304
306
  Raises:
307
+ ExecTimeoutError: If the deadline is exceeded.
305
308
  Other errors: If retries are exhausted on transient errors or if there's an error
306
309
  from the RPC itself.
307
310
  """
308
311
  request = sr_pb2.TaskExecPollRequest(task_id=task_id, exec_id=exec_id)
309
- return await call_with_retries_on_transient_errors(
310
- lambda: self._call_with_auth_retry(self._stub.TaskExecPoll, request)
311
- )
312
+ # The timeout here is really a backstop in the event of a hang contacting
313
+ # the command router. Poll should usually be instantaneous.
314
+ timeout = deadline - time.monotonic() if deadline is not None else None
315
+ if timeout is not None and timeout <= 0:
316
+ raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")
317
+ try:
318
+ return await asyncio.wait_for(
319
+ call_with_retries_on_transient_errors(
320
+ lambda: self._call_with_auth_retry(self._stub.TaskExecPoll, request)
321
+ ),
322
+ timeout=timeout,
323
+ )
324
+ except asyncio.TimeoutError:
325
+ raise ExecTimeoutError(f"Deadline exceeded while polling for exec {exec_id}")
312
326
 
313
327
  async def exec_wait(
314
328
  self,
modal/cli/cluster.py CHANGED
@@ -83,7 +83,9 @@ async def shell(
83
83
  )
84
84
  exec_res: api_pb2.ContainerExecResponse = await client.stub.ContainerExec(req)
85
85
  if pty:
86
- await _ContainerProcess(exec_res.exec_id, client).attach()
86
+ await _ContainerProcess(exec_res.exec_id, task_id, client).attach()
87
87
  else:
88
88
  # TODO: redirect stderr to its own stream?
89
- await _ContainerProcess(exec_res.exec_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT).wait()
89
+ await _ContainerProcess(
90
+ exec_res.exec_id, task_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT
91
+ ).wait()
modal/cli/container.py CHANGED
@@ -80,10 +80,12 @@ async def exec(
80
80
  res: api_pb2.ContainerExecResponse = await client.stub.ContainerExec(req)
81
81
 
82
82
  if pty:
83
- await _ContainerProcess(res.exec_id, client).attach()
83
+ await _ContainerProcess(res.exec_id, container_id, client).attach()
84
84
  else:
85
85
  # TODO: redirect stderr to its own stream?
86
- await _ContainerProcess(res.exec_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT).wait()
86
+ await _ContainerProcess(
87
+ res.exec_id, container_id, client, stdout=StreamType.STDOUT, stderr=StreamType.STDOUT
88
+ ).wait()
87
89
 
88
90
 
89
91
  @container_cli.command("stop")
modal/client.pyi CHANGED
@@ -33,7 +33,7 @@ class _Client:
33
33
  server_url: str,
34
34
  client_type: int,
35
35
  credentials: typing.Optional[tuple[str, str]],
36
- version: str = "1.2.1.dev13",
36
+ version: str = "1.2.1.dev15",
37
37
  ):
38
38
  """mdmd:hidden
39
39
  The Modal client object is not intended to be instantiated directly by users.
@@ -164,7 +164,7 @@ class Client:
164
164
  server_url: str,
165
165
  client_type: int,
166
166
  credentials: typing.Optional[tuple[str, str]],
167
- version: str = "1.2.1.dev13",
167
+ version: str = "1.2.1.dev15",
168
168
  ):
169
169
  """mdmd:hidden
170
170
  The Modal client object is not intended to be instantiated directly by users.
@@ -9,16 +9,17 @@ from modal_proto import api_pb2
9
9
  from ._utils.async_utils import TaskContext, synchronize_api
10
10
  from ._utils.grpc_utils import retry_transient_errors
11
11
  from ._utils.shell_utils import stream_from_stdin, write_to_fd
12
+ from ._utils.task_command_router_client import TaskCommandRouterClient
12
13
  from .client import _Client
13
14
  from .config import logger
14
- from .exception import InteractiveTimeoutError, InvalidError
15
+ from .exception import ExecTimeoutError, InteractiveTimeoutError, InvalidError
15
16
  from .io_streams import _StreamReader, _StreamWriter
16
17
  from .stream_type import StreamType
17
18
 
18
19
  T = TypeVar("T", str, bytes)
19
20
 
20
21
 
21
- class _ContainerProcess(Generic[T]):
22
+ class _ContainerProcessThroughServer(Generic[T]):
22
23
  _process_id: Optional[str] = None
23
24
  _stdout: _StreamReader[T]
24
25
  _stderr: _StreamReader[T]
@@ -31,6 +32,7 @@ class _ContainerProcess(Generic[T]):
31
32
  def __init__(
32
33
  self,
33
34
  process_id: str,
35
+ task_id: str,
34
36
  client: _Client,
35
37
  stdout: StreamType = StreamType.PIPE,
36
38
  stderr: StreamType = StreamType.PIPE,
@@ -52,6 +54,7 @@ class _ContainerProcess(Generic[T]):
52
54
  text=text,
53
55
  by_line=by_line,
54
56
  deadline=exec_deadline,
57
+ task_id=task_id,
55
58
  )
56
59
  self._stderr = _StreamReader[T](
57
60
  api_pb2.FILE_DESCRIPTOR_STDERR,
@@ -62,6 +65,7 @@ class _ContainerProcess(Generic[T]):
62
65
  text=text,
63
66
  by_line=by_line,
64
67
  deadline=exec_deadline,
68
+ task_id=task_id,
65
69
  )
66
70
  self._stdin = _StreamWriter(process_id, "container_process", self._client)
67
71
 
@@ -201,4 +205,266 @@ class _ContainerProcess(Generic[T]):
201
205
  raise InteractiveTimeoutError("Failed to establish connection to container. Please try again.")
202
206
 
203
207
 
208
+ async def _iter_stream_as_bytes(stream: _StreamReader[T]):
209
+ """Yield raw bytes from a StreamReader regardless of text mode/backend."""
210
+ async for part in stream:
211
+ if isinstance(part, str):
212
+ yield part.encode("utf-8")
213
+ else:
214
+ yield part
215
+
216
+
217
+ class _ContainerProcessThroughCommandRouter(Generic[T]):
218
+ """
219
+ Container process implementation that works via direct communication with
220
+ the Modal worker where the container is running.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ process_id: str,
226
+ client: _Client,
227
+ command_router_client: TaskCommandRouterClient,
228
+ task_id: str,
229
+ *,
230
+ stdout: StreamType = StreamType.PIPE,
231
+ stderr: StreamType = StreamType.PIPE,
232
+ exec_deadline: Optional[float] = None,
233
+ text: bool = True,
234
+ by_line: bool = False,
235
+ ) -> None:
236
+ self._client = client
237
+ self._command_router_client = command_router_client
238
+ self._process_id = process_id
239
+ self._exec_deadline = exec_deadline
240
+ self._text = text
241
+ self._by_line = by_line
242
+ self._task_id = task_id
243
+ self._stdout = _StreamReader[T](
244
+ api_pb2.FILE_DESCRIPTOR_STDOUT,
245
+ process_id,
246
+ "container_process",
247
+ self._client,
248
+ stream_type=stdout,
249
+ text=text,
250
+ by_line=by_line,
251
+ deadline=exec_deadline,
252
+ command_router_client=self._command_router_client,
253
+ task_id=self._task_id,
254
+ )
255
+ self._stderr = _StreamReader[T](
256
+ api_pb2.FILE_DESCRIPTOR_STDERR,
257
+ process_id,
258
+ "container_process",
259
+ self._client,
260
+ stream_type=stderr,
261
+ text=text,
262
+ by_line=by_line,
263
+ deadline=exec_deadline,
264
+ command_router_client=self._command_router_client,
265
+ task_id=self._task_id,
266
+ )
267
+ self._stdin = _StreamWriter(
268
+ process_id,
269
+ "container_process",
270
+ self._client,
271
+ command_router_client=self._command_router_client,
272
+ task_id=self._task_id,
273
+ )
274
+ self._returncode = None
275
+
276
+ @property
277
+ def stdout(self) -> _StreamReader[T]:
278
+ return self._stdout
279
+
280
+ @property
281
+ def stderr(self) -> _StreamReader[T]:
282
+ return self._stderr
283
+
284
+ @property
285
+ def stdin(self) -> _StreamWriter:
286
+ return self._stdin
287
+
288
+ @property
289
+ def returncode(self) -> int:
290
+ if self._returncode is None:
291
+ raise InvalidError(
292
+ "You must call wait() before accessing the returncode. "
293
+ "To poll for the status of a running process, use poll() instead."
294
+ )
295
+ return self._returncode
296
+
297
+ async def poll(self) -> Optional[int]:
298
+ if self._returncode is not None:
299
+ return self._returncode
300
+ try:
301
+ resp = await self._command_router_client.exec_poll(self._task_id, self._process_id, self._exec_deadline)
302
+ which = resp.WhichOneof("exit_status")
303
+ if which is None:
304
+ return None
305
+
306
+ if which == "code":
307
+ self._returncode = int(resp.code)
308
+ return self._returncode
309
+ elif which == "signal":
310
+ self._returncode = 128 + int(resp.signal)
311
+ return self._returncode
312
+ else:
313
+ logger.debug(f"ContainerProcess {self._process_id} exited with unexpected status: {which}")
314
+ raise InvalidError("Unexpected exit status")
315
+ except ExecTimeoutError:
316
+ logger.debug(f"ContainerProcess poll for {self._process_id} did not complete within deadline")
317
+ return None
318
+ except Exception as e:
319
+ # Re-raise non-transient errors or errors resulting from exceeding retries on transient errors.
320
+ logger.warning(f"ContainerProcess poll for {self._process_id} failed: {e}")
321
+ raise
322
+
323
+ async def wait(self) -> int:
324
+ if self._returncode is not None:
325
+ return self._returncode
326
+
327
+ try:
328
+ resp = await self._command_router_client.exec_wait(self._task_id, self._process_id, self._exec_deadline)
329
+ which = resp.WhichOneof("exit_status")
330
+ if which == "code":
331
+ self._returncode = int(resp.code)
332
+ elif which == "signal":
333
+ self._returncode = 128 + int(resp.signal)
334
+ else:
335
+ logger.debug(f"ContainerProcess {self._process_id} exited with unexpected status: {which}")
336
+ self._returncode = -1
337
+ raise InvalidError("Unexpected exit status")
338
+ except ExecTimeoutError:
339
+ logger.debug(f"ContainerProcess {self._process_id} did not complete within deadline")
340
+ # TODO(saltzm): This is a weird API, but customers currently may rely on it. This
341
+ # should be a ExecTimeoutError.
342
+ self._returncode = -1
343
+
344
+ return self._returncode
345
+
346
+ async def attach(self):
347
+ if platform.system() == "Windows":
348
+ print("interactive exec is not currently supported on Windows.")
349
+ return
350
+
351
+ from ._output import make_console
352
+
353
+ console = make_console()
354
+
355
+ connecting_status = console.status("Connecting...")
356
+ connecting_status.start()
357
+ on_connect = asyncio.Event()
358
+
359
+ async def _write_to_fd_loop(stream: _StreamReader[T]):
360
+ async for chunk in _iter_stream_as_bytes(stream):
361
+ if chunk is None:
362
+ break
363
+
364
+ if not on_connect.is_set():
365
+ connecting_status.stop()
366
+ on_connect.set()
367
+
368
+ await write_to_fd(stream.file_descriptor, chunk)
369
+
370
+ async def _handle_input(data: bytes, message_index: int):
371
+ self.stdin.write(data)
372
+ await self.stdin.drain()
373
+
374
+ async with TaskContext() as tc:
375
+ stdout_task = tc.create_task(_write_to_fd_loop(self.stdout))
376
+ stderr_task = tc.create_task(_write_to_fd_loop(self.stderr))
377
+
378
+ try:
379
+ # Time out if we can't connect fast enough.
380
+ await asyncio.wait_for(on_connect.wait(), timeout=60)
381
+
382
+ async with stream_from_stdin(_handle_input, use_raw_terminal=True):
383
+ await stdout_task
384
+ await stderr_task
385
+
386
+ except (asyncio.TimeoutError, TimeoutError):
387
+ connecting_status.stop()
388
+ stdout_task.cancel()
389
+ stderr_task.cancel()
390
+ raise InteractiveTimeoutError("Failed to establish connection to container. Please try again.")
391
+
392
+
393
+ class _ContainerProcess(Generic[T]):
394
+ """Represents a running process in a container."""
395
+
396
+ def __init__(
397
+ self,
398
+ process_id: str,
399
+ task_id: str,
400
+ client: _Client,
401
+ stdout: StreamType = StreamType.PIPE,
402
+ stderr: StreamType = StreamType.PIPE,
403
+ exec_deadline: Optional[float] = None,
404
+ text: bool = True,
405
+ by_line: bool = False,
406
+ command_router_client: Optional[TaskCommandRouterClient] = None,
407
+ ) -> None:
408
+ if command_router_client is None:
409
+ self._impl = _ContainerProcessThroughServer(
410
+ process_id,
411
+ task_id,
412
+ client,
413
+ stdout=stdout,
414
+ stderr=stderr,
415
+ exec_deadline=exec_deadline,
416
+ text=text,
417
+ by_line=by_line,
418
+ )
419
+ else:
420
+ self._impl = _ContainerProcessThroughCommandRouter(
421
+ process_id,
422
+ client,
423
+ command_router_client,
424
+ task_id,
425
+ stdout=stdout,
426
+ stderr=stderr,
427
+ exec_deadline=exec_deadline,
428
+ text=text,
429
+ by_line=by_line,
430
+ )
431
+
432
+ def __repr__(self) -> str:
433
+ return self._impl.__repr__()
434
+
435
+ @property
436
+ def stdout(self) -> _StreamReader[T]:
437
+ """StreamReader for the container process's stdout stream."""
438
+ return self._impl.stdout
439
+
440
+ @property
441
+ def stderr(self) -> _StreamReader[T]:
442
+ """StreamReader for the container process's stderr stream."""
443
+ return self._impl.stderr
444
+
445
+ @property
446
+ def stdin(self) -> _StreamWriter:
447
+ """StreamWriter for the container process's stdin stream."""
448
+ return self._impl.stdin
449
+
450
+ @property
451
+ def returncode(self) -> int:
452
+ return self._impl.returncode
453
+
454
+ async def poll(self) -> Optional[int]:
455
+ """Check if the container process has finished running.
456
+
457
+ Returns `None` if the process is still running, else returns the exit code.
458
+ """
459
+ return await self._impl.poll()
460
+
461
+ async def wait(self) -> int:
462
+ """Wait for the container process to finish running. Returns the exit code."""
463
+ return await self._impl.wait()
464
+
465
+ async def attach(self):
466
+ """mdmd:hidden"""
467
+ await self._impl.attach()
468
+
469
+
204
470
  ContainerProcess = synchronize_api(_ContainerProcess)
@@ -1,3 +1,4 @@
1
+ import modal._utils.task_command_router_client
1
2
  import modal.client
2
3
  import modal.io_streams
3
4
  import modal.stream_type
@@ -6,7 +7,7 @@ import typing_extensions
6
7
 
7
8
  T = typing.TypeVar("T")
8
9
 
9
- class _ContainerProcess(typing.Generic[T]):
10
+ class _ContainerProcessThroughServer(typing.Generic[T]):
10
11
  """Abstract base class for generic types.
11
12
 
12
13
  A generic type is typically declared by inheriting from
@@ -39,6 +40,7 @@ class _ContainerProcess(typing.Generic[T]):
39
40
  def __init__(
40
41
  self,
41
42
  process_id: str,
43
+ task_id: str,
42
44
  client: modal.client._Client,
43
45
  stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
44
46
  stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
@@ -86,47 +88,110 @@ class _ContainerProcess(typing.Generic[T]):
86
88
  """mdmd:hidden"""
87
89
  ...
88
90
 
89
- SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)
91
+ def _iter_stream_as_bytes(stream: modal.io_streams._StreamReader[T]):
92
+ """Yield raw bytes from a StreamReader regardless of text mode/backend."""
93
+ ...
90
94
 
91
- class ContainerProcess(typing.Generic[T]):
92
- """Abstract base class for generic types.
95
+ class _ContainerProcessThroughCommandRouter(typing.Generic[T]):
96
+ """Container process implementation that works via direct communication with
97
+ the Modal worker where the container is running.
98
+ """
99
+ def __init__(
100
+ self,
101
+ process_id: str,
102
+ client: modal.client._Client,
103
+ command_router_client: modal._utils.task_command_router_client.TaskCommandRouterClient,
104
+ task_id: str,
105
+ *,
106
+ stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
107
+ stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
108
+ exec_deadline: typing.Optional[float] = None,
109
+ text: bool = True,
110
+ by_line: bool = False,
111
+ ) -> None:
112
+ """Initialize self. See help(type(self)) for accurate signature."""
113
+ ...
93
114
 
94
- A generic type is typically declared by inheriting from
95
- this class parameterized with one or more type variables.
96
- For example, a generic mapping type might be defined as::
115
+ @property
116
+ def stdout(self) -> modal.io_streams._StreamReader[T]: ...
117
+ @property
118
+ def stderr(self) -> modal.io_streams._StreamReader[T]: ...
119
+ @property
120
+ def stdin(self) -> modal.io_streams._StreamWriter: ...
121
+ @property
122
+ def returncode(self) -> int: ...
123
+ async def poll(self) -> typing.Optional[int]: ...
124
+ async def wait(self) -> int: ...
125
+ async def attach(self): ...
97
126
 
98
- class Mapping(Generic[KT, VT]):
99
- def __getitem__(self, key: KT) -> VT:
100
- ...
101
- # Etc.
127
+ class _ContainerProcess(typing.Generic[T]):
128
+ """Represents a running process in a container."""
129
+ def __init__(
130
+ self,
131
+ process_id: str,
132
+ task_id: str,
133
+ client: modal.client._Client,
134
+ stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
135
+ stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
136
+ exec_deadline: typing.Optional[float] = None,
137
+ text: bool = True,
138
+ by_line: bool = False,
139
+ command_router_client: typing.Optional[modal._utils.task_command_router_client.TaskCommandRouterClient] = None,
140
+ ) -> None:
141
+ """Initialize self. See help(type(self)) for accurate signature."""
142
+ ...
102
143
 
103
- This class can then be used as follows::
144
+ def __repr__(self) -> str:
145
+ """Return repr(self)."""
146
+ ...
104
147
 
105
- def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
106
- try:
107
- return mapping[key]
108
- except KeyError:
109
- return default
110
- """
148
+ @property
149
+ def stdout(self) -> modal.io_streams._StreamReader[T]:
150
+ """StreamReader for the container process's stdout stream."""
151
+ ...
111
152
 
112
- _process_id: typing.Optional[str]
113
- _stdout: modal.io_streams.StreamReader[T]
114
- _stderr: modal.io_streams.StreamReader[T]
115
- _stdin: modal.io_streams.StreamWriter
116
- _exec_deadline: typing.Optional[float]
117
- _text: bool
118
- _by_line: bool
119
- _returncode: typing.Optional[int]
153
+ @property
154
+ def stderr(self) -> modal.io_streams._StreamReader[T]:
155
+ """StreamReader for the container process's stderr stream."""
156
+ ...
157
+
158
+ @property
159
+ def stdin(self) -> modal.io_streams._StreamWriter:
160
+ """StreamWriter for the container process's stdin stream."""
161
+ ...
162
+
163
+ @property
164
+ def returncode(self) -> int: ...
165
+ async def poll(self) -> typing.Optional[int]:
166
+ """Check if the container process has finished running.
120
167
 
168
+ Returns `None` if the process is still running, else returns the exit code.
169
+ """
170
+ ...
171
+
172
+ async def wait(self) -> int:
173
+ """Wait for the container process to finish running. Returns the exit code."""
174
+ ...
175
+
176
+ async def attach(self):
177
+ """mdmd:hidden"""
178
+ ...
179
+
180
+ SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)
181
+
182
+ class ContainerProcess(typing.Generic[T]):
183
+ """Represents a running process in a container."""
121
184
  def __init__(
122
185
  self,
123
186
  process_id: str,
187
+ task_id: str,
124
188
  client: modal.client.Client,
125
189
  stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
126
190
  stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
127
191
  exec_deadline: typing.Optional[float] = None,
128
192
  text: bool = True,
129
193
  by_line: bool = False,
194
+ command_router_client: typing.Optional[modal._utils.task_command_router_client.TaskCommandRouterClient] = None,
130
195
  ) -> None: ...
131
196
  def __repr__(self) -> str: ...
132
197
  @property
@@ -164,12 +229,6 @@ class ContainerProcess(typing.Generic[T]):
164
229
 
165
230
  poll: __poll_spec[typing_extensions.Self]
166
231
 
167
- class ___wait_for_completion_spec(typing_extensions.Protocol[SUPERSELF]):
168
- def __call__(self, /) -> int: ...
169
- async def aio(self, /) -> int: ...
170
-
171
- _wait_for_completion: ___wait_for_completion_spec[typing_extensions.Self]
172
-
173
232
  class __wait_spec(typing_extensions.Protocol[SUPERSELF]):
174
233
  def __call__(self, /) -> int:
175
234
  """Wait for the container process to finish running. Returns the exit code."""
@@ -321,7 +321,7 @@ class _FlashPrometheusAutoscaler:
321
321
 
322
322
  async def _compute_target_containers(self, current_replicas: int) -> int:
323
323
  """
324
- Gets internal metrics from container to autoscale up or down.
324
+ Gets metrics from container to autoscale up or down.
325
325
  """
326
326
  containers = await self._get_all_containers()
327
327
  if len(containers) > current_replicas:
@@ -334,7 +334,7 @@ class _FlashPrometheusAutoscaler:
334
334
  if current_replicas == 0:
335
335
  return 1
336
336
 
337
- # Get metrics based on autoscaler type (prometheus or internal)
337
+ # Get metrics based on autoscaler type
338
338
  sum_metric, n_containers_with_metrics = await self._get_scaling_info(containers)
339
339
 
340
340
  desired_replicas = self._calculate_desired_replicas(
@@ -406,39 +406,26 @@ class _FlashPrometheusAutoscaler:
406
406
  return desired_replicas
407
407
 
408
408
  async def _get_scaling_info(self, containers) -> tuple[float, int]:
409
- """Get metrics using either internal container metrics API or prometheus HTTP endpoints."""
410
- if self.metrics_endpoint == "internal":
411
- container_metrics_results = await asyncio.gather(
412
- *[self._get_container_metrics(container.task_id) for container in containers]
413
- )
414
- container_metrics_list = []
415
- for container_metric in container_metrics_results:
416
- if container_metric is None:
417
- continue
418
- container_metrics_list.append(getattr(container_metric.metrics, self.target_metric))
419
-
420
- sum_metric = sum(container_metrics_list)
421
- n_containers_with_metrics = len(container_metrics_list)
422
- else:
423
- sum_metric = 0
424
- n_containers_with_metrics = 0
425
-
426
- container_metrics_list = await asyncio.gather(
427
- *[
428
- self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
429
- for container in containers
430
- ]
431
- )
409
+ """Get metrics using container exposed metrics endpoints."""
410
+ sum_metric = 0
411
+ n_containers_with_metrics = 0
412
+
413
+ container_metrics_list = await asyncio.gather(
414
+ *[
415
+ self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
416
+ for container in containers
417
+ ]
418
+ )
432
419
 
433
- for container_metrics in container_metrics_list:
434
- if (
435
- container_metrics is None
436
- or self.target_metric not in container_metrics
437
- or len(container_metrics[self.target_metric]) == 0
438
- ):
439
- continue
440
- sum_metric += container_metrics[self.target_metric][0].value
441
- n_containers_with_metrics += 1
420
+ for container_metrics in container_metrics_list:
421
+ if (
422
+ container_metrics is None
423
+ or self.target_metric not in container_metrics
424
+ or len(container_metrics[self.target_metric]) == 0
425
+ ):
426
+ continue
427
+ sum_metric += container_metrics[self.target_metric][0].value
428
+ n_containers_with_metrics += 1
442
429
 
443
430
  return sum_metric, n_containers_with_metrics
444
431
 
@@ -474,15 +461,6 @@ class _FlashPrometheusAutoscaler:
474
461
 
475
462
  return metrics
476
463
 
477
- async def _get_container_metrics(self, container_id: str) -> Optional[api_pb2.TaskGetAutoscalingMetricsResponse]:
478
- req = api_pb2.TaskGetAutoscalingMetricsRequest(task_id=container_id)
479
- try:
480
- resp = await retry_transient_errors(self.client.stub.TaskGetAutoscalingMetrics, req)
481
- return resp
482
- except Exception as e:
483
- logger.warning(f"[Modal Flash] Error getting metrics for container {container_id}: {e}")
484
- return None
485
-
486
464
  async def _get_all_containers(self):
487
465
  req = api_pb2.FlashContainerListRequest(function_id=self.fn.object_id)
488
466
  resp = await retry_transient_errors(self.client.stub.FlashContainerList, req)
@@ -572,14 +550,10 @@ async def flash_prometheus_autoscaler(
572
550
  app_name: str,
573
551
  cls_name: str,
574
552
  # Endpoint to fetch metrics from. Must be in Prometheus format. Example: "/metrics"
575
- # If metrics_endpoint is "internal", we will use containers' internal metrics to autoscale instead.
576
553
  metrics_endpoint: str,
577
554
  # Target metric to autoscale on. Example: "vllm:num_requests_running"
578
- # If metrics_endpoint is "internal", target_metrics options are: [cpu_usage_percent, memory_usage_percent]
579
555
  target_metric: str,
580
556
  # Target metric value. Example: 25
581
- # If metrics_endpoint is "internal", target_metric_value is a percentage value between 0.1 and 1.0 (inclusive),
582
- # indicating container's usage of that metric.
583
557
  target_metric_value: float,
584
558
  min_containers: Optional[int] = None,
585
559
  max_containers: Optional[int] = None,