modal 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of modal might be problematic; see the advisory details on the package registry's page for this release.

Files changed (49)
  1. modal/_container_entrypoint.py +4 -1
  2. modal/_partial_function.py +28 -3
  3. modal/_utils/function_utils.py +4 -0
  4. modal/_utils/task_command_router_client.py +537 -0
  5. modal/app.py +93 -54
  6. modal/app.pyi +48 -18
  7. modal/cli/_download.py +19 -3
  8. modal/cli/cluster.py +4 -2
  9. modal/cli/container.py +4 -2
  10. modal/cli/entry_point.py +1 -0
  11. modal/cli/launch.py +1 -2
  12. modal/cli/run.py +6 -0
  13. modal/cli/volume.py +7 -1
  14. modal/client.pyi +2 -2
  15. modal/cls.py +5 -12
  16. modal/config.py +14 -0
  17. modal/container_process.py +283 -3
  18. modal/container_process.pyi +95 -32
  19. modal/exception.py +4 -0
  20. modal/experimental/flash.py +21 -47
  21. modal/experimental/flash.pyi +6 -20
  22. modal/functions.pyi +6 -6
  23. modal/io_streams.py +455 -122
  24. modal/io_streams.pyi +220 -95
  25. modal/partial_function.pyi +4 -1
  26. modal/runner.py +39 -36
  27. modal/runner.pyi +40 -24
  28. modal/sandbox.py +130 -11
  29. modal/sandbox.pyi +145 -9
  30. modal/volume.py +23 -3
  31. modal/volume.pyi +30 -0
  32. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/METADATA +5 -5
  33. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/RECORD +49 -48
  34. modal_proto/api.proto +2 -26
  35. modal_proto/api_grpc.py +0 -32
  36. modal_proto/api_pb2.py +327 -367
  37. modal_proto/api_pb2.pyi +6 -69
  38. modal_proto/api_pb2_grpc.py +0 -67
  39. modal_proto/api_pb2_grpc.pyi +0 -22
  40. modal_proto/modal_api_grpc.py +0 -2
  41. modal_proto/sandbox_router.proto +0 -4
  42. modal_proto/sandbox_router_pb2.pyi +0 -4
  43. modal_proto/task_command_router.proto +1 -1
  44. modal_proto/task_command_router_pb2.py +2 -2
  45. modal_version/__init__.py +1 -1
  46. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/WHEEL +0 -0
  47. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/entry_points.txt +0 -0
  48. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/licenses/LICENSE +0 -0
  49. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/top_level.txt +0 -0
modal/config.py CHANGED
@@ -71,6 +71,10 @@ Other possible configuration options are:
71
71
  The log formatting pattern that will be used by the modal client itself.
72
72
  See https://docs.python.org/3/library/logging.html#logrecord-attributes for available
73
73
  log attributes.
74
+ * `dev_suffix` (in the .toml file) / `MODAL_DEV_SUFFIX` (as an env var).
75
+ Overrides the default `-dev` suffix added to URLs generated for web endpoints
76
+ when the App is ephemeral (i.e., created via `modal serve`). Must be a short
77
+ alphanumeric string.
74
78
 
75
79
  Meta-configuration
76
80
  ------------------
@@ -85,6 +89,7 @@ Some "meta-options" are set using environment variables only:
85
89
 
86
90
  import logging
87
91
  import os
92
+ import re
88
93
  import typing
89
94
  import warnings
90
95
  from typing import Any, Callable, Optional
@@ -206,6 +211,12 @@ def _check_value(options: list[str]) -> Callable[[str], str]:
206
211
  return checker
207
212
 
208
213
 
214
+ def _enforce_suffix_rules(x: str) -> str:
215
+ if x and not re.match(r"^[a-zA-Z0-9]{1,8}$", x):
216
+ raise ValueError("Suffix must be an alphanumeric string of no more than 8 characters.")
217
+ return x
218
+
219
+
209
220
  class _Setting(typing.NamedTuple):
210
221
  default: typing.Any = None
211
222
  transform: typing.Callable[[str], typing.Any] = lambda x: x # noqa: E731
