modal 1.1.5.dev83__py3-none-any.whl → 1.3.1.dev8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of modal might be problematic.

Files changed (139)
  1. modal/__init__.py +4 -4
  2. modal/__main__.py +4 -29
  3. modal/_billing.py +84 -0
  4. modal/_clustered_functions.py +1 -3
  5. modal/_container_entrypoint.py +33 -208
  6. modal/_functions.py +146 -121
  7. modal/_grpc_client.py +191 -0
  8. modal/_ipython.py +16 -6
  9. modal/_load_context.py +106 -0
  10. modal/_object.py +72 -21
  11. modal/_output.py +12 -14
  12. modal/_partial_function.py +31 -4
  13. modal/_resolver.py +44 -57
  14. modal/_runtime/container_io_manager.py +26 -28
  15. modal/_runtime/container_io_manager.pyi +42 -44
  16. modal/_runtime/gpu_memory_snapshot.py +9 -7
  17. modal/_runtime/user_code_event_loop.py +80 -0
  18. modal/_runtime/user_code_imports.py +236 -10
  19. modal/_serialization.py +2 -1
  20. modal/_traceback.py +4 -13
  21. modal/_tunnel.py +16 -11
  22. modal/_tunnel.pyi +25 -3
  23. modal/_utils/async_utils.py +337 -10
  24. modal/_utils/auth_token_manager.py +1 -4
  25. modal/_utils/blob_utils.py +29 -22
  26. modal/_utils/function_utils.py +20 -21
  27. modal/_utils/grpc_testing.py +6 -3
  28. modal/_utils/grpc_utils.py +223 -64
  29. modal/_utils/mount_utils.py +26 -1
  30. modal/_utils/package_utils.py +0 -1
  31. modal/_utils/rand_pb_testing.py +8 -1
  32. modal/_utils/task_command_router_client.py +524 -0
  33. modal/_vendor/cloudpickle.py +144 -48
  34. modal/app.py +215 -96
  35. modal/app.pyi +78 -37
  36. modal/billing.py +5 -0
  37. modal/builder/2025.06.txt +6 -3
  38. modal/builder/PREVIEW.txt +2 -1
  39. modal/builder/base-images.json +4 -2
  40. modal/cli/_download.py +19 -3
  41. modal/cli/cluster.py +4 -2
  42. modal/cli/config.py +3 -1
  43. modal/cli/container.py +5 -4
  44. modal/cli/dict.py +5 -2
  45. modal/cli/entry_point.py +26 -2
  46. modal/cli/environment.py +2 -16
  47. modal/cli/launch.py +1 -76
  48. modal/cli/network_file_system.py +5 -20
  49. modal/cli/queues.py +5 -4
  50. modal/cli/run.py +24 -204
  51. modal/cli/secret.py +1 -2
  52. modal/cli/shell.py +375 -0
  53. modal/cli/utils.py +1 -13
  54. modal/cli/volume.py +11 -17
  55. modal/client.py +16 -125
  56. modal/client.pyi +94 -144
  57. modal/cloud_bucket_mount.py +3 -1
  58. modal/cloud_bucket_mount.pyi +4 -0
  59. modal/cls.py +101 -64
  60. modal/cls.pyi +9 -8
  61. modal/config.py +21 -1
  62. modal/container_process.py +288 -12
  63. modal/container_process.pyi +99 -38
  64. modal/dict.py +72 -33
  65. modal/dict.pyi +88 -57
  66. modal/environments.py +16 -8
  67. modal/environments.pyi +6 -2
  68. modal/exception.py +154 -16
  69. modal/experimental/__init__.py +23 -5
  70. modal/experimental/flash.py +161 -74
  71. modal/experimental/flash.pyi +97 -49
  72. modal/file_io.py +50 -92
  73. modal/file_io.pyi +117 -89
  74. modal/functions.pyi +70 -87
  75. modal/image.py +73 -47
  76. modal/image.pyi +33 -30
  77. modal/io_streams.py +500 -149
  78. modal/io_streams.pyi +279 -189
  79. modal/mount.py +60 -45
  80. modal/mount.pyi +41 -17
  81. modal/network_file_system.py +19 -11
  82. modal/network_file_system.pyi +72 -39
  83. modal/object.pyi +114 -22
  84. modal/parallel_map.py +42 -44
  85. modal/parallel_map.pyi +9 -17
  86. modal/partial_function.pyi +4 -2
  87. modal/proxy.py +14 -6
  88. modal/proxy.pyi +10 -2
  89. modal/queue.py +45 -38
  90. modal/queue.pyi +88 -52
  91. modal/runner.py +96 -96
  92. modal/runner.pyi +44 -27
  93. modal/sandbox.py +225 -108
  94. modal/sandbox.pyi +226 -63
  95. modal/secret.py +58 -56
  96. modal/secret.pyi +28 -13
  97. modal/serving.py +7 -11
  98. modal/serving.pyi +7 -8
  99. modal/snapshot.py +29 -15
  100. modal/snapshot.pyi +18 -10
  101. modal/token_flow.py +1 -1
  102. modal/token_flow.pyi +4 -6
  103. modal/volume.py +102 -55
  104. modal/volume.pyi +125 -66
  105. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/METADATA +10 -9
  106. modal-1.3.1.dev8.dist-info/RECORD +189 -0
  107. modal_proto/api.proto +86 -30
  108. modal_proto/api_grpc.py +10 -25
  109. modal_proto/api_pb2.py +1080 -1047
  110. modal_proto/api_pb2.pyi +253 -79
  111. modal_proto/api_pb2_grpc.py +14 -48
  112. modal_proto/api_pb2_grpc.pyi +6 -18
  113. modal_proto/modal_api_grpc.py +175 -176
  114. modal_proto/{sandbox_router.proto → task_command_router.proto} +62 -45
  115. modal_proto/task_command_router_grpc.py +138 -0
  116. modal_proto/task_command_router_pb2.py +180 -0
  117. modal_proto/{sandbox_router_pb2.pyi → task_command_router_pb2.pyi} +110 -63
  118. modal_proto/task_command_router_pb2_grpc.py +272 -0
  119. modal_proto/task_command_router_pb2_grpc.pyi +100 -0
  120. modal_version/__init__.py +1 -1
  121. modal_version/__main__.py +1 -1
  122. modal/cli/programs/launch_instance_ssh.py +0 -94
  123. modal/cli/programs/run_marimo.py +0 -95
  124. modal-1.1.5.dev83.dist-info/RECORD +0 -191
  125. modal_proto/modal_options_grpc.py +0 -3
  126. modal_proto/options.proto +0 -19
  127. modal_proto/options_grpc.py +0 -3
  128. modal_proto/options_pb2.py +0 -35
  129. modal_proto/options_pb2.pyi +0 -20
  130. modal_proto/options_pb2_grpc.py +0 -4
  131. modal_proto/options_pb2_grpc.pyi +0 -7
  132. modal_proto/sandbox_router_grpc.py +0 -105
  133. modal_proto/sandbox_router_pb2.py +0 -148
  134. modal_proto/sandbox_router_pb2_grpc.py +0 -203
  135. modal_proto/sandbox_router_pb2_grpc.pyi +0 -75
  136. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/WHEEL +0 -0
  137. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/entry_points.txt +0 -0
  138. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/licenses/LICENSE +0 -0
  139. {modal-1.1.5.dev83.dist-info → modal-1.3.1.dev8.dist-info}/top_level.txt +0 -0
modal/io_streams.py CHANGED
@@ -1,27 +1,33 @@
  # Copyright Modal Labs 2022
  import asyncio
+ import codecs
+ import contextlib
+ import io
+ import sys
  import time
  from collections.abc import AsyncGenerator, AsyncIterator
+ from dataclasses import dataclass
  from typing import (
      TYPE_CHECKING,
      Generic,
      Literal,
      Optional,
+     TextIO,
      TypeVar,
      Union,
      cast,
  )

- from grpclib import Status
- from grpclib.exceptions import GRPCError, StreamTerminatedError
+ from grpclib.exceptions import StreamTerminatedError

- from modal.exception import ClientClosed, InvalidError
+ from modal.exception import ClientClosed, ExecTimeoutError, InvalidError
  from modal_proto import api_pb2

- from ._utils.async_utils import synchronize_api
- from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
+ from ._utils.async_utils import aclosing, synchronize_api, synchronizer
+ from ._utils.task_command_router_client import TaskCommandRouterClient
  from .client import _Client
  from .config import logger
+ from .exception import ConflictError, InternalError, ServiceError
  from .stream_type import StreamType

  if TYPE_CHECKING:
@@ -61,7 +67,6 @@ async def _container_process_logs_iterator(
          get_raw_bytes=True,
          last_batch_index=last_index,
      )
-
      stream = client.stub.ContainerExecGetOutput.unary_stream(req)
      while True:
          # Check deadline before attempting to receive the next batch
@@ -73,39 +78,22 @@ async def _container_process_logs_iterator(
              break
          except StopAsyncIteration:
              break
+
+         for item in batch.items:
+             yield item.message_bytes, batch.batch_index
+
          if batch.HasField("exit_code"):
              yield None, batch.batch_index
              break
-         for item in batch.items:
-             yield item.message_bytes, batch.batch_index


  T = TypeVar("T", str, bytes)


- class _StreamReader(Generic[T]):
-     """Retrieve logs from a stream (`stdout` or `stderr`).
-
-     As an asynchronous iterable, the object supports the `for` and `async for`
-     statements. Just loop over the object to read in chunks.
-
-     **Usage**
+ class _StreamReaderThroughServer(Generic[T]):
+     """A StreamReader implementation that reads from the server."""

-     ```python fixture:running_app
-     from modal import Sandbox
-
-     sandbox = Sandbox.create(
-         "bash",
-         "-c",
-         "for i in $(seq 1 10); do echo foo; sleep 0.1; done",
-         app=running_app,
-     )
-     for message in sandbox.stdout:
-         print(f"Message: {message}")
-     ```
-     """
-
-     _stream: Optional[AsyncGenerator[Optional[bytes], None]]
+     _stream: Optional[AsyncGenerator[T, None]]

      def __init__(
          self,
@@ -133,10 +121,6 @@ class _StreamReader(Generic[T]):
          if object_type == "sandbox" and not text:
              raise ValueError("Sandbox streams must have text mode enabled.")

-         # line-buffering is only supported when text=True
-         if by_line and not text:
-             raise ValueError("line-buffering is only supported when text=True")
-
          self._text = text
          self._by_line = by_line

@@ -154,10 +138,9 @@ class _StreamReader(Generic[T]):
          self._stream_type = stream_type

          if self._object_type == "container_process":
-             # Container process streams need to be consumed as they are produced,
-             # otherwise the process will block. Use a buffer to store the stream
-             # until the client consumes it.
-             self._container_process_buffer: list[Optional[bytes]] = []
+             # TODO: we should not have this async code in constructors!
+             # it only works as long as all the construction happens inside of synchronicity code
+             self._container_process_buffer: list[Optional[bytes]] = []  # TODO: change this to an asyncio.Queue
              self._consume_container_process_task = asyncio.create_task(self._consume_container_process_stream())

      @property
@@ -166,35 +149,20 @@ class _StreamReader(Generic[T]):
          return self._file_descriptor

      async def read(self) -> T:
-         """Fetch the entire contents of the stream until EOF.
-
-         **Usage**
-
-         ```python fixture:running_app
-         from modal import Sandbox
-
-         sandbox = Sandbox.create("echo", "hello", app=running_app)
-         sandbox.wait()
-
-         print(sandbox.stdout.read())
-         ```
-         """
-         data_str = ""
-         data_bytes = b""
+         """Fetch the entire contents of the stream until EOF."""
          logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read starting")
-         async for message in self._get_logs():
-             if message is None:
-                 break
-             if self._text:
-                 data_str += message.decode("utf-8")
-             else:
-                 data_bytes += message
-
-         logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read completed after EOF")
          if self._text:
-             return cast(T, data_str)
+             buffer = io.StringIO()
+             async for message in _decode_bytes_stream_to_str(self._get_logs()):
+                 buffer.write(message)
+             logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read completed after EOF")
+             return cast(T, buffer.getvalue())
          else:
-             return cast(T, data_bytes)
+             buffer = io.BytesIO()
+             async for message in self._get_logs():
+                 buffer.write(message)
+             logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read completed after EOF")
+             return cast(T, buffer.getvalue())

      async def _consume_container_process_stream(self):
          """Consume the container process stream and store messages in the buffer."""
@@ -213,7 +181,8 @@ class _StreamReader(Generic[T]):
              )
              async for message, batch_index in iterator:
                  if self._stream_type == StreamType.STDOUT and message:
-                     print(message.decode("utf-8"), end="")
+                     # TODO: rearchitect this, since these bytes aren't necessarily decodable
+                     print(message.decode("utf-8"), end="")  # noqa: T201
                  elif self._stream_type == StreamType.PIPE:
                      self._container_process_buffer.append(message)

@@ -223,13 +192,12 @@ class _StreamReader(Generic[T]):
                  else:
                      last_index = batch_index

-         except (GRPCError, StreamTerminatedError, ClientClosed) as exc:
+         except (ServiceError, InternalError, StreamTerminatedError, ClientClosed) as exc:
              if retries_remaining > 0:
                  retries_remaining -= 1
-                 if isinstance(exc, GRPCError):
-                     if exc.status in RETRYABLE_GRPC_STATUS_CODES:
-                         await asyncio.sleep(1.0)
-                         continue
+                 if isinstance(exc, (ServiceError, InternalError)):
+                     await asyncio.sleep(1.0)
+                     continue
                  elif isinstance(exc, StreamTerminatedError):
                      continue
                  elif isinstance(exc, ClientClosed):
@@ -240,6 +208,9 @@ class _StreamReader(Generic[T]):

      async def _stream_container_process(self) -> AsyncGenerator[tuple[Optional[bytes], str], None]:
          """Streams the container process buffer to the reader."""
+         # Container process streams need to be consumed as they are produced,
+         # otherwise the process will block. Use a buffer to store the stream
+         # until the client consumes it.
          entry_id = 0
          if self._last_entry_id:
              entry_id = int(self._last_entry_id) + 1
@@ -257,7 +228,7 @@ class _StreamReader(Generic[T]):

              entry_id += 1

-     async def _get_logs(self, skip_empty_messages: bool = True) -> AsyncGenerator[Optional[bytes], None]:
+     async def _get_logs(self, skip_empty_messages: bool = True) -> AsyncGenerator[bytes, None]:
          """Streams sandbox or process logs from the server to the reader.

          Logs returned by this method may contain partial or multiple lines at a time.
@@ -269,7 +240,6 @@ class _StreamReader(Generic[T]):
              raise InvalidError("Logs can only be retrieved using the PIPE stream type.")

          if self.eof:
-             yield None
              return

          completed = False
@@ -294,72 +264,373 @@ class _StreamReader(Generic[T]):
                      if message is None:
                          completed = True
                          self.eof = True
+                         return
+
                      yield message

-         except (GRPCError, StreamTerminatedError) as exc:
+         except (ServiceError, InternalError, StreamTerminatedError) as exc:
              if retries_remaining > 0:
                  retries_remaining -= 1
-                 if isinstance(exc, GRPCError):
-                     if exc.status in RETRYABLE_GRPC_STATUS_CODES:
-                         await asyncio.sleep(1.0)
-                         continue
+                 if isinstance(exc, (ServiceError, InternalError)):
+                     await asyncio.sleep(1.0)
+                     continue
                  elif isinstance(exc, StreamTerminatedError):
                      continue
              raise

-     async def _get_logs_by_line(self) -> AsyncGenerator[Optional[bytes], None]:
+     async def _get_logs_by_line(self) -> AsyncGenerator[bytes, None]:
          """Process logs from the server and yield complete lines only."""
          async for message in self._get_logs():
-             if message is None:
-                 if self._line_buffer:
-                     yield self._line_buffer
-                     self._line_buffer = b""
-                 yield None
-             else:
-                 assert isinstance(message, bytes)
-                 self._line_buffer += message
-                 while b"\n" in self._line_buffer:
-                     line, self._line_buffer = self._line_buffer.split(b"\n", 1)
-                     yield line + b"\n"
+             assert isinstance(message, bytes)
+             self._line_buffer += message
+             while b"\n" in self._line_buffer:
+                 line, self._line_buffer = self._line_buffer.split(b"\n", 1)
+                 yield line + b"\n"
+
+         if self._line_buffer:
+             yield self._line_buffer
+             self._line_buffer = b""

-     def _ensure_stream(self) -> AsyncGenerator[Optional[bytes], None]:
+     def __aiter__(self) -> AsyncGenerator[T, None]:
          if not self._stream:
              if self._by_line:
-                 self._stream = self._get_logs_by_line()
+                 # TODO: This is quite odd - it does line buffering in binary mode
+                 # but we then always add the buffered text decoding on top of that.
+                 # feels a bit upside down...
+                 stream = self._get_logs_by_line()
              else:
-                 self._stream = self._get_logs()
+                 stream = self._get_logs()
+             if self._text:
+                 stream = _decode_bytes_stream_to_str(stream)
+             self._stream = cast(AsyncGenerator[T, None], stream)
          return self._stream

-     def __aiter__(self) -> AsyncIterator[T]:
+     async def aclose(self):
          """mdmd:hidden"""
-         self._ensure_stream()
-         return self
+         if self._stream:
+             await self._stream.aclose()
+
+
+ async def _decode_bytes_stream_to_str(stream: AsyncGenerator[bytes, None]) -> AsyncGenerator[str, None]:
+     """Incrementally decode a bytes async generator as UTF-8 without breaking on chunk boundaries.
+
+     This function uses a streaming UTF-8 decoder so that multi-byte characters split across
+     chunks are handled correctly instead of raising ``UnicodeDecodeError``.
+     """
+     decoder = codecs.getincrementaldecoder("utf-8")(errors="strict")
+     async for item in stream:
+         text = decoder.decode(item, final=False)
+         if text:
+             yield text
+
+     # Flush any buffered partial character at end-of-stream
+     tail = decoder.decode(b"", final=True)
+     if tail:
+         yield tail
+
+
+ async def _stream_by_line(stream: AsyncGenerator[bytes, None]) -> AsyncGenerator[bytes, None]:
+     """Yield complete lines only (ending with \n), buffering partial lines until complete.
+
+     When this generator returns, the underlying generator is closed.
+     """
+     line_buffer = b""
+     try:
+         async for message in stream:
+             assert isinstance(message, bytes)
+             line_buffer += message
+             while b"\n" in line_buffer:
+                 line, line_buffer = line_buffer.split(b"\n", 1)
+                 yield line + b"\n"
+
+         if line_buffer:
+             yield line_buffer
+     finally:
+         await stream.aclose()
+
+
+ @dataclass
+ class _StreamReaderThroughCommandRouterParams:
+     file_descriptor: "api_pb2.FileDescriptor.ValueType"
+     task_id: str
+     object_id: str
+     command_router_client: TaskCommandRouterClient
+     deadline: Optional[float]
+
+
+ async def _stdio_stream_from_command_router(
+     params: _StreamReaderThroughCommandRouterParams,
+ ) -> AsyncGenerator[bytes, None]:
+     """Stream raw bytes from the router client."""
+     async with aclosing(
+         params.command_router_client.exec_stdio_read(
+             params.task_id, params.object_id, params.file_descriptor, params.deadline
+         )
+     ) as stream:
+         try:
+             async for item in stream:
+                 if len(item.data) == 0:
+                     # This is an error.
+                     raise ValueError("Received empty message streaming stdio from sandbox.")
+
+                 yield item.data
+         except ExecTimeoutError:
+             logger.debug(f"Deadline exceeded while streaming stdio for exec {params.object_id}")
+             # TODO(saltzm): This is a weird API, but customers currently may rely on it. We
+             # should probably raise this error rather than just ending the stream.
+             return
+
+
+ class _BytesStreamReaderThroughCommandRouter:
+     """
+     StreamReader implementation that will read directly from the worker that
+     hosts the sandbox.
+
+     This implementation is used for non-text streams.
+     """
+
+     def __init__(
+         self,
+         params: _StreamReaderThroughCommandRouterParams,
+     ) -> None:
+         self._params = params
+         self._stream = None
+
+     @property
+     def file_descriptor(self) -> int:
+         return self._params.file_descriptor
+
+     async def read(self) -> bytes:
+         buffer = io.BytesIO()
+         async for part in self:
+             buffer.write(part)
+         return buffer.getvalue()
+
+     def __aiter__(self) -> AsyncGenerator[bytes, None]:
+         return _stdio_stream_from_command_router(self._params)
+
+     async def _print_all(self, output_stream: TextIO) -> None:
+         async for part in self:
+             output_stream.buffer.write(part)
+             output_stream.buffer.flush()
+
+
+ class _TextStreamReaderThroughCommandRouter:
+     """
+     StreamReader implementation that will read directly from the worker
+     that hosts the sandbox.
+
+     This implementation is used for text streams.
+     """
+
+     def __init__(
+         self,
+         params: _StreamReaderThroughCommandRouterParams,
+         by_line: bool,
+     ) -> None:
+         self._params = params
+         self._by_line = by_line
+
+     @property
+     def file_descriptor(self) -> int:
+         return self._params.file_descriptor
+
+     async def read(self) -> str:
+         buffer = io.StringIO()
+         async for part in self:
+             buffer.write(part)
+         return buffer.getvalue()
+
+     async def __aiter__(self) -> AsyncGenerator[str, None]:
+         async with aclosing(_stdio_stream_from_command_router(self._params)) as bytes_stream:
+             if self._by_line:
+                 stream = _decode_bytes_stream_to_str(_stream_by_line(bytes_stream))
+             else:
+                 stream = _decode_bytes_stream_to_str(bytes_stream)
+
+             async with aclosing(stream):
+                 async for part in stream:
+                     yield part
+
+     async def _print_all(self, output_stream: TextIO) -> None:
+         async with aclosing(self.__aiter__()) as stream:
+             async for part in stream:
+                 output_stream.write(part)
+
+
+ class _StdoutPrintingStreamReaderThroughCommandRouter(Generic[T]):
+     """
+     StreamReader implementation for StreamType.STDOUT when using the task command router.
+
+     This mirrors the behavior from the server-backed implementation: the stream is printed to
+     the local stdout immediately and is not readable via StreamReader methods.
+     """
+
+     _reader: Union[_TextStreamReaderThroughCommandRouter, _BytesStreamReaderThroughCommandRouter]
+
+     def __init__(
+         self,
+         reader: Union[_TextStreamReaderThroughCommandRouter, _BytesStreamReaderThroughCommandRouter],
+     ) -> None:
+         self._reader = reader
+         self._task: Optional[asyncio.Task[None]] = None
+         # Kick off a background task that reads from the underlying text stream and prints to stdout.
+         self._start_printing_task()
+
+     @property
+     def file_descriptor(self) -> int:
+         return self._reader.file_descriptor
+
+     def _start_printing_task(self) -> None:
+         async def _run():
+             try:
+                 await self._reader._print_all(sys.stdout)
+             except Exception as e:
+                 logger.exception(f"Error printing stream: {e}")
+
+         self._task = asyncio.create_task(_run())
+
+     async def read(self) -> T:
+         raise InvalidError("Output can only be retrieved using the PIPE stream type.")
+
+     def __aiter__(self) -> AsyncIterator[T]:
+         raise InvalidError("Output can only be retrieved using the PIPE stream type.")

      async def __anext__(self) -> T:
-         """mdmd:hidden"""
-         stream = self._ensure_stream()
+         raise InvalidError("Output can only be retrieved using the PIPE stream type.")

-         value = await stream.__anext__()
+     async def aclose(self):
+         if self._task is not None:
+             self._task.cancel()
+             with contextlib.suppress(asyncio.CancelledError):
+                 await self._task
+             self._task = None

-         # The stream yields None if it receives an EOF batch.
-         if value is None:
-             raise StopAsyncIteration

-         if self._text:
-             return cast(T, value.decode("utf-8"))
+ class _DevnullStreamReader(Generic[T]):
+     """StreamReader implementation for a stream configured with
+     StreamType.DEVNULL. Throws an error if read or any other method is
+     called.
+     """
+
+     def __init__(self, file_descriptor: "api_pb2.FileDescriptor.ValueType") -> None:
+         self._file_descriptor = file_descriptor
+
+     @property
+     def file_descriptor(self) -> int:
+         return self._file_descriptor
+
+     async def read(self) -> T:
+         raise ValueError("read is not supported for a stream configured with StreamType.DEVNULL")
+
+     def __aiter__(self) -> AsyncIterator[T]:
+         raise ValueError("__aiter__ is not supported for a stream configured with StreamType.DEVNULL")
+
+     async def __anext__(self) -> T:
+         raise ValueError("__anext__ is not supported for a stream configured with StreamType.DEVNULL")
+
+     async def aclose(self):
+         raise ValueError("aclose is not supported for a stream configured with StreamType.DEVNULL")
+
+
+ class _StreamReader(Generic[T]):
+     """Retrieve logs from a stream (`stdout` or `stderr`).
+
+     As an asynchronous iterable, the object supports the `for` and `async for`
+     statements. Just loop over the object to read in chunks.
+     """
+
+     _impl: Union[
+         _StreamReaderThroughServer,
+         _DevnullStreamReader,
+         _TextStreamReaderThroughCommandRouter,
+         _BytesStreamReaderThroughCommandRouter,
+         _StdoutPrintingStreamReaderThroughCommandRouter,
+     ]
+     _read_gen: Optional[AsyncGenerator[T, None]] = None
+
+     def __init__(
+         self,
+         file_descriptor: "api_pb2.FileDescriptor.ValueType",
+         object_id: str,
+         object_type: Literal["sandbox", "container_process"],
+         client: _Client,
+         stream_type: StreamType = StreamType.PIPE,
+         text: bool = True,
+         by_line: bool = False,
+         deadline: Optional[float] = None,
+         command_router_client: Optional[TaskCommandRouterClient] = None,
+         task_id: Optional[str] = None,
+     ) -> None:
+         """mdmd:hidden"""
+         # we can remove this once we ensure no constructors use async code
+         assert asyncio.get_running_loop() == synchronizer._get_loop(start=False)
+
+         if by_line and not text:
+             raise ValueError("line-buffering is only supported when text=True")
+
+         if command_router_client is None:
+             self._impl = _StreamReaderThroughServer(
+                 file_descriptor, object_id, object_type, client, stream_type, text, by_line, deadline
+             )
          else:
-             return cast(T, value)
+             # The only reason task_id is optional is because StreamReader is also used for sandbox
+             # logs, which don't have a task ID available when the StreamReader is created.
+             assert task_id is not None
+             assert object_type == "container_process"
+             if stream_type == StreamType.DEVNULL:
+                 self._impl = _DevnullStreamReader(file_descriptor)
+             else:
+                 assert stream_type == StreamType.PIPE or stream_type == StreamType.STDOUT
+                 params = _StreamReaderThroughCommandRouterParams(
+                     file_descriptor, task_id, object_id, command_router_client, deadline
+                 )
+                 if text:
+                     reader = _TextStreamReaderThroughCommandRouter(params, by_line)
+                 else:
+                     reader = _BytesStreamReaderThroughCommandRouter(params)
+
+                 if stream_type == StreamType.STDOUT:
+                     self._impl = _StdoutPrintingStreamReaderThroughCommandRouter(reader)
+                 else:
+                     self._impl = reader
+
+     @property
+     def file_descriptor(self) -> int:
+         """Possible values are `1` for stdout and `2` for stderr."""
+         return self._impl.file_descriptor
+
+     async def read(self) -> T:
+         """Fetch the entire contents of the stream until EOF."""
+         return cast(T, await self._impl.read())
+
+     def __aiter__(self) -> AsyncGenerator[T, None]:
+         if not self._read_gen:
+             self._read_gen = cast(AsyncGenerator[T, None], self._impl.__aiter__())
+         return self._read_gen
+
+     async def __anext__(self) -> T:
+         """Deprecated: This exists for backwards compatibility and will be removed in a future version of Modal
+
+         Only use next/anext on the return value of iter/aiter on the StreamReader object (treat streamreader as
+         an iterable, not an iterator).
+         """
+         if not self._read_gen:
+             self.__aiter__()  # initialize the read generator
+         assert self._read_gen
+         return await self._read_gen.__anext__()

      async def aclose(self):
          """mdmd:hidden"""
-         if self._stream:
-             await self._stream.aclose()
+         if self._read_gen:
+             await self._read_gen.aclose()
+             self._read_gen = None


  MAX_BUFFER_SIZE = 2 * 1024 * 1024


- class _StreamWriter:
+ class _StreamWriterThroughServer:
      """Provides an interface to buffer and write logs to a sandbox or container process stream (`stdin`)."""

      def __init__(self, object_id: str, object_type: Literal["sandbox", "container_process"], client: _Client) -> None:
@@ -381,25 +652,6 @@ class _StreamWriter:

          This is non-blocking and queues the data to an internal buffer. Must be
          used along with the `drain()` method, which flushes the buffer.
-
-         **Usage**
-
-         ```python fixture:running_app
-         from modal import Sandbox
-
-         sandbox = Sandbox.create(
-             "bash",
-             "-c",
-             "while read line; do echo $line; done",
-             app=running_app,
-         )
-         sandbox.stdin.write(b"foo\\n")
-         sandbox.stdin.write(b"bar\\n")
-         sandbox.stdin.write_eof()
-
-         sandbox.stdin.drain()
-         sandbox.wait()
-         ```
          """
          if self._is_closed:
              raise ValueError("Stdin is closed. Cannot write to it.")
@@ -407,7 +659,7 @@ class _StreamWriter:
              if isinstance(data, str):
                  data = data.encode("utf-8")
              if len(self._buffer) + len(data) > MAX_BUFFER_SIZE:
-                 raise BufferError("Buffer size exceed limit. Call drain to clear the buffer.")
+                 raise BufferError("Buffer size exceed limit. Call drain to flush the buffer.")
              self._buffer.extend(data)
          else:
              raise TypeError(f"data argument must be a bytes-like object, not {type(data).__name__}")
@@ -426,19 +678,6 @@ class _StreamWriter:

          This is a flow control method that blocks until data is sent. It returns
          when it is appropriate to continue writing data to the stream.
-
-         **Usage**
-
-         ```python notest
-         writer.write(data)
-         writer.drain()
-         ```
-
-         Async usage:
-         ```python notest
-         writer.write(data)  # not a blocking operation
-         await writer.drain.aio()
-         ```
          """
          data = bytes(self._buffer)
          self._buffer.clear()
@@ -446,25 +685,137 @@ class _StreamWriter:

          try:
              if self._object_type == "sandbox":
-                 await retry_transient_errors(
-                     self._client.stub.SandboxStdinWrite,
+                 await self._client.stub.SandboxStdinWrite(
                      api_pb2.SandboxStdinWriteRequest(
                          sandbox_id=self._object_id, index=index, eof=self._is_closed, input=data
                      ),
                  )
              else:
-                 await retry_transient_errors(
-                     self._client.stub.ContainerExecPutInput,
+                 await self._client.stub.ContainerExecPutInput(
                      api_pb2.ContainerExecPutInputRequest(
                          exec_id=self._object_id,
                          input=api_pb2.RuntimeInputMessage(message=data, message_index=index, eof=self._is_closed),
                      ),
                  )
-         except GRPCError as exc:
-             if exc.status == Status.FAILED_PRECONDITION:
-                 raise ValueError(exc.message)
-             else:
-                 raise exc
+         except ConflictError as exc:
+             raise ValueError(str(exc))
+
+
+ class _StreamWriterThroughCommandRouter:
+     def __init__(
+         self,
+         object_id: str,
+         command_router_client: TaskCommandRouterClient,
+         task_id: str,
+     ) -> None:
+         self._object_id = object_id
+         self._command_router_client = command_router_client
+         self._task_id = task_id
+         self._is_closed = False
+         self._buffer = bytearray()
+         self._offset = 0
+
+     def write(self, data: Union[bytes, bytearray, memoryview, str]) -> None:
+         if self._is_closed:
+             raise ValueError("Stdin is closed. Cannot write to it.")
+         if isinstance(data, (bytes, bytearray, memoryview, str)):
+             if isinstance(data, str):
+                 data = data.encode("utf-8")
+             if len(self._buffer) + len(data) > MAX_BUFFER_SIZE:
+                 raise BufferError("Buffer size exceed limit. Call drain to flush the buffer.")
+             self._buffer.extend(data)
+         else:
+             raise TypeError(f"data argument must be a bytes-like object, not {type(data).__name__}")
+
+     def write_eof(self) -> None:
+         self._is_closed = True
+
+     async def drain(self) -> None:
+         eof = self._is_closed
+         # NB: There's no need to prevent writing eof twice, because the command router will ignore the second EOF.
+         if self._buffer or eof:
+             data = bytes(self._buffer)
+             await self._command_router_client.exec_stdin_write(
+                 task_id=self._task_id, exec_id=self._object_id, offset=self._offset, data=data, eof=eof
+             )
+             # Only clear the buffer after writing the data to the command router is successful.
+             # This allows the client to retry drain() in the event of an exception (though
+             # exec_stdin_write already retries on transient errors, so most users will probably
+             # not do this).
+             self._buffer.clear()
+             self._offset += len(data)
+
+
+ class _StreamWriter:
+     """Provides an interface to buffer and write logs to a sandbox or container process stream (`stdin`)."""
+
+     def __init__(
+         self,
+         object_id: str,
+         object_type: Literal["sandbox", "container_process"],
+         client: _Client,
+         command_router_client: Optional[TaskCommandRouterClient] = None,
+         task_id: Optional[str] = None,
+     ) -> None:
+         """mdmd:hidden"""
+         if command_router_client is None:
+             self._impl = _StreamWriterThroughServer(object_id, object_type, client)
+         else:
+             assert task_id is not None
+             assert object_type == "container_process"
+             self._impl = _StreamWriterThroughCommandRouter(object_id, command_router_client, task_id=task_id)
+
+     def write(self, data: Union[bytes, bytearray, memoryview, str]) -> None:
+         """Write data to the stream but does not send it immediately.
+
+         This is non-blocking and queues the data to an internal buffer. Must be
+         used along with the `drain()` method, which flushes the buffer.
+
+         **Usage**
+
+         ```python fixture:sandbox
+         proc = sandbox.exec(
+             "bash",
+             "-c",
+             "while read line; do echo $line; done",
+         )
+         proc.stdin.write(b"foo\\n")
+         proc.stdin.write(b"bar\\n")
+         proc.stdin.write_eof()
+         proc.stdin.drain()
+         ```
+         """
+         self._impl.write(data)
+
+     def write_eof(self) -> None:
+         """Close the write end of the stream after the buffered data is drained.
+
+         If the process was blocked on input, it will become unblocked after
+         `write_eof()`. This method needs to be used along with the `drain()`
+         method, which flushes the EOF to the process.
+         """
+         self._impl.write_eof()
+
+     async def drain(self) -> None:
+         """Flush the write buffer and send data to the running process.
+
+         This is a flow control method that blocks until data is sent. It returns
+         when it is appropriate to continue writing data to the stream.
+
+         **Usage**
+
+         ```python notest
+         writer.write(data)
+         writer.drain()
+         ```
+
+         Async usage:
+         ```python notest
+         writer.write(data)  # not a blocking operation
+         await writer.drain.aio()
+         ```
+         """
+         await self._impl.drain()


  StreamReader = synchronize_api(_StreamReader)
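
The incremental decoder added in this diff (`_decode_bytes_stream_to_str`) exists because a chunked stdio stream can split a multi-byte UTF-8 character across two messages, where a naive per-chunk `bytes.decode("utf-8")` would raise `UnicodeDecodeError`. A minimal, self-contained sketch of the same pattern; the `chunks` generator is a hypothetical stand-in for a sandbox stdio stream, not a Modal API:

```python
import asyncio
import codecs
from collections.abc import AsyncGenerator


async def chunks() -> AsyncGenerator[bytes, None]:
    # "é" is two bytes in UTF-8 (0xC3 0xA9); split it across two chunks.
    yield b"caf\xc3"
    yield b"\xa9\n"


async def decode_stream(stream: AsyncGenerator[bytes, None]) -> AsyncGenerator[str, None]:
    # Same approach as _decode_bytes_stream_to_str: the incremental decoder
    # buffers the dangling 0xC3 byte instead of raising.
    decoder = codecs.getincrementaldecoder("utf-8")(errors="strict")
    async for item in stream:
        text = decoder.decode(item, final=False)
        if text:
            yield text
    # Flush any buffered partial character at end-of-stream.
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


async def main() -> None:
    parts = [part async for part in decode_stream(chunks())]
    assert "".join(parts) == "café\n"


asyncio.run(main())
```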
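
`_stream_by_line` applies the same chunk-boundary care to line buffering: only complete lines (ending in `\n`) are emitted, and a trailing partial line is flushed at EOF. A runnable sketch of that buffering loop, again with a hypothetical `chunks` generator in place of the real stream:

```python
import asyncio
from collections.abc import AsyncGenerator


async def chunks() -> AsyncGenerator[bytes, None]:
    # Chunk boundaries deliberately do not line up with newlines.
    for piece in (b"first li", b"ne\nsecond", b" line\ntail"):
        yield piece


async def by_line(stream: AsyncGenerator[bytes, None]) -> AsyncGenerator[bytes, None]:
    # Same buffering loop as _stream_by_line: hold back bytes until a full
    # line is available, then flush the remainder when the stream ends.
    buffer = b""
    try:
        async for message in stream:
            buffer += message
            while b"\n" in buffer:
                line, buffer = buffer.split(b"\n", 1)
                yield line + b"\n"
        if buffer:
            yield buffer
    finally:
        await stream.aclose()


async def main() -> None:
    lines = [line async for line in by_line(chunks())]
    assert lines == [b"first line\n", b"second line\n", b"tail"]


asyncio.run(main())
```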
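
On the write side, the new `_StreamWriterThroughCommandRouter.drain` clears its buffer only after the send succeeds and advances a byte offset, so a `drain()` that failed mid-flight can be retried without duplicating input. A toy model of that `write()`/`write_eof()`/`drain()` contract; the async `send(offset, data, eof)` callable below is an assumed stand-in for the router RPC, not part of Modal's API:

```python
import asyncio
from typing import Union

MAX_BUFFER_SIZE = 2 * 1024 * 1024


class BufferedWriter:
    def __init__(self, send) -> None:
        self._send = send  # async callable: (offset, data, eof) -> None
        self._buffer = bytearray()
        self._offset = 0
        self._is_closed = False

    def write(self, data: Union[bytes, bytearray, memoryview, str]) -> None:
        # write() only queues; nothing is sent until drain().
        if self._is_closed:
            raise ValueError("Stdin is closed. Cannot write to it.")
        if isinstance(data, str):
            data = data.encode("utf-8")
        if len(self._buffer) + len(data) > MAX_BUFFER_SIZE:
            raise BufferError("Buffer full. Call drain to flush it.")
        self._buffer.extend(data)

    def write_eof(self) -> None:
        self._is_closed = True

    async def drain(self) -> None:
        if self._buffer or self._is_closed:
            data = bytes(self._buffer)
            await self._send(self._offset, data, self._is_closed)
            # Clear only after a successful send: a failed drain() leaves the
            # buffer intact, and the offset makes a retried send idempotent.
            self._buffer.clear()
            self._offset += len(data)


async def main() -> None:
    sent = []

    async def send(offset: int, data: bytes, eof: bool) -> None:
        sent.append((offset, data, eof))

    w = BufferedWriter(send)
    w.write(b"foo\n")
    w.write(b"bar\n")
    w.write_eof()
    await w.drain()
    assert sent == [(0, b"foo\nbar\n", True)]


asyncio.run(main())
```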