modal 0.70.3__py3-none-any.whl → 0.71.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,7 +11,6 @@ if telemetry_socket:
11
11
  instrument_imports(telemetry_socket)
12
12
 
13
13
  import asyncio
14
- import base64
15
14
  import concurrent.futures
16
15
  import inspect
17
16
  import queue
@@ -337,14 +336,17 @@ def call_function(
337
336
  signal.signal(signal.SIGUSR1, usr1_handler) # reset signal handler
338
337
 
339
338
 
340
- def get_active_app_fallback(function_def: api_pb2.Function) -> Optional[_App]:
339
+ def get_active_app_fallback(function_def: api_pb2.Function) -> _App:
341
340
  # This branch is reached in the special case that the imported function/class is:
342
341
  # 1) not serialized, and
343
342
  # 2) isn't a FunctionHandle - i.e, not decorated at definition time
344
343
  # Look at all instantiated apps - if there is only one with the indicated name, use that one
345
344
  app_name: Optional[str] = function_def.app_name or None # coalesce protobuf field to None
346
345
  matching_apps = _App._all_apps.get(app_name, [])
347
- active_app = None
346
+ if len(matching_apps) == 1:
347
+ active_app: _App = matching_apps[0]
348
+ return active_app
349
+
348
350
  if len(matching_apps) > 1:
349
351
  if app_name is not None:
350
352
  warning_sub_message = f"app with the same name ('{app_name}')"
@@ -354,12 +356,10 @@ def get_active_app_fallback(function_def: api_pb2.Function) -> Optional[_App]:
354
356
  f"You have more than one {warning_sub_message}. "
355
357
  "It's recommended to name all your Apps uniquely when using multiple apps"
356
358
  )
357
- elif len(matching_apps) == 1:
358
- (active_app,) = matching_apps
359
- # there could also technically be zero found apps, but that should probably never be an
360
- # issue since that would mean user won't use is_inside or other function handles anyway
361
359
 
362
- return active_app
360
+ # If we don't have an active app, create one on the fly
361
+ # The app object is used to carry the app layout etc
362
+ return _App()
363
363
 
364
364
 
365
365
  def call_lifecycle_functions(
@@ -403,7 +403,7 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
403
403
  # This is a bit weird but we need both the blocking and async versions of ContainerIOManager.
404
404
  # At some point, we should fix that by having built-in support for running "user code"
405
405
  container_io_manager = ContainerIOManager(container_args, client)
406
- active_app: Optional[_App] = None
406
+ active_app: _App
407
407
  service: Service
408
408
  function_def = container_args.function_def
409
409
  is_auto_snapshot: bool = function_def.is_auto_snapshot
@@ -450,8 +450,9 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
450
450
  )
451
451
 
452
452
  # If the cls/function decorator was applied in local scope, but the app is global, we can look it up
453
- active_app = service.app
454
- if active_app is None:
453
+ if service.app is not None:
454
+ active_app = service.app
455
+ else:
455
456
  # if the app can't be inferred by the imported function, use name-based fallback
456
457
  active_app = get_active_app_fallback(function_def)
457
458
 
@@ -468,9 +469,8 @@ def main(container_args: api_pb2.ContainerArguments, client: Client):
468
469
 
469
470
  # Initialize objects on the app.
470
471
  # This is basically only functions and classes - anything else is deprecated and will be unsupported soon
471
- if active_app is not None:
472
- app: App = synchronizer._translate_out(active_app)
473
- app._init_container(client, container_app)
472
+ app: App = synchronizer._translate_out(active_app)
473
+ app._init_container(client, container_app)
474
474
 
475
475
  # Hydrate all function dependencies.
476
476
  # TODO(erikbern): we an remove this once we
@@ -581,7 +581,15 @@ if __name__ == "__main__":
581
581
  logger.debug("Container: starting")
582
582
 
583
583
  container_args = api_pb2.ContainerArguments()
584
- container_args.ParseFromString(base64.b64decode(sys.argv[1]))
584
+
585
+ container_arguments_path: Optional[str] = os.environ.get("MODAL_CONTAINER_ARGUMENTS_PATH")
586
+ if container_arguments_path is None:
587
+ # TODO(erikbern): this fallback is for old workers and we can remove it very soon (days)
588
+ import base64
589
+
590
+ container_args.ParseFromString(base64.b64decode(sys.argv[1]))
591
+ else:
592
+ container_args.ParseFromString(open(container_arguments_path, "rb").read())
585
593
 