@@ -236,6 +247,8 @@ _SETTINGS = {
236
247
  "traceback": _Setting(False, transform=_to_boolean),
237
248
  "image_builder_version": _Setting(),
238
249
  "strict_parameters": _Setting(False, transform=_to_boolean), # For internal/experimental use
250
+ # Allow insecure TLS for the task command router when running locally (testing/dev only)
251
+ "task_command_router_insecure": _Setting(False, transform=_to_boolean),
239
252
  "snapshot_debug": _Setting(False, transform=_to_boolean),
240
253
  "cuda_checkpoint_path": _Setting("/__modal/.bin/cuda-checkpoint"), # Used for snapshotting GPU memory.
241
254
  "build_validation": _Setting("error", transform=_check_value(["error", "warn", "ignore"])),
@@ -244,6 +257,7 @@ _SETTINGS = {
244
257
  "pickle",
245
258
  transform=lambda s: _check_value(["pickle", "cbor"])(s.lower()),
246
259
  ),
260
+ "dev_suffix": _Setting("", transform=_enforce_suffix_rules),
247
261
  }
248
262
 
249
263
 
@@ -9,16 +9,17 @@ from modal_proto import api_pb2
9
9
  from ._utils.async_utils import TaskContext, synchronize_api
10
10
  from ._utils.grpc_utils import retry_transient_errors
11
11
  from ._utils.shell_utils import stream_from_stdin, write_to_fd
12
+ from ._utils.task_command_router_client import TaskCommandRouterClient
12
13
  from .client import _Client
13
14
  from .config import logger
14
- from .exception import InteractiveTimeoutError, InvalidError
15
+ from .exception import ExecTimeoutError, InteractiveTimeoutError, InvalidError
15
16
  from .io_streams import _StreamReader, _StreamWriter
16
17
  from .stream_type import StreamType
17
18
 
18
19
  T = TypeVar("T", str, bytes)
19
20
 
20
21
 
21
- class _ContainerProcess(Generic[T]):
22
+ class _ContainerProcessThroughServer(Generic[T]):
22
23
  _process_id: Optional[str] = None
23
24
  _stdout: _StreamReader[T]
24
25
  _stderr: _StreamReader[T]
@@ -31,6 +32,7 @@ class _ContainerProcess(Generic[T]):
31
32
  def __init__(
32
33
  self,
33
34
  process_id: str,
35
+ task_id: str,
34
36
  client: _Client,
35
37
  stdout: StreamType = StreamType.PIPE,
36
38
  stderr: StreamType = StreamType.PIPE,
@@ -52,6 +54,7 @@ class _ContainerProcess(Generic[T]):
52
54
  text=text,
53
55
  by_line=by_line,
54
56
  deadline=exec_deadline,
57
+ task_id=task_id,
55
58
  )
56
59
  self._stderr = _StreamReader[T](
57
60
  api_pb2.FILE_DESCRIPTOR_STDERR,
@@ -62,6 +65,7 @@ class _ContainerProcess(Generic[T]):
62
65
  text=text,
63
66
  by_line=by_line,
64
67
  deadline=exec_deadline,
68
+ task_id=task_id,
65
69
  )
66
70
  self._stdin = _StreamWriter(process_id, "container_process", self._client)
67
71
 
@@ -155,8 +159,16 @@ class _ContainerProcess(Generic[T]):
155
159
  on_connect = asyncio.Event()
156
160
 
157
161
  async def _write_to_fd_loop(stream: _StreamReader):
162
+ # This is required to make modal shell to an existing task work,
163
+ # since that uses ContainerExec RPCs directly, but this is hacky.
164
+ #
165
+ # TODO(saltzm): Once we use the new exec path for that use case, this code can all be removed.
166
+ from .io_streams import _StreamReaderThroughServer
167
+
168
+ assert isinstance(stream._impl, _StreamReaderThroughServer)
169
+ stream_impl = stream._impl
158
170
  # Don't skip empty messages so we can detect when the process has booted.
159
- async for chunk in stream._get_logs(skip_empty_messages=False):
171
+ async for chunk in stream_impl._get_logs(skip_empty_messages=False):
160
172
  if chunk is None:
161
173
  break
162
174
 
@@ -193,4 +205,272 @@ class _ContainerProcess(Generic[T]):
193
205
  raise InteractiveTimeoutError("Failed to establish connection to container. Please try again.")
194
206
 
195
207
 
