modal 1.2.0__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of modal might be problematic.

Files changed (49)
  1. modal/_container_entrypoint.py +4 -1
  2. modal/_partial_function.py +28 -3
  3. modal/_utils/function_utils.py +4 -0
  4. modal/_utils/task_command_router_client.py +537 -0
  5. modal/app.py +93 -54
  6. modal/app.pyi +48 -18
  7. modal/cli/_download.py +19 -3
  8. modal/cli/cluster.py +4 -2
  9. modal/cli/container.py +4 -2
  10. modal/cli/entry_point.py +1 -0
  11. modal/cli/launch.py +1 -2
  12. modal/cli/run.py +6 -0
  13. modal/cli/volume.py +7 -1
  14. modal/client.pyi +2 -2
  15. modal/cls.py +5 -12
  16. modal/config.py +14 -0
  17. modal/container_process.py +283 -3
  18. modal/container_process.pyi +95 -32
  19. modal/exception.py +4 -0
  20. modal/experimental/flash.py +21 -47
  21. modal/experimental/flash.pyi +6 -20
  22. modal/functions.pyi +6 -6
  23. modal/io_streams.py +455 -122
  24. modal/io_streams.pyi +220 -95
  25. modal/partial_function.pyi +4 -1
  26. modal/runner.py +39 -36
  27. modal/runner.pyi +40 -24
  28. modal/sandbox.py +130 -11
  29. modal/sandbox.pyi +145 -9
  30. modal/volume.py +23 -3
  31. modal/volume.pyi +30 -0
  32. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/METADATA +5 -5
  33. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/RECORD +49 -48
  34. modal_proto/api.proto +2 -26
  35. modal_proto/api_grpc.py +0 -32
  36. modal_proto/api_pb2.py +327 -367
  37. modal_proto/api_pb2.pyi +6 -69
  38. modal_proto/api_pb2_grpc.py +0 -67
  39. modal_proto/api_pb2_grpc.pyi +0 -22
  40. modal_proto/modal_api_grpc.py +0 -2
  41. modal_proto/sandbox_router.proto +0 -4
  42. modal_proto/sandbox_router_pb2.pyi +0 -4
  43. modal_proto/task_command_router.proto +1 -1
  44. modal_proto/task_command_router_pb2.py +2 -2
  45. modal_version/__init__.py +1 -1
  46. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/WHEEL +0 -0
  47. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/entry_points.txt +0 -0
  48. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/licenses/LICENSE +0 -0
  49. {modal-1.2.0.dist-info → modal-1.2.1.dist-info}/top_level.txt +0 -0
modal/io_streams.py CHANGED
@@ -1,7 +1,9 @@
 # Copyright Modal Labs 2022
 import asyncio
+import codecs
 import time
 from collections.abc import AsyncGenerator, AsyncIterator
+from dataclasses import dataclass
 from typing import (
     TYPE_CHECKING,
     Generic,
@@ -15,11 +17,12 @@ from typing import (
 from grpclib import Status
 from grpclib.exceptions import GRPCError, StreamTerminatedError
 
-from modal.exception import ClientClosed, InvalidError
+from modal.exception import ClientClosed, ExecTimeoutError, InvalidError
 from modal_proto import api_pb2
 
 from ._utils.async_utils import synchronize_api
 from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
+from ._utils.task_command_router_client import TaskCommandRouterClient
 from .client import _Client
 from .config import logger
 from .stream_type import StreamType
@@ -61,7 +64,6 @@ async def _container_process_logs_iterator(
         get_raw_bytes=True,
         last_batch_index=last_index,
     )
-
    stream = client.stub.ContainerExecGetOutput.unary_stream(req)
    while True:
        # Check deadline before attempting to receive the next batch
@@ -73,39 +75,22 @@ async def _container_process_logs_iterator(
                break
        except StopAsyncIteration:
            break
+
+        for item in batch.items:
+            yield item.message_bytes, batch.batch_index
+
        if batch.HasField("exit_code"):
            yield None, batch.batch_index
            break
-        for item in batch.items:
-            yield item.message_bytes, batch.batch_index
 
 
 T = TypeVar("T", str, bytes)
 
 
-class _StreamReader(Generic[T]):
-    """Retrieve logs from a stream (`stdout` or `stderr`).
-
-    As an asynchronous iterable, the object supports the `for` and `async for`
-    statements. Just loop over the object to read in chunks.
-
-    **Usage**
-
-    ```python fixture:running_app
-    from modal import Sandbox
-
-    sandbox = Sandbox.create(
-        "bash",
-        "-c",
-        "for i in $(seq 1 10); do echo foo; sleep 0.1; done",
-        app=running_app,
-    )
-    for message in sandbox.stdout:
-        print(f"Message: {message}")
-    ```
-    """
+class _StreamReaderThroughServer(Generic[T]):
+    """A StreamReader implementation that reads from the server."""
 
-    _stream: Optional[AsyncGenerator[Optional[bytes], None]]
+    _stream: Optional[AsyncGenerator[T, None]]
 
     def __init__(
         self,
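Editor's note: the reordering in this hunk matters because items in the final batch (the one carrying `exit_code`) are now yielded before the `None` EOF sentinel, so trailing output is no longer dropped. A schematic consumer of this `(message_bytes, batch_index)` iterator protocol, using plain Python types rather than the real proto objects:

```python
from collections.abc import AsyncIterator
from typing import Optional

async def consume(iterator: AsyncIterator[tuple[Optional[bytes], int]]) -> bytes:
    """Collect process output until the None EOF sentinel arrives."""
    collected = b""
    async for message, _batch_index in iterator:
        if message is None:  # EOF batch: the process has exited
            break
        collected += message
    return collected
```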
@@ -133,10 +118,6 @@ class _StreamReader(Generic[T]):
         if object_type == "sandbox" and not text:
             raise ValueError("Sandbox streams must have text mode enabled.")
 
-        # line-buffering is only supported when text=True
-        if by_line and not text:
-            raise ValueError("line-buffering is only supported when text=True")
-
         self._text = text
         self._by_line = by_line
 
@@ -154,10 +135,9 @@ class _StreamReader(Generic[T]):
         self._stream_type = stream_type
 
         if self._object_type == "container_process":
-            # Container process streams need to be consumed as they are produced,
-            # otherwise the process will block. Use a buffer to store the stream
-            # until the client consumes it.
-            self._container_process_buffer: list[Optional[bytes]] = []
+            # TODO: we should not have this async code in constructors!
+            # it only works as long as all the construction happens inside of synchronicity code
+            self._container_process_buffer: list[Optional[bytes]] = []  # TODO: change this to an asyncio.Queue
             self._consume_container_process_task = asyncio.create_task(self._consume_container_process_stream())
 
     @property
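Editor's note: the TODO above points at a real constraint: `asyncio.create_task` requires a running event loop, so this constructor is only safe to call from inside async code (which the synchronicity wrapper currently guarantees). A minimal sketch of the constraint, with hypothetical names:

```python
import asyncio

class BackgroundConsumer:
    """Spawns its consumer task at construction time (hypothetical example)."""

    def __init__(self) -> None:
        self._buffer: list[bytes] = []
        # Raises RuntimeError("no running event loop") if constructed outside async code.
        self._task = asyncio.create_task(self._consume())

    async def _consume(self) -> None:
        await asyncio.sleep(0)  # stand-in for reading from a stream

async def main() -> None:
    BackgroundConsumer()  # OK: the event loop is running here

asyncio.run(main())  # constructing BackgroundConsumer at module scope would fail
```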
@@ -166,34 +146,19 @@ class _StreamReader(Generic[T]):
         return self._file_descriptor
 
     async def read(self) -> T:
-        """Fetch the entire contents of the stream until EOF.
-
-        **Usage**
-
-        ```python fixture:running_app
-        from modal import Sandbox
-
-        sandbox = Sandbox.create("echo", "hello", app=running_app)
-        sandbox.wait()
-
-        print(sandbox.stdout.read())
-        ```
-        """
-        data_str = ""
-        data_bytes = b""
+        """Fetch the entire contents of the stream until EOF."""
         logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read starting")
-        async for message in self._get_logs():
-            if message is None:
-                break
-            if self._text:
-                data_str += message.decode("utf-8")
-            else:
-                data_bytes += message
-
-        logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read completed after EOF")
         if self._text:
+            data_str = ""
+            async for message in _decode_bytes_stream_to_str(self._get_logs()):
+                data_str += message
+            logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read completed after EOF")
             return cast(T, data_str)
         else:
+            data_bytes = b""
+            async for message in self._get_logs():
+                data_bytes += message
+            logger.debug(f"{self._object_id} StreamReader fd={self._file_descriptor} read completed after EOF")
             return cast(T, data_bytes)
 
     async def _consume_container_process_stream(self):
@@ -213,6 +178,7 @@ class _StreamReader(Generic[T]):
             )
             async for message, batch_index in iterator:
                 if self._stream_type == StreamType.STDOUT and message:
+                    # TODO: rearchitect this, since these bytes aren't necessarily decodable
                     print(message.decode("utf-8"), end="")
                 elif self._stream_type == StreamType.PIPE:
                     self._container_process_buffer.append(message)
@@ -240,6 +206,9 @@ class _StreamReader(Generic[T]):
 
     async def _stream_container_process(self) -> AsyncGenerator[tuple[Optional[bytes], str], None]:
         """Streams the container process buffer to the reader."""
+        # Container process streams need to be consumed as they are produced,
+        # otherwise the process will block. Use a buffer to store the stream
+        # until the client consumes it.
         entry_id = 0
         if self._last_entry_id:
             entry_id = int(self._last_entry_id) + 1
@@ -257,7 +226,7 @@ class _StreamReader(Generic[T]):
 
             entry_id += 1
 
-    async def _get_logs(self, skip_empty_messages: bool = True) -> AsyncGenerator[Optional[bytes], None]:
+    async def _get_logs(self, skip_empty_messages: bool = True) -> AsyncGenerator[bytes, None]:
         """Streams sandbox or process logs from the server to the reader.
 
         Logs returned by this method may contain partial or multiple lines at a time.
@@ -269,7 +238,6 @@ class _StreamReader(Generic[T]):
             raise InvalidError("Logs can only be retrieved using the PIPE stream type.")
 
         if self.eof:
-            yield None
             return
 
         completed = False
@@ -294,6 +262,8 @@ class _StreamReader(Generic[T]):
                     if message is None:
                         completed = True
                         self.eof = True
+                        return
+
                     yield message
 
             except (GRPCError, StreamTerminatedError) as exc:
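Editor's note: this hunk changes EOF signaling from a `None` sentinel to ending the generator. When an async generator executes a bare `return`, the next `__anext__()` raises `StopAsyncIteration`, so callers no longer need to check for `None`. A small illustration (not Modal code):

```python
import asyncio
from collections.abc import AsyncGenerator

async def chunks() -> AsyncGenerator[bytes, None]:
    yield b"data"
    return  # EOF: ends the generator instead of yielding a None sentinel

async def main() -> None:
    gen = chunks()
    print(await gen.__anext__())  # b'data'
    try:
        await gen.__anext__()
    except StopAsyncIteration:
        print("EOF")

asyncio.run(main())
```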
@@ -307,59 +277,332 @@ class _StreamReader(Generic[T]):
                        continue
                raise
 
-    async def _get_logs_by_line(self) -> AsyncGenerator[Optional[bytes], None]:
+    async def _get_logs_by_line(self) -> AsyncGenerator[bytes, None]:
         """Process logs from the server and yield complete lines only."""
         async for message in self._get_logs():
-            if message is None:
-                if self._line_buffer:
-                    yield self._line_buffer
-                    self._line_buffer = b""
-                yield None
-            else:
-                assert isinstance(message, bytes)
-                self._line_buffer += message
-                while b"\n" in self._line_buffer:
-                    line, self._line_buffer = self._line_buffer.split(b"\n", 1)
-                    yield line + b"\n"
+            assert isinstance(message, bytes)
+            self._line_buffer += message
+            while b"\n" in self._line_buffer:
+                line, self._line_buffer = self._line_buffer.split(b"\n", 1)
+                yield line + b"\n"
+
+        if self._line_buffer:
+            yield self._line_buffer
+            self._line_buffer = b""
 
-    def _ensure_stream(self) -> AsyncGenerator[Optional[bytes], None]:
+    def _ensure_stream(self) -> AsyncGenerator[T, None]:
         if not self._stream:
             if self._by_line:
-                self._stream = self._get_logs_by_line()
+                # TODO: This is quite odd - it does line buffering in binary mode
+                # but we then always add the buffered text decoding on top of that.
+                # feels a bit upside down...
+                stream = self._get_logs_by_line()
             else:
-                self._stream = self._get_logs()
+                stream = self._get_logs()
+            if self._text:
+                stream = _decode_bytes_stream_to_str(stream)
+            self._stream = cast(AsyncGenerator[T, None], stream)
         return self._stream
 
-    def __aiter__(self) -> AsyncIterator[T]:
+    async def __anext__(self) -> T:
+        """mdmd:hidden"""
+        stream = self._ensure_stream()
+        return cast(T, await stream.__anext__())
+
+    async def aclose(self):
         """mdmd:hidden"""
-        self._ensure_stream()
+        if self._stream:
+            await self._stream.aclose()
+
+
+async def _decode_bytes_stream_to_str(stream: AsyncGenerator[bytes, None]) -> AsyncGenerator[str, None]:
+    """Incrementally decode a bytes async generator as UTF-8 without breaking on chunk boundaries.
+
+    This function uses a streaming UTF-8 decoder so that multi-byte characters split across
+    chunks are handled correctly instead of raising ``UnicodeDecodeError``.
+    """
+    decoder = codecs.getincrementaldecoder("utf-8")(errors="strict")
+    async for item in stream:
+        text = decoder.decode(item, final=False)
+        if text:
+            yield text
+
+    # Flush any buffered partial character at end-of-stream
+    tail = decoder.decode(b"", final=True)
+    if tail:
+        yield tail
+
+
+async def _stream_by_line(stream: AsyncGenerator[bytes, None]) -> AsyncGenerator[bytes, None]:
+    """Yield complete lines only (ending with \n), buffering partial lines until complete."""
+    line_buffer = b""
+    async for message in stream:
+        assert isinstance(message, bytes)
+        line_buffer += message
+        while b"\n" in line_buffer:
+            line, line_buffer = line_buffer.split(b"\n", 1)
+            yield line + b"\n"
+
+    if line_buffer:
+        yield line_buffer
+
+
+@dataclass
+class _StreamReaderThroughCommandRouterParams:
+    file_descriptor: "api_pb2.FileDescriptor.ValueType"
+    task_id: str
+    object_id: str
+    command_router_client: TaskCommandRouterClient
+    deadline: Optional[float]
+
+
+async def _stdio_stream_from_command_router(
+    params: _StreamReaderThroughCommandRouterParams,
+) -> AsyncGenerator[bytes, None]:
+    """Stream raw bytes from the router client."""
+    stream = params.command_router_client.exec_stdio_read(
+        params.task_id, params.object_id, params.file_descriptor, params.deadline
+    )
+    try:
+        async for item in stream:
+            if len(item.data) == 0:
+                # This is an error.
+                raise ValueError("Received empty message streaming stdio from sandbox.")
+
+            yield item.data
+    except ExecTimeoutError:
+        logger.debug(f"Deadline exceeded while streaming stdio for exec {params.object_id}")
+        # TODO(saltzm): This is a weird API, but customers currently may rely on it. We
+        # should probably raise this error rather than just ending the stream.
+        return
+
+
+class _BytesStreamReaderThroughCommandRouter(Generic[T]):
+    """
+    StreamReader implementation that will read directly from the worker that
+    hosts the sandbox.
+
+    This implementation is used for non-text streams.
+    """
+
+    def __init__(
+        self,
+        params: _StreamReaderThroughCommandRouterParams,
+    ) -> None:
+        self._params = params
+        self._stream = None
+
+    @property
+    def file_descriptor(self) -> int:
+        return self._params.file_descriptor
+
+    async def read(self) -> T:
+        data_bytes = b""
+        async for part in self:
+            data_bytes += cast(bytes, part)
+        return cast(T, data_bytes)
+
+    def __aiter__(self) -> AsyncIterator[T]:
         return self
 
     async def __anext__(self) -> T:
-        """mdmd:hidden"""
-        stream = self._ensure_stream()
+        if self._stream is None:
+            self._stream = _stdio_stream_from_command_router(self._params)
+        # This raises StopAsyncIteration if the stream is at EOF.
+        return cast(T, await self._stream.__anext__())
 
-        value = await stream.__anext__()
+    async def aclose(self):
+        if self._stream:
+            await self._stream.aclose()
 
-        # The stream yields None if it receives an EOF batch.
-        if value is None:
-            raise StopAsyncIteration
 
-        if self._text:
-            return cast(T, value.decode("utf-8"))
-        else:
-            return cast(T, value)
+class _TextStreamReaderThroughCommandRouter(Generic[T]):
+    """
+    StreamReader implementation that will read directly from the worker
+    that hosts the sandbox.
+
+    This implementation is used for text streams.
+    """
+
+    def __init__(
+        self,
+        params: _StreamReaderThroughCommandRouterParams,
+        by_line: bool,
+    ) -> None:
+        self._params = params
+        self._by_line = by_line
+        self._stream = None
+
+    @property
+    def file_descriptor(self) -> int:
+        return self._params.file_descriptor
+
+    async def read(self) -> T:
+        data_str = ""
+        async for part in self:
+            data_str += cast(str, part)
+        return cast(T, data_str)
+
+    def __aiter__(self) -> AsyncIterator[T]:
+        return self
+
+    async def __anext__(self) -> T:
+        if self._stream is None:
+            bytes_stream = _stdio_stream_from_command_router(self._params)
+            if self._by_line:
+                self._stream = _decode_bytes_stream_to_str(_stream_by_line(bytes_stream))
+            else:
+                self._stream = _decode_bytes_stream_to_str(bytes_stream)
+        # This raises StopAsyncIteration if the stream is at EOF.
+        return cast(T, await self._stream.__anext__())
 
     async def aclose(self):
-        """mdmd:hidden"""
         if self._stream:
             await self._stream.aclose()
 
 
+class _DevnullStreamReader(Generic[T]):
+    """StreamReader implementation for a stream configured with
+    StreamType.DEVNULL. Throws an error if read or any other method is
+    called.
+    """
+
+    def __init__(self, file_descriptor: "api_pb2.FileDescriptor.ValueType") -> None:
+        self._file_descriptor = file_descriptor
+
+    @property
+    def file_descriptor(self) -> int:
+        return self._file_descriptor
+
+    async def read(self) -> T:
+        raise ValueError("read is not supported for a stream configured with StreamType.DEVNULL")
+
+    def __aiter__(self) -> AsyncIterator[T]:
+        raise ValueError("__aiter__ is not supported for a stream configured with StreamType.DEVNULL")
+
+    async def __anext__(self) -> T:
+        raise ValueError("__anext__ is not supported for a stream configured with StreamType.DEVNULL")
+
+    async def aclose(self):
+        raise ValueError("aclose is not supported for a stream configured with StreamType.DEVNULL")
+
+
+class _StreamReader(Generic[T]):
+    """Retrieve logs from a stream (`stdout` or `stderr`).
+
+    As an asynchronous iterable, the object supports the `for` and `async for`
+    statements. Just loop over the object to read in chunks.
+
+    **Usage**
+
+    ```python fixture:running_app
+    from modal import Sandbox
+
+    sandbox = Sandbox.create(
+        "bash",
+        "-c",
+        "for i in $(seq 1 10); do echo foo; sleep 0.1; done",
+        app=running_app,
+    )
+    for message in sandbox.stdout:
+        print(f"Message: {message}")
+    ```
+    """
+
+    def __init__(
+        self,
+        file_descriptor: "api_pb2.FileDescriptor.ValueType",
+        object_id: str,
+        object_type: Literal["sandbox", "container_process"],
+        client: _Client,
+        stream_type: StreamType = StreamType.PIPE,
+        text: bool = True,
+        by_line: bool = False,
+        deadline: Optional[float] = None,
+        command_router_client: Optional[TaskCommandRouterClient] = None,
+        task_id: Optional[str] = None,
+    ) -> None:
+        """mdmd:hidden"""
+        if by_line and not text:
+            raise ValueError("line-buffering is only supported when text=True")
+
+        if command_router_client is None:
+            self._impl = _StreamReaderThroughServer(
+                file_descriptor, object_id, object_type, client, stream_type, text, by_line, deadline
+            )
+        else:
+            # The only reason task_id is optional is because StreamReader is
+            # also used for sandbox logs, which don't have a task ID available
+            # when the StreamReader is created.
+            assert task_id is not None
+            assert object_type == "container_process"
+            if stream_type == StreamType.DEVNULL:
+                self._impl = _DevnullStreamReader(file_descriptor)
+            else:
+                assert stream_type == StreamType.PIPE or stream_type == StreamType.STDOUT
+                # TODO(saltzm): The original implementation of STDOUT StreamType in
+                # _StreamReaderThroughServer prints to stdout immediately. This doesn't match
+                # python subprocess.run, which uses None to print to stdout immediately, and uses
+                # STDOUT as an argument to stderr to redirect stderr to the stdout stream. We should
+                # implement the old behavior here before moving out of beta, but after that
+                # we should consider changing the API to match python subprocess.run. I don't expect
+                # many customers are using this in any case, so I think it's fine to leave this
+                # unimplemented for now.
+                if stream_type == StreamType.STDOUT:
+                    raise NotImplementedError(
+                        "Currently the STDOUT stream type is not supported when using exec "
+                        "through a task command router, which is currently in beta."
+                    )
+                params = _StreamReaderThroughCommandRouterParams(
+                    file_descriptor, task_id, object_id, command_router_client, deadline
+                )
+                if text:
+                    self._impl = _TextStreamReaderThroughCommandRouter(params, by_line)
+                else:
+                    self._impl = _BytesStreamReaderThroughCommandRouter(params)
+
+    @property
+    def file_descriptor(self) -> int:
+        """Possible values are `1` for stdout and `2` for stderr."""
+        return self._impl.file_descriptor
+
+    async def read(self) -> T:
+        """Fetch the entire contents of the stream until EOF.
+
+        **Usage**
+
+        ```python fixture:running_app
+        from modal import Sandbox
+
+        sandbox = Sandbox.create("echo", "hello", app=running_app)
+        sandbox.wait()
+
+        print(sandbox.stdout.read())
+        ```
+        """
+        return await self._impl.read()
+
+    # TODO(saltzm): I'd prefer to have the implementation classes only implement __aiter__
+    # and have them return generator functions directly, but synchronicity doesn't let us
+    # return self._impl.__aiter__() here because it won't properly wrap the implementation
+    # classes.
+    def __aiter__(self) -> AsyncIterator[T]:
+        """mdmd:hidden"""
+        return self
+
+    async def __anext__(self) -> T:
+        """mdmd:hidden"""
+        return await self._impl.__anext__()
+
+    async def aclose(self):
+        """mdmd:hidden"""
+        await self._impl.aclose()
+
+
 MAX_BUFFER_SIZE = 2 * 1024 * 1024
 
 
-class _StreamWriter:
+class _StreamWriterThroughServer:
     """Provides an interface to buffer and write logs to a sandbox or container process stream (`stdin`)."""
 
     def __init__(self, object_id: str, object_type: Literal["sandbox", "container_process"], client: _Client) -> None:
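Editor's note: the `_decode_bytes_stream_to_str` helper added in this hunk relies on Python's incremental UTF-8 decoder, which buffers the bytes of a partially received character instead of raising `UnicodeDecodeError`. A standalone sketch of the same technique:

```python
import codecs

decoder = codecs.getincrementaldecoder("utf-8")(errors="strict")

# "é" is b"\xc3\xa9"; split it across two chunks to simulate a stream boundary.
chunks = [b"caf\xc3", b"\xa9 au lait"]

out = []
for chunk in chunks:
    text = decoder.decode(chunk, final=False)
    if text:
        out.append(text)  # first chunk yields "caf"; the \xc3 byte is buffered
out.append(decoder.decode(b"", final=True))  # flush any buffered partial character

print("".join(out))  # "café au lait"
```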
@@ -381,25 +624,6 @@ class _StreamWriter:
 
         This is non-blocking and queues the data to an internal buffer. Must be
         used along with the `drain()` method, which flushes the buffer.
-
-        **Usage**
-
-        ```python fixture:running_app
-        from modal import Sandbox
-
-        sandbox = Sandbox.create(
-            "bash",
-            "-c",
-            "while read line; do echo $line; done",
-            app=running_app,
-        )
-        sandbox.stdin.write(b"foo\\n")
-        sandbox.stdin.write(b"bar\\n")
-        sandbox.stdin.write_eof()
-
-        sandbox.stdin.drain()
-        sandbox.wait()
-        ```
         """
         if self._is_closed:
             raise ValueError("Stdin is closed. Cannot write to it.")
@@ -407,7 +631,7 @@ class _StreamWriter:
             if isinstance(data, str):
                 data = data.encode("utf-8")
             if len(self._buffer) + len(data) > MAX_BUFFER_SIZE:
-                raise BufferError("Buffer size exceed limit. Call drain to clear the buffer.")
+                raise BufferError("Buffer size exceed limit. Call drain to flush the buffer.")
             self._buffer.extend(data)
         else:
             raise TypeError(f"data argument must be a bytes-like object, not {type(data).__name__}")
@@ -426,19 +650,6 @@ class _StreamWriter:
 
         This is a flow control method that blocks until data is sent. It returns
         when it is appropriate to continue writing data to the stream.
-
-        **Usage**
-
-        ```python notest
-        writer.write(data)
-        writer.drain()
-        ```
-
-        Async usage:
-        ```python notest
-        writer.write(data)  # not a blocking operation
-        await writer.drain.aio()
-        ```
         """
         data = bytes(self._buffer)
         self._buffer.clear()
@@ -467,5 +678,127 @@ class _StreamWriter:
             raise exc
 
 
+class _StreamWriterThroughCommandRouter:
+    def __init__(
+        self,
+        object_id: str,
+        command_router_client: TaskCommandRouterClient,
+        task_id: str,
+    ) -> None:
+        self._object_id = object_id
+        self._command_router_client = command_router_client
+        self._task_id = task_id
+        self._is_closed = False
+        self._buffer = bytearray()
+        self._offset = 0
+
+    def write(self, data: Union[bytes, bytearray, memoryview, str]) -> None:
+        if self._is_closed:
+            raise ValueError("Stdin is closed. Cannot write to it.")
+        if isinstance(data, (bytes, bytearray, memoryview, str)):
+            if isinstance(data, str):
+                data = data.encode("utf-8")
+            if len(self._buffer) + len(data) > MAX_BUFFER_SIZE:
+                raise BufferError("Buffer size exceed limit. Call drain to flush the buffer.")
+            self._buffer.extend(data)
+        else:
+            raise TypeError(f"data argument must be a bytes-like object, not {type(data).__name__}")
+
+    def write_eof(self) -> None:
+        self._is_closed = True
+
+    async def drain(self) -> None:
+        eof = self._is_closed
+        # NB: There's no need to prevent writing eof twice, because the command router will ignore the second EOF.
+        if self._buffer or eof:
+            data = bytes(self._buffer)
+            await self._command_router_client.exec_stdin_write(
+                task_id=self._task_id, exec_id=self._object_id, offset=self._offset, data=data, eof=eof
+            )
+            # Only clear the buffer after writing the data to the command router is successful.
+            # This allows the client to retry drain() in the event of an exception (though
+            # exec_stdin_write already retries on transient errors, so most users will probably
+            # not do this).
+            self._buffer.clear()
+            self._offset += len(data)
+
+
+class _StreamWriter:
+    """Provides an interface to buffer and write logs to a sandbox or container process stream (`stdin`)."""
+
+    def __init__(
+        self,
+        object_id: str,
+        object_type: Literal["sandbox", "container_process"],
+        client: _Client,
+        command_router_client: Optional[TaskCommandRouterClient] = None,
+        task_id: Optional[str] = None,
+    ) -> None:
+        """mdmd:hidden"""
+        if command_router_client is None:
+            self._impl = _StreamWriterThroughServer(object_id, object_type, client)
+        else:
+            assert task_id is not None
+            assert object_type == "container_process"
+            self._impl = _StreamWriterThroughCommandRouter(object_id, command_router_client, task_id=task_id)
+
+    def write(self, data: Union[bytes, bytearray, memoryview, str]) -> None:
+        """Write data to the stream but does not send it immediately.
+
+        This is non-blocking and queues the data to an internal buffer. Must be
+        used along with the `drain()` method, which flushes the buffer.
+
+        **Usage**
+
+        ```python fixture:running_app
+        from modal import Sandbox
+
+        sandbox = Sandbox.create(
+            "bash",
+            "-c",
+            "while read line; do echo $line; done",
+            app=running_app,
+        )
+        sandbox.stdin.write(b"foo\\n")
+        sandbox.stdin.write(b"bar\\n")
+        sandbox.stdin.write_eof()
+
+        sandbox.stdin.drain()
+        sandbox.wait()
+        ```
+        """
+        self._impl.write(data)
+
+    def write_eof(self) -> None:
+        """Close the write end of the stream after the buffered data is drained.
+
+        If the process was blocked on input, it will become unblocked after
+        `write_eof()`. This method needs to be used along with the `drain()`
+        method, which flushes the EOF to the process.
+        """
+        self._impl.write_eof()
+
+    async def drain(self) -> None:
+        """Flush the write buffer and send data to the running process.
+
+        This is a flow control method that blocks until data is sent. It returns
+        when it is appropriate to continue writing data to the stream.
+
+        **Usage**
+
+        ```python notest
+        writer.write(data)
+        writer.drain()
+        ```
+
+        Async usage:
+        ```python notest
+        writer.write(data)  # not a blocking operation
+        await writer.drain.aio()
+        ```
+        """
+        await self._impl.drain()
+
+
 StreamReader = synchronize_api(_StreamReader)
 StreamWriter = synchronize_api(_StreamWriter)
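Editor's note: one design point in `_StreamWriterThroughCommandRouter.drain` is that the buffer is cleared and the offset advanced only after `exec_stdin_write` succeeds, which makes `drain()` safe to retry as long as the receiver deduplicates by offset. A minimal sketch of that idempotent-write pattern, with a hypothetical `send(offset, data)` transport standing in for the router client:

```python
from collections.abc import Awaitable, Callable

class OffsetWriter:
    """Buffers bytes and flushes them with an explicit stream offset.

    If a flush fails partway, the buffer and offset are unchanged, so the
    caller can retry drain() and the receiver can discard duplicate ranges.
    """

    def __init__(self, send: Callable[[int, bytes], Awaitable[None]]) -> None:
        self._send = send  # hypothetical transport: send(offset, data)
        self._buffer = bytearray()
        self._offset = 0

    def write(self, data: bytes) -> None:
        self._buffer.extend(data)

    async def drain(self) -> None:
        data = bytes(self._buffer)
        await self._send(self._offset, data)  # may raise; state untouched until success
        self._buffer.clear()
        self._offset += len(data)
```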