586
594
  # Note that we're creating the client in a synchronous context, but it will be running in a separate thread.
587
595
  # This is good because if the function is long running then we the client can still send heartbeats
@@ -457,12 +457,7 @@ class _ContainerIOManager:
457
457
  resp = await retry_transient_errors(self._client.stub.AppGetLayout, req)
458
458
  app_layout = resp.app_layout
459
459
 
460
- return running_app_from_layout(
461
- self.app_id,
462
- app_layout,
463
- self._client,
464
- environment_name=self._environment_name,
465
- )
460
+ return running_app_from_layout(self.app_id, app_layout)
466
461
 
467
462
  async def get_serialized_function(self) -> tuple[Optional[Any], Optional[Callable[..., Any]]]:
468
463
  # Fetch the serialized function definition
modal/app.py CHANGED
@@ -168,7 +168,7 @@ class _App:
168
168
  """
169
169
 
170
170
  _all_apps: ClassVar[dict[Optional[str], list["_App"]]] = {}
171
- _container_app: ClassVar[Optional[RunningApp]] = None
171
+ _container_app: ClassVar[Optional["_App"]] = None
172
172
 
173
173
  _name: Optional[str]
174
174
  _description: Optional[str]
@@ -294,12 +294,7 @@ class _App:
294
294
  app = _App(name)
295
295
  app._app_id = response.app_id
296
296
  app._client = client
297
- app._running_app = RunningApp(
298
- response.app_id,
299
- client=client,
300
- environment_name=environment_name,
301
- interactive=False,
302
- )
297
+ app._running_app = RunningApp(response.app_id, interactive=False)
303
298
  return app
304
299
 
305
300
  def set_description(self, description: str):
@@ -488,7 +483,7 @@ class _App:
488
483
  self._running_app = running_app
489
484
  self._client = client
490
485
 
491
- _App._container_app = running_app
486
+ _App._container_app = self
492
487
 
493
488
  # Hydrate function objects
494
489
  for tag, object_id in running_app.function_ids.items():
@@ -1047,6 +1042,13 @@ class _App:
1047
1042
  if log.data:
1048
1043
  yield log.data
1049
1044
 
1045
+ @classmethod
1046
+ def _get_container_app(cls) -> Optional["_App"]:
1047
+ """Returns the `App` running inside a container.
1048
+
1049
+ This will return `None` outside of a Modal container."""
1050
+ return cls._container_app
1051
+
1050
1052
  @classmethod
1051
1053
  def _reset_container_app(cls):
1052
1054
  """Only used for tests."""
modal/app.pyi CHANGED
@@ -73,7 +73,7 @@ class _FunctionDecoratorType:
73
73
 
74
74
  class _App:
75
75
  _all_apps: typing.ClassVar[dict[typing.Optional[str], list[_App]]]
76
- _container_app: typing.ClassVar[typing.Optional[modal.running_app.RunningApp]]
76
+ _container_app: typing.ClassVar[typing.Optional[_App]]
77
77
  _name: typing.Optional[str]
78
78
  _description: typing.Optional[str]
79
79
  _functions: dict[str, modal.functions._Function]
@@ -266,11 +266,13 @@ class _App:
266
266
  self, client: typing.Optional[modal.client._Client] = None
267
267
  ) -> collections.abc.AsyncGenerator[str, None]: ...
268
268
  @classmethod
269
+ def _get_container_app(cls) -> typing.Optional[_App]: ...
270
+ @classmethod
269
271
  def _reset_container_app(cls): ...
270
272
 
271
273
  class App:
272
274
  _all_apps: typing.ClassVar[dict[typing.Optional[str], list[App]]]
273
- _container_app: typing.ClassVar[typing.Optional[modal.running_app.RunningApp]]
275
+ _container_app: typing.ClassVar[typing.Optional[App]]
274
276
  _name: typing.Optional[str]
275
277
  _description: typing.Optional[str]
276
278
  _functions: dict[str, modal.functions.Function]
@@ -530,6 +532,8 @@ class App:
530
532
 
531
533
  _logs: ___logs_spec
532
534
 
535
+ @classmethod
536
+ def _get_container_app(cls) -> typing.Optional[App]: ...
533
537
  @classmethod
534
538
  def _reset_container_app(cls): ...
535
539
 
modal/client.pyi CHANGED
@@ -26,7 +26,7 @@ class _Client:
26
26
  _stub: typing.Optional[modal_proto.api_grpc.ModalClientStub]
27
27
 
28
28
  def __init__(
29
- self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "0.70.3"
29
+ self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "0.71.2"
30
30
  ): ...
31
31
  def is_closed(self) -> bool: ...
32
32
  @property
@@ -81,7 +81,7 @@ class Client:
81
81
  _stub: typing.Optional[modal_proto.api_grpc.ModalClientStub]
82
82
 
83
83
  def __init__(
84
- self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "0.70.3"
84
+ self, server_url: str, client_type: int, credentials: typing.Optional[tuple[str, str]], version: str = "0.71.2"
85
85
  ): ...
86
86
  def is_closed(self) -> bool: ...
87
87
  @property
modal/experimental.py CHANGED
@@ -48,6 +48,9 @@ def clustered(size: int, broadcast: bool = True):
48
48
 
49
49
  assert broadcast, "broadcast=False has not been implemented yet!"
50
50
 
51
+ if size <= 0:
52
+ raise ValueError("cluster size must be greater than 0")
53
+
51
54
  def wrapper(raw_f: Callable[..., Any]) -> _PartialFunction:
52
55
  if isinstance(raw_f, _Function):
53
56
  raw_f = raw_f.get_raw_f()
modal/file_io.py CHANGED
@@ -1,6 +1,8 @@
1
1
  # Copyright Modal Labs 2024
2
2
  import asyncio
3
+ import enum
3
4
  import io
5
+ from dataclasses import dataclass
4
6
  from typing import TYPE_CHECKING, AsyncIterator, Generic, Optional, Sequence, TypeVar, Union, cast
5
7
 
6
8
  if TYPE_CHECKING:
@@ -11,6 +13,7 @@ import json
11
13
  from grpclib.exceptions import GRPCError, StreamTerminatedError
12
14
 
13
15
  from modal._utils.grpc_utils import retry_transient_errors
16
+ from modal.io_streams_helper import consume_stream_with_retries
14
17
  from modal_proto import api_pb2
15
18
 
16
19
  from ._utils.async_utils import synchronize_api
@@ -94,6 +97,20 @@ async def _replace_bytes(file: "_FileIO", data: bytes, start: Optional[int] = No
94
97
  await file._wait(resp.exec_id)
95
98
 
96
99
 
100
+ class FileWatchEventType(enum.Enum):
101
+ Unknown = "Unknown"
102
+ Access = "Access"
103
+ Create = "Create"
104
+ Modify = "Modify"
105
+ Remove = "Remove"
106
+
107
+
108
+ @dataclass
109
+ class FileWatchEvent:
110
+ paths: list[str]
111
+ type: FileWatchEventType
112
+
113
+
97
114
  # The FileIO class is designed to mimic Python's io.FileIO
98
115
  # See https://github.com/python/cpython/blob/main/Lib/_pyio.py#L1459
99
116
  class _FileIO(Generic[T]):
@@ -124,6 +141,7 @@ class _FileIO(Generic[T]):
124
141
  _task_id: str = ""
125
142
  _file_descriptor: str = ""
126
143
  _client: Optional[_Client] = None
144
+ _watch_output_buffer: list[Optional[bytes]] = []
127
145
 
128
146
  def _validate_mode(self, mode: str) -> None:
129
147
  if not any(char in mode for char in "rwax"):
@@ -167,6 +185,44 @@ class _FileIO(Generic[T]):
167
185
  for message in batch.output:
168
186
  yield message
169
187
 
188
+ async def _consume_watch_output(self, exec_id: str) -> None:
189
+ def item_handler(item: Optional[bytes]):
190
+ self._watch_output_buffer.append(item)
191
+
192
+ def completion_check(item: Optional[bytes]):
193
+ return item is None
194
+
195
+ await consume_stream_with_retries(
196
+ self._consume_output(exec_id),
197
+ item_handler,
198
+ completion_check,
199
+ )
200
+
201
+ async def _parse_watch_output(self, event: bytes) -> Optional[FileWatchEvent]:
202
+ try:
203
+ event_json = json.loads(event.decode())
204
+ return FileWatchEvent(type=FileWatchEventType(event_json["event_type"]), paths=event_json["paths"])
205
+ except (json.JSONDecodeError, KeyError, ValueError):
206
+ # skip invalid events
207
+ return None
208
+
209
+ async def _stream_watch_output(self) -> AsyncIterator[FileWatchEvent]:
210
+ buffer = b""
211
+ while True:
212
+ if len(self._watch_output_buffer) > 0:
213
+ item = self._watch_output_buffer.pop(0)
214
+ if item is None:
215
+ break
216
+ buffer += item
217
+ # a single event may be split across multiple messages, the end of an event is marked by two newlines
218
+ if buffer.endswith(b"\n\n"):
219
+ event = await self._parse_watch_output(buffer.strip())
220
+ if event is not None:
221
+ yield event
222
+ buffer = b""
223
+ else:
224
+ await asyncio.sleep(0.1)
225
+
170
226
  async def _wait(self, exec_id: str) -> bytes:
171
227
  # The logic here is similar to how output is read from `exec`
172
228
  output = b""
@@ -391,6 +447,36 @@ class _FileIO(Generic[T]):
391
447
  )
392
448
  await self._wait(resp.exec_id)
393
449
 
450
+ @classmethod
451
+ async def watch(
452
+ cls,
453
+ path: str,
454
+ client: _Client,
455
+ task_id: str,
456
+ filter: Optional[list[FileWatchEventType]] = None,
457
+ recursive: bool = False,
458
+ timeout: Optional[int] = None,
459
+ ) -> AsyncIterator[FileWatchEvent]:
460
+ self = cls.__new__(cls)
461
+ self._client = client
462
+ self._task_id = task_id
463
+ resp = await self._make_request(
464
+ api_pb2.ContainerFilesystemExecRequest(
465
+ file_watch_request=api_pb2.ContainerFileWatchRequest(
466
+ path=path,
467
+ recursive=recursive,
468
+ timeout_secs=timeout,
469
+ ),
470
+ task_id=self._task_id,
471
+ )
472
+ )
473
+ task = asyncio.create_task(self._consume_watch_output(resp.exec_id))
474
+ async for event in self._stream_watch_output():
475
+ if filter and event.type not in filter:
476
+ continue
477
+ yield event
478
+ task.cancel()
479
+
394
480
  async def _close(self) -> None:
395
481
  # Buffer is flushed by the runner on close
396
482
  resp = await self._make_request(
modal/file_io.pyi CHANGED
@@ -1,4 +1,5 @@
1
1
  import _typeshed
2
+ import enum
2
3
  import modal.client
3
4
  import modal_proto.api_pb2
4
5
  import typing
@@ -13,14 +14,33 @@ async def _replace_bytes(
13
14
  file: _FileIO, data: bytes, start: typing.Optional[int] = None, end: typing.Optional[int] = None
14
15
  ) -> None: ...
15
16
 
17
+ class FileWatchEventType(enum.Enum):
18
+ Unknown = "Unknown"
19
+ Access = "Access"
20
+ Create = "Create"
21
+ Modify = "Modify"
22
+ Remove = "Remove"
23
+
24
+ class FileWatchEvent:
25
+ paths: list[str]
26
+ type: FileWatchEventType
27
+
28
+ def __init__(self, paths: list[str], type: FileWatchEventType) -> None: ...
29
+ def __repr__(self): ...
30
+ def __eq__(self, other): ...
31
+
16
32
  class _FileIO(typing.Generic[T]):
17
33
  _task_id: str
18
34
  _file_descriptor: str
19
35
  _client: typing.Optional[modal.client._Client]
36
+ _watch_output_buffer: list[typing.Optional[bytes]]
20
37
 
21
38
  def _validate_mode(self, mode: str) -> None: ...
22
39
  def _handle_error(self, error: modal_proto.api_pb2.SystemErrorMessage) -> None: ...
23
40
  def _consume_output(self, exec_id: str) -> typing.AsyncIterator[typing.Optional[bytes]]: ...
41
+ async def _consume_watch_output(self, exec_id: str) -> None: ...
42
+ async def _parse_watch_output(self, event: bytes) -> typing.Optional[FileWatchEvent]: ...
43
+ def _stream_watch_output(self) -> typing.AsyncIterator[FileWatchEvent]: ...
24
44
  async def _wait(self, exec_id: str) -> bytes: ...
25
45
  def _validate_type(self, data: typing.Union[bytes, str]) -> None: ...
26
46
  async def _open_file(self, path: str, mode: str) -> None: ...
@@ -49,6 +69,16 @@ class _FileIO(typing.Generic[T]):
49
69
  async def mkdir(cls, path: str, client: modal.client._Client, task_id: str, parents: bool = False) -> None: ...
50
70
  @classmethod
51
71
  async def rm(cls, path: str, client: modal.client._Client, task_id: str, recursive: bool = False) -> None: ...
72
+ @classmethod
73
+ def watch(
74
+ cls,
75
+ path: str,
76
+ client: modal.client._Client,
77
+ task_id: str,
78
+ filter: typing.Optional[list[FileWatchEventType]] = None,
79
+ recursive: bool = False,
80
+ timeout: typing.Optional[int] = None,
81
+ ) -> typing.AsyncIterator[FileWatchEvent]: ...
52
82
  async def _close(self) -> None: ...
53
83
  async def close(self) -> None: ...
54
84
  def _check_writable(self) -> None: ...
@@ -79,6 +109,7 @@ class FileIO(typing.Generic[T]):
79
109
  _task_id: str
80
110
  _file_descriptor: str
81
111
  _client: typing.Optional[modal.client.Client]
112
+ _watch_output_buffer: list[typing.Optional[bytes]]
82
113
 
83
114
  def __init__(self, /, *args, **kwargs): ...
84
115
  def _validate_mode(self, mode: str) -> None: ...
@@ -90,6 +121,24 @@ class FileIO(typing.Generic[T]):
90
121
 
91
122
  _consume_output: ___consume_output_spec
92
123
 
124
+ class ___consume_watch_output_spec(typing_extensions.Protocol):
125
+ def __call__(self, exec_id: str) -> None: ...
126
+ async def aio(self, exec_id: str) -> None: ...
127
+
128
+ _consume_watch_output: ___consume_watch_output_spec
129
+
130
+ class ___parse_watch_output_spec(typing_extensions.Protocol):
131
+ def __call__(self, event: bytes) -> typing.Optional[FileWatchEvent]: ...
132
+ async def aio(self, event: bytes) -> typing.Optional[FileWatchEvent]: ...
133
+
134
+ _parse_watch_output: ___parse_watch_output_spec
135
+
136
+ class ___stream_watch_output_spec(typing_extensions.Protocol):
137
+ def __call__(self) -> typing.Iterator[FileWatchEvent]: ...
138
+ def aio(self) -> typing.AsyncIterator[FileWatchEvent]: ...
139
+
140
+ _stream_watch_output: ___stream_watch_output_spec
141
+
93
142
  class ___wait_spec(typing_extensions.Protocol):
94
143
  def __call__(self, exec_id: str) -> bytes: ...
95
144
  async def aio(self, exec_id: str) -> bytes: ...
@@ -173,6 +222,16 @@ class FileIO(typing.Generic[T]):
173
222
  def mkdir(cls, path: str, client: modal.client.Client, task_id: str, parents: bool = False) -> None: ...
174
223
  @classmethod
175
224
  def rm(cls, path: str, client: modal.client.Client, task_id: str, recursive: bool = False) -> None: ...
225
+ @classmethod
226
+ def watch(
227
+ cls,
228
+ path: str,
229
+ client: modal.client.Client,
230
+ task_id: str,
231
+ filter: typing.Optional[list[FileWatchEventType]] = None,
232
+ recursive: bool = False,
233
+ timeout: typing.Optional[int] = None,
234
+ ) -> typing.Iterator[FileWatchEvent]: ...
176
235
 
177
236
  class ___close_spec(typing_extensions.Protocol):
178
237
  def __call__(self) -> None: ...
modal/functions.pyi CHANGED
@@ -462,11 +462,11 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
462
462
 
463
463
  _call_generator_nowait: ___call_generator_nowait_spec
464
464
 
465
- class __remote_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER]):
465
+ class __remote_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER]):
466
466
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
467
467
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> ReturnType_INNER: ...
468
468
 
469
- remote: __remote_spec[ReturnType, P]
469
+ remote: __remote_spec[P, ReturnType]
470
470
 
471
471
  class __remote_gen_spec(typing_extensions.Protocol):
472
472
  def __call__(self, *args, **kwargs) -> typing.Generator[typing.Any, None, None]: ...
@@ -479,17 +479,17 @@ class Function(typing.Generic[P, ReturnType, OriginalReturnType], modal.object.O
479
479
  def _get_obj(self) -> typing.Optional[modal.cls.Obj]: ...
480
480
  def local(self, *args: P.args, **kwargs: P.kwargs) -> OriginalReturnType: ...
481
481
 
482
- class ___experimental_spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER]):
482
+ class ___experimental_spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER]):
483
483
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
484
484
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
485
485
 
486
- _experimental_spawn: ___experimental_spawn_spec[ReturnType, P]
486
+ _experimental_spawn: ___experimental_spawn_spec[P, ReturnType]
487
487
 
488
- class __spawn_spec(typing_extensions.Protocol[ReturnType_INNER, P_INNER]):
488
+ class __spawn_spec(typing_extensions.Protocol[P_INNER, ReturnType_INNER]):
489
489
  def __call__(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
490
490
  async def aio(self, *args: P_INNER.args, **kwargs: P_INNER.kwargs) -> FunctionCall[ReturnType_INNER]: ...
491
491
 
492
- spawn: __spawn_spec[ReturnType, P]
492
+ spawn: __spawn_spec[P, ReturnType]
493
493
 
494
494
  def get_raw_f(self) -> typing.Callable[..., typing.Any]: ...
495
495
 
modal/io_streams.py CHANGED
@@ -14,7 +14,8 @@ from typing import (
14
14
  from grpclib import Status
15
15
  from grpclib.exceptions import GRPCError, StreamTerminatedError
16
16
 
17
- from modal.exception import ClientClosed, InvalidError
17
+ from modal.exception import InvalidError
18
+ from modal.io_streams_helper import consume_stream_with_retries
18
19
  from modal_proto import api_pb2
19
20
 
20
21
  from ._utils.async_utils import synchronize_api
@@ -176,34 +177,21 @@ class _StreamReader(Generic[T]):
176
177
  if self._stream_type == StreamType.DEVNULL:
177
178
  return
178
179
 
179
- completed = False
180
- retries_remaining = 10
181
- while not completed:
182
- try:
183
- iterator = _container_process_logs_iterator(self._object_id, self._file_descriptor, self._client)
180
+ def item_handler(item: Optional[bytes]):
181
+ if self._stream_type == StreamType.STDOUT and item is not None:
182
+ print(item.decode("utf-8"), end="")
183
+ elif self._stream_type == StreamType.PIPE:
184
+ self._container_process_buffer.append(item)
184
185
 
185
- async for message in iterator:
186
- if self._stream_type == StreamType.STDOUT and message:
187
- print(message.decode("utf-8"), end="")
188
- elif self._stream_type == StreamType.PIPE:
189
- self._container_process_buffer.append(message)
190
- if message is None:
191
- completed = True
192
- break
186
+ def completion_check(item: Optional[bytes]):
187
+ return item is None
193
188
 
194
- except (GRPCError, StreamTerminatedError, ClientClosed) as exc:
195
- if retries_remaining > 0:
196
- retries_remaining -= 1
197
- if isinstance(exc, GRPCError):
198
- if exc.status in RETRYABLE_GRPC_STATUS_CODES:
199
- await asyncio.sleep(1.0)
200
- continue
201
- elif isinstance(exc, StreamTerminatedError):
202
- continue
203
- elif isinstance(exc, ClientClosed):
204
- # If the client was closed, the user has triggered a cleanup.
205
- break
206
- raise exc
189
+ iterator = _container_process_logs_iterator(self._object_id, self._file_descriptor, self._client)
190
+ await consume_stream_with_retries(
191
+ iterator,
192
+ item_handler,
193
+ completion_check,
194
+ )
207
195
 
208
196
  async def _stream_container_process(self) -> AsyncGenerator[tuple[Optional[bytes], str], None]:
209
197
  """Streams the container process buffer to the reader."""
@@ -0,0 +1,53 @@
1
+ # Copyright Modal Labs 2024
2
+ import asyncio
3
+ from typing import AsyncIterator, Callable, TypeVar
4
+
5
+ from grpclib.exceptions import GRPCError, StreamTerminatedError
6
+
7
+ from modal.exception import ClientClosed
8
+
9
+ from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES
10
+
11
+ T = TypeVar("T")
12
+
13
+
14
+ async def consume_stream_with_retries(
15
+ stream: AsyncIterator[T],
16
+ item_handler: Callable[[T], None],
17
+ completion_check: Callable[[T], bool],
18
+ max_retries: int = 10,
19
+ retry_delay: float = 1.0,
20
+ ) -> None:
21
+ """mdmd:hidden
22
+ Helper function to consume a stream with retry logic for transient errors.
23
+
24
+ Args:
25
+ stream_generator: Function that returns an AsyncIterator to consume
26
+ item_handler: Callback function to handle each item from the stream
27
+ completion_check: Callback function to check if the stream is complete
28
+ max_retries: Maximum number of retry attempts
29
+ retry_delay: Delay in seconds between retries
30
+ """
31
+ completed = False
32
+ retries_remaining = max_retries
33
+
34
+ while not completed:
35
+ try:
36
+ async for item in stream:
37
+ item_handler(item)
38
+ if completion_check(item):
39
+ completed = True
40
+ break
41
+
42
+ except (GRPCError, StreamTerminatedError, ClientClosed) as exc:
43
+ if retries_remaining > 0:
44
+ retries_remaining -= 1
45
+ if isinstance(exc, GRPCError):
46
+ if exc.status in RETRYABLE_GRPC_STATUS_CODES:
47
+ await asyncio.sleep(retry_delay)
48
+ continue
49
+ elif isinstance(exc, StreamTerminatedError):
50
+ continue
51
+ elif isinstance(exc, ClientClosed):
52
+ break
53
+ raise
modal/runner.py CHANGED
@@ -64,7 +64,6 @@ async def _init_local_app_existing(client: _Client, existing_app_id: str, enviro
64
64
  return running_app_from_layout(
65
65
  existing_app_id,
66
66
  obj_resp.app_layout,
67
- client,
68
67
  app_page_url=app_page_url,
69
68
  )
70
69
 
@@ -89,10 +88,8 @@ async def _init_local_app_new(
89
88
  logger.debug(f"Created new app with id {app_resp.app_id}")
90
89
  return RunningApp(
91
90
  app_resp.app_id,
92
- client=client,
93
91
  app_page_url=app_resp.app_page_url,
94
92
  app_logs_url=app_resp.app_logs_url,
95
- environment_name=environment_name,
96
93
  interactive=interactive,
97
94
  )
98
95
 
modal/running_app.py CHANGED
@@ -7,14 +7,10 @@ from google.protobuf.message import Message
7
7
  from modal._utils.grpc_utils import get_proto_oneof
8
8
  from modal_proto import api_pb2
9
9
 
10
- from .client import _Client
11
-
12
10
 
13
11
  @dataclass
14
12
  class RunningApp:
15
13
  app_id: str
16
- client: _Client
17
- environment_name: Optional[str] = None
18
14
  app_page_url: Optional[str] = None
19
15
  app_logs_url: Optional[str] = None
20
16
  function_ids: dict[str, str] = field(default_factory=dict)
@@ -26,8 +22,6 @@ class RunningApp:
26
22
  def running_app_from_layout(
27
23
  app_id: str,
28
24
  app_layout: api_pb2.AppLayout,
29
- client: _Client,
30
- environment_name: Optional[str] = None,
31
25
  app_page_url: Optional[str] = None,
32
26
  ) -> RunningApp:
33
27
  object_handle_metadata = {}
@@ -37,8 +31,6 @@ def running_app_from_layout(
37
31
 
38
32
  return RunningApp(
39
33
  app_id,
40
- client,
41
- environment_name=environment_name,
42
34
  function_ids=dict(app_layout.function_ids),
43
35
  class_ids=dict(app_layout.class_ids),
44
36
  object_handle_metadata=object_handle_metadata,