208
+ async def _iter_stream_as_bytes(stream: _StreamReader[T]):
209
+ """Yield raw bytes from a StreamReader regardless of text mode/backend."""
210
+ async for part in stream:
211
+ if isinstance(part, str):
212
+ yield part.encode("utf-8")
213
+ else:
214
+ yield part
215
+
216
+
217
+ class _ContainerProcessThroughCommandRouter(Generic[T]):
218
+ """
219
+ Container process implementation that works via direct communication with
220
+ the Modal worker where the container is running.
221
+ """
222
+
223
+ def __init__(
224
+ self,
225
+ process_id: str,
226
+ client: _Client,
227
+ command_router_client: TaskCommandRouterClient,
228
+ task_id: str,
229
+ *,
230
+ stdout: StreamType = StreamType.PIPE,
231
+ stderr: StreamType = StreamType.PIPE,
232
+ exec_deadline: Optional[float] = None,
233
+ text: bool = True,
234
+ by_line: bool = False,
235
+ ) -> None:
236
+ self._client = client
237
+ self._command_router_client = command_router_client
238
+ self._process_id = process_id
239
+ self._exec_deadline = exec_deadline
240
+ self._text = text
241
+ self._by_line = by_line
242
+ self._task_id = task_id
243
+ self._stdout = _StreamReader[T](
244
+ api_pb2.FILE_DESCRIPTOR_STDOUT,
245
+ process_id,
246
+ "container_process",
247
+ self._client,
248
+ stream_type=stdout,
249
+ text=text,
250
+ by_line=by_line,
251
+ deadline=exec_deadline,
252
+ command_router_client=self._command_router_client,
253
+ task_id=self._task_id,
254
+ )
255
+ self._stderr = _StreamReader[T](
256
+ api_pb2.FILE_DESCRIPTOR_STDERR,
257
+ process_id,
258
+ "container_process",
259
+ self._client,
260
+ stream_type=stderr,
261
+ text=text,
262
+ by_line=by_line,
263
+ deadline=exec_deadline,
264
+ command_router_client=self._command_router_client,
265
+ task_id=self._task_id,
266
+ )
267
+ self._stdin = _StreamWriter(
268
+ process_id,
269
+ "container_process",
270
+ self._client,
271
+ command_router_client=self._command_router_client,
272
+ task_id=self._task_id,
273
+ )
274
+ self._returncode = None
275
+
276
+ def __repr__(self) -> str:
277
+ return f"ContainerProcess(process_id={self._process_id!r})"
278
+
279
+ @property
280
+ def stdout(self) -> _StreamReader[T]:
281
+ return self._stdout
282
+
283
+ @property
284
+ def stderr(self) -> _StreamReader[T]:
285
+ return self._stderr
286
+
287
+ @property
288
+ def stdin(self) -> _StreamWriter:
289
+ return self._stdin
290
+
291
+ @property
292
+ def returncode(self) -> int:
293
+ if self._returncode is None:
294
+ raise InvalidError(
295
+ "You must call wait() before accessing the returncode. "
296
+ "To poll for the status of a running process, use poll() instead."
297
+ )
298
+ return self._returncode
299
+
300
+ async def poll(self) -> Optional[int]:
301
+ if self._returncode is not None:
302
+ return self._returncode
303
+ try:
304
+ resp = await self._command_router_client.exec_poll(self._task_id, self._process_id, self._exec_deadline)
305
+ which = resp.WhichOneof("exit_status")
306
+ if which is None:
307
+ return None
308
+
309
+ if which == "code":
310
+ self._returncode = int(resp.code)
311
+ return self._returncode
312
+ elif which == "signal":
313
+ self._returncode = 128 + int(resp.signal)
314
+ return self._returncode
315
+ else:
316
+ logger.debug(f"ContainerProcess {self._process_id} exited with unexpected status: {which}")
317
+ raise InvalidError("Unexpected exit status")
318
+ except ExecTimeoutError:
319
+ logger.debug(f"ContainerProcess poll for {self._process_id} did not complete within deadline")
320
+ # TODO(saltzm): This is a weird API, but customers currently may rely on it. This
321
+ # should probably raise an ExecTimeoutError instead.
322
+ self._returncode = -1
323
+ return self._returncode
324
+ except Exception as e:
325
+ # Re-raise non-transient errors or errors resulting from exceeding retries on transient errors.
326
+ logger.warning(f"ContainerProcess poll for {self._process_id} failed: {e}")
327
+ raise
328
+
329
+ async def wait(self) -> int:
330
+ if self._returncode is not None:
331
+ return self._returncode
332
+
333
+ try:
334
+ resp = await self._command_router_client.exec_wait(self._task_id, self._process_id, self._exec_deadline)
335
+ which = resp.WhichOneof("exit_status")
336
+ if which == "code":
337
+ self._returncode = int(resp.code)
338
+ elif which == "signal":
339
+ self._returncode = 128 + int(resp.signal)
340
+ else:
341
+ logger.debug(f"ContainerProcess {self._process_id} exited with unexpected status: {which}")
342
+ self._returncode = -1
343
+ raise InvalidError("Unexpected exit status")
344
+ except ExecTimeoutError:
345
+ logger.debug(f"ContainerProcess {self._process_id} did not complete within deadline")
346
+ # TODO(saltzm): This is a weird API, but customers currently may rely on it. This
347
+ # should be a ExecTimeoutError.
348
+ self._returncode = -1
349
+
350
+ return self._returncode
351
+
352
+ async def attach(self):
353
+ if platform.system() == "Windows":
354
+ print("interactive exec is not currently supported on Windows.")
355
+ return
356
+
357
+ from ._output import make_console
358
+
359
+ console = make_console()
360
+
361
+ connecting_status = console.status("Connecting...")
362
+ connecting_status.start()
363
+ on_connect = asyncio.Event()
364
+
365
+ async def _write_to_fd_loop(stream: _StreamReader[T]):
366
+ async for chunk in _iter_stream_as_bytes(stream):
367
+ if chunk is None:
368
+ break
369
+
370
+ if not on_connect.is_set():
371
+ connecting_status.stop()
372
+ on_connect.set()
373
+
374
+ await write_to_fd(stream.file_descriptor, chunk)
375
+
376
+ async def _handle_input(data: bytes, message_index: int):
377
+ self.stdin.write(data)
378
+ await self.stdin.drain()
379
+
380
+ async with TaskContext() as tc:
381
+ stdout_task = tc.create_task(_write_to_fd_loop(self.stdout))
382
+ stderr_task = tc.create_task(_write_to_fd_loop(self.stderr))
383
+
384
+ try:
385
+ # Time out if we can't connect fast enough.
386
+ await asyncio.wait_for(on_connect.wait(), timeout=60)
387
+
388
+ async with stream_from_stdin(_handle_input, use_raw_terminal=True):
389
+ await stdout_task
390
+ await stderr_task
391
+
392
+ except (asyncio.TimeoutError, TimeoutError):
393
+ connecting_status.stop()
394
+ stdout_task.cancel()
395
+ stderr_task.cancel()
396
+ raise InteractiveTimeoutError("Failed to establish connection to container. Please try again.")
397
+
398
+
399
+ class _ContainerProcess(Generic[T]):
400
+ """Represents a running process in a container."""
401
+
402
+ def __init__(
403
+ self,
404
+ process_id: str,
405
+ task_id: str,
406
+ client: _Client,
407
+ stdout: StreamType = StreamType.PIPE,
408
+ stderr: StreamType = StreamType.PIPE,
409
+ exec_deadline: Optional[float] = None,
410
+ text: bool = True,
411
+ by_line: bool = False,
412
+ command_router_client: Optional[TaskCommandRouterClient] = None,
413
+ ) -> None:
414
+ if command_router_client is None:
415
+ self._impl = _ContainerProcessThroughServer(
416
+ process_id,
417
+ task_id,
418
+ client,
419
+ stdout=stdout,
420
+ stderr=stderr,
421
+ exec_deadline=exec_deadline,
422
+ text=text,
423
+ by_line=by_line,
424
+ )
425
+ else:
426
+ self._impl = _ContainerProcessThroughCommandRouter(
427
+ process_id,
428
+ client,
429
+ command_router_client,
430
+ task_id,
431
+ stdout=stdout,
432
+ stderr=stderr,
433
+ exec_deadline=exec_deadline,
434
+ text=text,
435
+ by_line=by_line,
436
+ )
437
+
438
+ def __repr__(self) -> str:
439
+ return self._impl.__repr__()
440
+
441
+ @property
442
+ def stdout(self) -> _StreamReader[T]:
443
+ """StreamReader for the container process's stdout stream."""
444
+ return self._impl.stdout
445
+
446
+ @property
447
+ def stderr(self) -> _StreamReader[T]:
448
+ """StreamReader for the container process's stderr stream."""
449
+ return self._impl.stderr
450
+
451
+ @property
452
+ def stdin(self) -> _StreamWriter:
453
+ """StreamWriter for the container process's stdin stream."""
454
+ return self._impl.stdin
455
+
456
+ @property
457
+ def returncode(self) -> int:
458
+ return self._impl.returncode
459
+
460
+ async def poll(self) -> Optional[int]:
461
+ """Check if the container process has finished running.
462
+
463
+ Returns `None` if the process is still running, else returns the exit code.
464
+ """
465
+ return await self._impl.poll()
466
+
467
+ async def wait(self) -> int:
468
+ """Wait for the container process to finish running. Returns the exit code."""
469
+ return await self._impl.wait()
470
+
471
+ async def attach(self):
472
+ """mdmd:hidden"""
473
+ await self._impl.attach()
474
+
475
+
196
476
  ContainerProcess = synchronize_api(_ContainerProcess)
@@ -1,3 +1,4 @@
1
+ import modal._utils.task_command_router_client
1
2
  import modal.client
2
3
  import modal.io_streams
3
4
  import modal.stream_type
@@ -6,7 +7,7 @@ import typing_extensions
6
7
 
7
8
  T = typing.TypeVar("T")
8
9
 
9
- class _ContainerProcess(typing.Generic[T]):
10
+ class _ContainerProcessThroughServer(typing.Generic[T]):
10
11
  """Abstract base class for generic types.
11
12
 
12
13
  A generic type is typically declared by inheriting from
@@ -39,6 +40,7 @@ class _ContainerProcess(typing.Generic[T]):
39
40
  def __init__(
40
41
  self,
41
42
  process_id: str,
43
+ task_id: str,
42
44
  client: modal.client._Client,
43
45
  stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
44
46
  stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
@@ -86,47 +88,114 @@ class _ContainerProcess(typing.Generic[T]):
86
88
  """mdmd:hidden"""
87
89
  ...
88
90
 
89
- SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)
91
+ def _iter_stream_as_bytes(stream: modal.io_streams._StreamReader[T]):
92
+ """Yield raw bytes from a StreamReader regardless of text mode/backend."""
93
+ ...
90
94
 
91
- class ContainerProcess(typing.Generic[T]):
92
- """Abstract base class for generic types.
95
+ class _ContainerProcessThroughCommandRouter(typing.Generic[T]):
96
+ """Container process implementation that works via direct communication with
97
+ the Modal worker where the container is running.
98
+ """
99
+ def __init__(
100
+ self,
101
+ process_id: str,
102
+ client: modal.client._Client,
103
+ command_router_client: modal._utils.task_command_router_client.TaskCommandRouterClient,
104
+ task_id: str,
105
+ *,
106
+ stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
107
+ stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
108
+ exec_deadline: typing.Optional[float] = None,
109
+ text: bool = True,
110
+ by_line: bool = False,
111
+ ) -> None:
112
+ """Initialize self. See help(type(self)) for accurate signature."""
113
+ ...
93
114
 
94
- A generic type is typically declared by inheriting from
95
- this class parameterized with one or more type variables.
96
- For example, a generic mapping type might be defined as::
115
+ def __repr__(self) -> str:
116
+ """Return repr(self)."""
117
+ ...
97
118
 
98
- class Mapping(Generic[KT, VT]):
99
- def __getitem__(self, key: KT) -> VT:
100
- ...
101
- # Etc.
119
+ @property
120
+ def stdout(self) -> modal.io_streams._StreamReader[T]: ...
121
+ @property
122
+ def stderr(self) -> modal.io_streams._StreamReader[T]: ...
123
+ @property
124
+ def stdin(self) -> modal.io_streams._StreamWriter: ...
125
+ @property
126
+ def returncode(self) -> int: ...
127
+ async def poll(self) -> typing.Optional[int]: ...
128
+ async def wait(self) -> int: ...
129
+ async def attach(self): ...
102
130
 
103
- This class can then be used as follows::
131
+ class _ContainerProcess(typing.Generic[T]):
132
+ """Represents a running process in a container."""
133
+ def __init__(
134
+ self,
135
+ process_id: str,
136
+ task_id: str,
137
+ client: modal.client._Client,
138
+ stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
139
+ stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
140
+ exec_deadline: typing.Optional[float] = None,
141
+ text: bool = True,
142
+ by_line: bool = False,
143
+ command_router_client: typing.Optional[modal._utils.task_command_router_client.TaskCommandRouterClient] = None,
144
+ ) -> None:
145
+ """Initialize self. See help(type(self)) for accurate signature."""
146
+ ...
104
147
 
105
- def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT:
106
- try:
107
- return mapping[key]
108
- except KeyError:
109
- return default
110
- """
148
+ def __repr__(self) -> str:
149
+ """Return repr(self)."""
150
+ ...
111
151
 
112
- _process_id: typing.Optional[str]
113
- _stdout: modal.io_streams.StreamReader[T]
114
- _stderr: modal.io_streams.StreamReader[T]
115
- _stdin: modal.io_streams.StreamWriter
116
- _exec_deadline: typing.Optional[float]
117
- _text: bool
118
- _by_line: bool
119
- _returncode: typing.Optional[int]
152
+ @property
153
+ def stdout(self) -> modal.io_streams._StreamReader[T]:
154
+ """StreamReader for the container process's stdout stream."""
155
+ ...
156
+
157
+ @property
158
+ def stderr(self) -> modal.io_streams._StreamReader[T]:
159
+ """StreamReader for the container process's stderr stream."""
160
+ ...
161
+
162
+ @property
163
+ def stdin(self) -> modal.io_streams._StreamWriter:
164
+ """StreamWriter for the container process's stdin stream."""
165
+ ...
166
+
167
+ @property
168
+ def returncode(self) -> int: ...
169
+ async def poll(self) -> typing.Optional[int]:
170
+ """Check if the container process has finished running.
171
+
172
+ Returns `None` if the process is still running, else returns the exit code.
173
+ """
174
+ ...
175
+
176
+ async def wait(self) -> int:
177
+ """Wait for the container process to finish running. Returns the exit code."""
178
+ ...
179
+
180
+ async def attach(self):
181
+ """mdmd:hidden"""
182
+ ...
183
+
184
+ SUPERSELF = typing.TypeVar("SUPERSELF", covariant=True)
120
185
 
186
+ class ContainerProcess(typing.Generic[T]):
187
+ """Represents a running process in a container."""
121
188
  def __init__(
122
189
  self,
123
190
  process_id: str,
191
+ task_id: str,
124
192
  client: modal.client.Client,
125
193
  stdout: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
126
194
  stderr: modal.stream_type.StreamType = modal.stream_type.StreamType.PIPE,
127
195
  exec_deadline: typing.Optional[float] = None,
128
196
  text: bool = True,
129
197
  by_line: bool = False,
198
+ command_router_client: typing.Optional[modal._utils.task_command_router_client.TaskCommandRouterClient] = None,
130
199
  ) -> None: ...
131
200
  def __repr__(self) -> str: ...
132
201
  @property
@@ -164,12 +233,6 @@ class ContainerProcess(typing.Generic[T]):
164
233
 
165
234
  poll: __poll_spec[typing_extensions.Self]
166
235
 
167
- class ___wait_for_completion_spec(typing_extensions.Protocol[SUPERSELF]):
168
- def __call__(self, /) -> int: ...
169
- async def aio(self, /) -> int: ...
170
-
171
- _wait_for_completion: ___wait_for_completion_spec[typing_extensions.Self]
172
-
173
236
  class __wait_spec(typing_extensions.Protocol[SUPERSELF]):
174
237
  def __call__(self, /) -> int:
175
238
  """Wait for the container process to finish running. Returns the exit code."""
modal/exception.py CHANGED
@@ -42,6 +42,10 @@ class SandboxTimeoutError(TimeoutError):
42
42
  """Raised when a Sandbox exceeds its execution duration limit and times out."""
43
43
 
44
44
 
45
+ class ExecTimeoutError(TimeoutError):
46
+ """Raised when a container process exceeds its execution duration limit and times out."""
47
+
48
+
45
49
  class SandboxTerminatedError(Error):
46
50
  """Raised when a Sandbox is terminated for an internal reason."""
47
51
 
@@ -321,7 +321,7 @@ class _FlashPrometheusAutoscaler:
321
321
 
322
322
  async def _compute_target_containers(self, current_replicas: int) -> int:
323
323
  """
324
- Gets internal metrics from container to autoscale up or down.
324
+ Gets metrics from container to autoscale up or down.
325
325
  """
326
326
  containers = await self._get_all_containers()
327
327
  if len(containers) > current_replicas:
@@ -334,7 +334,7 @@ class _FlashPrometheusAutoscaler:
334
334
  if current_replicas == 0:
335
335
  return 1
336
336
 
337
- # Get metrics based on autoscaler type (prometheus or internal)
337
+ # Get metrics based on autoscaler type
338
338
  sum_metric, n_containers_with_metrics = await self._get_scaling_info(containers)
339
339
 
340
340
  desired_replicas = self._calculate_desired_replicas(
@@ -406,39 +406,26 @@ class _FlashPrometheusAutoscaler:
406
406
  return desired_replicas
407
407
 
408
408
  async def _get_scaling_info(self, containers) -> tuple[float, int]:
409
- """Get metrics using either internal container metrics API or prometheus HTTP endpoints."""
410
- if self.metrics_endpoint == "internal":
411
- container_metrics_results = await asyncio.gather(
412
- *[self._get_container_metrics(container.task_id) for container in containers]
413
- )
414
- container_metrics_list = []
415
- for container_metric in container_metrics_results:
416
- if container_metric is None:
417
- continue
418
- container_metrics_list.append(getattr(container_metric.metrics, self.target_metric))
419
-
420
- sum_metric = sum(container_metrics_list)
421
- n_containers_with_metrics = len(container_metrics_list)
422
- else:
423
- sum_metric = 0
424
- n_containers_with_metrics = 0
425
-
426
- container_metrics_list = await asyncio.gather(
427
- *[
428
- self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
429
- for container in containers
430
- ]
431
- )
409
+ """Get metrics using container exposed metrics endpoints."""
410
+ sum_metric = 0
411
+ n_containers_with_metrics = 0
412
+
413
+ container_metrics_list = await asyncio.gather(
414
+ *[
415
+ self._get_metrics(f"https://{container.host}:{container.port}/{self.metrics_endpoint}")
416
+ for container in containers
417
+ ]
418
+ )
432
419
 
433
- for container_metrics in container_metrics_list:
434
- if (
435
- container_metrics is None
436
- or self.target_metric not in container_metrics
437
- or len(container_metrics[self.target_metric]) == 0
438
- ):
439
- continue
440
- sum_metric += container_metrics[self.target_metric][0].value
441
- n_containers_with_metrics += 1
420
+ for container_metrics in container_metrics_list:
421
+ if (
422
+ container_metrics is None
423
+ or self.target_metric not in container_metrics
424
+ or len(container_metrics[self.target_metric]) == 0
425
+ ):
426
+ continue
427
+ sum_metric += container_metrics[self.target_metric][0].value
428
+ n_containers_with_metrics += 1
442
429
 
443
430
  return sum_metric, n_containers_with_metrics
444
431
 
@@ -474,15 +461,6 @@ class _FlashPrometheusAutoscaler:
474
461
 
475
462
  return metrics
476
463
 
477
- async def _get_container_metrics(self, container_id: str) -> Optional[api_pb2.TaskGetAutoscalingMetricsResponse]:
478
- req = api_pb2.TaskGetAutoscalingMetricsRequest(task_id=container_id)
479
- try:
480
- resp = await retry_transient_errors(self.client.stub.TaskGetAutoscalingMetrics, req)
481
- return resp
482
- except Exception as e:
483
- logger.warning(f"[Modal Flash] Error getting metrics for container {container_id}: {e}")
484
- return None
485
-
486
464
  async def _get_all_containers(self):
487
465
  req = api_pb2.FlashContainerListRequest(function_id=self.fn.object_id)
488
466
  resp = await retry_transient_errors(self.client.stub.FlashContainerList, req)
@@ -572,14 +550,10 @@ async def flash_prometheus_autoscaler(
572
550
  app_name: str,
573
551
  cls_name: str,
574
552
  # Endpoint to fetch metrics from. Must be in Prometheus format. Example: "/metrics"
575
- # If metrics_endpoint is "internal", we will use containers' internal metrics to autoscale instead.
576
553
  metrics_endpoint: str,
577
554
  # Target metric to autoscale on. Example: "vllm:num_requests_running"
578
- # If metrics_endpoint is "internal", target_metrics options are: [cpu_usage_percent, memory_usage_percent]
579
555
  target_metric: str,
580
556
  # Target metric value. Example: 25
581
- # If metrics_endpoint is "internal", target_metric_value is a percentage value between 0.1 and 1.0 (inclusive),
582
- # indicating container's usage of that metric.
583
557
  target_metric_value: float,
584
558
  min_containers: Optional[int] = None,
585
559
  max_containers: Optional[int] = None,